もう少しDL4JのExampleを調査する。
MNISTAnomalyExampleは、
http://yann.lecun.com/exdb/mnist/
にある50,000枚の手書き数字学習用画像ライブラリについて、機械学習オートエンコーダで判断で最善のスコア5件と最悪のスコア5件をリストアップする
さて、このアルゴリズムがどのようにコードに埋め込まれているのか確認していこう。
入力データアルゴリズム:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
//Load data and split into training and testing sets. 40000 train, 10000 test: 40,000の訓練データと10,000の評価データ DataSetIterator iter = new MnistDataSetIterator(100,50000,false); //バッチサイズ100、 50,000の訓練データ // MnistDataSetIterator(int batchSize, boolean train, int seed) List<INDArray> featuresTrain = new ArrayList<>(); //INDArrayデータを複数含むリストである。 List<INDArray> featuresTest = new ArrayList<>(); //INDArrayデータを複数含むリストである。 List<INDArray> labelsTest = new ArrayList<>(); //INDArrayデータを複数含むリストである。 Random r = new Random(12345); while(iter.hasNext()){ DataSet ds = iter.next(); //dsにデータ・セット SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //80/20 split (from miniBatch = 100) 訓練データ80と評価データ20を分離 featuresTrain.add(split.getTrain().getFeatures()); //訓練データをfeaturesTrainへスプリット DataSet dsTest = split.getTest(); featuresTest.add(dsTest.getFeatures()); //評価データをfeaturesTestへスプリット INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //Convert from one-hot representation -> index labelsTest.add(indexes); //one-hotで得られた最大値をindexesにおさめて、labelsTestへ加える } |
ニューラルネットワーク構造:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
//Load data and split into training and testing sets. 40000 train, 10000 test: 40,000の訓練データと10,000の評価データ DataSetIterator iter = new MnistDataSetIterator(100,50000,false); //バッチサイズ100、 50,000の訓練データ // MnistDataSetIterator(int batchSize, boolean train, int seed) List<INDArray> featuresTrain = new ArrayList<>(); //INDArrayデータを複数含むリストである。 List<INDArray> featuresTest = new ArrayList<>(); //INDArrayデータを複数含むリストである。 List<INDArray> labelsTest = new ArrayList<>(); //INDArrayデータを複数含むリストである。 Random r = new Random(12345); while(iter.hasNext()){ DataSet ds = iter.next(); //dsにデータ・セット SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //80/20 split (from miniBatch = 100) 訓練データ80と評価データ20を分離 featuresTrain.add(split.getTrain().getFeatures()); //訓練データをfeaturesTrainへスプリット DataSet dsTest = split.getTest(); featuresTest.add(dsTest.getFeatures()); //評価データをfeaturesTestへスプリット INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //Convert from one-hot representation -> index labelsTest.add(indexes); //one-hotで得られた最大値をindexesにおさめて、labelsTestへ加える |
訓練モデル:
1 2 3 4 5 6 7 8 |
//Train model: int nEpochs = 30; //30エポック for( int epoch=0; epoch<nEpochs; epoch++ ){ for(INDArray data : featuresTrain){ net.fit(data,data); //訓練 } System.out.println("Epoch " + epoch + " complete"); } |
テストデータでのモデル評価
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
//Evaluate the model on the test data テストデータでのモデル評価 //Score each example in the test set separately //Compose a map that relates each digit to a list of (score, example) pairs スコアと例のリスト作成 //Then find N best and N worst scores per digit ベストとワーストを見つける Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>(); //listsByDifit Integer,List<Pair<Double,INDArray>>というMapをHashMap関数を使って実装する。 // Pair(K key, V value) for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>()); // 0から10まで、Map実装listsByDigitにdigitとArrayList<>を配置。 for( int i=0; i<featuresTest.size(); i++ ){ //featuresTest(List<INDArray>)のサイズ繰り返す。500個 INDArray testData = featuresTest.get(i); //featuresTest(List<INDArray>)の中身を順番にINDArray testDataへ取り出す。 INDArray labels = labelsTest.get(i); // List<INDArray> labelsTestの中身を順番にINDArray labelsへ取り出す。 int nRows = testData.rows(); // INDArray testDataの行数をnRowsへ 20行:テストデータは20件。 for( int j=0; j<nRows; j++){ // INDArray testDataの20行数繰り返す。 INDArray example = testData.getRow(j); //行づつINDArray exampleへ移す。 int digit = (int)labels.getDouble(j); // INDArray labelsの中身を移す。 double score = net.score(new DataSet(example,example)); // Add (score, example) pair to the appropriate list List digitAllPairs = listsByDigit.get(digit); digitAllPairs.add(new ImmutablePair<>(score, example)); } } |
スコアリング結果でのソート:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
//Sort each list in the map by score Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() { // public interface Comparator<T> オブジェクトのコレクションで全体順序付けを行う比較関数です。 //Doubleクラスは、プリミティブ型doubleの値をオブジェクトにラップします。 @Override //スーパークラスのメソッドをサブクラスで記述 public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) { return Double.compare(o1.getLeft(),o2.getLeft()); } // compare(T o1,T o2) 順序付けのために2つの引数を比較します。 }; for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){ Collections.sort(digitAllPairs, c); } // Collections.sort(List<T> list, Comparator<? super T> c) 指定されたコンパレータが示す順序に従って、指定されたリストをソートします。 |
Best,Worst画像の表示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
//After sorting, select N best and N worst scores (by reconstruction error) for each digit, where N=5 List<INDArray> best = new ArrayList<>(50); List<INDArray> worst = new ArrayList<>(50); for( int i=0; i<10; i++ ){ List<Pair<Double,INDArray>> list = listsByDigit.get(i); for( int j=0; j<5; j++ ){ best.add(list.get(j).getRight()); worst.add(list.get(list.size()-j-1).getRight()); } } //Visualize the best and worst digits MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0,best,"Best (Low Rec. Error)"); bestVisualizer.visualize(); MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0,worst,"Worst (High Rec. Error)"); worstVisualizer.visualize(); } |
数字画像を表示させる関数 MNISTVisualizer()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
public static class MNISTVisualizer { private double imageScale; private List<INDArray> digits; //Digits (as row vectors), one per INDArray private String title; private int gridWidth; public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) { this(imageScale, digits, title, 5); } public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) { this.imageScale = imageScale; this.digits = digits; this.title = title; this.gridWidth = gridWidth; } public void visualize(){ JFrame frame = new JFrame(); frame.setTitle(title); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); JPanel panel = new JPanel(); panel.setLayout(new GridLayout(0,gridWidth)); List<JLabel> list = getComponents(); for(JLabel image : list){ panel.add(image); } frame.add(panel); frame.setVisible(true); frame.pack(); } private List<JLabel> getComponents(){ List<JLabel> images = new ArrayList<>(); for( INDArray arr : digits ){ // List<INDArray> digits 繰り返し BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY); for( int i=0; i<784; i++ ){ //28x28=784 データ値をBufferedImageへ bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i))); } ImageIcon orig = new ImageIcon(bi); Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE); ImageIcon scaled = new ImageIcon(imageScaled); images.add(new JLabel(scaled)); } return images; } } } |
ここまででアルゴリズムの概略が把握できたが、詳細なデータ構造は、多重のリストや配列、マップでちょっと複雑ですね。そこで、変数の内容をファイルに書き込んで、データを分析してみよう。
標準出力をファイルへのデータ書き込みへ変更するコードを埋め込む。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>(); for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>()); PrintStream sysOut = System.out; FileOutputStream fos = new FileOutputStream("out_MNISTAnom.txt"); PrintStream ps = new PrintStream(fos); System.setOut(ps); for( int i=0; i<featuresTest.size(); i++ ){ INDArray testData = featuresTest.get(i); INDArray labels = labelsTest.get(i); int nRows = testData.rows(); //System.out.println("featuresTest.size:"+ featuresTest.size()); //System.out.println("testData:"+ testData.toString()); //System.out.println("labesl:" + labels); //System.out.println("nRows:"+ nRows); for( int j=0; j<nRows; j++){ INDArray example = testData.getRow(j); int digit = (int)labels.getDouble(j); System.out.println("example:"+ example); System.out.println("digit:"+ digit); double score = net.score(new DataSet(example,example)); // Add (score, example) pair to the appropriate list List digitAllPairs = listsByDigit.get(digit); digitAllPairs.add(new ImmutablePair<>(score, example)); System.out.println("score:"+ score); //System.out.println("digitAllPairs:"+ digitAllPairs); } } ps.close(); fos.close(); System.setOut(sysOut); |
ファイルに記述された1万個の評価用データの中から、はじめの1個めのexmple変数の内容、digit、そしてscoreを覗いてみる。
1 2 3 |
example:[[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0235, 0.0745, 0.5216, 0.5216, 0.6118, 0.9961, 0.9961, 0.8392, 0.3255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0392, 0.5255, 0.7725, 0.9961, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.6118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1098, 0.3059, 0.7608, 0.9922, 0.9922, 0.9961, 0.9804, 0.8510, 0.8510, 0.8863, 0.9922, 0.9922, 0.6118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0392, 0.5294, 0.9176, 0.9922, 0.9922, 0.9922, 0.9922, 0.9647, 0.2980, 0, 0.0392, 0.3843, 0.9922, 0.9922, 0.5216, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0353, 0.6078, 0.9922, 0.9922, 0.9922, 0.8784, 0.7765, 0.5255, 0.2706, 0, 0, 0.3059, 0.9922, 0.9922, 0.7529, 0.0392, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0353, 0.5490, 0.7059, 0.3451, 0.2353, 0.1255, 0.0235, 0, 0, 0, 0.2471, 0.9176, 0.9922, 0.6392, 0.1569, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0353, 0.7412, 0.9922, 0.6392, 0.0471, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0431, 0.8667, 0.9922, 0.8275, 0.0392, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2549, 0.9922, 0.9333, 0.2314, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2314, 0.9765, 0.9490, 0.2196, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2510, 0.9765, 0.9765, 0.2039, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0392, 0.9098, 0.9922, 0.6314, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2078, 0.7608, 0.9922, 0.6980, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6510, 0.9922, 0.9098, 0.1922, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3216, 0.9412, 0.9020, 0.2039, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0706, 0.8275, 0.8980, 0.2118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0667, 0.4941, 0.9922, 0.5020, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1922, 0.9922, 0.5647, 0.0275, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0627, 0.7647, 0.8902, 0.1333, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5961, 0.9020, 0.1686, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] digit:7 score:0.06809146115858107 |
このデータをエクセルに移してみると、数値が7であることがわかる。スコアは、0.06809146115858107.
リストや配列の中にそういうデータが転送されてくるのか、もう一度、整理確認してみよう。