Integrating DeepLearning4J (DL4J) into a Spring Boot Application to Build and Use a Neural Network

Below is a step-by-step guide to creating a Spring Boot application that uses DeepLearning4J (DL4J) to build and use a neural network. As a simple example, we'll train a small network to learn the XOR function, which is a binary classification problem.

Prerequisites

1. Java 17 or higher (required by Spring Boot 3.x, which Spring Initializr generates; DL4J itself runs on Java 8+)

2. Maven or Gradle

3. IDE like IntelliJ IDEA or Eclipse


1. Create a Spring Boot Project

1. Use Spring Initializr (https://start.spring.io) to generate a Spring Boot project, choosing Maven or Gradle as the build tool.

2. Add dependencies:

  • Spring Web

2. Add DL4J Dependencies

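If you use Maven, add the following dependencies to your pom.xml. The version of spring-boot-starter-web is managed by the spring-boot-starter-parent that Spring Initializr puts in the generated pom.xml. Gradle users can declare the same coordinates.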

<dependencies>
    <!-- DL4J Core -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>

    <!-- ND4J CPU backend; the -platform artifact bundles native binaries for all supported operating systems -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>

    <!-- Spring Boot -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>


3. Create a Neural Network Configuration

Define a small feed-forward network for binary classification: a hidden dense layer that maps 2 inputs to 4 ReLU units, followed by a single sigmoid output unit.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NeuralNetExample {

    public static MultiLayerNetwork buildModel() {
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .seed(123)                          // fixed seed for reproducible weight initialization
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))              // plain SGD with learning rate 0.1
                .list()
                // Hidden layer: 2 inputs -> 4 ReLU units
                .layer(new DenseLayer.Builder().nIn(2).nOut(4).activation(Activation.RELU).build())
                // Output layer: 1 sigmoid unit; binary cross-entropy (XENT) suits a 0/1 label
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .nIn(4).nOut(1).activation(Activation.SIGMOID).build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        return model;
    }
}
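To make the configuration less of a black box, here is a plain-Java sketch (no DL4J) of the computation a 2-4-1 network with a ReLU hidden layer and a sigmoid output performs. The weights are hypothetical and hand-picked to show that this architecture can represent XOR; DL4J learns its own weights during fit(). The sketch also counts parameters: each dense layer has nIn * nOut weights plus nOut biases, so this network has 12 + 5 = 17.

```java
import java.util.Arrays;

/**
 * Hypothetical hand-weighted version of a 2-4-1 network (ReLU hidden layer, sigmoid output).
 * It only illustrates what each layer computes; these are not weights DL4J would learn.
 */
final class XorForwardPass {

    // Hidden layer: 4 units x 2 inputs. Units 3 and 4 are zeroed out (unused).
    static final double[][] W1 = {{1, 1}, {1, 1}, {0, 0}, {0, 0}};
    static final double[] B1 = {0, -1, 0, 0};

    // Output layer: 1 unit reading the 4 hidden activations.
    static final double[] W2 = {10, -20, 0, 0};
    static final double B2 = -5;

    static double relu(double z) {
        return Math.max(0.0, z);
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Dense layer with ReLU, then a single sigmoid output unit. */
    static double forward(double[] x) {
        double z = B2;
        for (int j = 0; j < W1.length; j++) {
            double h = B1[j];
            for (int i = 0; i < x.length; i++) {
                h += W1[j][i] * x[i];
            }
            z += W2[j] * relu(h);
        }
        return sigmoid(z);
    }

    /** Trainable parameters of a dense layer: nIn * nOut weights plus nOut biases. */
    static int denseParams(int nIn, int nOut) {
        return nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        for (double[] x : inputs) {
            System.out.printf("%s -> %.4f%n", Arrays.toString(x), forward(x));
        }
        // 2->4 hidden layer plus 4->1 output layer
        System.out.println("parameters: " + (denseParams(2, 4) + denseParams(4, 1)));
    }
}
```

Hidden unit 1 computes x1 + x2 and hidden unit 2 computes max(0, x1 + x2 - 1); subtracting twice the second from the first gives 0, 1, 1, 0, which the output unit scales and squashes into probabilities near 0 or 1.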


4. Train the Model

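Add a service class that trains the model and keeps the trained network (and its fitted normalizer) so the controller can use them for predictions.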

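import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

@Service
public class TrainingService {

    private static final Logger log = LoggerFactory.getLogger(TrainingService.class);

    // Trained state used by predict(); null until trainModel() has run
    private MultiLayerNetwork model;
    private NormalizerMinMaxScaler scaler;

    // synchronized so a prediction never sees a half-trained model
    public synchronized void trainModel() {
        MultiLayerNetwork newModel = NeuralNetExample.buildModel();

        // Example training data (XOR logic gate)
        double[][] input = {
                {0, 0},
                {0, 1},
                {1, 0},
                {1, 1}
        };
        double[][] labels = {
                {0},
                {1},
                {1},
                {0}
        };

        DataSet dataSet = new DataSet(Nd4j.create(input), Nd4j.create(labels));

        // Normalize features to [0, 1]; the same fitted scaler must be applied at prediction time
        NormalizerMinMaxScaler newScaler = new NormalizerMinMaxScaler();
        newScaler.fit(dataSet);
        newScaler.transform(dataSet);

        // Train model: full-batch epochs over the four examples (increase if outputs are not yet near 0/1)
        for (int i = 0; i < 1000; i++) {
            newModel.fit(dataSet);
        }

        this.model = newModel;
        this.scaler = newScaler;
        log.info("Training complete. Score: {}", newModel.score());
    }

    /** Returns the predicted probability that the label is 1. */
    public synchronized double predict(double[] features) {
        if (model == null) {
            throw new IllegalStateException("Model has not been trained yet; call /api/dl4j/train first");
        }
        // The network expects a 2D [batchSize, nIn] array, so wrap the input as a single row
        INDArray row = Nd4j.create(new double[][]{features});
        scaler.transform(row);
        return model.output(row).getDouble(0);
    }
}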
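NormalizerMinMaxScaler rescales each feature column to [0, 1] using the minimum and maximum recorded during fit(). The plain-Java sketch below illustrates that arithmetic; it is not ND4J's implementation, and the handling of constant columns is an assumption of the sketch. It also shows why prediction inputs must be transformed with the statistics fitted on the training data: new values are mapped relative to the training range and are not clipped.

```java
import java.util.Arrays;

/**
 * Illustrative per-column min-max scaling to [0, 1], in the spirit of ND4J's
 * NormalizerMinMaxScaler with its default range. Not that class's code.
 */
final class MinMaxScaling {

    final double[] min;
    final double[] max;

    private MinMaxScaling(double[] min, double[] max) {
        this.min = min;
        this.max = max;
    }

    /** "fit": record the per-column minimum and maximum of the training features. */
    static MinMaxScaling fit(double[][] rows) {
        int cols = rows[0].length;
        double[] min = new double[cols];
        double[] max = new double[cols];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : rows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
        return new MinMaxScaling(min, max);
    }

    /** "transform": apply the stored statistics, x' = (x - min) / (max - min). */
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            double range = max[c] - min[c];
            // Assumption for this sketch: a constant column maps to 0 instead of dividing by zero.
            out[c] = range == 0 ? 0.0 : (row[c] - min[c]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] xor = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        MinMaxScaling scaler = MinMaxScaling.fit(xor);
        // XOR inputs already span [0, 1], so scaling leaves them unchanged.
        System.out.println(Arrays.toString(scaler.transform(new double[]{1, 0})));
        // An unseen input is scaled with the training statistics; 2.0 stays outside [0, 1].
        System.out.println(Arrays.toString(scaler.transform(new double[]{0.5, 2})));
    }
}
```

For the XOR data the scaler is effectively a no-op, but keeping it in the pipeline matters once real features with different ranges are used.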


5. Expose an API Endpoint

Create a REST controller to trigger training and prediction.

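import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.server.ResponseStatusException;

@RestController
@RequestMapping("/api/dl4j")
public class NeuralNetController {

    private final TrainingService trainingService;

    // Constructor injection; @Autowired is optional when there is a single constructor
    public NeuralNetController(TrainingService trainingService) {
        this.trainingService = trainingService;
    }

    @PostMapping("/train")
    public String trainModel() {
        trainingService.trainModel();
        return "Training completed!";
    }

    @PostMapping("/predict")
    public double predict(@RequestBody double[] input) {
        try {
            // The service holds the trained model and scaler
            return trainingService.predict(input);
        } catch (IllegalStateException e) {
            // Model not trained yet: answer 409 instead of a generic 500
            throw new ResponseStatusException(HttpStatus.CONFLICT, e.getMessage());
        }
    }
}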
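The /predict endpoint returns the raw sigmoid output, a number between 0 and 1 that can be read as the probability that the label is 1. For binary classification you usually also want a class label. One option is a threshold; the hypothetical helper below assumes 0.5, which is a common default but should be chosen according to the cost of false positives versus false negatives.

```java
/**
 * Hypothetical helper that turns a sigmoid probability into a 0/1 class label.
 * The default threshold of 0.5 is an assumption, not something DL4J prescribes.
 */
final class BinaryDecision {

    static final double DEFAULT_THRESHOLD = 0.5;

    static int toLabel(double probability, double threshold) {
        if (Double.isNaN(probability) || probability < 0.0 || probability > 1.0) {
            throw new IllegalArgumentException("probability must be in [0, 1]: " + probability);
        }
        return probability >= threshold ? 1 : 0;
    }

    static int toLabel(double probability) {
        return toLabel(probability, DEFAULT_THRESHOLD);
    }

    public static void main(String[] args) {
        System.out.println(toLabel(0.93)); // prints 1
        System.out.println(toLabel(0.07)); // prints 0
    }
}
```

Once the network has learned XOR, an input of [1.0, 0.0] should produce a probability above the threshold and therefore label 1.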


6. Run the Application

  • Start the Spring Boot application.
  • Use a tool such as Postman or curl to call the /api/dl4j/train endpoint and train the model.
  • After training, call the /api/dl4j/predict endpoint with input data; it returns the predicted probability that the label is 1.


Example Requests

Training:

curl -X POST http://localhost:8080/api/dl4j/train


Prediction:

curl -X POST http://localhost:8080/api/dl4j/predict \
  -H "Content-Type: application/json" \
  -d "[1.0, 0.0]"
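If you would rather call the endpoints from Java than from curl, the JDK's built-in java.net.http client (Java 11+) can send the same two requests. This is a sketch that assumes the application is running on localhost:8080 as in the curl examples; the class and method names are illustrative.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

/** Illustrative Java equivalent of the curl commands, assuming the app listens on localhost:8080. */
final class Dl4jApiClient {

    static final String BASE_URL = "http://localhost:8080/api/dl4j";

    static HttpRequest trainRequest() {
        return HttpRequest.newBuilder(URI.create(BASE_URL + "/train"))
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
    }

    /** Same JSON array body as curl's -d "[1.0, 0.0]". */
    static String predictBody(double x1, double x2) {
        return "[" + x1 + ", " + x2 + "]";
    }

    static HttpRequest predictRequest(double x1, double x2) {
        return HttpRequest.newBuilder(URI.create(BASE_URL + "/predict"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(predictBody(x1, x2)))
                .build();
    }

    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // Train first; /predict needs a trained model
        HttpResponse<String> trained = client.send(trainRequest(), HttpResponse.BodyHandlers.ofString());
        System.out.println(trained.statusCode() + " " + trained.body());
        HttpResponse<String> predicted = client.send(predictRequest(1.0, 0.0), HttpResponse.BodyHandlers.ofString());
        System.out.println(predicted.statusCode() + " " + predicted.body());
    }
}
```

Building the requests separately from sending them keeps the request construction easy to check without a running server.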

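This example shows a basic implementation. For production use, consider persisting the trained model to disk so it survives application restarts (for example with DL4J's ModelSerializer.writeModel and ModelSerializer.restoreMultiLayerNetwork), and validate input data before it reaches the network.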
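For the input-validation point, a minimal sketch could check the request body before it is converted to an ND4J array. The class name is hypothetical, and the expected size of 2 is taken from nIn(2) of the first layer in this example.

```java
/**
 * Hypothetical validator for the /predict request body. EXPECTED_FEATURES mirrors
 * nIn(2) of the hidden layer in this example; adjust it if the model changes.
 */
final class PredictInputValidator {

    static final int EXPECTED_FEATURES = 2;

    /** Throws IllegalArgumentException describing the first problem found. */
    static void validate(double[] input) {
        if (input == null) {
            throw new IllegalArgumentException("request body must be a JSON array of numbers");
        }
        if (input.length != EXPECTED_FEATURES) {
            throw new IllegalArgumentException(
                    "expected " + EXPECTED_FEATURES + " features but got " + input.length);
        }
        for (int i = 0; i < input.length; i++) {
            // Defensive check: reject NaN and infinite values
            if (!Double.isFinite(input[i])) {
                throw new IllegalArgumentException("feature " + i + " is not a finite number");
            }
        }
    }

    public static void main(String[] args) {
        validate(new double[]{1.0, 0.0}); // valid input passes silently
        try {
            validate(new double[]{1.0});
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

In the controller, the IllegalArgumentException could be translated into a 400 Bad Request, for instance by rethrowing it as a ResponseStatusException with HttpStatus.BAD_REQUEST.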
