kmizuの日記

プログラミングや形式言語に関係のあることを書いたり書かなかったり。

JavaでContinuationモナド

.@kmizu のScalaのサンプルコードが、まるでPerlのコードのようにイミフだ... http://d.hatena.ne.jp/kmizushima/20090925/1253890980 *P3

http://twitter.com/t_yano/statuses/4370238778

よーし。じゃあ、Javaなら大丈夫ですよね?*1というわけでJavaで書いてみた。…………非常にしんどかった。なんか本当に人間型推論器になった気分。せっかく作ったので一応貼っておくけど、読む気がしないことうけあいのコードになったと思う(ライブラリ側はさほどではないけど、利用側が…)。

import java.util.*;
import static java.lang.Integer.*;
import static java.lang.Character.*;
/**
 * A small continuation monad.
 *
 * <p>{@code R} is the answer type of the whole computation. {@link Continuation} and
 * {@link CC} are inner (non-static) classes so that they share {@code R} with the enclosing
 * instance: create one {@code Continuations<R>} per answer type and build computations from it.
 */
public class Continuations<R> {
  /**
   * Runs a computation by handing it {@code f} as its final continuation.
   *
   * @param f the final continuation, receiving the computation's result
   * @param c the computation to run
   * @return the answer produced by running {@code c} with {@code f}
   */
  public static <A, B> B run(
    Function<A, B> f, Continuations<B>.Continuation<A> c
  ) {
    return c.run.call(f);
  }

  /** Returns the identity function, typically used as the final continuation for {@link #run}. */
  public static <A> Function<A, A> id() {
    return new Function<A, A>() {
      @Override
      public A call(A a) { return a; }
    };
  }

  /** A delayed value; used to keep the branches of {@link #when} and {@link Continuation#then} lazy. */
  public interface Thunk<A> { A thaw(); }

  /** A one-argument function. */
  public interface Function<A, B> { B call(A a); }

  /**
   * An escape continuation passed to the body of {@link #callCC}. Calling it with a value
   * abandons the rest of the body and delivers that value to the continuation of the
   * {@code callCC}. Its result type {@code B} is unconstrained because control never comes back.
   */
  public abstract class CC<A> {
    abstract <B> Continuation<B> call(A a);
  }

  /** Lifts a plain value: the resulting computation passes {@code a} straight to its continuation. */
  public <A> Continuation<A> returns(final A a) {
    return new Continuation<A>(new Function<Function<A, R>, R>() {
      @Override
      public R call(Function<A, R> k) { return k.call(a); }
    });
  }

  /**
   * Call-with-current-continuation: runs the computation built by {@code f}, giving it an
   * escape continuation that jumps straight to this {@code callCC}'s continuation.
   *
   * @param f builds the body from the escape continuation
   * @return a computation yielding either the body's result or the escaped value
   */
  public <A> Continuation<A> callCC(final Function<CC<A>, Continuation<A>> f) {
    return new Continuation<A>(new Function<Function<A, R>, R>() {
      @Override
      public R call(final Function<A, R> k) {
        return f.call(new CC<A>() {
          @Override
          public <B> Continuation<B> call(final A a) {
            return new Continuation<B>(new Function<Function<B, R>, R>() {
              @Override
              public R call(Function<B, R> ignoredRest) {
                // Escape: drop the local continuation and resume callCC's continuation k.
                return k.call(a);
              }
            });
          }
        }).run.call(k);
      }
    });
  }

  /**
   * Returns {@code cont.thaw()} when {@code cond} holds, otherwise a computation yielding
   * {@code null} as a unit-like placeholder. The thunk keeps the branch from being built eagerly.
   */
  public <A> Continuation<A> when(boolean cond, Thunk<Continuation<A>> cont) {
    return cond ? cont.thaw() : this.<A>returns(null);
  }

  /** A computation producing an {@code A}, in continuation-passing style with answer type {@code R}. */
  public class Continuation<A> {
    /** Takes the rest of the computation ({@code A -> R}) and produces the final answer. */
    final Function<Function<A, R>, R> run;

    Continuation(Function<Function<A, R>, R> run) {
      this.run = run;
    }

    /** Monadic bind: passes this computation's result to {@code f} and continues with what it returns. */
    public <B> Continuation<B> bind(final Function<A, Continuation<B>> f) {
      return new Continuation<B>(new Function<Function<B, R>, R>() {
        @Override
        public R call(final Function<B, R> k) {
          return run.call(new Function<A, R>() {
            @Override
            public R call(final A a) {
              return f.call(a).run.call(k);
            }
          });
        }
      });
    }

    /** Sequences the lazily-built {@code c} after this computation, discarding this result. */
    public <B> Continuation<B> then(final Thunk<Continuation<B>> c) {
      return bind(new Function<A, Continuation<B>>() {
        @Override
        public Continuation<B> call(A ignored) { return c.thaw(); }
      });
    }
  }

  /** Sums the integers in {@code list}. */
  private static int sum(List<Integer> list) {
    int r = 0;
    for (int i : list) r += i;
    return r;
  }

  /** Concatenates the characters in {@code list} into a string. */
  private static String concat(List<Character> list) {
    StringBuilder s = new StringBuilder(list.size());
    for (char ch : list) s.append(ch);
    return s.toString();
  }

