chainer MNIST を試す

下準備

git, matplotlib を入れます。

$ sudo apt-get install git
$ sudo pip install matplotlib

データをもらう

$ sudo git clone https://github.com/pfnet/chainer.git

mnist を動かす

$ cd chainer/examples/mnist/
$ ./train_mnist.py -g 0

ここでエラーが出ました。

nvidia-smi でもエラーが出たので、ドライバーを入れ直した所、正常に動きました。

初回の起動でデータセットをダウンロードしますが、http://yann.lecun.com/exdb/mnist/ のサイトがダウンしてたりします。

その際は、/usr/local/lib/python2.7/dist-packages/chainer/datasets/mnist.py のURLを適当なミラーサイトに書き換えれば、とりあえず先に進めます。

次のエラーが出ました。

96     if extensions.PlotReport.available():

この部分で AttributeError: type object 'PlotReport' has no attribute 'available' ということです。

ソースコードをみると、コンストラクタの中で def _check_available(): しています。
バージョンアップに伴って、メソッドの改定が有ったのでしょうか。
正常にコンストラクトできたら、そのチェックは出来ているという雰囲気?!

やむないので、

96     #if extensions.PlotReport.available():
97     if True:

などしちゃいます。

とりあえず実行する。

$ ./train_mnist.py -g 0
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.193103    0.0990286             0.94295        0.9691                    25.6391       
2           0.0731526   0.071707              0.977382       0.9773                    31.8223       
3           0.0473904   0.086795              0.984782       0.9736                    37.7731       
4           0.0367857   0.081037              0.987615       0.9758                    44.0105       
5           0.0254271   0.0803791             0.991698       0.9775                    49.7157       
6           0.0242218   0.0728918             0.992148       0.9819                    55.351        
7           0.0225581   0.0713523             0.992582       0.9805                    61.1866       
8           0.0151419   0.085011              0.995082       0.9789                    66.9343       
9           0.0154926   0.102077              0.994649       0.9768                    72.5936       
10          0.0187828   0.072758              0.993915       0.9823                    78.331        
11          0.0160584   0.0846834             0.994532       0.9812                    83.9434       
12          0.0101223   0.0924774             0.996732       0.9807                    89.8204       
13          0.0125349   0.0968882             0.996049       0.98                      96.0304       
14          0.0136353   0.079207              0.995899       0.9839                    101.928       
15          0.0109211   0.138                 0.996482       0.9743                    107.863       
16          0.00852963  0.096225              0.997516       0.9812                    113.651       
17          0.00918454  0.0962035             0.997199       0.9819                    119.358       
18          0.00934543  0.0869832             0.997282       0.984                     125.122       
19          0.00991504  0.0877433             0.996999       0.983                     130.87        
20          0.00880426  0.103433              0.997282       0.9804                    136.807

バージョンに有った train_mnist.py を使う

もしくは、ちゃんと github から、自分が使う chainer の version に有ったコードを手に入れます。

Chainer 1.21.0 用

chainer/train_mnist.py at v1.21.0 · pfnet/chainer · GitHub

実行

$ ./train_mnist.py -g 0
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.19085     0.099535              0.941401       0.9684                    6.44824       
2           0.0744746   0.0768928             0.976717       0.9764                    10.8679       
3           0.0486528   0.0741029             0.984432       0.9772                    15.2855       
4           0.0382096   0.0868596             0.987816       0.9758                    19.7071       
5           0.0295276   0.073381              0.990398       0.9795                    24.1291       
6           0.0233222   0.0931531             0.992381       0.9751                    28.5248       
7           0.0199249   0.0811498             0.993565       0.9801                    32.9042       
8           0.0202854   0.10219               0.993432       0.9779                    37.2762       
9           0.0196941   0.0857413             0.993398       0.9805                    41.6189       
10          0.0114205   0.0849754             0.995932       0.9811                    45.969        
11          0.0153012   0.0857274             0.994949       0.9808                    50.3032       
12          0.0126985   0.0856742             0.996065       0.9805                    54.6374       
13          0.0142371   0.0987941             0.995348       0.9802                    58.9803       
14          0.00942709  0.0975199             0.996999       0.9811                    63.318        
15          0.0112361   0.109751              0.997066       0.9772                    67.661        
16          0.00985846  0.0991543             0.996982       0.9815                    71.9995       
17          0.0107471   0.107608              0.996932       0.9821                    76.3405       
18          0.0105749   0.11046               0.996966       0.98                      80.6792       
19          0.00803951  0.123615              0.997666       0.9814                    85.0176       
20          0.0115847   0.0917939             0.996799       0.9853                    89.358    

速度について

なお、結構遅いです。

nvidia-smi でチェックした所、GPU使用率は20%程度ですが、CPU使用率は100%に成っていましたので、私の環境におけるボトルネックGPUではなく、CPUのようです。

ログのプロット

サンプルに含まれていそうだけど、ちょっと探して見当たらなかったので書きました。

import matplotlib.pyplot as plt
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("result")
args = parser.parse_args()

with open(args.result, 'r') as f:
        jsonData = json.load(f)

xs = []
ys1 = []
ys2 = []
for line in jsonData:
#       print(line["main/loss"])
        xs.append(line["iteration"])
        ys1.append(line["main/loss"])
        ys2.append(line["main/accuracy"])

plt.xlabel("iteration")
plt.ylabel("error")
plt.plot(xs, ys1, label="main/loss")
plt.plot(xs, ys2, label="main/accuracy")
plt.legend(loc='upper left')
plt.show()
$ python plot.py ./result/log

f:id:pongsuke:20170317111522p:plain

学習結果を保存して、判定させる

保存する

スナップショット取るように、改変する。

(しらべちゅう)

...
from chainer import serializers

 ...

def main():
    global model

 ...

serializers.save_hdf5('modelhdf5', model)

読み込んで判定させる