MNISTデータセット 手書き0-9の機械認識
—————————————-
In[1]
## 1.ライブラリの読み込み ##
# TensorFlowライブラリ
import tensorflow as tf
# tflearnライブラリ
import tflearn
# mnistデータセットを扱うためのライブラリ
import tflearn.datasets.mnist as mnist
# MNIST画像を表示するためのライブラリ
from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np
—————————————-
In[2]
## 2.データの読み込みと前処理 ##
# MNISTデータを./data/mnistへダウンロードし、解凍して各変数へ格納
trainX, trainY, testX, testY = mnist.load_data(‘./data/mnist/’, one_hot=True)
Extracting ./data/mnist/train-images-idx3-ubyte.gz
Extracting ./data/mnist/train-labels-idx1-ubyte.gz
Extracting ./data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./data/mnist/t10k-labels-idx1-ubyte.gz
—————————————-
In[3]
## データの確認
# 学習用の画像ピクセルデータと正解データのサイズを確認
print(len(trainX),len(trainY))
# テスト用の画像ピクセルデータと正解データのサイズを確認
print(len(testX),len(testY))
# 学習用の画像ピクセルデータを確認
print(trainX)
# 学習用の正解データを確認
print(trainY)
out[3]
55000 55000
10000 10000
[[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
…,
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]]
[[ 0. 0. 0. …, 1. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
…,
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 0. 0.]
[ 0. 0. 0. …, 0. 1. 0.]]
—————————————-
In[4]
# 学習用の画像ピクセルデータを確認(1枚目)
trainX[0]
out[v4]
array([ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.38039219, 0.37647063, 0.3019608 ,
0.46274513, 0.2392157 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.35294119, 0.5411765 , 0.92156869,
0.92156869, 0.92156869, 0.92156869, 0.92156869, 0.92156869,
0.98431379, 0.98431379, 0.97254908, 0.99607849, 0.96078438,
0.92156869, 0.74509805, 0.08235294, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.54901963,
0.98431379, 0.99607849, 0.99607849, 0.99607849, 0.99607849,
0.99607849, 0.99607849, 0.99607849, 0.99607849, 0.99607849,
0.99607849, 0.99607849, 0.99607849, 0.99607849, 0.99607849,
0.74117649, 0.09019608, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.88627458, 0.99607849, 0.81568635,
0.78039223, 0.78039223, 0.78039223, 0.78039223, 0.54509807,
0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 ,
0.50196081, 0.8705883 , 0.99607849, 0.99607849, 0.74117649,
0.08235294, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.14901961, 0.32156864, 0.0509804 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.13333334,
0.83529419, 0.99607849, 0.99607849, 0.45098042, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.32941177, 0.99607849,
0.99607849, 0.91764712, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.32941177, 0.99607849, 0.99607849, 0.91764712,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.41568631, 0.6156863 ,
0.99607849, 0.99607849, 0.95294124, 0.20000002, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.09803922, 0.45882356, 0.89411771, 0.89411771,
0.89411771, 0.99215692, 0.99607849, 0.99607849, 0.99607849,
0.99607849, 0.94117653, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.26666668, 0.4666667 , 0.86274517,
0.99607849, 0.99607849, 0.99607849, 0.99607849, 0.99607849,
0.99607849, 0.99607849, 0.99607849, 0.99607849, 0.55686277,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.14509805, 0.73333335,
0.99215692, 0.99607849, 0.99607849, 0.99607849, 0.87450987,
0.80784321, 0.80784321, 0.29411766, 0.26666668, 0.84313732,
0.99607849, 0.99607849, 0.45882356, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.44313729, 0.8588236 , 0.99607849, 0.94901967, 0.89019614,
0.45098042, 0.34901962, 0.12156864, 0. , 0. ,
0. , 0. , 0.7843138 , 0.99607849, 0.9450981 ,
0.16078432, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.66274512, 0.99607849,
0.6901961 , 0.24313727, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.18823531,
0.90588242, 0.99607849, 0.91764712, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.07058824, 0.48627454, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.32941177, 0.99607849, 0.99607849,
0.65098041, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.54509807, 0.99607849, 0.9333334 , 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.82352948, 0.98039222, 0.99607849,
0.65882355, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.94901967, 0.99607849, 0.93725497, 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.34901962, 0.98431379, 0.9450981 ,
0.33725491, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01960784,
0.80784321, 0.96470594, 0.6156863 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01568628, 0.45882356, 0.27058825,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. ], dtype=float32)
—————————————-
In[5]
# 学習用の画像データを確認(1枚目)
plt.imshow(trainX[0].reshape(28, 28), cmap=cm.gray_r, interpolation=’nearest’)
plt.show()
—————————————-
In[6]
# 学習用の正解データを確認(1枚目)
trainY[0]
out[6]
array([ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
—————————————-
In[7]
## 3.ニューラルネットワークの作成 ##
## 初期化
tf.reset_default_graph()
## 入力層の作成
net = tflearn.input_data(shape=[None, 784])
## 中間層の作成
net = tflearn.fully_connected(net, 128, activation=’relu’)
net = tflearn.dropout(net, 0.5)
## 出力層の作成
net = tflearn.fully_connected(net, 10, activation=’softmax’)
net = tflearn.regression(net, optimizer=’sgd’, learning_rate=0.5, loss=’categorical_crossentropy’)
—————————————-
In[8]
## 4.モデルの作成(学習) ##
# 学習の実行
model = tflearn.DNN(net)
model.fit(trainX, trainY, n_epoch=20, batch_size=100, validation_set=0.1, show_metric=True)
Out[8]
Training Step: 9899 | total loss: 0.10077 | time: 3.300s
| SGD | epoch: 020 | loss: 0.10077 – acc: 0.9680 — iter: 49400/49500
Training Step: 9900 | total loss: 0.09962 | time: 4.334s
| SGD | epoch: 020 | loss: 0.09962 – acc: 0.9682 | val_loss: 0.08137 – val_acc: 0.9767 — iter: 49500/49500
—
—————————————-
In[9]
# 5.モデルの適用(予測) ##
pred = np.array(model.predict(testX)).argmax(axis=1)
print(pred)
label = testY.argmax(axis=1)
print(label)
accuracy = np.mean(pred == label, axis=0)
print(accuracy)
Out[9]
[7 2 1 …, 4 5 6]
[7 2 1 …, 4 5 6]
0.976
—————————————-