deeplearning4j实现多层感知器的手写数字识别

xiaoxiao2021-02-28  84

package com.itcast.wang.test_dl;

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.deeplearning4j.eval.Evaluation;   import org.nd4j.linalg.api.ndarray.INDArray; /**A Simple MLP applied to digit classification for MNIST.  */ public class MLPMnistSingleLayerExample {       private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);       public static void main(String[] args) throws Exception {           final int numRows = 28;//图像宽         final int numColumns = 28;//图像长         int outputNum = 10;//输出的类别数         int batchSize = 128;//没128个样本参加训练         int rngSeed = 123;//         int numEpochs = 15;//训练集样本迭代的次数           //Get the DataSetIterators:         DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);         DataSetIterator mnist = new MnistDataSetIterator(batchSize, false, rngSeed);         log.info("Build model....");         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()                 .seed(rngSeed)                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)//随机梯度下降                 .iterations(1)                 .learningRate(0.006)//学习率                 .updater(Updater.NESTEROVS).momentum(0.9)//运动惯量                 
.regularization(true).l2(1e-4)//是否使用正则化                 .list()                 .layer(0, new DenseLayer.Builder()//第一层网络配置                         .nIn(numRows * numColumns)//输入数目                         .nOut(1000)//输出数目                         .activation("relu")//激活函数 relu                         .weightInit(WeightInit.XAVIER)//权值初始化                         .build())                 //输出层指定误差函数                 .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)//误差函数                         .nIn(1000)//输入                         .nOut(outputNum)//输出                         .activation("softmax")//激活函数                         .weightInit(WeightInit.XAVIER)                         .build())                 .pretrain(false).backprop(true)                 .build();         MultiLayerNetwork model = new MultiLayerNetwork(conf);         model.init();         model.setListeners(new ScoreIterationListener(1));         log.info("Train model....");         for(int i=0;i<4;i++){               model.fit(mnistTrain);               System.out.println(" Completed epoch is :" + i);                      System.out.println("Evaluate model....");                   Evaluation eval = new Evaluation(outputNum);                   while(mnist.hasNext()){                       DataSet ds = mnist.next();                       INDArray output = model.output(ds.getFeatureMatrix(), false);                       eval.eval(ds.getLabels(), output);                   }                   System.out.println(eval.stats());                   mnist.reset();                          }           System.out.println("model finish");       }     }        
转载请注明原文地址: https://www.6miu.com/read-58704.html

最新回复(0)