  /** Returns the decimal digits of a non-negative {@code value}, most significant first. */
  private static List<Integer> digitsOf(int value) {
    List<Integer> ds = new ArrayList<Integer>();
    for (char ch : Integer.toString(value).toCharArray()) {
      ds.add(digit(ch, 10));
    }
    return ds;
  }

  /**
   * Renders {@code ds} in reverse order as a decimal string with leading zeros removed
   * (a lone "0" is kept). Does not modify {@code ds}.
   */
  private static String reversedDigits(List<Integer> ds) {
    LinkedList<Character> chars = new LinkedList<Character>();
    for (int d : ds) {
      chars.addFirst(forDigit(d, 10));
    }
    // Keep at least one character so an all-zero list cannot run off the end.
    while (chars.size() > 1 && chars.getFirst() == '0') {
      chars.removeFirst();
    }
    return concat(chars);
  }

  /**
   * Demo of two nested {@code callCC}s escaping early based on the digits of {@code n / 2}.
   *
   * <ul>
   *   <li>{@code n < 10}: outer escape with {@code n} itself.
   *   <li>fewer than 3 digits: inner escape with the digit count.
   *   <li>fewer than 5 digits: inner escape with {@code n}.
   *   <li>fewer than 7 digits: outer escape with the digits reversed, leading zeros dropped.
   *   <li>otherwise: the digit sum.
   * </ul>
   * Inner results are formatted as {@code "(ns = [digits]) value"}; every result is prefixed
   * with {@code "Answer: "}.
   */
  public static String fun(final int n) {
    final Continuations<String> c = new Continuations<String>();
    return run(
      Continuations.<String>id(),
      c.callCC(
        new Function<
          Continuations<String>.CC<String>, Continuations<String>.Continuation<String>
        >() {
          @Override
          public Continuations<String>.Continuation<String> call(
            final Continuations<String>.CC<String> exit1
          ) {
            return c.when(n < 10,
              new Thunk<Continuations<String>.Continuation<String>>() {
                @Override
                public Continuations<String>.Continuation<String> thaw() {
                  return exit1.call(Integer.toString(n));
                }
              }
            ).bind(new Function<String, Continuations<String>.Continuation<String>>() {
              @Override
              public Continuations<String>.Continuation<String> call(String unused) {
                final List<Integer> ns = digitsOf(n / 2);
                return c.callCC(
                  new Function<
                    Continuations<String>.CC<Integer>, Continuations<String>.Continuation<Integer>
                  >() {
                    @Override
                    public Continuations<String>.Continuation<Integer> call(
                      final Continuations<String>.CC<Integer> exit2
                    ) {
                      return c.when(ns.size() < 3,
                        new Thunk<Continuations<String>.Continuation<Integer>>() {
                          @Override
                          public Continuations<String>.Continuation<Integer> thaw() {
                            return exit2.call(ns.size());
                          }
                        }
                      ).then(
                        new Thunk<Continuations<String>.Continuation<Integer>>() {
                          @Override
                          public Continuations<String>.Continuation<Integer> thaw() {
                            return c.when(ns.size() < 5,
                              new Thunk<Continuations<String>.Continuation<Integer>>() {
                                @Override
                                public Continuations<String>.Continuation<Integer> thaw() {
                                  return exit2.call(n);
                                }
                              }
                            );
                          }
                        }
                      ).then(
                        new Thunk<Continuations<String>.Continuation<Integer>>() {
                          @Override
                          public Continuations<String>.Continuation<Integer> thaw() {
                            return c.when(ns.size() < 7,
                              new Thunk<Continuations<String>.Continuation<Integer>>() {
                                @Override
                                public Continuations<String>.Continuation<Integer> thaw() {
                                  // Escapes all the way out through the OUTER continuation.
                                  return exit1.call(reversedDigits(ns));
                                }
                              }
                            );
                          }
                        }
                      ).then(
                        new Thunk<Continuations<String>.Continuation<Integer>>() {
                          @Override
                          public Continuations<String>.Continuation<Integer> thaw() {
                            return c.returns(sum(ns));
                          }
                        }
                      );
                    }
                  }
                ).bind(new Function<Integer, Continuations<String>.Continuation<String>>() {
                  @Override
                  public Continuations<String>.Continuation<String> call(Integer n2) {
                    return c.returns(String.format("(ns = %s) %d", ns.toString(), n2));
                  }
                });
              }
            });
          }
        }
      ).bind(new Function<String, Continuations<String>.Continuation<String>>() {
        @Override
        public Continuations<String>.Continuation<String> call(String str) {
          return c.returns("Answer: " + str);
        }
      })
    );
  }

  /** Prints {@code fun(5 * 16^i)} for {@code i = 1..6}. */
  public static void main(String[] args) {
    for (int i = 1; i <= 6; i++) {
      int x = (int) (5 * Math.pow(16, i));
      System.out.printf("fun(%8d):%s%n", x, fun(x));
    }
  }
}

実行結果。途中、遅延評価すべきところを遅延評価してなくて間違った結果が出たり、型エラーが出まくったりしたけど、なんとか元のHaskellコードと同等の結果が出るようになった。

fun(      80):Answer: (ns = [4, 0]) 2
fun(    1280):Answer: (ns = [6, 4, 0]) 1280
fun(   20480):Answer: 4201
fun(  327680):Answer: 48361
fun( 5242880):Answer: (ns = [2, 6, 2, 1, 4, 4, 0]) 19
fun(83886080):Answer: (ns = [4, 1, 9, 4, 3, 0, 4, 0]) 25

*1:もちろん冗談です