Chainer Linear を確認する
Chainer Linear を確認する
理解が進む度に、書き加えていく。
サンプルスクリプト
y = 2x + 2
を想定したもの。
#!/usr/bin/env python # coding:utf-8 import numpy as np import chainer.functions as F import chainer.links as L from chainer import Variable, optimizers # モデル定義 model = L.Linear(1, 1) optimizer = optimizers.SGD() optimizer.setup(model) # 学習させる回数 times = 1000 # 入力ベクトル x = Variable(np.array( [[1],[3],[5],[7]], dtype=np.float32)) #x = Variable(np.array( [[1]], dtype=np.float32)) # 正解ベクトル t = Variable(np.array( [[2],[6],[10],[14]], dtype=np.float32)) #t = Variable(np.array( [[2]], dtype=np.float32)) # 学習ループ for i in range(0, times): optimizer.zero_grads() # 勾配を初期化 y = model(x) # モデルに予測させる loss = F.mean_squared_error(y, t) # 損失を計算 print("Data: {}, Loss: {}".format(y.data, loss.data) ) loss.backward() # 逆伝播する optimizer.update() # optimizer を更新する print("Weight : {}".format(model.W.data) ) print("Bias : {}".format(model.b.data) ) print("---TEST---") x = Variable(np.array( [[3],[4],[5]], dtype=np.float32) ) y = model(x) print("Test data : {}".format(x.data) ) print("Test result : {}".format(y.data) )
実行
$ python 001.py ・・・ Data: [[ 3.98719549] [ 7.99330425] [ 11.99941349] [ 16.00552177]], Loss: 5.99056111241e-05 Weight : [[ 2.00304031]] Bias : [ 1.98421395] ---TEST--- Test data : [[ 3.] [ 4.] [ 5.]] Test result : [[ 7.99333477] [ 9.99637508] [ 11.9994154 ]]
ということで、
y = 2.00304031 x + 1.98421395
と、作成されました!
ソースコード
/usr/local/lib/python2.7/dist-packages/chainer/links/connection/linear.py
/usr/local/lib/python2.7/dist-packages/chainer/functions/connection/linear.py
APIから
Linearモデル
コンスタント
class chainer.links.Linear(in_size, out_size, wscale=1, bias=0, nobias=False, initialW=None, initial_bias=None)
とのこと。
今回は model = L.Linear(1, 1)
として、 in_size=1, out_size=1 で作っている。
class Linear(link.Link):
とのことで、link.Link
を継承しているらしい。
call
model(x)
すると、 __call__
が呼ばれるわけだけど、そのなかで
from chainer.functions.connection import linear ・・・ return linear.linear(x, self.W, self.b)
なので、結局のところ何をするのかは、 functions/connection/linear.py
を見る。
def linear(x, W, b=None): if b is None: return LinearFunction()(x, W) else: return LinearFunction()(x, W, b)
コメントをカットしたらこうなった。
LinearFunction()(x, W)
をみて、ぎょっとしたが、これは、LinearFunction インスタンスを作成して、その ___call___
を読んでいる。
class LinearFunction(function.Function):
とのことなので、 Function をみよう。
/usr/local/lib/python2.7/dist-packages/chainer/function.py
をみる。
class Function(object): ・・・ def __call__(self, *inputs): """Applies forward propagation with chaining backward references.
とのこと。
cuda が使えるかどうかなどで条件分岐されているが、forward するんだろと思う。
forward
def forward(self, inputs): x = _as_mat(inputs[0]) W = inputs[1] y = x.dot(W.T).astype(x.dtype, copy=False) if len(inputs) == 3: b = inputs[2] y += b return y,
ドット演算(行列の積)している。
あてがうのは、 W.data.T なので、 transpose されたものだ。
検算
y = 2.00304031 x + 1.98421395
に、x=3, 4, 5 を入れてみると、たしかに、 7.99333477, 9.99637508, 11.9994154 になる。