package com.itcast.wang.test_dl;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.deeplearning4j.eval.Evaluation; import org.nd4j.linalg.api.ndarray.INDArray; /**A Simple MLP applied to digit classification for MNIST. */ public class MLPMnistSingleLayerExample { private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class); public static void main(String[] args) throws Exception { final int numRows = 28;//图像宽 final int numColumns = 28;//图像长 int outputNum = 10;//输出的类别数 int batchSize = 128;//没128个样本参加训练 int rngSeed = 123;// int numEpochs = 15;//训练集样本迭代的次数 //Get the DataSetIterators: DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnist = new MnistDataSetIterator(batchSize, false, rngSeed); log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngSeed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)//随机梯度下降 .iterations(1) .learningRate(0.006)//学习率 .updater(Updater.NESTEROVS).momentum(0.9)//运动惯量 .regularization(true).l2(1e-4)//是否使用正则化 .list() .layer(0, new DenseLayer.Builder()//第一层网络配置 .nIn(numRows * numColumns)//输入数目 .nOut(1000)//输出数目 .activation("relu")//激活函数 relu .weightInit(WeightInit.XAVIER)//权值初始化 .build()) //输出层指定误差函数 .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)//误差函数 .nIn(1000)//输入 .nOut(outputNum)//输出 .activation("softmax")//激活函数 .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(1)); log.info("Train model...."); for(int i=0;i<4;i++){ model.fit(mnistTrain); System.out.println(" Completed epoch is :" + i); System.out.println("Evaluate model...."); Evaluation eval = new Evaluation(outputNum); while(mnist.hasNext()){ DataSet ds = mnist.next(); INDArray output = model.output(ds.getFeatureMatrix(), false); eval.eval(ds.getLabels(), output); } System.out.println(eval.stats()); mnist.reset(); } System.out.println("model finish"); } }