Building a LeNet Convolutional Neural Network
Characteristics of the overall architecture
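1) Going from the input side to the output side, convolution layers and pooling layers are repeated several times in pairs.
2) After the convolution and pooling layers come fully-connected layers, in which every unit is connected to every unit in the adjacent layer.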
Here we build a simple network from the following six layers (layer 0 to layer 5).
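First convolution layer (convolution):
In this layer, a small image-like filter matrix is applied across the image data, which is represented as a matrix. The MNIST digit images used for training here each consist of 28 x 28 = 784 values.
The filters (which correspond to the weights) are 5×5, and 20 of them are used.
Input data channels refers to the number of channels in the input image. A color image would have three RGB channels, but the black-and-white MNIST images have only one.
When applying the small filter to the image matrix, we also choose the stride, which is how far the filter moves at each step. Here it moves one pixel at a time in both x and y, so the stride is (1,1).
The activation function here is the identity, y = f(x) = x.
Layer 0 looks like this: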
```java
.layer(0, new ConvolutionLayer.Builder(5, 5)
        .nIn(nChannels)   // int nChannels = 1
        .stride(1, 1)
        .nOut(20)
        .activation(Activation.IDENTITY)
        .build())
```
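To show what one filter of layer 0 computes, here is a minimal plain-Java sketch with no DL4J. It handles a single-channel image with no padding and a configurable stride. Class and method names are hypothetical and for illustration only. Like most CNN libraries, the sketch computes a sliding dot product (cross-correlation, without flipping the filter). DL4J's real implementation is vectorized and structured differently. With identity activation, the layer's output is this sum plus a bias, and layer 0 produces 20 such 24×24 feature maps, one per filter.

```java
import java.util.Arrays;

// Hypothetical sketch: one filter sliding over a single-channel image, no padding.
class ConvSketch {

    // Output side length for a square input with no padding: (in - k) / stride + 1
    static int outputSize(int in, int k, int stride) {
        return (in - k) / stride + 1;
    }

    // Multiplies the k x k filter element-wise with each image patch and sums the products.
    static double[][] convolve(double[][] image, double[][] filter, int stride) {
        int k = filter.length;
        int out = outputSize(image.length, k, stride);
        double[][] result = new double[out][out];
        for (int y = 0; y < out; y++) {
            for (int x = 0; x < out; x++) {
                double sum = 0.0;
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < k; j++) {
                        sum += image[y * stride + i][x * stride + j] * filter[i][j];
                    }
                }
                result[y][x] = sum; // bias omitted; identity activation leaves it unchanged
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] image = new double[28][28];  // stands in for one MNIST digit
        for (double[] row : image) Arrays.fill(row, 1.0);
        double[][] filter = new double[5][5];   // one 5x5 filter (the weights)
        for (double[] row : filter) Arrays.fill(row, 0.04);
        double[][] featureMap = convolve(image, filter, 1);
        System.out.println(featureMap.length + " x " + featureMap[0].length); // 24 x 24
    }
}
```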
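First pooling layer (pooling):
Within each 2×2 window, the largest pixel value is kept as the representative value (max pooling). This downsizes the feature map, and the stride here is (2,2).
Layer 1 looks like this: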
```java
.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
        .kernelSize(2, 2)
        .stride(2, 2)
        .build())
```
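Here is a minimal plain-Java sketch of 2×2 max pooling with stride (2,2) on one feature map. The class name is hypothetical, and this is not DL4J's implementation. Each non-overlapping 2×2 block collapses to its maximum, so a 24×24 map from layer 0 becomes 12×12.

```java
import java.util.Arrays;

// Hypothetical sketch: max pooling over k x k windows moved by `stride`, no padding.
class MaxPoolSketch {

    static double[][] maxPool(double[][] input, int k, int stride) {
        int out = (input.length - k) / stride + 1;
        double[][] result = new double[out][out];
        for (int y = 0; y < out; y++) {
            for (int x = 0; x < out; x++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < k; j++) {
                        max = Math.max(max, input[y * stride + i][x * stride + j]);
                    }
                }
                result[y][x] = max; // the largest value represents the window
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] featureMap = {
            {1, 3, 2, 0},
            {4, 2, 1, 5},
            {0, 1, 7, 2},
            {3, 6, 1, 1}
        };
        System.out.println(Arrays.deepToString(maxPool(featureMap, 2, 2))); // [[4.0, 5.0], [6.0, 7.0]]
    }
}
```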
Second convolution layer (convolution):
This layer again uses 5×5 filters with stride (1,1), this time 50 of them. Layer 2 looks like this:
```java
.layer(2, new ConvolutionLayer.Builder(5, 5)
        // nIn is not set here: it is inferred (20 channels from layer 0) via setInputType(...)
        .stride(1, 1)
        .nOut(50)
        .activation(Activation.IDENTITY)
        .build())
```
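The second convolution layer receives the 20 feature maps produced by layer 0, so each of its 50 filters spans 5×5×20 input values. That is why its input depth is 20 rather than nChannels = 1. The sketch below counts the trainable parameters of a convolution layer, assuming one bias per filter (the usual convention). The class name is hypothetical.

```java
// Hypothetical sketch: parameter count of a convolution layer.
class ConvParamsSketch {

    // One kH x kW x nIn weight tensor per filter, plus one bias per filter.
    static long convParams(int kH, int kW, int nIn, int nOut) {
        return (long) kH * kW * nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        System.out.println("layer 0: " + convParams(5, 5, 1, 20));   // layer 0: 520
        System.out.println("layer 2: " + convParams(5, 5, 20, 50));  // layer 2: 25050
    }
}
```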
Second pooling layer (pooling):
Again, the largest pixel value in each 2×2 window is kept (max pooling), downsizing the feature maps with stride (2,2).
Layer 3 looks like this:
```java
.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
        .kernelSize(2, 2)
        .stride(2, 2)
        .build())
```
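Layer 4 (dense layer):
The number of input neurons here is 4 × 4 × 50 = 800. The 28×28 input shrinks to 24×24 after the first 5×5 convolution, to 12×12 after the first pooling, to 8×8 after the second convolution, and to 4×4 after the second pooling, leaving 50 feature maps of 4×4.
The layer uses ReLU activation and has 500 outputs.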
```java
.layer(4, new DenseLayer.Builder()
        .activation(Activation.RELU)
        .nOut(500)
        .build())
```
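You can check the size arithmetic for the dense layer's input with a short sketch. It assumes zero padding, which the configuration in this post implies because it never sets any padding. The class name is hypothetical. The sketch also counts the dense layer's weights and biases.

```java
// Hypothetical sketch: feature-map sizes through layers 0-3 and the dense layer's input size.
class ShapeTraceSketch {

    // Side length after a k x k window moved with the given stride, no padding
    static int outSize(int in, int k, int stride) {
        return (in - k) / stride + 1;
    }

    // Flattened number of values fed into the dense layer (comments assume a 28x28 input)
    static int denseInputSize(int imageSize) {
        int s = outSize(imageSize, 5, 1); // layer 0, 5x5 conv: 28 -> 24 (20 channels)
        s = outSize(s, 2, 2);             // layer 1, 2x2 max pool: 24 -> 12
        s = outSize(s, 5, 1);             // layer 2, 5x5 conv: 12 -> 8 (50 channels)
        s = outSize(s, 2, 2);             // layer 3, 2x2 max pool: 8 -> 4
        return s * s * 50;                // 50 channels of 4 x 4
    }

    // Activation used by layer 4
    static double relu(double x) {
        return Math.max(0.0, x);
    }

    public static void main(String[] args) {
        int nIn = denseInputSize(28);
        long denseParams = (long) nIn * 500 + 500; // weights + one bias per output unit
        System.out.println("dense nIn = " + nIn);            // dense nIn = 800
        System.out.println("dense params = " + denseParams); // dense params = 400500
    }
}
```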
Layer 5 (output layer): this layer uses the softmax function as its activation and has 10 outputs, one per digit from 0 to 9.
```java
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nOut(outputNum)   // int outputNum = 10
        .activation(Activation.SOFTMAX)
        .build())
```
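The output layer pairs softmax with the negative log-likelihood loss. The plain-Java sketch below shows the idea for a single example. Softmax turns the 10 scores into probabilities, and the loss is -log of the probability assigned to the correct digit. This is not DL4J's code, and the class name is hypothetical.

```java
// Hypothetical sketch: softmax over the 10 output scores and the NLL loss of the true class.
class SoftmaxNllSketch {

    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v); // subtract the max for numerical stability
        double sum = 0.0;
        double[] p = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            p[i] = Math.exp(z[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum; // probabilities now sum to 1
        return p;
    }

    // Loss for one example whose correct digit is `label`: -log p[label]
    static double nll(double[] probs, int label) {
        return -Math.log(probs[label]);
    }

    public static void main(String[] args) {
        double[] scores = new double[10]; // scores for digits 0-9
        scores[3] = 2.0;                  // digit 3 gets the highest score
        double[] p = softmax(scores);
        System.out.printf("p[3] = %.4f, loss if label is 3 = %.4f%n", p[3], nll(p, 3));
    }
}
```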
To combine the six layers into one network, we first define the hyperparameters as follows:
```java
.seed(seed)
.iterations(iterations)
.regularization(true).l2(0.0005)
.learningRate(.01)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.NESTEROVS).momentum(0.9)
```
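A single update step helps illustrate what the updater settings above do: learning rate 0.01, Nesterov momentum 0.9, and L2 coefficient 0.0005. The sketch uses one common formulation of Nesterov momentum. It also assumes the L2 penalty is 0.5 · l2 · w², which adds l2 · w to the gradient. DL4J's NESTEROVS updater and its handling of L2 may differ in detail, so treat this as an assumption rather than DL4J's code. The class name is hypothetical.

```java
// Hypothetical sketch: one SGD step with Nesterov momentum and L2 regularization.
class NesterovSketch {

    static final double LEARNING_RATE = 0.01;
    static final double MOMENTUM = 0.9;
    static final double L2 = 0.0005;

    // Updates weights w and velocities v in place for gradient g (all the same length).
    static void step(double[] w, double[] v, double[] g) {
        for (int i = 0; i < w.length; i++) {
            double grad = g[i] + L2 * w[i];                    // assumed L2 term: 0.5 * l2 * w^2
            double vPrev = v[i];
            v[i] = MOMENTUM * v[i] - LEARNING_RATE * grad;     // velocity update
            w[i] += -MOMENTUM * vPrev + (1 + MOMENTUM) * v[i]; // Nesterov look-ahead form
        }
    }

    public static void main(String[] args) {
        double[] w = {1.0};
        double[] v = {0.0};
        step(w, v, new double[] {2.0});
        System.out.println("w = " + w[0] + ", v = " + v[0]);
    }
}
```

On the first step, v starts at 0. The weight therefore moves by (1 + momentum) times the plain SGD step, and later steps also carry the accumulated velocity.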
Putting it together, the complete MultiLayerConfiguration conf is:
```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .regularization(true).l2(0.0005)
        .learningRate(.01)
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
                // nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
                // Note that nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
        .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
        .setInputType(InputType.convolutionalFlat(28, 28, 1)) // inputs are flattened 28x28 single-channel images
        .backprop(true).pretrain(false).build();
```
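WeightInit.XAVIER refers to Glorot/Xavier initialization. The sketch below assumes the Gaussian variant with mean 0 and variance 2 / (fanIn + fanOut), which is how I understand DL4J documents XAVIER. Other libraries use a uniform variant, so check the documentation for the version you use. For a dense layer, fanIn and fanOut are the input and output unit counts, for example 800 and 500 for layer 4. The class name is hypothetical.

```java
import java.util.Random;

// Hypothetical sketch: Xavier (Glorot) Gaussian initialization of a dense weight matrix.
class XavierSketch {

    // Standard deviation for N(0, 2 / (fanIn + fanOut))
    static double xavierStd(int fanIn, int fanOut) {
        return Math.sqrt(2.0 / (fanIn + fanOut));
    }

    // Returns a fanIn x fanOut matrix of Gaussian weights, reproducible for a given seed.
    static double[][] init(int fanIn, int fanOut, long seed) {
        Random rng = new Random(seed);
        double std = xavierStd(fanIn, fanOut);
        double[][] w = new double[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++) {
            for (int j = 0; j < fanOut; j++) {
                w[i][j] = rng.nextGaussian() * std;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] w = init(800, 500, 123); // dense layer 4: 800 inputs, 500 outputs
        System.out.printf("%d x %d weights, std = %.5f%n", w.length, w[0].length, xavierStd(800, 500));
    }
}
```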
To load the MNIST data, we use MnistDataSetIterator, a class DL4J provides for this purpose.
```java
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
```
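MnistDataSetIterator(batchSize, train, 12345) feeds the 60,000 MNIST training images (or the 10,000 test images) to the model in minibatches of batchSize. The third argument is a random seed. The plain-Java sketch below only illustrates the minibatch idea: a seeded shuffle of example indices, then fixed-size chunks. It is not DL4J's implementation. Keeping the final partial batch is an assumption here. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: split shuffled example indices into minibatches.
class MiniBatchSketch {

    static List<int[]> batches(int numExamples, int batchSize, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) idx.add(i);
        Collections.shuffle(idx, new Random(seed)); // same order for the same seed
        List<int[]> result = new ArrayList<>();
        for (int start = 0; start < numExamples; start += batchSize) {
            int end = Math.min(start + batchSize, numExamples); // last batch may be smaller
            int[] batch = new int[end - start];
            for (int i = start; i < end; i++) batch[i - start] = idx.get(i);
            result.add(batch);
        }
        return result;
    }

    public static void main(String[] args) {
        List<int[]> train = batches(60000, 64, 12345); // MNIST training set size
        List<int[]> test = batches(10000, 64, 12345);  // MNIST test set size
        System.out.println(train.size() + " train batches, last has "
                + train.get(train.size() - 1).length); // 938 train batches, last has 32
        System.out.println(test.size() + " test batches"); // 157 test batches
    }
}
```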
We also define the key hyperparameters.
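```java
int nChannels = 1;   // Number of input channels
int outputNum = 10;  // The number of possible outcomes
int batchSize = 64;  // Minibatch size (used for both training and testing)
int nEpochs = 1;     // Number of training epochs
int iterations = 1;  // Number of training iterations (used by .iterations(...) above)
int seed = 123;      // Random seed
```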
Finally, we feed the data to the model, train it, and evaluate it on the test set.
```java
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

log.info("Train model....");
model.setListeners(new ScoreIterationListener(10)); // Print score every 10 iterations
for (int i = 0; i < nEpochs; i++) {
    model.fit(mnistTrain);
    log.info("*** Completed epoch {} ***", i);

    log.info("Evaluate model....");
    Evaluation eval = model.evaluate(mnistTest);
    log.info(eval.stats());
    mnistTest.reset();
}
log.info("****************Example finished********************");
```
Putting all the code together:
```java
package org.deeplearning4j.examples.convolution;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LenetMnistExample {
    private static final Logger log = LoggerFactory.getLogger(LenetMnistExample.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;   // Number of input channels
        int outputNum = 10;  // The number of possible outcomes
        int batchSize = 64;  // Minibatch size (used for both training and testing)
        int nEpochs = 1;     // Number of training epochs
        int iterations = 1;  // Number of training iterations
        int seed = 123;      // Random seed

        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations) // Training iterations as above
                .regularization(true).l2(0.0005)
                .learningRate(.01)//.biasLearningRate(0.02)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1)) // inputs are flattened 28x28 single-channel images
                .backprop(true).pretrain(false).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(1)); // Print score every iteration
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);

            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnistTest.hasNext()) {
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false); // inference mode
                eval.eval(ds.getLabels(), output);
            }
            log.info(eval.stats());
            mnistTest.reset();
        }
        log.info("****************Example finished********************");
    }
}
```
Running the program produces the following output (excerpt):
```text
......
Examples labeled as 0 classified by model as 0: 967 times
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 0 classified by model as 2: 2 times
Examples labeled as 0 classified by model as 5: 3 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 4 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 1 classified by model as 1: 1122 times
Examples labeled as 1 classified by model as 2: 4 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 1 classified by model as 8: 5 times
Examples labeled as 2 classified by model as 2: 1025 times
Examples labeled as 2 classified by model as 3: 1 times
Examples labeled as 2 classified by model as 7: 5 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 2: 10 times
Examples labeled as 3 classified by model as 3: 964 times
Examples labeled as 3 classified by model as 5: 24 times
Examples labeled as 3 classified by model as 7: 8 times
Examples labeled as 3 classified by model as 8: 3 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 2 times
Examples labeled as 4 classified by model as 4: 977 times
Examples labeled as 4 classified by model as 6: 1 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 4 classified by model as 9: 1 times
Examples labeled as 5 classified by model as 0: 1 times
Examples labeled as 5 classified by model as 2: 1 times
Examples labeled as 5 classified by model as 3: 2 times
Examples labeled as 5 classified by model as 5: 886 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 6 classified by model as 0: 5 times
Examples labeled as 6 classified by model as 1: 4 times
Examples labeled as 6 classified by model as 2: 4 times
Examples labeled as 6 classified by model as 3: 1 times
Examples labeled as 6 classified by model as 4: 3 times
Examples labeled as 6 classified by model as 5: 13 times
Examples labeled as 6 classified by model as 6: 928 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 20 times
Examples labeled as 7 classified by model as 7: 1004 times
Examples labeled as 8 classified by model as 0: 2 times
Examples labeled as 8 classified by model as 2: 7 times
Examples labeled as 8 classified by model as 3: 2 times
Examples labeled as 8 classified by model as 4: 3 times
Examples labeled as 8 classified by model as 5: 12 times
Examples labeled as 8 classified by model as 6: 1 times
Examples labeled as 8 classified by model as 7: 12 times
Examples labeled as 8 classified by model as 8: 925 times
Examples labeled as 8 classified by model as 9: 10 times
Examples labeled as 9 classified by model as 0: 2 times
Examples labeled as 9 classified by model as 1: 4 times
Examples labeled as 9 classified by model as 2: 3 times
Examples labeled as 9 classified by model as 3: 2 times
Examples labeled as 9 classified by model as 4: 20 times
Examples labeled as 9 classified by model as 5: 12 times
Examples labeled as 9 classified by model as 7: 12 times
Examples labeled as 9 classified by model as 8: 1 times
Examples labeled as 9 classified by model as 9: 953 times

==========================Scores========================================
 Accuracy: 0.9751
 Precision: 0.9753
 Recall: 0.9751
 F1 Score: 0.9752
========================================================================
o.d.e.c.LenetMnistExample - ****************Example finished********************
```
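The summary scores can be recomputed from the confusion counts in the log. The correctly classified (diagonal) counts sum to 9751 out of 10,000 test images, which gives the reported accuracy of 0.9751. The sketch below assumes that precision and recall are macro-averaged (per-class values averaged without weighting) and that F1 is the harmonic mean of those two averages. These are assumptions about how Evaluation reports its figures, not something this post states, and the exact definitions can vary between DL4J versions. The class name is hypothetical, and the matrix is transcribed from the log above, with 0 for pairs the log does not list.

```java
// Hypothetical sketch: accuracy, macro precision/recall and F1 from the confusion counts above.
class MetricsSketch {

    // Rows = actual label, columns = predicted label (transcribed from the log)
    static final int[][] CONFUSION = {
        {967,    1,    2,   0,   0,   3,   2,    4,   1,   0},
        {  0, 1122,    4,   1,   0,   1,   2,    0,   5,   0},
        {  0,    0, 1025,   1,   0,   0,   0,    5,   1,   0},
        {  0,    0,   10, 964,   0,  24,   0,    8,   3,   1},
        {  0,    0,    2,   0, 977,   0,   1,    1,   0,   1},
        {  1,    0,    1,   2,   0, 886,   1,    1,   0,   0},
        {  5,    4,    4,   1,   3,  13, 928,    0,   0,   0},
        {  0,    4,   20,   0,   0,   0,   0, 1004,   0,   0},
        {  2,    0,    7,   2,   3,  12,   1,   12, 925,  10},
        {  2,    4,    3,   2,  20,  12,   0,   12,   1, 953}
    };

    static double accuracy(int[][] m) {
        long correct = 0, total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m.length; j++) {
                total += m[i][j];
                if (i == j) correct += m[i][j];
            }
        }
        return (double) correct / total;
    }

    // Mean over classes of diag / column sum (assumes every class is predicted at least once)
    static double macroPrecision(int[][] m) {
        double sum = 0;
        for (int c = 0; c < m.length; c++) {
            long predicted = 0;
            for (int[] row : m) predicted += row[c];
            sum += (double) m[c][c] / predicted;
        }
        return sum / m.length;
    }

    // Mean over classes of diag / row sum
    static double macroRecall(int[][] m) {
        double sum = 0;
        for (int c = 0; c < m.length; c++) {
            long actual = 0;
            for (int v : m[c]) actual += v;
            sum += (double) m[c][c] / actual;
        }
        return sum / m.length;
    }

    static double f1(double p, double r) {
        return 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        double p = macroPrecision(CONFUSION);
        double r = macroRecall(CONFUSION);
        System.out.printf("Accuracy: %.4f Precision: %.4f Recall: %.4f F1: %.4f%n",
                accuracy(CONFUSION), p, r, f1(p, r));
    }
}
```

Under these assumptions, the four values round to the ones printed in the Scores section of the log.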