kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

ScalaでContinuationモナド

Scalaのfor-comprehensionについて、これは単なる拡張for文みたいに機能が限定されたものではなく、モナドを使ったプログラムを簡潔に書ける汎用の構文だよーみたいな説明がよくなされる/したことがある。だが、じゃあ、実際にScalaでforで使えるモナドっぽいデータ型定義するのはどれくらい簡単にできるのか試したことが無かったので、ちょっと試してみた。よくあるMaybeやListモナドは既にOptionやList型として存在しているので、組み込みで存在しないContinuationモナドを作ってみた。

http://www.sampou.org/haskell/a-a-monads/html/contmonad.html

のページの例をできるだけ忠実に再現することを目標にしたのだが、ScalaはHaskellよりも型推論が弱いので、人間型推論器となってあちこちに型注釈をしてまわったり、結構めんどうだった。他にも、HaskellとScalaの違いに色々悩まされりとか、Haskellの構文忘れててリファレンス引きながらだったので結構時間がかかってしまったが、なんとか移植できた。

以下がContinuaionモナドHaskellコードをScalaに移植したもの。Haskell版に比べてあちこちに型注釈が入っているのがわかると思う。

class Cont[R, +A](val runCont: (A => R) => R) {
  def map[B](f: A => B): Cont[R, B] = {
    new Cont[R, B](k => runCont(a => Cont.ret[R, B](f(a)).runCont(k)))
  }
  def flatMap[B](f: A => Cont[R, B]) :Cont[R, B] = {
    new Cont[R, B](k => runCont(a => f(a).runCont(k)))
  }
}
object Cont {
  type ==>[A,B] = A => Cont[B, Any]
  def ret[R, A](a: A) = new Cont[R, A](k => k(a))
  def callCC[A, R](f: (A ==> R) => Cont[R, A]): Cont[R, A] = {
    new Cont[R, A](k => f(a => new Cont((x:Any) => k(a))).runCont(k))
  }
  def runCont[A, B](f: A => B)(c: Cont[B, A]) = c.runCont(f)
  def when[A](cond: Boolean, cont: => Cont[A, Any]): Cont[A, Any] = {
    if(cond) cont else ret[A, Unit](())
  }
  def id[A](a: A) = a
}

それで、こっちがサンプルプログラム。

import Cont._, Character._
val fun : Int => String = {n =>
  runCont(id[String])(for (
    str <- callCC{(exit1: String ==> String) => for (
      _ <- when(n < 10, exit1(n.toString));
      ns = (n / 2).toString.map(digit(_, 10));
      `n'` <- callCC{(exit2: Int ==> String) => for(
        _ <- when(ns.length < 3, exit2(ns.length));
        _ <- when(ns.length < 5, exit2(n));
        _ <- when(ns.length < 7, {
          val `ns'` = ns.reverse.map(forDigit(_, 10))
          exit1(`ns'`.dropWhile(_ == '0').mkString(""))
        })
      ) yield (0/:ns)(_+_)}
    ) yield "(ns = " + ns.toList + ") " + `n'`.toString}
  ) yield "Answer: " + str)
}
println(fun(5))
println(fun(10))
println(fun(100))
println(fun(1000))

実行結果。

Answer: 5
Answer: (ns = List(5)) 1
Answer: (ns = List(5, 0)) 2
Answer: (ns = List(5, 0, 0)) 1000

今回のはとにかくHaskell版のプログラムの構造をできるだけそのまま再現するのを目指したのでこんな冗長になってしまったが、Scalaの機能を生かせばもうちょっと簡潔にできるかも。