kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

S-99: Ninety-Nine Scala Problems(P11-P20)を解いてみた

「S-99: Ninety-Nine Scala Problems(P01-P10)を解いてみた」の続き。今回は、表題の通り、P11-P20を解いてみた。最初の回答を作るまでにかかった時間は大体50分。その後、dropとrotateについて題意を勘違いしていたことに気づき、コードを修正した。

以下、回答のコード。

/** Solutions to the S-99 list problems P11-P20.
  *
  * Built on the hand-written list primitives imported from `P01toP10`
  * (`foldLeft`, `append`, `map`, `pack`, `length`, `reverse`) rather than the
  * standard collection methods. Invalid arguments are reported with
  * `sys.error("")`, which throws a `RuntimeException` with an empty message
  * (the same exception the removed `Predef.error` used to throw).
  */
object P11toP20 {
  import P01toP10._

  /** Applies `f` to every element and concatenates the resulting lists in order.
    *
    * Produced elements are consed onto a reversed accumulator, which is reversed
    * once at the end. This keeps the work linear in the output size, whereas
    * appending each chunk to the end of the growing result re-traverses that
    * result on every step when `append` walks its first argument
    * (NOTE(review): assumes `P01toP10.append` is the usual cons-list append).
    * `f` is still invoked on the elements in list order.
    */
  def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] = {
    val reversedResult = foldLeft(list, List[B]()){(acc, a) =>
      foldLeft(f(a), acc){(acc1, b) => b :: acc1}
    }
    reverse(reversedResult)
  }

  /** Returns a list holding `t` exactly `n` times; fails when `n` is negative. */
  def copyN[T](t: T, n: Int): List[T] = {
    // All elements are identical, so building by prepending needs no reversal.
    @annotation.tailrec
    def loop(remaining: Int, acc: List[T]): List[T] =
      if (remaining == 0) acc else loop(remaining - 1, t :: acc)
    if (n >= 0) loop(n, Nil) else sys.error("")
  }

  /** Run-length encodes `list` using `pack`.
    *
    * A run of length 1 is emitted as the bare element; a longer run is emitted
    * as `(count, element)`. Hence the `List[Any]` result type.
    */
  def encodeModified(list: List[Any]): List[Any] = {
    map(pack(list)){x =>
      length(x) match {
        case 1 => x.head
        case n => (n, x.head)
      }
    }
  }

  /** Expands `(count, element)` pairs back into the flat list they encode. */
  def decode[T](encoded: List[(Int, T)]): List[T] = {
    flatMap(encoded){ case (n, e) => copyN(e, n) }
  }

  /** Run-length encodes `list` into `(count, element)` pairs without using `pack`.
    * An empty list encodes to an empty list.
    */
  def encodeDirect[T](list: List[T]): List[(Int, T)] = {
    // `pre` is the element of the current run and `n` its length so far.
    def encodeDirect1(pre: T, n: Int, rest: List[T]): List[(Int, T)] = {
      rest match {
        case x :: xs if pre == x => encodeDirect1(pre, n + 1, xs)
        case x :: xs => (n, pre) :: encodeDirect1(x, 1, xs)
        case _ => List((n, pre))
      }
    }
    list match {
      case x :: xs => encodeDirect1(x, 1, xs)
      case _ => List()
    }
  }

  /** Repeats every element of `list` twice. */
  def duplicate[T](list: List[T]): List[T] = {
    flatMap(list){ x => List(x, x) }
  }

  /** Repeats every element of `list` `n` times; fails when `n` is negative. */
  def duplicateN[T](n: Int, list: List[T]): List[T] = {
    flatMap(list){ x => copyN(x, n) }
  }

  /** Removes every `nth` element (the nth, 2nth, 3nth, ...); requires `nth > 0`. */
  def drop[T](nth: Int, list: List[T]): List[T] = {
    // `countdown` is the 1-based position of the next element to remove;
    // it restarts at `nth` right after each removal.
    def drop1(countdown: Int, rest: List[T]): List[T] = rest match {
      case _ :: xs if countdown == 1 => drop1(nth, xs)
      case x :: xs if countdown > 1 => x :: drop1(countdown - 1, xs)
      case _ => Nil
    }
    if (nth > 0) drop1(nth, list) else sys.error("")
  }

  /** Splits `list` into its first `n` elements and the remainder.
    * Fails when `n` is negative or exceeds the length of the list.
    */
  def split[T](n: Int, list: List[T]): (List[T], List[T]) = list match {
    case xs if n == 0 => (List[T](), xs)
    case x :: xs if n > 0 =>
      split(n - 1, xs) match { case (a, b) => (x :: a, b) }
    case _ => sys.error("")
  }

  /** Returns the elements from index `i` (inclusive) to index `k` (exclusive).
    *
    * Requires `0 <= i <= k` and fails when the list has fewer than `i`
    * elements. When the list ends before index `k`, the elements that do exist
    * are returned.
    */
  def slice[T](i: Int, k: Int, list: List[T]): List[T] = {
    // Recurses into itself (not the outer `slice`): the `0 <= from <= until`
    // invariant established below is preserved by every step.
    def slice1(from: Int, until: Int, rest: List[T]): List[T] = rest match {
      case _ :: xs if from > 0 => slice1(from - 1, until - 1, xs)
      case x :: xs if until > 0 => x :: slice1(0, until - 1, xs)
      case _ if from == 0 => Nil
      case _ => sys.error("")
    }
    if (0 <= i && i <= k) slice1(i, k, list) else sys.error("")
  }

  /* Previous version: wrong, because it assumed abs(n) <= length(list).
  def rotate[T](n: Int, list: List[T]): List[T] = {
    def rotate_+(n: Int, list1: List[T], list2: List[T]): List[T] = list1 match {
      case x::xs if n > 0 => rotate_+(n - 1, xs, x::list2)
      case _ if n == 0 => append(list1, reverse(list2))
      case _ => error("")
    }
    def rotate_-(n: Int, list1: List[T], list2: List[T]): List[T] = list1 match {
      case x::xs if n > 0 => rotate_-(n - 1, xs, x::list2)
      case _ if n == 0 => append(list2, reverse(list1))
      case _ => error("")
    }
    if(n > 0) rotate_+(n, list, List[T]())
    else if(n < 0) rotate_-(-n, reverse(list), List[T]())
    else List[T]()
  }
  */

  /** Rotates `list` left by `n` places; a negative `n` rotates right.
    *
    * `n` may exceed the list length in either direction. Rotating an empty
    * list yields an empty list.
    */
  def rotate[T](n: Int, list: List[T]): List[T] = list match {
    case Nil => Nil
    case _ =>
      val len = length(list)
      // Normalize to [0, len) so the walk below never wraps around the list,
      // however large |n| is. Safe for Int.MinValue: n % len lies in (-len, len).
      val shift = ((n % len) + len) % len
      // Moves the first `k` elements onto `visited` (reversed), then puts them
      // back after the untouched remainder.
      @annotation.tailrec
      def loop(k: Int, visited: List[T], rest: List[T]): List[T] = rest match {
        case x :: xs if k > 0 => loop(k - 1, x :: visited, xs)
        case _ => append(rest, reverse(visited))
      }
      loop(shift, Nil, list)
  }

  /** Removes the element at 0-based index `nth`.
    * Returns the remaining list together with the removed element, and fails
    * when `nth` is negative or out of range.
    */
  def removeAt[T](nth: Int, list: List[T]): (List[T], T) = list match {
    case x :: xs if nth == 0 => (xs, x)
    case x :: xs if nth > 0 =>
      removeAt(nth - 1, xs) match { case (ys, y) => (x :: ys, y) }
    case _ => sys.error("")
  }
}

