Integrating DeepLearning4J (DL4J) into a Spring Boot Application to Build and Use a Neural Network
Below is a step-by-step guide to creating a Spring Boot application that uses DeepLearning4J (DL4J) to build and use a neural network. As a simple binary classification example, we'll train a small network on the XOR logic gate.
Prerequisites
1. Java 17 or higher (Spring Boot 3.x, which Spring Initializr generates by default, requires Java 17)
2. Maven or Gradle
3. IDE like IntelliJ IDEA or Eclipse
1. Create a Spring Boot Project
1. Use Spring Initializr to generate a Spring Boot project.
2. Select Maven or Gradle as the build tool.
3. Add the Spring Web dependency.
2. Add DL4J Dependencies
If you use Maven, add the following dependencies to your pom.xml (the Spring Boot starter needs no version because the Spring Boot parent POM manages it):
<dependencies>
    <!-- DL4J Core -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- ND4J Backend -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- Spring Boot -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
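Before wiring DL4J into Spring, it can help to confirm that the ND4J native backend loads on your machine, since a missing or incompatible native library is a common first failure. The class below is a minimal sketch (the name `Nd4jSmokeTest` is hypothetical): it squares a small matrix, which forces the backend to load, and prints which backend class was picked up.

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Nd4jSmokeTest {

    // Multiplies a 2x2 matrix by itself; any ND4J operation triggers native backend loading
    public static INDArray squareOfSample() {
        INDArray a = Nd4j.create(new double[][] { {1, 2}, {3, 4} });
        return a.mmul(a);
    }

    public static void main(String[] args) {
        // With nd4j-native-platform this should report a CPU backend class
        System.out.println("ND4J backend: " + Nd4j.getBackend().getClass().getSimpleName());
        System.out.println(squareOfSample());
    }
}
```

If this runs without an `UnsatisfiedLinkError` or similar native-loading error, the dependencies above are in place.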
3. Create a Neural Network Configuration
Create a simple neural network to perform binary classification.
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NeuralNetExample {

    public static MultiLayerNetwork buildModel() {
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .seed(123)                      // fixed seed for reproducible weight initialization
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))          // plain SGD with learning rate 0.1
                .list()
                // Hidden layer: 2 inputs -> 4 ReLU units
                .layer(new DenseLayer.Builder().nIn(2).nOut(4).activation(Activation.RELU).build())
                // Output layer: 1 sigmoid unit giving the probability of class 1;
                // binary cross-entropy (XENT) is the usual loss for a sigmoid output
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .nIn(4).nOut(1).activation(Activation.SIGMOID).build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        return model;
    }
}
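To see what this 2-4-1 architecture computes, the plain-Java sketch below evaluates the same shape of network (a ReLU hidden layer of four units feeding one sigmoid output) without DL4J, using hand-picked weights that solve XOR. The weights are illustrative assumptions, not values DL4J will learn; two of the four hidden units are left at zero to show that four units are more than enough. The class name `XorForwardPass` is hypothetical.

```java
public class XorForwardPass {

    // Hidden layer: 4 units; each row holds one unit's weights for inputs (x1, x2)
    static final double[][] W1 = { {1, 1}, {1, 1}, {0, 0}, {0, 0} };
    static final double[] B1 = { 0, -1, 0, 0 };

    // Output layer: 1 unit combining the 4 hidden activations
    static final double[] W2 = { 10, -20, 0, 0 };
    static final double B2 = -5;

    static double relu(double z) { return Math.max(0, z); }

    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    // Forward pass: sigmoid(W2 . relu(W1 x + B1) + B2)
    public static double predict(double x1, double x2) {
        double z = B2;
        for (int j = 0; j < W1.length; j++) {
            double h = relu(W1[j][0] * x1 + W1[j][1] * x2 + B1[j]);
            z += W2[j] * h;
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[][] inputs = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
        for (double[] x : inputs) {
            System.out.printf("%.0f XOR %.0f -> %.4f%n", x[0], x[1], predict(x[0], x[1]));
        }
    }
}
```

The first hidden unit counts active inputs and the second fires only when both are active, so the output layer sees 0, 1, 1, 0 before the sigmoid. Training searches for weights that play the same role.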
4. Train the Model
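Add a service class that trains the model and keeps it, together with its fitted scaler, for later predictions.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.stereotype.Service;

@Service
public class TrainingService {

    // Trained network and the scaler fitted on its training data (null until trainModel() runs)
    private MultiLayerNetwork model;
    private NormalizerMinMaxScaler scaler;

    // synchronized: training and prediction share this state across HTTP request threads
    public synchronized void trainModel() {
        MultiLayerNetwork newModel = NeuralNetExample.buildModel();

        // Example training data (XOR logic gate)
        double[][] input = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
        double[][] labels = { {0}, {1}, {1}, {0} };
        DataSet dataSet = new DataSet(Nd4j.create(input), Nd4j.create(labels));

        // Normalize the features; the same fitted scaler is reused in predict()
        NormalizerMinMaxScaler newScaler = new NormalizerMinMaxScaler();
        newScaler.fit(dataSet);
        newScaler.transform(dataSet);

        // Train for 1000 epochs (the whole data set is a single mini-batch)
        for (int i = 0; i < 1000; i++) {
            newModel.fit(dataSet);
        }

        model = newModel;
        scaler = newScaler;
        System.out.println("Training complete.");
    }

    public synchronized double predict(double[] features) {
        if (model == null) {
            throw new IllegalStateException("Model has not been trained yet; call /api/dl4j/train first.");
        }
        // The network expects a 2-D array of shape [batchSize, nIn], here [1, 2]
        INDArray row = Nd4j.create(new double[][] { features });
        scaler.transform(row);
        return model.output(row).getDouble(0);
    }
}
5. Expose an API Endpoint
Create a REST controller to trigger training and prediction. The controller delegates to TrainingService, so /predict uses the network trained by /train instead of a field that is never assigned.
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.server.ResponseStatusException;

@RestController
@RequestMapping("/api/dl4j")
public class NeuralNetController {

    @Autowired
    private TrainingService trainingService;

    @PostMapping("/train")
    public String trainModel() {
        trainingService.trainModel();
        return "Training completed!";
    }

    @PostMapping("/predict")
    public double predict(@RequestBody double[] input) {
        // Basic input validation: the network has exactly two input features
        if (input == null || input.length != 2) {
            throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Expected exactly 2 input values");
        }
        try {
            return trainingService.predict(input);
        } catch (IllegalStateException e) {
            // No trained model yet
            throw new ResponseStatusException(HttpStatus.CONFLICT, e.getMessage());
        }
    }
}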
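The training step fits a NormalizerMinMaxScaler on the training features. By default that scaler maps each feature column to [0, 1] using x' = (x - min) / (max - min), where min and max are the per-column values seen during fit. The plain-Java sketch below reproduces that formula without DL4J to make two points concrete: any input sent to the model later has to go through the same fitted transform, and for the XOR data the transform is the identity because both columns already span exactly [0, 1]. The class name is hypothetical, and mapping a constant column to 0 is this sketch's own choice, not necessarily what ND4J does.

```java
import java.util.Arrays;

public class MinMaxScalingSketch {

    // Per-column minimum and maximum learned by fit()
    private double[] min;
    private double[] max;

    public void fit(double[][] rows) {
        int cols = rows[0].length;
        min = new double[cols];
        max = new double[cols];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : rows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
    }

    // x' = (x - min) / (max - min); a constant column maps to 0 (assumption of this sketch)
    public double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            double range = max[c] - min[c];
            out[c] = range == 0 ? 0 : (row[c] - min[c]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        MinMaxScalingSketch scaler = new MinMaxScalingSketch();
        scaler.fit(new double[][] { {0, 0}, {0, 1}, {1, 0}, {1, 1} });
        // XOR inputs already span [0, 1], so scaling leaves them unchanged
        System.out.println(Arrays.toString(scaler.transform(new double[] {1.0, 0.0})));
    }
}
```

With real-valued features the transform is not the identity, which is why the prediction path must apply the scaler fitted during training rather than a new one.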
6. Run the Application
- Start the Spring Boot application.
- Send a POST request to /api/dl4j/train (for example with Postman or curl) to train the model.
- Send a POST request to /api/dl4j/predict with a JSON array of two input values to get a prediction.
Example Request
Training:
curl -X POST http://localhost:8080/api/dl4j/train
Prediction:
curl -X POST http://localhost:8080/api/dl4j/predict \
  -H "Content-Type: application/json" \
  -d "[1.0, 0.0]"
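If you would rather call the endpoints from Java than from curl, the sketch below builds the same two POST requests with the JDK's java.net.http.HttpClient (Java 11+). It assumes the application is running on localhost:8080 as in the curl examples; the class and helper names are hypothetical. Building the requests is kept separate from sending them so the request shape can be checked without a running server.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Arrays;
import java.util.stream.Collectors;

public class Dl4jApiClient {

    // Trailing slash so that resolve("train") yields .../api/dl4j/train
    static final URI BASE = URI.create("http://localhost:8080/api/dl4j/");

    // POST /api/dl4j/train with an empty body
    static HttpRequest trainRequest() {
        return HttpRequest.newBuilder(BASE.resolve("train"))
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
    }

    // Serializes the features as a JSON array, e.g. [1.0, 0.0]
    static String toJsonArray(double[] features) {
        return Arrays.stream(features)
                .mapToObj(Double::toString)
                .collect(Collectors.joining(", ", "[", "]"));
    }

    // POST /api/dl4j/predict with a JSON array body
    static HttpRequest predictRequest(double[] features) {
        return HttpRequest.newBuilder(BASE.resolve("predict"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(toJsonArray(features)))
                .build();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        HttpClient client = HttpClient.newHttpClient();
        HttpResponse<String> trained = client.send(trainRequest(), HttpResponse.BodyHandlers.ofString());
        System.out.println(trained.body());
        HttpResponse<String> predicted =
                client.send(predictRequest(new double[] {1.0, 0.0}), HttpResponse.BodyHandlers.ofString());
        System.out.println(predicted.body());
    }
}
```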
This example shows a basic implementation. For production use, consider persisting the trained model to disk and adding proper validation for input data.
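As a starting point for persisting the trained model, the sketch below wraps DL4J's ModelSerializer, which writes a network's configuration and parameters to a single file and restores it later. The `ModelStore` class name and the choice to also save the updater state are assumptions of this sketch. The fitted NormalizerMinMaxScaler is not part of this file; ModelSerializer also offers addNormalizerToModel for storing a normalizer alongside the model, but check its exact signature in the DL4J version you use.

```java
import java.io.File;
import java.io.IOException;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

public class ModelStore {

    // Writes configuration and parameters; saveUpdater=true also stores updater state,
    // which only matters if you want to resume training from the saved file
    public static void save(MultiLayerNetwork model, File file) throws IOException {
        ModelSerializer.writeModel(model, file, true);
    }

    // Restores a network written by save(); it can be used for output() immediately
    public static MultiLayerNetwork load(File file) throws IOException {
        return ModelSerializer.restoreMultiLayerNetwork(file);
    }
}
```

One way to use it: call ModelStore.save after training finishes, and on application startup call ModelStore.load if the file exists, so predictions survive a restart without retraining.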
