Haskell で『ゼロから作るDeep Learning』(2)
『ゼロから作るDeep Learning – Pythonで学ぶディープラーニングの理論と実装』 の読書メモです。
今回は 「3.6 手書き数字認識」の手前まで。
3章 ニューラルネットワーク
ニューラルネットワークとは
- パーセプトロンでは人力で重みを決めていくのに対して、ニューラルネットワークは適切な重みを自動で学習できる
- ニューラルネットワークは入力層、中間層(隠れ層)、出力層で構成される
- 前章の
ニューロンは受け取った信号の総和を計算し、それがある値を超えたときのみに 1 を出力する
というような動作は関数で表すことができる。このような関数を「活性化関数」と呼ぶ - 出力層では、重み付き入力信号とバイアスの総和を計算し、それを活性化関数に渡した結果が出力される
活性化関数
- 前章のパーセプトロンで使われていた、閾値を境にして出力が切り替わる関数を「ステップ関数(階段関数)」と呼ばれる
- 活性化関数としてステップ関数以外の関数を使うことでパーセプトロンからニューラルネットワークの世界へ進むことができる
- ニューラルネットワークでよく使われる活性化関数として「シグモイド関数」と呼ばれる関数がある
- ステップ関数とシグモイド関数にはいくつか共通点がある
- 入力信号の重要度に応じて、重要であれば大きな値を、逆に重要でなければ小さな値を出力する
- 出力信号の値は 0 から 1 の間
- 非線形関数である
- 何か入力に対して、出力が入力の定数倍になるような関数を線形関数と呼ぶので、そうではない関数のこと
- 最近は ReLU 関数というのも使われる
- 入力が 0 を超えるとその入力をそのまま出力し、 0 以下なら 0 を出力する関数
出力層の設計
- ニューラルネットワークの出力層の活性化関数は、問題の種類によって変更する必要がある
- 回帰問題なら恒等関数、分類問題ならソフトマックス関数
- ソフトマックス関数の出力の総和は 1 になる(もちろん各出力は 0 から 1.0 の実数の間に収まる)
- この性質のおかげでこの出力を確率として解釈することができる
- 確率として解釈できるから分類問題に適している
- しかし、各要素の大小関係はこの関数を適用しても変わらないので、推論(分類)のフェーズでは省略されるのが一般的
- 出力層のニューロンの数は解く問題に応じて決める必要がある
- たとえば、手描きの画像が数字の 0 から 9 のどれか当てる問題なら、ニューロンの数を 10 個に設定する
実践編
行列の計算をやってみる
hmatrix ライブラリを使うことでベクトルや行列の計算ができるようになります。 以下は書籍で使われている Python の NumPy ライブラリとの比較です。
行列の生成
// Python import numpy as np A = np.array([1,2],[3,4]) B = np.array([5,6],[7,8]) // Haskell // `R` は `Double` の型シノニムです。 import Numeric.LinearAlgebra let a = (2><2) [1,2,3,4] :: Matrix R let b = (2><2) [5,6,7,8] :: Matrix R
行列の次元数
// Python A.shape // Haskell size a
行列の積
// Python np.dot(A,B) // Haskell a <> b
例として、書籍 p.57 のニュートラルネットワークの内積を hmatrix を使って計算してみます。
*Main Numeric.LinearAlgebra> let x = (1><2) [1,2] :: Matrix R *Main Numeric.LinearAlgebra> size x (1,2) *Main Numeric.LinearAlgebra> let w = (2><3) [1,3,5,2,4,6] :: Matrix R *Main Numeric.LinearAlgebra> print w (2><3) [ 1.0, 3.0, 5.0 , 2.0, 4.0, 6.0 ] *Main Numeric.LinearAlgebra> size w (2,3) *Main Numeric.LinearAlgebra> let y = x <> w *Main Numeric.LinearAlgebra> print y (1><3) [ 5.0, 11.0, 17.0 ]
活性化関数のグラフを描画してみる
グラフ描画には今回 Chart というライブラリを使いました。実際にコードを書くときにはパッケージのドキュメントを見ても使い方がよくわからなかったので、こちらの Wiki を見ながらやりました。
そして、書いたコードから出力されたグラフが以下。