tickの例でStateモナドの動きを覗いてみる

参考: http://www.haskell.org/hawiki/MonadState

newtype State s a = State { runState ::  (s -> (a, s)) }

instance Monad (State s) where
    return a = State $ \s -> (a, s)
    x >>= f  = State $ \s -> let (v, s') = runState x s
                             in runState (f v) s'

class MonadState m s | m -> s where
    get :: m s
    put :: s -> m ()

instance MonadState (State s) s where
    get   = State $ \s -> (s, s)
    put s = State $ \_ -> ((), s)

Stateモナドは、値のほかに現在の状態を表す関数を持ったペアで表される。…こんな理解でいいのかな。こんな感じで使うっぽい。

runState/execState/evalState <monad> <initial-state>

で、浅ーい理解だった自分にはget, putがどこから状態を取ってくるのか?とか何とか、こんがらがってしまいとりあえずtickを分解することにした。

import Monad
import Control.Monad.State

tick :: State Int Int
tick =  do { n <- get; put (n+1); return n }

tickは、Stateモナドを使って数字を1刻みでカウントアップしていく関数。値の部分にはもちろん現在の数字、状態の部分には次の数字を置いたペアを持ち回る。*1

*Main> evalState (tick >> tick >> tick) 0
2
*Main> runState (tick >> tick >> tick) 0
(2,3)
tick' = get >>= \n -> put (n+1) >> return n

tick'' = State $ \s0 -> let (n, s) = runState get s0
                        in runState (put (n+1) >> return n) s

tick''' = State $ \s0 -> let (n, s) = runState get s0
                         in runState (put (n+1) >>= (\_ -> return n)) s

tick'''' = State $ \s0 -> let (n, s) = runState get s0
                          in runState (State $ \s -> let (n', s') = runState (put (n+1)) s
                                                      in runState ((\_ -> return n) n') s') s

tick''''' = State $ \s0 -> let (n, s) = runState get s0
                           in runState (State $ \s -> let (n', s') = runState (put (n+1)) s
                                                       in runState (return n) s') s

tick'''''' = State $ \s0 -> let (n, s) = runState get s0
                            in let (n', s') = runState (put (n+1)) s
                               in runState (return n) s'

tick''''''' = State $ \s0 -> let (n, s) = runState (State $ \s -> (s, s)) s0        -- get
                             in let (n', s') = runState (State $ \_ -> ((), n+1)) s -- put
                                in runState (State $ \s -> (n, s)) s'

tick'''''''' = State $ \s0 -> let (n, s)   = (s0, s0)  -- get
                                  (n', s') = ((), n+1) -- put
                              in (n, s')

何だか分かった…のか?結局get/putが上手くいくのはbind(>>=)がうまくやってくれているから(そりゃそうだ)。bindはモナドをたらい回しにしてるイメージがあったんだけど、結局関数の結合なわけだった。

結論: do記法、>>=の見た目に惑わされてはいけない。

*1:この言い方であってるんだろうか