Integrating TensorFlow with Spring Boot for Image Detection

1. What is TensorFlow?

The name TensorFlow comes from the flow of tensors through a computational graph, as shown in the diagram. It builds on the graph-based automatic differentiation introduced earlier. Besides computing gradients automatically, it provides common operations (ops, which are the nodes of the computational graph), common loss functions, and optimization algorithms.

  • TensorFlow is an open-source software library for high-performance numerical computation. Its flexible architecture lets users deploy computation across a variety of platforms (CPU, GPU, TPU) and devices (desktops, server clusters, mobile devices, edge devices, etc.). See www.tensorflow.org/tutorials/?hl=zh-cn

  • TensorFlow is an open-source machine learning library for research and production. TensorFlow provides various APIs for beginners and experts to develop in desktop, mobile, web, and cloud environments.

  • TensorFlow computes with data flow graphs, so we first create a graph and then feed our data (in the form of tensors) into it for computation. Nodes in the graph represent mathematical operations, while edges represent the multidimensional data arrays passed between nodes, namely tensors. During model training, tensors continuously flow from one node to another in the graph, which is the origin of the name TensorFlow.

  • Tensors come in different orders. A zero-order tensor is a scalar, a single numerical value such as 1. A first-order tensor is a vector, such as the one-dimensional array [1, 2, 3]. A second-order tensor is a matrix, such as the two-dimensional array [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. A third-order tensor is a three-dimensional array, and so on. The image of tensors flowing from one end of the graph to the other describes how complex data structures are passed, analyzed, and processed in artificial neural networks.

In machine learning, numerical data usually takes one of four forms: (1) Scalar: a single value, the smallest unit of computation, such as “1” or “3.2”. (2) Vector: a one-dimensional array of scalars, such as [1, 3.2, 4.6]. (3) Matrix: a two-dimensional array of scalars. (4) Tensor: a multi-dimensional array (usually with three or more dimensions), which can be understood as a higher-dimensional generalization of a matrix.
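To make the idea of tensor order concrete, here is a minimal plain-Java sketch that models tensors as nested arrays and reports their order as the nesting depth. It does not use TensorFlow at all; the class name TensorRankDemo and the rank helper are hypothetical names chosen for this illustration.

```java
// Plain-Java illustration (no TensorFlow dependency):
// the order (rank) of a tensor modelled as nested arrays is its nesting depth.
final class TensorRankDemo {

    // Counts how many array levels wrap the underlying element type.
    static int rank(Object value) {
        int rank = 0;
        Class<?> type = value.getClass();
        while (type.isArray()) {
            rank++;
            type = type.getComponentType();
        }
        return rank;
    }

    public static void main(String[] args) {
        float scalar = 3.2f;                                   // order 0
        float[] vector = {1f, 3.2f, 4.6f};                     // order 1
        float[][] matrix = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};  // order 2
        float[][][] cube = new float[2][3][4];                 // order 3

        System.out.println("scalar: " + rank(scalar));
        System.out.println("vector: " + rank(vector));
        System.out.println("matrix: " + rank(matrix));
        System.out.println("cube:   " + rank(cube));
    }
}
```

A real TensorFlow tensor also carries a shape and a data type; the nested arrays here only illustrate the number of dimensions.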

Basic Concepts of TensorFlow

  • Graph: Describes the computation process; TensorFlow uses graphs to represent the computation process.

  • Tensor: TensorFlow uses tensors to represent data; each tensor is a multi-dimensional array.

  • Operation: The nodes in the graph are ops; an op receives 0 or more tensors as input, executes and computes, producing 0 or more tensors.

  • Session: A graph must be executed inside a session. This is the graph-and-session execution model that the Java code below also uses.
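The four concepts above can be sketched with a toy data flow graph in plain Java. This is only an illustration under simplifying assumptions: tensors are reduced to single double values, and Node, placeholder, op, and run are hypothetical names, not part of the TensorFlow API.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleBinaryOperator;

// Toy data flow graph: nodes are ops, edges carry values, run() plays the session.
final class ToyGraphDemo {

    // A node is either a placeholder (op == null) or a binary op over two input nodes.
    static final class Node {
        final String name;
        final Node left;
        final Node right;
        final DoubleBinaryOperator op;

        private Node(String name, Node left, Node right, DoubleBinaryOperator op) {
            this.name = name;
            this.left = left;
            this.right = right;
            this.op = op;
        }

        static Node placeholder(String name) {
            return new Node(name, null, null, null);
        }

        static Node op(String name, Node left, Node right, DoubleBinaryOperator op) {
            return new Node(name, left, right, op);
        }
    }

    // The "session": evaluates a node after values have been fed to the placeholders.
    static double run(Node node, Map<String, Double> feed) {
        if (node.op == null) {
            Double value = feed.get(node.name);
            if (value == null) {
                throw new IllegalArgumentException("No value fed for placeholder " + node.name);
            }
            return value;
        }
        return node.op.applyAsDouble(run(node.left, feed), run(node.right, feed));
    }

    public static void main(String[] args) {
        // Build the graph for (x + y) * y; nothing is computed yet.
        Node x = Node.placeholder("x");
        Node y = Node.placeholder("y");
        Node sum = Node.op("add", x, y, Double::sum);
        Node out = Node.op("mul", sum, y, (a, b) -> a * b);

        // Feed values and execute: (2 + 3) * 3
        Map<String, Double> feed = new HashMap<>();
        feed.put("x", 2.0);
        feed.put("y", 3.0);
        System.out.println(run(out, feed)); // prints 15.0
    }
}
```

The point of the split is that building the graph and running it are separate steps, which is why the service later in this article creates a Graph first and only then opens a Session on it.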

Code Flow in TensorFlow

  • Define variable placeholders.

  • Write equations based on mathematical principles.

  • Define the loss (cost) function.

  • Define the optimizer, for example gradient descent (GradientDescentOptimizer).

  • Train within a session using a for loop.

  • Save the model using saver.
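The flow above can be sketched without TensorFlow as a hand-written gradient descent loop that fits y = w * x + b. This is an illustrative stand-in, not TensorFlow code: the class LinearRegressionDemo, the fit method, the sample data, and the learning rate are all assumptions chosen for the example, and the final "save the model" step is left out.

```java
// Hand-written version of the training flow: model equation, loss, gradient descent, loop.
final class LinearRegressionDemo {

    // Fits y = w * x + b by batch gradient descent on the mean squared error.
    // Returns {w, b}.
    static double[] fit(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0;
        double b = 0.0;
        int n = xs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = (w * xs[i] + b) - ys[i]; // prediction minus target
                gradW += 2.0 * error * xs[i] / n;       // d(loss)/dw
                gradB += 2.0 * error / n;               // d(loss)/db
            }
            w -= learningRate * gradW;                  // gradient descent update
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        // Training data generated from y = 2x + 1.
        double[] xs = {0, 1, 2, 3, 4};
        double[] ys = {1, 3, 5, 7, 9};
        double[] params = fit(xs, ys, 0.05, 2000);
        // The fitted values should be close to w = 2 and b = 1.
        System.out.printf("w = %.4f, b = %.4f%n", params[0], params[1]);
    }
}
```

In TensorFlow the gradients are derived automatically from the graph, so only the model equation, the loss, and the optimizer need to be declared.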

2. Environment Preparation

Integration Steps

  1. Model Construction: First, we need to define and train a deep learning model in TensorFlow. This may involve selecting the appropriate network architecture, optimizer, and loss function.

  2. Training Data Preparation: Next, we need to prepare the data for training and validating the model. This may include data cleaning, labeling, and preprocessing steps.

  3. REST API Design: To interact with the TensorFlow model, we need to create a REST API in Spring Boot. This can be achieved using Spring Boot’s built-in features, such as Spring MVC or Spring WebFlux.

  4. Model Deployment: After training is complete, we need to deploy the model in the Spring Boot application. The model is exported from TensorFlow (usually from Python) as a SavedModel or a frozen graph (.pb file) and then loaded with TensorFlow’s Java API inside the Spring Boot application. A model converted to ONNX would need a separate ONNX runtime rather than the TensorFlow Java API.

There are two key points to note during integration. First, firewall settings can block the network access that TensorFlow needs, for example when downloading models or datasets or when training across several machines, so make sure the required network resources are reachable to avoid interrupted training. Second, pay attention to version compatibility. Spring Boot and TensorFlow each have their own release cycles, and choosing versions that are known to work together avoids a lot of unnecessary trouble.

Model Download

Model construction and training involve Python code, which we skip here. If you are interested, you can download the source code and train the model yourself; in this article we simply download a pre-trained model.

  • https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz

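After downloading, extract it into the inception-v3 directory under resources. This is the location that application.yaml below points to (inception-v3/inception_v3_2016_08_28_frozen.pb).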

3. Code Project

Experiment Purpose

Implement image detection: classify an uploaded image with the pre-trained Inception v3 model.

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>springboot-demo</artifactId>
        <groupId>com.et</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Tensorflow</artifactId>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>jmimemagic</groupId>
            <artifactId>jmimemagic</artifactId>
            <version>0.1.2</version>
        </dependency>
        <dependency>
            <groupId>jakarta.platform</groupId>
            <artifactId>jakarta.jakartaee-api</artifactId>
            <version>9.0.0</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.16.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.restdocs</groupId>
            <artifactId>spring-restdocs-mockmvc</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>

Controller

package com.et.tf.api;

import java.io.IOException;

import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@RequestMapping("/api")
public class AppController {

    @Autowired
    ClassifyImageService classifyImageService;

    @PostMapping(value = "/classify")
    @CrossOrigin(origins = "*")
    public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {
        checkImageContents(file);
        return classifyImageService.classifyImage(file.getBytes());
    }

    @RequestMapping(value = "/")
    public String index() {
        return "index";
    }

    private void checkImageContents(MultipartFile file) {
        MagicMatch match;
        try {
            match = Magic.getMagicMatch(file.getBytes());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        String mimeType = match.getMimeType();
        if (!mimeType.startsWith("image")) {
            throw new IllegalArgumentException("Not an image type: " + mimeType);
        }
    }
}

Service

package com.et.tf.service;

import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

@Service
@Slf4j
public class ClassifyImageService {

    private final Session session;
    private final List<String> labels;
    private final String outputLayer;
    private final int W;
    private final int H;
    private final float mean;
    private final float scale;

    public ClassifyImageService(
        Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,
        @Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,
        @Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale
    ) {
        this.labels = labels;
        this.outputLayer = outputLayer;
        this.H = imageH;
        this.W = imageW;
        this.mean = mean;
        this.scale = scale;
        // One long-lived session on the pre-trained graph, closed in close().
        this.session = new Session(inceptionGraph);
    }

    // Decodes and normalizes the image, runs the model, and returns the most probable label.
    public LabelWithProbability classifyImage(byte[] imageBytes) {
        long start = System.currentTimeMillis();
        try (Tensor image = normalizedImageToTensor(imageBytes)) {
            float[] labelProbabilities = classifyImageProbabilities(image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            LabelWithProbability labelWithProbability =
                new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);
            log.debug(String.format(
                    "Image classification [%s %.2f%%] took %d ms",
                    labelWithProbability.getLabel(),
                    labelWithProbability.getProbability(),
                    labelWithProbability.getElapsed()
                )
            );
            return labelWithProbability;
        }
    }

    // Feeds the image into the "input" op and reads the [1 N] probability tensor from the output layer.
    private float[] classifyImageProbabilities(Tensor image) {
        try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {
            final Shape resultShape = result.shape();
            final long[] rShape = resultShape.asArray();
            if (resultShape.numDimensions() != 2 || rShape[0] != 1) {
                throw new RuntimeException(
                    String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rShape)
                    ));
            }
            int nlabels = (int) rShape[1];
            FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();
            float[] dst = new float[nlabels];
            resultFloatBuffer.read(dst);
            return dst;
        }
    }

    // Index of the largest probability.
    private int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    // Builds a small preprocessing graph: decode JPEG, cast to float, add a batch
    // dimension, resize to H x W, then compute (pixel - mean) / scale.
    private Tensor normalizedImageToTensor(byte[] imageBytes) {
        try (Graph g = new Graph();
             TInt32 batchTensor = TInt32.scalarOf(0);
             TInt32 sizeTensor = TInt32.vectorOf(H, W);
             TFloat32 meanTensor = TFloat32.scalarOf(mean);
             TFloat32 scaleTensor = TFloat32.scalarOf(scale);
        ) {
            GraphBuilder b = new GraphBuilder(g);
            final Output input = b.constant("input", TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes)));
            final Output output =
                b.div(
                    b.sub(
                        b.resizeBilinear(
                            b.expandDims(
                                b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),
                                b.constant("make_batch", batchTensor)
                            ),
                            b.constant("size", sizeTensor)
                        ),
                        b.constant("mean", meanTensor)
                    ),
                    b.constant("scale", scaleTensor)
                );
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    // Thin helper that adds raw ops to a graph by op type name.
    static class GraphBuilder {
        private final Graph g;
        final Scope scope;

        GraphBuilder(Graph g) {
            this.g = g;
            this.scope = new OpScope(g);
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope)
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .output(0);
        }

        Output<? extends TType> constant(String name, Tensor t) {
            return g.opBuilder("Const", name, scope)
                .setAttr("dtype", t.dataType())
                .setAttr("value", t)
                .build()
                .output(0);
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);
        }
    }

    @PreDestroy
    public void close() {
        session.close();
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class LabelWithProbability {
        private String label;
        private float probability;
        private long elapsed;
    }
}

application.yaml

tf:
    frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pb
    labelsPath: inception-v3/imagenet_slim_labels.txt
    outputLayer: InceptionV3/Predictions/Reshape_1
    image:
        width: 299
        height: 299
        mean: 0
        scale: 255

logging.level.net.sf.jmimemagic: WARN

spring:
  servlet:
    multipart:
      max-file-size: 5MB
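The tf.image.mean and tf.image.scale settings are the values used by the sub and div ops in normalizedImageToTensor, so each pixel value is turned into (pixel - mean) / scale. The following plain-Java sketch shows the same arithmetic on a few sample pixel values; PixelNormalizationDemo and normalize are hypothetical names used only for this illustration, and the real service performs this step inside a TensorFlow graph.

```java
import java.util.Arrays;

// Same arithmetic as the preprocessing graph: (pixel - mean) / scale.
final class PixelNormalizationDemo {

    static float[] normalize(int[] pixels, float mean, float scale) {
        float[] normalized = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            normalized[i] = (pixels[i] - mean) / scale;
        }
        return normalized;
    }

    public static void main(String[] args) {
        int[] pixels = {0, 51, 255};      // sample 8-bit channel values
        // With mean = 0 and scale = 255, 0 maps to 0.0 and 255 maps to 1.0.
        System.out.println(Arrays.toString(normalize(pixels, 0f, 255f)));
    }
}
```

With the configured mean of 0 and scale of 255, 8-bit channel values in the range 0 to 255 are mapped to the range 0 to 1 before they are fed to the model.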

Application.java

package com.et.tf;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;

@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    // Loads the frozen Inception graph once and exposes it as a bean.
    @Bean
    public Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {
        Resource graphResource = getResource(tfFrozenModelPath);
        Graph graph = new Graph();
        graph.importGraphDef(GraphDef.parseFrom(graphResource.getInputStream()));
        log.info("Loaded TensorFlow model");
        return graph;
    }

    // Looks for the path on the file system first, then on the classpath.
    private Resource getResource(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) {
        Resource graphResource = new FileSystemResource(tfFrozenModelPath);
        if (!graphResource.exists()) {
            graphResource = new ClassPathResource(tfFrozenModelPath);
        }
        if (!graphResource.exists()) {
            throw new IllegalArgumentException(String.format("File %s does not exist", tfFrozenModelPath));
        }
        return graphResource;
    }

    // Reads one label per line, dropping an optional "prefix:" before the label text.
    @Bean
    public List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {
        Resource labelsRes = getResource(labelsPath);
        log.info("Loaded model labels");
        return IOUtils.readLines(labelsRes.getInputStream(), StandardCharsets.UTF_8).stream()
            .map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0))
            .collect(Collectors.toList());
    }
}
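The label loading in tfModelLabels keeps only the text after the first colon on each line, and keeps the whole line when there is no colon. A small standalone sketch of that mapping is shown below; LabelParsingDemo, stripPrefixes, and the sample lines are assumptions made up for this illustration and are not taken from the labels file.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Standalone version of the label mapping used when the labels file is read.
final class LabelParsingDemo {

    // Drops everything up to and including the first ':' on each line, if present.
    static List<String> stripPrefixes(List<String> lines) {
        return lines.stream()
            .map(line -> line.substring(line.contains(":") ? line.indexOf(":") + 1 : 0))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical sample lines: one with an "index:" prefix, one without.
        List<String> lines = Arrays.asList("0:background", "kit fox");
        System.out.println(stripPrefixes(lines)); // prints [background, kit fox]
    }
}
```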

The above is just some key code; for all the code, please refer to the code repository below.

Code Repository

4. Testing

Start the Spring Boot application.

Test Image Classification

Visit http://127.0.0.1:8080/, upload an image, and click classify.
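The endpoint can also be called without the browser. The sketch below posts an image to /api/classify as multipart form data using the HttpClient that ships with Java 11. It is a minimal illustration under a few assumptions: the application is running on 127.0.0.1:8080, the class name ClassifyClientDemo and its helper methods are hypothetical, and the image path is passed as the first command-line argument.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Minimal client for the /api/classify endpoint defined in AppController.
final class ClassifyClientDemo {

    // Builds a multipart/form-data body containing a single file part.
    static byte[] multipartBody(String boundary, String field, String filename, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] header = ("--" + boundary + "\r\n"
            + "Content-Disposition: form-data; name=\"" + field + "\"; filename=\"" + filename + "\"\r\n"
            + "Content-Type: application/octet-stream\r\n\r\n").getBytes(StandardCharsets.UTF_8);
        byte[] footer = ("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8);
        out.write(header, 0, header.length);
        out.write(content, 0, content.length);
        out.write(footer, 0, footer.length);
        return out.toByteArray();
    }

    // Sends the image as the "file" request parameter and returns the raw response body.
    static String classify(String baseUrl, Path image) throws IOException, InterruptedException {
        String boundary = "demo-boundary-" + System.nanoTime();
        byte[] body = multipartBody(boundary, "file", image.getFileName().toString(), Files.readAllBytes(image));
        HttpRequest request = HttpRequest.newBuilder(URI.create(baseUrl + "/api/classify"))
            .header("Content-Type", "multipart/form-data; boundary=" + boundary)
            .POST(HttpRequest.BodyPublishers.ofByteArray(body))
            .build();
        HttpResponse<String> response =
            HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        return response.body();
    }

    public static void main(String[] args) throws Exception {
        // args[0] is the path to a local image file.
        System.out.println(classify("http://127.0.0.1:8080", Path.of(args[0])));
    }
}
```

The response body should be the serialized LabelWithProbability object, with the label, probability, and elapsed fields defined in the service.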

