読者です 読者をやめる 読者になる 読者になる

お布団宇宙ねこ

Haskell ねこ

HaskellでMNISTを使えるようにする

Qiita に 雑な記事 を書いただけになっていたのでちょっとした解説記事を書きました。

今回やったこと

『ゼロから作るDeep Learning – Pythonで学ぶディープラーニングの理論と実装』 のMNISTを扱うための下記サンプルコードを、PythonからHaskellに実装し直しました。

deep-learning-from-scratch/mnist.py at master · oreilly-japan/deep-learning-from-scratch

MNISTについて

MNISTとは手書き数字画像のデータセットのことです。

MNISTには画像データ本体とそれに対応した数字のラベルがあり、それぞれに訓練用とテスト用のものが用意されています。

データの中身は例えば訓練画像は以下のような構成です。 16バイト以降が画像データで、28*28バイトのピクセルデータが60000枚入っています。

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

また、訓練ラベルは以下のような構成です。 8バイト以降がラベルデータで、1つ1バイトで10000枚入っています。

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

The labels values are 0 to 9.

ref: MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

サンプルコードの処理の流れ

元のサンプルコードでは以下のような処理をやっています。

  • MNISTを準備する
    • MNISTをダウンロード
    • MNISTをNumPyで変換する
    • 変換したものをPickle化して保存する
  • Pickle化されたファイルを読み込んで復元する
  • オプションパラメータに応じて処理を行う
  • 画像とラベルのデータセットを返す

どう実装したのか

上記の処理フローの中でも重要な部分について解説します。

MNISTをNumPyで変換する→hmatrixで変換する

サンプルコードでは NumPy というライブラリを利用することでベクトルや行列を扱っています。

Haskellで実装する際には hmatrix というライブラリを利用しました。

例えば行列はこんな感じで定義できます。

// Python
import numpy as np
A = np.array([1,2],[3,4])

// Haskell
import Numeric.LinearAlgebra
let a = (2><2) [1,2,3,4] :: Matrix R

Haskellでは型を指定する必要があるので、画像データを Matrix R 、ラベルデータを Vector R とし、それをまとめたデータセットをタプルで定義しました。

type Image = Matrix R
type Label = Vector R
type DataSet = (Image, Label)

ダウンロードしたデータセットをhmatrixで変換するコードは下記のようになります。 画像データとラベルデータそれぞれの変換関数を作りました。

loadImg :: String -> IO Image
loadImg fn = do
    c <- fmap GZ.decompress (BL.readFile $ generatePath fn)
    return . matrix imgSize . toDoubleList $ BL.drop 16 c

loadLabel :: String -> IO Label
loadLabel fn = do
    c <- fmap GZ.decompress (BL.readFile $ generatePath fn)
    return . vector . toDoubleList $ BL.drop 8 c

BLData.ByteString.LazyGZCodec.Compression.GZip のことです。

fmap GZ.decompress (BL.readFile $ generatePath fn) の部分では、あるデータセットファイルを ByteString で読み込み decompress という関数を使って解凍しています。(ダウンロードするファイルはgz形式であるため)

はじめの方の項目で書いたようにMNISTのデータセットは、画像データが16バイト以降、ラベルデータが8バイト以降が必要となるデータなので、それより前は drop で切り捨てています。

hmatrixへの変換は、画像データの場合は matrix imgSize . toDoubleList 、ラベルデータの場合は vector . toDoubleList でやっています。共通する関数 toDoubleListVectorMatrix で扱う数値の型を Double にするための変換関数です。

toDoubleList :: BL.ByteString -> [Double]
toDoubleList = fmap (read . show . fromEnum) . BL.unpack

matrix には列数を引数として渡すことで行列を作ることができます。1行に画像データ1つが入ればよいので列数には画像サイズの 28*28=784 を指定します。一方でラベルデータは、1行にすべてのデータを入れるので vector を使っています。

最後に、loadImg , loadLabel を使ってhmatrixで変換した画像データとラベルデータをデータセット DataSet 型としてまとめます。

convertDataset :: IO [DataSet]
convertDataset = do
    tri <- loadImg . snd . head $ keyFiles
    trl <- loadLabel . snd . (!!1) $ keyFiles
    ti <- loadImg . snd . (!!2) $ keyFiles
    tl <- loadLabel . snd . (!!3) $ keyFiles

    return [(tri, trl), (ti, tl)]

