同じプログラムコードを用いて、線形クラスター分類のトレーニングデータと評価データをアプライしてみる。グラフの表示範囲は、拡大しておく。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
package org.deeplearning4j.examples.feedforward.classification;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;

/**
 * "Saturn" Data Classification Example
 *
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * <p>This variant keeps the Saturn example's network and training loop but points the
 * train/eval file names at the {@code linear_data_*.csv} resources, and narrows the plot
 * window accordingly. The original Saturn settings are left next to the active ones as
 * comments so the two configurations can be compared.
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 *
 */
public class MLPClassifierSaturn {

    /**
     * Trains a two-layer feed-forward classifier on CSV data, prints the evaluation
     * statistics for the held-out CSV file, and finally plots training and test points on
     * top of the network's predictions over a regular grid.
     *
     * @param args command-line arguments (not used)
     * @throws Exception if a classpath resource cannot be resolved, or a record reader
     *                   fails to initialize, read, or close
     */
    public static void main(String[] args) throws Exception {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
        int batchSize = 50;
        int seed = 123;
        double learningRate = 0.005;
        //Number of epochs (full passes of the data)
        int nEpochs = 30;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 20;

        // The iterators below are told that column 0 of each CSV record is the class label.
        int labelIndex = 0;

        // Original Saturn data set, kept for reference:
        //final String filenameTrain = new ClassPathResource("/classification/saturn_data_train.csv").getFile().getPath();
        //final String filenameTest = new ClassPathResource("/classification/saturn_data_eval.csv").getFile().getPath();
        final String filenameTrain = new ClassPathResource("/classification/linear_data_train.csv").getFile().getPath();
        final String filenameTest = new ClassPathResource("/classification/linear_data_eval.csv").getFile().getPath();

        // Both readers hold open files; try-with-resources guarantees they are closed even
        // if training, evaluation, or plotting throws.
        try (RecordReader rr = new CSVRecordReader();
             RecordReader rrTest = new CSVRecordReader()) {

            //Load the training data:
            rr.initialize(new FileSplit(new File(filenameTrain)));
            DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numOutputs);

            //Load the test/evaluation data:
            rrTest.initialize(new FileSplit(new File(filenameTest)));
            DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, labelIndex, numOutputs);

            // Network: numInputs -> numHiddenNodes (ReLU) -> numOutputs (softmax)
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .updater(new Nesterovs(learningRate, 0.9))
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.RELU)
                            .build())
                    .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.SOFTMAX)
                            .nIn(numHiddenNodes).nOut(numOutputs).build())
                    .pretrain(false).backprop(true).build();

            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();

            // Dump the freshly initialized parameters so they can be compared with the
            // values printed after each epoch.
            System.out.println(model.paramTable());
            System.out.println(model.params());

            // Optional: uncomment to report the score after every parameter update.
            //model.setListeners(new ScoreIterationListener(1));

            for (int n = 0; n < nEpochs; n++) {
                model.fit(trainIter);
                System.out.println(model.params());
            }

            System.out.println(model.summary());
            System.out.println(model.paramTable());

            System.out.println("Evaluate model....");
            Evaluation eval = new Evaluation(numOutputs);
            while (testIter.hasNext()) {
                DataSet t = testIter.next();
                INDArray features = t.getFeatures();
                INDArray labels = t.getLabels();
                INDArray predicted = model.output(features, false);

                eval.eval(labels, predicted);
            }

            System.out.println(eval.stats());

            //------------------------------------------------------------------------------------
            //Training is complete. Code that follows is for plotting the data & predictions only

            // Plot window used for the Saturn data set, kept for reference:
            //double xMin = -15;
            //double xMax = 15;
            //double yMin = -15;
            //double yMax = 15;
            double xMin = 0;
            double xMax = 1.0;
            double yMin = -0.2;
            double yMax = 0.8;

            //Let's evaluate the predictions at every point in the x/y input space, and plot this in the background
            int nPointsPerAxis = 100;
            double[][] evalPoints = buildEvalGrid(xMin, xMax, yMin, yMax, nPointsPerAxis);

            INDArray allXYPoints = Nd4j.create(evalPoints);
            INDArray predictionsAtXYPoints = model.output(allXYPoints);

            //Get all of the training data in a single array, and plot it:
            rr.initialize(new FileSplit(new File(filenameTrain)));
            rr.reset();
            int nTrainPoints = 1000;
            trainIter = new RecordReaderDataSetIterator(rr, nTrainPoints, labelIndex, numOutputs);
            DataSet ds = trainIter.next();
            PlotUtil.plotTrainingData(ds.getFeatures(), ds.getLabels(), allXYPoints, predictionsAtXYPoints, nPointsPerAxis);

            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
            rrTest.initialize(new FileSplit(new File(filenameTest)));
            rrTest.reset();
            int nTestPoints = 500;
            testIter = new RecordReaderDataSetIterator(rrTest, nTestPoints, labelIndex, numOutputs);
            ds = testIter.next();
            INDArray testPredicted = model.output(ds.getFeatures());

            PlotUtil.plotTestData(ds.getFeatures(), ds.getLabels(), testPredicted, allXYPoints, predictionsAtXYPoints, nPointsPerAxis);

            System.out.println("****************Example finished********************");
        }
    }

    /**
     * Builds a regular {@code nPointsPerAxis x nPointsPerAxis} grid of (x, y) points that
     * spans the given window, both end points included.
     *
     * <p>Rows are ordered with x as the outer (slow) index and y as the inner (fast) index.
     *
     * @param xMin           smallest x value of the grid
     * @param xMax           largest x value of the grid
     * @param yMin           smallest y value of the grid
     * @param yMax           largest y value of the grid
     * @param nPointsPerAxis number of grid points along each axis
     * @return an array with one {@code {x, y}} row per grid point
     */
    private static double[][] buildEvalGrid(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        double[][] grid = new double[nPointsPerAxis * nPointsPerAxis][2];
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                int row = i * nPointsPerAxis + j;
                grid[row][0] = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                grid[row][1] = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
            }
        }
        return grid;
    }
}
と、いとも簡単に1 […]