kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

ScalaでStateモナド

Scalaでは普通に副作用が使えるので、Stateモナドみたいなのの出番はまず無いんだけど、最近Haskellの各種モナドをScalaで書き直すのがマイブーム(死語)なのでやってみた。今回も、All About MonadsのState monadの解説ページの実装丸写しで大した工夫も無いけど、まあそれはそれで。しかし、こういう事毎回やるたびに思うんだけど、やはり型推論がHaskellに比べて弱いのは痛いなあと。無駄な型注釈が必要になる場面がたびたびあるし。

class State[S, A](x: S => (A, S)) {
  val runState = x
  def flatMap[B](f: A => State[S, B]): State[S, B] = {
    new State(s => { val (v, s_) = x(s); f(v).runState(s_) })
  }
  def map[B](f: A => B): State[S, B] = {
    new State(s => { val (v, s_) = x(s); State.returns[S, B](f(v)).runState(s_) })
  }
}
object State {
  def returns[S, A](a: A): State[S, A] = new State(s => (a, s))
}
object MonadState {
  def get[S]: State[S, S] = new State(s => (s, s))
  def put[S](s: S): State[S, Unit] = new State(_ => ((), s))
}

使用例。解説ページのように、乱数ジェネレータ使った例は、そもそも純粋関数型の乱数ジェネレータ作るのがめんどかったので止めて、1つの値だけを入れられる変数をエミュレートするようにしてみた。

class Var[T](c: T) {
  def :=(newContent: T): Var[T] = new Var(newContent)
  def unary_! : T = c
  override def toString(): String = "State(" + c.toString + ")"
}

def current: State[Var[Int], Int] = 
  for(g <- MonadState.get[Var[Int]]) yield !g

def add(n: Int): State[Var[Int], Int] =
  for ( g <- MonadState.get[Var[Int]];
        x = !g + n;
        g_ = (g := x);
        _ <- MonadState.put[Var[Int]](g_)
      ) yield x

val foo = for ( _ <- add(1);
                _ <- add(2);
                _ <- add(3);
                r <- current
              ) yield r
println(foo.runState(new Var(0))._1)

実行結果。ちゃんと、1,2,3の合計値になっていることがわかる。

6