Chainer Linear を確認する

Chainer Linear を確認する

理解が進む度に、書き加えていく。

サンプルスクリプト

y = 2x + 2 を想定したもの。

#!/usr/bin/env python
# coding:utf-8
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import Variable, optimizers

# モデル定義
model = L.Linear(1, 1)

optimizer = optimizers.SGD()
optimizer.setup(model)

# 学習させる回数
times   = 1000

# 入力ベクトル
x   = Variable(np.array( [[1],[3],[5],[7]], dtype=np.float32))
#x  = Variable(np.array( [[1]], dtype=np.float32))

# 正解ベクトル
t   = Variable(np.array( [[2],[6],[10],[14]], dtype=np.float32))
#t  = Variable(np.array( [[2]], dtype=np.float32))

# 学習ループ
for i in range(0, times):
    optimizer.zero_grads()  # 勾配を初期化
    y   = model(x)          # モデルに予測させる
    loss    = F.mean_squared_error(y, t) # 損失を計算
    print("Data: {}, Loss: {}".format(y.data, loss.data) )
    loss.backward()         # 逆伝播する
    optimizer.update()      # optimizer を更新する

print("Weight : {}".format(model.W.data) )
print("Bias   : {}".format(model.b.data) )

print("---TEST---")
x   = Variable(np.array( [[3],[4],[5]], dtype=np.float32) )
y   = model(x)

print("Test data   : {}".format(x.data) )
print("Test result : {}".format(y.data) )

実行

$ python 001.py 

 ・・・

Data: [[  3.98719549]
 [  7.99330425]
 [ 11.99941349]
 [ 16.00552177]], Loss: 5.99056111241e-05
Weight : [[ 2.00304031]]
Bias   : [ 1.98421395]
---TEST---
Test data   : [[ 3.]
 [ 4.]
 [ 5.]]
Test result : [[  7.99333477]
 [  9.99637508]
 [ 11.9994154 ]]

ということで、

y = 2.00304031 x + 1.98421395

と、作成されました!

ソースコード

/usr/local/lib/python2.7/dist-packages/chainer/links/connection/linear.py

/usr/local/lib/python2.7/dist-packages/chainer/functions/connection/linear.py

APIから

Linearモデル

コンスタント

class chainer.links.Linear(in_size, out_size, wscale=1, bias=0, nobias=False, initialW=None, initial_bias=None) とのこと。

今回は model = L.Linear(1, 1) として、 in_size=1, out_size=1 で作っている。

class Linear(link.Link): とのことで、link.Link を継承しているらしい。

call

model(x) すると、 __call__ が呼ばれるわけだけど、そのなかで

from chainer.functions.connection import linear
 ・・・
return linear.linear(x, self.W, self.b)

なので、結局のところ何をするのかは、 functions/connection/linear.py を見る。

def linear(x, W, b=None):
    if b is None:
        return LinearFunction()(x, W)
    else:
        return LinearFunction()(x, W, b)

コメントをカットしたらこうなった。

LinearFunction()(x, W) をみて、ぎょっとしたが、これは、LinearFunction インスタンスを作成して、その ___call___ を読んでいる。

class LinearFunction(function.Function): とのことなので、 Function をみよう。

/usr/local/lib/python2.7/dist-packages/chainer/function.py をみる。

class Function(object):
 ・・・
    def __call__(self, *inputs):
    """Applies forward propagation with chaining backward references.

とのこと。

cuda が使えるかどうかなどで条件分岐されているが、forward するんだろと思う。

forward

    def forward(self, inputs):
        x = _as_mat(inputs[0])
        W = inputs[1]
        y = x.dot(W.T).astype(x.dtype, copy=False)
        if len(inputs) == 3:
            b = inputs[2]
            y += b
        return y,

ドット演算(行列の積)している。

あてがうのは、 W.data.T なので、 transpose されたものだ。

検算

y = 2.00304031 x + 1.98421395

に、x=3, 4, 5 を入れてみると、たしかに、 7.99333477, 9.99637508, 11.9994154 になる。