ScalaでStateモナド
Scalaでは普通に副作用が使えるので、Stateモナドみたいなのの出番はまず無いんだけど、最近Haskellの各種モナドをScalaで書き直すのがマイブーム(死語)なのでやってみた。今回も、All About MonadsのState monadの解説ページの実装丸写しで大した工夫も無いけど、まあそれはそれで。しかし、こういう事毎回やるたびに思うんだけど、やはり型推論がHaskellに比べて弱いのは痛いなあと。無駄な型注釈が必要になる場面がたびたびあるし。
class State[S, A](x: S => (A, S)) {
val runState = x
def flatMap[B](f: A => State[S, B]): State[S, B] = {
new State(s => { val (v, s_) = x(s); f(v).runState(s_) })
}
def map[B](f: A => B): State[S, B] = {
new State(s => { val (v, s_) = x(s); State.returns[S, B](f(v)).runState(s_) })
}
}
object State {
def returns[S, A](a: A): State[S, A] = new State(s => (a, s))
}
object MonadState {
def get[S]: State[S, S] = new State(s => (s, s))
def put[S](s: S): State[S, Unit] = new State(_ => ((), s))
}使用例。解説ページのように、乱数ジェネレータ使った例は、そもそも純粋関数型の乱数ジェネレータ作るのがめんどかったので止めて、1つの値だけを入れられる変数をエミュレートするようにしてみた。
class Var[T](c: T) {
def :=(newContent: T): Var[T] = new Var(newContent)
def unary_! : T = c
override def toString(): String = "State(" + c.toString + ")"
}
def current: State[Var[Int], Int] =
for(g <- MonadState.get[Var[Int]]) yield !g
def add(n: Int): State[Var[Int], Int] =
for ( g <- MonadState.get[Var[Int]];
x = !g + n;
g_ = (g := x);
_ <- MonadState.put[Var[Int]](g_)
) yield x
val foo = for ( _ <- add(1);
_ <- add(2);
_ <- add(3);
r <- current
) yield r
println(foo.runState(new Var(0))._1)実行結果。ちゃんと、1,2,3の合計値になっていることがわかる。
6