実行例。

scala> import P11toP20._
import P11toP20._

scala> encodeModified(List('a, 'a, 'a, 'a, 'b, 'c, 'c, 'a, 'a, 'd, 'e, 'e, 'e, 'e))
res0: List[Any] = List((4,'a), 'b, (2,'c), (2,'a), 'd, (4,'e))

scala> decode(List((4, 'a), (1, 'b), (2, 'c), (2, 'a), (1, 'd), (4, 'e)))
res1: List[Symbol] = List('a, 'a, 'a, 'a, 'b, 'c, 'c, 'a, 'a, 'd, 'e, 'e, 'e, 'e)

scala> encodeDirect(List('a, 'a, 'a, 'a, 'b, 'c, 'c, 'a, 'a, 'd, 'e, 'e, 'e, 'e))
res2: List[(Int, Symbol)] = List((4,'a), (1,'b), (2,'c), (2,'a), (1,'d), (4,'e))

scala> duplicate(List('a, 'b, 'c, 'c, 'd))
res3: List[Symbol] = List('a, 'a, 'b, 'b, 'c, 'c, 'c, 'c, 'd, 'd)

scala> duplicateN(3, List('a, 'b, 'c, 'c, 'd))
res4: List[Symbol] = List('a, 'a, 'a, 'b, 'b, 'b, 'c, 'c, 'c, 'c, 'c, 'c, 'd, 'd, 'd)

scala> drop(3, List('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k))
res5: List[Symbol] = List('a, 'b, 'd, 'e, 'g, 'h, 'j, 'k)

scala> split(3, List('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k))
res6: (List[Symbol], List[Symbol]) = (List('a, 'b, 'c),List('d, 'e, 'f, 'g, 'h, 'i, 'j, 'k))

scala> slice(3, 7, List('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k))
res7: List[Symbol] = List('d, 'e, 'f, 'g)

scala> rotate(3, List('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k))
res8: List[Symbol] = List('d, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'a, 'b, 'c)

scala> rotate(-2, List('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k))
res9: List[Symbol] = List('j, 'k, 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i)

scala> removeAt(1, List('a, 'b, 'c, 'd))
res10: (List[Symbol], Symbol) = (List('a, 'c, 'd),'b)

追記:
nの絶対値がlength(list)以下であることを仮定してしまっていたため、rotateの定義を修正した。これで、以下のようにnの絶対値がリストの長さを超える場合でも正しく動作するようになった。

scala> rotate(5, List('a, 'b, 'c))
res6: List[Symbol] = List('c, 'a, 'b)