chainer MNIST を試す
下準備
git, matplotlib を入れます。
$ sudo apt-get install git $ sudo pip install matplotlib
データをもらう
$ sudo git clone https://github.com/pfnet/chainer.git
mnist を動かす
$ cd chainer/examples/mnist/ $ ./train_mnist.py -g 0
ここでエラーが出ました。
nvidia-smi でもエラーが出たので、ドライバーを入れ直した所、正常に動きました。
初回の起動でデータセットをダウンロードしますが、http://yann.lecun.com/exdb/mnist/
のサイトがダウンしてたりします。
その際は、/usr/local/lib/python2.7/dist-packages/chainer/datasets/mnist.py
のURLを適当なミラーサイトに書き換えれば、とりあえず先に進めます。
次のエラーが出ました。
96 if extensions.PlotReport.available():
この部分で AttributeError: type object 'PlotReport' has no attribute 'available'
ということです。
ソースコードをみると、コンストラクタの中で def _check_available():
しています。
バージョンアップに伴って、メソッドの改定が有ったのでしょうか。
正常にコンストラクトできたら、そのチェックは出来ているという雰囲気?!
やむないので、
96 #if extensions.PlotReport.available(): 97 if True:
などしちゃいます。
とりあえず実行する。
$ ./train_mnist.py -g 0 GPU: 0 # unit: 1000 # Minibatch-size: 100 # epoch: 20 epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time 1 0.193103 0.0990286 0.94295 0.9691 25.6391 2 0.0731526 0.071707 0.977382 0.9773 31.8223 3 0.0473904 0.086795 0.984782 0.9736 37.7731 4 0.0367857 0.081037 0.987615 0.9758 44.0105 5 0.0254271 0.0803791 0.991698 0.9775 49.7157 6 0.0242218 0.0728918 0.992148 0.9819 55.351 7 0.0225581 0.0713523 0.992582 0.9805 61.1866 8 0.0151419 0.085011 0.995082 0.9789 66.9343 9 0.0154926 0.102077 0.994649 0.9768 72.5936 10 0.0187828 0.072758 0.993915 0.9823 78.331 11 0.0160584 0.0846834 0.994532 0.9812 83.9434 12 0.0101223 0.0924774 0.996732 0.9807 89.8204 13 0.0125349 0.0968882 0.996049 0.98 96.0304 14 0.0136353 0.079207 0.995899 0.9839 101.928 15 0.0109211 0.138 0.996482 0.9743 107.863 16 0.00852963 0.096225 0.997516 0.9812 113.651 17 0.00918454 0.0962035 0.997199 0.9819 119.358 18 0.00934543 0.0869832 0.997282 0.984 125.122 19 0.00991504 0.0877433 0.996999 0.983 130.87 20 0.00880426 0.103433 0.997282 0.9804 136.807
バージョンに有った train_mnist.py を使う
もしくは、ちゃんと github から、自分が使う chainer の version に有ったコードを手に入れます。
Chainer 1.21.0 用
chainer/train_mnist.py at v1.21.0 · pfnet/chainer · GitHub
実行
$ ./train_mnist.py -g 0 GPU: 0 # unit: 1000 # Minibatch-size: 100 # epoch: 20 epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time 1 0.19085 0.099535 0.941401 0.9684 6.44824 2 0.0744746 0.0768928 0.976717 0.9764 10.8679 3 0.0486528 0.0741029 0.984432 0.9772 15.2855 4 0.0382096 0.0868596 0.987816 0.9758 19.7071 5 0.0295276 0.073381 0.990398 0.9795 24.1291 6 0.0233222 0.0931531 0.992381 0.9751 28.5248 7 0.0199249 0.0811498 0.993565 0.9801 32.9042 8 0.0202854 0.10219 0.993432 0.9779 37.2762 9 0.0196941 0.0857413 0.993398 0.9805 41.6189 10 0.0114205 0.0849754 0.995932 0.9811 45.969 11 0.0153012 0.0857274 0.994949 0.9808 50.3032 12 0.0126985 0.0856742 0.996065 0.9805 54.6374 13 0.0142371 0.0987941 0.995348 0.9802 58.9803 14 0.00942709 0.0975199 0.996999 0.9811 63.318 15 0.0112361 0.109751 0.997066 0.9772 67.661 16 0.00985846 0.0991543 0.996982 0.9815 71.9995 17 0.0107471 0.107608 0.996932 0.9821 76.3405 18 0.0105749 0.11046 0.996966 0.98 80.6792 19 0.00803951 0.123615 0.997666 0.9814 85.0176 20 0.0115847 0.0917939 0.996799 0.9853 89.358
速度について
なお、結構遅いです。
nvidia-smi でチェックした所、GPU使用率は20%程度ですが、CPU使用率は100%に成っていましたので、私の環境におけるボトルネックは GPUではなく、CPUのようです。
ログのプロット
サンプルに含まれていそうだけど、ちょっと探して見当たらなかったので書きました。
import matplotlib.pyplot as plt import argparse import json parser = argparse.ArgumentParser() parser.add_argument("result") args = parser.parse_args() with open(args.result, 'r') as f: jsonData = json.load(f) xs = [] ys1 = [] ys2 = [] for line in jsonData: # print(line["main/loss"]) xs.append(line["iteration"]) ys1.append(line["main/loss"]) ys2.append(line["main/accuracy"]) plt.xlabel("iteration") plt.ylabel("error") plt.plot(xs, ys1, label="main/loss") plt.plot(xs, ys2, label="main/accuracy") plt.legend(loc='upper left') plt.show()
$ python plot.py ./result/log
学習結果を保存して、判定させる
保存する
スナップショット取るように、改変する。
(しらべちゅう)
... from chainer import serializers ... def main(): global model ... serializers.save_hdf5('modelhdf5', model)