kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

ScalaでMLスタイルのモジュールを使ったプログラミングをする

何はともあれ以下のコードを見てください(ちなみに複素数クラスの実装は、

d.hatena.ne.jp

を参考にさせていただきました):

trait Complex {
  type T

  def re(a: T): Double

  def im(a: T): Double

  def make(re: Double): T

  def plus(a: T, b: T): T

  def minus(a: T, b: T): T

  def multiply(a: T, b: T): T

  def divide(a: T, b: T): T
}

object Complex extends Complex {
  case class C(re: Double, im: Double)
  type T = C

  def re(a: T): Double = a.re

  def im(a: T): Double = a.im

  def make(re: Double): T = C(re, 0.0)

  def plus(a: T, b: T): T = C(a.re + b.re, a.im + b.im)

  def minus(a: T, b: T): T = C(a.re - b.re, a.im - b.im)

  def multiply(a: T, b: T): T = C(a.re * b.re - a.im * b.im, a.im * b.re + a.re * b.im)

  def divide(a: T, b: T): T = {
    require(b.re != 0.0 || b.im != 0)
    val x = Math.pow(b.re, 2) + Math.pow(b.im, 2)
    C((a.re * b.re + a.im * b.im) / x, (a.im * b.re - a.re * b.im) / x)
  }
}

基本的に、

  • MLのsignature=Scalaのtrait
  • MLのstructure=Scalaのobject
  • MLのfunctor=Scalaのclass

ととらえれば、それなりにMLスタイルのモジュールをまねることができます(というのは、Scala関係の発表資料かどこかで読んだのですが思い出せない)。これは特に、Scalaが抽象型メンバを持っていることに寄っています。なお、「それなりに」と書いた通り、includeとかそのままでは真似られないものが色々あります。

追記:上のコードだと、Complex.Cでcase classの実装が参照できちゃうので、実際には

trait Complex {
  type T

  def re(a: T): Double

  def im(a: T): Double

  def make(re: Double, im: Double): T

  def plus(a: T, b: T): T

  def minus(a: T, b: T): T

  def multiply(a: T, b: T): T

  def divide(a: T, b: T): T
}

object Complex {
  val C: Complex = new Complex {
    case class C(re: Double, im: Double)
    type T = C

    def re(a: T): Double = a.re

    def im(a: T): Double = a.im

    def make(re: Double, im: Double): T = C(re, im)

    def plus(a: T, b: T): T = C(a.re + b.re, a.im + b.im)

    def minus(a: T, b: T): T = C(a.re - b.re, a.im - b.im)

    def multiply(a: T, b: T): T = C(a.re * b.re - a.im * b.im, a.im * b.re + a.re * b.im)

    def divide(a: T, b: T): T = {
      require(b.re != 0.0 || b.im != 0)
      val x = Math.pow(b.re, 2) + Math.pow(b.im, 2)
      C((a.re * b.re + a.im * b.im) / x, (a.im * b.re - a.re * b.im) / x)
    }
  }
}
object Main {
  import Complex.C
  def main(args: Array[String]): Unit = {
    val x = C.make(1.0, 1.0)
    val y = C.make(2.0, 2.0)
    println(C.plus(x, y))
  }
}

のように書くべきですね。こうすれば、抽象型メンバTがcase class Cを完全に隠蔽してくれます。