MnistAnomalyの構造について、詳しく確認してみよう。ここでは、とくにJAVA 8で強化されたコレクションクラスを駆使した行列データの操作への理解が重要となる。
オートエンコーダーでは、28×28の数字の画像を入れて、同じ画像を出力するニューラルネットワークであり、データを表現する特徴を獲得するためのニューラルネットワーク。
https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
MnistAnomaly.javaの少し長いコードの大まかな以下の8つの構造から構成されていることをまず把握しておく。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
public class MNISTAnomalyExample { public static void main(String[] args) throws Exception { // [構造1:MutiLayerConfiguration],多層ニューラルネットワークの設定. // Mnist画像は縦28ピクセルx横28ピクセル=784ピクセル画像 // ニューラルネットは4層構造で、入力層784ニューロン -> 第0層250ニューロン -> 第1層10ニューロン -> 第2層250ニューロン -> 第3層784ニューロン(出力層) MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() ........ MultiLayerNetwork net = new MultiLayerNetwork(conf); net.setListeners(Collections.singletonList(new ScoreIterationListener(1))); // [構造2:データ読み込みと訓練・評価データへの分割],40,000個の訓練データと10,000個の評価データへの分割. DataSetIterator iter = new MnistDataSetIterator(100,50000,false); List<INDArray> featuresTrain = new ArrayList<>(); List<INDArray> featuresTest = new ArrayList<>(); List<INDArray> labelsTest = new ArrayList<>(); ............... ............... // [構造3:モデルへの訓練データの適用] 30エポック繰り返される. int nEpochs = 30; for( int epoch=0; epoch<nEpochs; epoch++ ){ for(INDArray data : featuresTrain){ net.fit(data,data); } System.out.println("Epoch " + epoch + " complete"); } // [構造4:訓練データの適用] // 各画像のスコア評価と数値データをペアにしたMap配列を作成する. Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>(); for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>()); ............... ............... // [構造5:スコアによるソート] Comparatorクラスが利用される。 Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() { ................ ................ // [構造6:Best, Worstスコアデータの選定] List<INDArray> best = new ArrayList<>(50); List<INDArray> worst = new ArrayList<>(50); ................ // [構造7:Best, Worstスコアデータの可視化] 後述のMNISTVisualizerクラスを用いる。 MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0,best,"Best (Low Rec. Error)"); bestVisualizer.visualize(); MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0,worst,"Worst (High Rec. Error)"); worstVisualizer.visualize(); } // [構造8:データの可視化のためのMNISTVisualizerクラス(構造7から呼び出される)] // メソットとして、MNISTVIsualizer(), visualize(), getComponents()を持つ. public static class MNISTVisualizer { public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) { ....... } public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) { ....... } public void visualize(){ ....... } private List<JLabel> getComponents(){ ....... } return images; } } } |
では、まず第1構造MutiLayerConfigurationで設定される多層ニューラルネットワーク構成から詳しく見てみる。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(12345) //乱数発生の初期値 .weightInit(WeightInit.XAVIER) //重み初期設定はサビエル法を選択 .updater(new AdaGrad(0.05)) //学習率調整はAdaGrad法を選択 .activation(Activation.RELU) //活性化関数はReLU法を選択 .l2(0.0001) //正規化係数2設定は0.0001 .list() //List.Builder .layer(0, new DenseLayer.Builder().nIn(784).nOut(250) .build()) //第0層:784>250 .layer(1, new DenseLayer.Builder().nIn(250).nOut(10) .build()) //第1層:250>10 .layer(2, new DenseLayer.Builder().nIn(10).nOut(250) .build()) //第2層:10>250 .layer(3, new OutputLayer.Builder().nIn(250).nOut(784) .lossFunction(LossFunctions.LossFunction.MSE) .build()) //第3層:250>784 .pretrain(false).backprop(true) //逆誤差伝搬法 .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); //MultiLayerNetworkのインスタンnetを上記のconfを初期条件として生成 net.setListeners(Collections.singletonList(new ScoreIterationListener(1))); //トレーニングリスナーのセット. |
4層構造のニューラルネットが設定されて、入力層では28×28=784ピクセルで構成されるMnist画像のそれぞれのピクセルが第0層784個のニューロンに接続されることと、最終的な出力層である第3層も同数の784個のニューロンで構成されることがわかる。
次に第2構造であるデータ読み込みと訓練・評価データへの分割について見る。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
DataSetIterator iter = new MnistDataSetIterator(100,50000,false); //DateSetIteratorクラスのインスタンスiterを生成する。MnistDataSetIterator(int batchSize, boolean train, int seed) では、バッチサイズ100個で、まずは訓練データが50,000個がセットされる。後ほど評価用データが分割される.) List<INDArray> featuresTrain = new ArrayList<>(); List<INDArray> featuresTest = new ArrayList<>(); List<INDArray> labelsTest = new ArrayList<>(); //INDArrayで構成されるArrayListとして訓練用、評価用、評価ラベル用リストが設定される. Random r = new Random(12345); //乱数r設定 while(iter.hasNext()){ //iterがデータを保持している間、データ更新を繰り返す DataSet ds = iter.next(); //DataSetクラスdsにデータをセットする. SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //訓練データ80と評価データ20を分離 (from miniBatch = 100) featuresTrain.add(split.getTrain().getFeatures()); //訓練データをfeaturesTrainへスプリット DataSet dsTest = split.getTest(); //DataSetクラスdsTestに評価データをセットする. featuresTest.add(dsTest.getFeatures()); //評価データをfeaturesTestへ格納する. INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1);//one-hotで得られた最大値をindexesに収める. labelsTest.add(indexes); //indexesをlabelsTestへ加える. } |
次は、構造3であるモデルへの訓練データの適用を分析する。
1 2 3 4 5 6 7 |
int nEpochs = 30; //以下のforループを通じて、30エポック繰り返される. for( int epoch=0; epoch<nEpochs; epoch++ ){ for(INDArray data : featuresTrain){ //訓練データ数繰り返される net.fit(data,data); } System.out.println("Epoch " + epoch + " complete"); } |
ここでは、30エポック繰り返される.
訓練用データは80個/バッチなので、40,000データの訓練データをカバーするためには、500回の繰り返しが必要ということから、出力結果は、1エポック500の繰り返しを30回で、以下のように0回目を含めて、15,000回まで繰り返された。
1 2 3 4 5 6 7 8 9 10 |
........ o.d.o.l.ScoreIterationListener - Score at iteration 2998 is 0.04381205167132551 o.d.o.l.ScoreIterationListener - Score at iteration 2999 is 0.04151316669116635 Epoch 5 complete o.d.o.l.ScoreIterationListener - Score at iteration 3000 is 0.0389922661273628 o.d.o.l.ScoreIterationListener - Score at iteration 3001 is 0.040439206650119595 ......... o.d.o.l.ScoreIterationListener - Score at iteration 14998 is 0.039607348121309865 o.d.o.l.ScoreIterationListener - Score at iteration 14999 is 0.03742313742759901 Epoch 29 complete |
次は、構造4の訓練データの適用。各画像のスコア評価と正解数値データをペアにしたMap配列を作成する.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>(); //Map配列listsByDigitを(整数、リスト(ダブル、INDArray配列))として設定. //ここではHashMap関数を使って初期設定. for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>()); //0〜9までを、Map配列listsByDigitをセット. for( int i=0; i<featuresTest.size(); i++ ){ //評価用データサイズ繰り返す. //featuresTest.size()は、500個. 1バッチ20で、評価データ10,000なので、500個 INDArray testData = featuresTest.get(i); //featuresTest(List<INDArray>)の中身を順番にINDArray testDataへ取り出す。 INDArray labels = labelsTest.get(i); // List<INDArray> labelsTestの中身を順番にINDArray labelsへ取り出す。 int nRows = testData.rows(); // INDArray testDataの行数をnRowsへ 20行:テストデータは20件。 for( int j=0; j<nRows; j++){ // INDArray testDataの20回、繰り返す。 INDArray example = testData.getRow(j); //1行づつINDArray exampleへ移す。 int digit = (int)labels.getDouble(j); // INDArray labelsの中身をdigitへ移す。 double score = net.score(new DataSet(example,example)); //評価スコアを算出し、スコアへ配置 List digitAllPairs = listsByDigit.get(digit); //digitをキーにPair<Double,INDArray>を取り出し、リストdigitAllPairsへ digitAllPairs.add(new ImmutablePair<>(score, example)); // リストdigitAllPairsへ、scoreとexampleを配置する。 } } |
参考までにdigitAllPairsの中身を出力させると、以下の通り scoreと784の配列データ.
1 |
digitAllPairs:[(0.06809146115858107,[[ 0, 0, ......]])] |
次は、構造5、スコアによるソート(Comparatorクラスが利用される)
1 2 3 4 5 6 7 8 9 10 11 12 13 |
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() { //Comparator<T> オブジェクトのコレクションで全体順序付けを行う比較関数. @Override //スーパークラスのメソッドをサブクラスで記述 public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) { return Double.compare(o1.getLeft(),o2.getLeft()); // compare(T o1,T o2) 順序付けのために2つの引数を比較する. } }; for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){ Collections.sort(digitAllPairs, c); // Collections.sort(List<T> list, Comparator<? super T> c) 指定されたコンパレータが示す順序に従って、指定されたリストをソートする. digitAllPairがscoreが低い順にソートされる. } |
構造6では、Best, Worstスコアデータの選定を行う。
1 2 3 4 5 6 7 8 9 10 |
List<INDArray> best = new ArrayList<>(50); //べストを収めるArrayList生成 List<INDArray> worst = new ArrayList<>(50);//ワーストを収めるArrayList生成 for( int i=0; i<10; i++ ){ //0〜9繰り返す。 List<Pair<Double,INDArray>> list = listsByDigit.get(i); //listByDigitから0〜9をキーとして、順番にPair<Double,INDArray>をlistへ取り出す。 for( int j=0; j<5; j++ ){ listの上位下位5番のbestとworstへ画像配列を取り出す. best.add(list.get(j).getRight()); worst.add(list.get(list.size()-j-1).getRight()); } } |
最後に構造7で、Best, Worstスコアデータの可視化する。この歳、後述のMNISTVisualizerクラスを用いる。
1 2 3 4 5 6 |
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0,best,"Best (Low Rec. Error)"); bestVisualizer.visualize(); MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0,worst,"Worst (High Rec. Error)"); worstVisualizer.visualize(); } |
構造7で用いるMNISTVisualizer関数を構造8で定義する。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
public static class MNISTVisualizer { private double imageScale; private List<INDArray> digits; //Digits (as row vectors), one per INDArray private String title; private int gridWidth; public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) { this(imageScale, digits, title, 5); } public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) { this.imageScale = imageScale; this.digits = digits; this.title = title; this.gridWidth = gridWidth; } public void visualize(){ JFrame frame = new JFrame(); frame.setTitle(title); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); JPanel panel = new JPanel(); panel.setLayout(new GridLayout(0,gridWidth)); List<JLabel> list = getComponents(); for(JLabel image : list){ panel.add(image); } frame.add(panel); frame.setVisible(true); frame.pack(); } private List<JLabel> getComponents(){ List<JLabel> images = new ArrayList<>(); for( INDArray arr : digits ){ BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY); for( int i=0; i<784; i++ ){ bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i))); } ImageIcon orig = new ImageIcon(bi); Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE); ImageIcon scaled = new ImageIcon(imageScaled); images.add(new JLabel(scaled)); } return images; } } |
分かりにくいのは、Map配列listsByDigitとListのdigitAllPairsでしょう。
仮に、出力を減らすために、アプライデータを1000に減らして、Epochを1に減らして見ると、訓練データ800、評価データは200となる。それぞれの数字ごとには20個程度。1バッチ100とすれば、10サイクルで1エポック終了する。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 2 o.n.n.Nd4jBlas - Number of threads used for BLAS: 2 o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Mac OS X] o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [3.6GB]; o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [MKL] o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE] o.d.o.l.ScoreIterationListener - Score at iteration 0 is 0.10551188530375336 o.d.o.l.ScoreIterationListener - Score at iteration 1 is 0.10076131940393633 o.d.o.l.ScoreIterationListener - Score at iteration 2 is 0.10534758334790928 o.d.o.l.ScoreIterationListener - Score at iteration 3 is 0.09853258587852988 o.d.o.l.ScoreIterationListener - Score at iteration 4 is 0.09353333141320336 o.d.o.l.ScoreIterationListener - Score at iteration 5 is 0.0901039258482666 o.d.o.l.ScoreIterationListener - Score at iteration 6 is 0.09359006410566302 o.d.o.l.ScoreIterationListener - Score at iteration 7 is 0.09574343548946489 o.d.o.l.ScoreIterationListener - Score at iteration 8 is 0.0939110892056225 o.d.o.l.ScoreIterationListener - Score at iteration 9 is 0.08888372561510648 Epoch 0 complete |
listsByDigitは、生成段階では、
1 |
listsByDigit_pre:{0=[], 1=[], 2=[], 3=[], 4=[], 5=[], 6=[], 7=[], 8=[], 9=[]} |
となっているが、その後、digiAllPairsを取り込んで、
1 |
listsByDigit_post:{0=[(0.15421215545258268,[[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1098, 0.7647, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 1.0000, 0.2392, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0235, 0.7490, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1020, 0.7451, 0.9922, 0.9922, 0.9922, 0.9922, 0.9412, 0.7490, 0.9490, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0588, 0.7333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7843, 0, 0.8275, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0863, 0.2588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9451, 0.8196, 0.1725, 0.0902, 0.8549, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.4863, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7137, 0, 0, 0.5137, 0.9922, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1490, 0.8510, 0.9922, 0.9922, 0.9569, 0.4353, 0.1451, 0, 0, 0.5137, 0.9922, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.4863, 0.9922, 0.9922, 0.9922, 0.6471, 0, 0, 0, 0.0863, 0.7137, 0.9922, 0.9922, 0.2353, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.4863, 0.9922, 0.9922, 0.9412, 0.1765, 0, 0, 0, 0.2078, 0.9922, 0.9922, 0.9765, 0.2275, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0627, 0.6588, 0.9922, 0.8471, 0.1765, 0, 0, 0, 0, 0.2078, 0.9922, 0.9922, 0.5412, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6235, 0.9922, 0.9922, 0.5765, 0, 0, 0, 0, 0, 0.2078, 0.9922, 0.9922, 0.5412, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5333, 0.9882, 0.9922, 0.8902, 0.0196, 0, 0, 0, 0, 0, 0.2078, 0.9922, 0.9529, 0.3961, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5490, 0.9922, 0.9922, 0.4863, 0, 0, 0, 0, 0, 0, 0.6118, 0.9922, 0.8549, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0510, 0.6431, 0.9922, 0.5569, 0.0196, 0, 0, 0, 0, 0, 0.1255, 0.9137, 0.9922, 0.8549, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2431, 0.9922, 0.9922, 0.5098, 0, 0, 0, 0, 0, 0.1451, 0.7961, 0.9922, 0.9922, 0.4980, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2431, 0.9922, 0.9922, 0.5765, 0.1412, 0.1412, 0.1412, 0.1412, 0.5922, 0.8706, 0.9922, 0.9608, 0.4980, 0.0314, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1333, 0.7922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7843, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5490, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9725, 0.9216, 0.2549, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3412, 0.6784, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7137, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0549, 0.3059, 0.3765, 0.9922, 0.9922, 0.9922, 0.5373, 0.2196, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), (0.12060000635023545,[[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1098, 0.6431, 0.9961, 0.9137, 0.5804, 0.0431, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0118, 0.6431, 0.9961, 0.9176, 0.8824, 0.9961, 0.8000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3569, 0.9961, 0.9216, 0.1882, 0.1255, 0.6510, 0.9843, 0.3608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1294, 0.4353, 0.8392, 0.8039, 0.1922, 0, 0, 0.0941, 0.8471, 0.8235, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1333, 0.8510, 0.9961, 0.9961, 0.8275, 0, 0, 0, 0, 0.3412, 0.9294, 0.1686, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1333, 0.8471, 0.9961, 0.9961, 0.9882, 0.9529, 0.2392, 0, 0, 0, 0.1490, 0.9725, 0.7137, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6706, 0.9961, 0.7216, 0.8039, 0.6863, 0.1412, 0, 0, 0, 0, 0, 0.6706, 0.8902, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1098, 0.9176, 0.7451, 0.0510, 0.7569, 0.6157, 0, 0, 0, 0, 0, 0, 0.4863, 0.9333, 0.1020, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5490, 0.9961, 0.5137, 0, 0.5059, 0.6157, 0, 0, 0, 0, 0, 0, 0.4863, 0.9961, 0.3725, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.7882, 0.9333, 0.2196, 0, 0.2745, 0.4039, 0, 0, 0, 0, 0, 0, 0.4863, 0.9961, 0.5804, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2431, 1.0000, 0.8235, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5882, 0.9961, 0.4784, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3373, 0.9961, 0.7882, 0.0588, 0, 0, 0, 0, 0, 0, 0, 0, 0.1098, 0.9294, 0.9647, 0 |
データ挿入後のlistsByDigitをエクセルで覗いてみると以下のように、ゼロからスタートして、ゼロが18個入っていた。
いくつかの疑問:以下の通り:
double score = net.score(new DataSet(example,example));
scoreは何を見ているのか:Class MultiLayerNetwork
double score(DataSet data, boolean training)
Calculate the score (loss function) of the prediction with respect to the true labels