kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

ScalaでContinuationモナド(改良版)

ScalaでContinuationモナドは、どうにもあまり美しくなかったので、もっとScalaらしく改良してみた。重要な点は、二引数の型コンストラクタCont[R, A]をContinuations[R]の内部クラスContinuation[A]として表現し、flatMap,mapなどはContinuation[A]のメソッドとしたことで、これによって無駄な型注釈をだいぶ省くことができるようになった。
以下、実装コード。

class Continuations[R] {
  type CC[A] = A => Continuation[Any]
  def returns[A](a: A) = new Continuation[A](k => k(a))
  def callCC[A](f: CC[A] => Continuation[A]): Continuation[A] = {
    new Continuation[A](k => f(a => new Continuation(x => k(a))).run(k))
  }
  class Continuation[+A](val run: (A => R) => R) {
    def map[B](f: A => B): Continuation[B] = {
      new Continuation[B](k => run(a => returns[B](f(a)).run(k)))
    }
    def flatMap[B](f: A => Continuation[B]) :Continuation[B] = {
      new Continuation[B](k => run(a => f(a).run(k)))
    }
    def then[B](c: => Continuation[B]): Continuation[B] = {
      this.flatMap(a => c)
    }
  }
  def when[A](cond: Boolean, cont: => Continuation[A]): Continuation[Any] = {
    if(cond) cont else returns[Unit](())
  }
}
object Continuations {
  def id[A](a: A) = a
  def run[A, B](f: A => B)(c: Continuations[B]#Continuation[A]): B = {
    c.run(f)
  }
}

こっちがサンプルプログラム。

import Continuations._, Character._
val fun: Int => String = {n =>
  val c = new Continuations[String]
  import c._
  run(id[String])(for {
    str <- callCC[String]{exit1 => for {
      _ <- when(n < 10, exit1(n toString))
      ns = (n / 2).toString.map(digit(_, 10))
      `n'` <- callCC[Int]{exit2 => for {
        _ <- when(ns.length < 3, exit2(ns length)) then
             when(ns.length < 5, exit2(n)) then
             when(ns.length < 7, {
               val `ns'` = ns.reverse.map(forDigit(_, 10))
               exit1(`ns'`.dropWhile(_ == '0').mkString)
             })
      } yield (0/:ns)(_+_)}
    } yield "(ns = " + ns.toList + ") " + `n'`.toString }
  } yield "Answer: " + str)
}
import Math.pow
for(i <- (1 to 6).map{i => 5 * pow(16, i).toInt}) {
  printf("fun(%8d):%s%n", i, fun(i))
}

実行結果。

fun(      80):Answer: (ns = List(4, 0)) 2
fun(    1280):Answer: (ns = List(6, 4, 0)) 1280
fun(   20480):Answer: 4201
fun(  327680):Answer: 48361
fun( 5242880):Answer: (ns = List(2, 6, 2, 1, 4, 4, 0)) 19
fun(83886080):Answer: (ns = List(4, 1, 9, 4, 3, 0, 4, 0)) 25

2009-09-26:02:18修正:
このサンプルでは動作に影響は無いが、問題のある箇所を修正(thenの引数をby-name parameterに変更)。