変換したものをPickle化して保存する→バイナリで保存する

Pythonにはオブジェクトをバイナリファイルで保存しそれを復元するための機能が pickleモジュール で提供されています。

Haskellにはpickleのような機能は見当たらなかったのでpickleでやっていることを愚直にやることにしました。幸いなことにhmatrixの VectorMatrix という型は Data.Binaryインスタンスとなっているため、 Data.Binary の関数がそのまま使えます。

https://github.com/albertoruiz/hmatrix/blob/0.18.0.0/packages/base/src/Internal/Vector.hs#L417 https://github.com/albertoruiz/hmatrix/blob/0.18.0.0/packages/base/src/Internal/Element.hs#L40

実際にバイナリで保存するコードは下記のようになります。

createPickle :: String -> [DataSet] -> IO ()
createPickle p ds = BL.writeFile p $ (GZ.compress . encode) ds

データセットをまずバイナリに変換して圧縮します。 Data.Binaryインスタンスとなっているので encode という関数を適用するだけでバイナリ ( ByteString )に変換することができます。あとはこのBytestringを圧縮してバイナリファイルとして保存するだけです。圧縮には compress 、保存には writeFile を使います。

Pickle化されたファイルを読み込んで復元する→バイナリを読み込む

バイナリで保存したときと逆のことをすれば元のデータセットを復元できます。

loadPickle :: String -> IO [DataSet]
loadPickle p = do
    eds <- BL.readFile p
    return $ (decode . GZ.decompress) eds

利用例

MNISTを扱えるようになっただけでは少し物足りないので、これらのコードを使ってニューラルネットワークの推論処理をやって締めようと思います。

下記コードは こちらのサンプルコードHaskellで実装し直したものです。

入力層を784個、出力層を10個のニューロンで構成しています。

import Numeric.LinearAlgebra
import ActivationFunction
import Mnist
import SampleWeight

batchSize = 100

predict :: SampleWeight -> Vector R -> Vector R
predict ([w1,w2,w3],[b1,b2,b3]) x =
    softMax' . (\x'' -> sumInput x'' w3 b3) . sigmoid . (\x' -> sumInput x' w2 b2) . sigmoid $ sumInput x w1 b1

sumInput :: Vector R -> Weight -> Bias -> Vector R
sumInput x w b = (x <# w) + b

maxIndexPredict :: SampleWeight -> Vector R -> Double
maxIndexPredict sw x = fromIntegral . maxIndex $ predict sw x

take' :: Indexable c t => Int -> Int -> c -> [t]
take' n1 n2 x
    | n1 >= n2  = []
    | otherwise = (x ! n1) : take' (n1+1) n2 x

increment :: [Double] -> [Double] -> Double
increment ps l = fromIntegral . length . filter id $ zipWith (==) ps l

countAccuracy' :: Double -> Int -> SampleWeight -> DataSet -> Double
countAccuracy' a n sw ds@(i,l)
    | n <= 0    = a
    | otherwise = countAccuracy' (a+cnt) (n-batchSize) sw ds
        where ps = maxIndexPredict sw <$> take' (n-batchSize) n i
              ls = take' (n-batchSize) n l
              cnt = increment ps ls

main = do
    [_, ds] <- loadMnist True
    sw <- loadSW
    let r = rows $ fst ds
        cnt = countAccuracy' 0 r sw ds

    putStrLn $ "Accuracy: " ++ show (cnt / fromIntegral r)
$ stack runghc src/NeuralnetMnist.hs
Accuracy: 0.9352

サンプルコードと同じ値が出力されたので正しく実装できていそうです。

まとめ

MNISTを扱うコードは100行ほどで実装できましたが、テストを書かなくてもそれなりに動くものが作れるのはやはり型を定義しているおかげなのでしょう(コンパイルが通れば大体意図した通りに動く)。コードを読むときも型が書いてあることで一目で関数の入出力がわかるため全体の処理の流れが追いやすいです。しかし、今回のようにファイル操作など IO モナドを多用しているために若干読みづらいコードになっている気がします…。

今回紹介したコードは一部なので全貌が気になる方は こちらのリポジトリソースコードを置いてあります。

参考文献

MNIST 手書き数字データを画像ファイルに変換する - y_uti のブログ
HaskellでParsecを使ってCSVをパースする - Qiita
Haskellから簡単にWeb APIを叩く方法 - Qiita