継続 in Java
継続Festに刺激されて、Javaで継続を書いてみようという遊び。
題材は、times関数。とりあえず、リストを定義しておきます。
static IntList il (Integer... ls){ IntList res = null; for (int i=ls.length-1; i>=0; --i){ res = new IntList(ls[i], res); } return res; } static class IntList { int n; IntList tl; IntList(int n, IntList tl) { this.n = n; this.tl = tl;} }
さて、まずは普通にtimes関数を作ります。
static void show (Object s) { System.out.println("end"); System.out.println(s); } static int times (IntList ls) { if (ls == null){ System.out.println("end"); return 1; }else{ System.out.print("" + ls.n + " - "); int s = ls.n * times(ls.tl); System.out.print("" + s + " - "); return s; } }
呼んでみる。
> show(times(il(1,2,3,4,5))); 1 - 2 - 3 - 4 - 5 - end 5 - 20 - 60 - 120 - 120 - end 120 > show(times(il(1,2,0,4,5))); 1 - 2 - 0 - 4 - 5 - end 5 - 20 - 0 - 0 - 0 - end 0
CPS
さて、これを継続を使って無駄な計算を省きたいということですが、まずは、CPS変換を試してみます。最初は補助関数をいくつか。
static void run (closure f) { while (null != (f = f._eval()));} static abstract class closure { abstract closure _eval (); } static abstract class fun1<A>{ abstract closure eval (A a); final closure call (final A a) { return new closure(){ closure _eval () { return eval(a);} }; } } static abstract class fun2<A,B>{ abstract closure eval (A a, B b); final closure call (final A a, final B b){ return new closure(){ closure _eval () { return eval(a,b);} }; } }
CPSはtail callじゃないときついのですが、Javaはjump機能がないので、関数適用(call)と評価(eval)をわけて、runで評価サイクルを回すようになっているところに注意。
では、普通にtimesを作成。
static fun1 <Object> show = new fun1<Object>(){closure eval (Object o){ System.out.println("end"); System.out.println(o); return null; } }; static fun2 <IntList, fun1<? super Integer>> times = new fun2<IntList, fun1<? super Integer>>(){closure eval (final IntList ls, final fun1<? super Integer> c){ if (null == ls){ System.out.println("end"); return c.call(1); }else{ System.out.print("" + ls.n + " - "); return times.call( ls.tl, new fun1 <Integer>(){closure eval (Integer i){ System.out.print("" + i + " - "); return c.call(ls.n * i); } } ); } } };
普通のプログラムにみえるように、見苦しいところを全部右の方に隠しています;p
実行してみると
> run(times.call(il(1,2,3,4,5), show)); 1 - 2 - 3 - 4 - 5 - end 5 - 20 - 60 - 120 - 120 - end 120 > run(times.call(il(1,2,0,4,5), show)); 1 - 2 - 0 - 4 - 5 - end 5 - 20 - 0 - 0 - 0 - end 0
普通ですね。
では、リストの中に0が出現したら、0を返すように変更してみます。Cleanで言うと下のようになります。
times [] = 1 times [0:_] = 0 //この行を追加 times [h:t] = h * times t
static fun2 <IntList, fun1<? super Integer>> times_cancel = new fun2<IntList, fun1<? super Integer>>(){closure eval (final IntList ls, final fun1<? super Integer> c){ if (null == ls){ System.out.println("end"); return c.call(1); }else if (ls.n == 0){ System.out.println("zero"); return c.call(0); }else{ System.out.print("" + ls.n + " - "); return times_cancel.call( ls.tl, new fun1 <Integer>(){closure eval (Integer i){ System.out.print("" + i + " - "); return c.call(ls.n * i); } } ); } } };
実行結果
> run(times_cancel.call(il(1,2,0,4,5), show)); 1 - 2 - zero 0 - 0 - end 0
さらに、大域脱出を実装。
static fun2 <IntList, fun1<? super Integer>> times_escape = new fun2<IntList, fun1<? super Integer>>(){closure eval (final IntList ls, final fun1<? super Integer> c){ if (null == ls){ System.out.println("end"); return c.call(1); }else if (ls.n == 0){ System.out.println("zero"); return show.call(0); // <-- assume that top level continuation is 'show' }else{ System.out.print("" + ls.n + " - "); return times_escape.call( ls.tl, new fun1 <Integer>(){closure eval (Integer i){ System.out.print("" + i + " - "); return c.call(ls.n * i); } } ); } } };
実行結果
> run(times_escape.call(il(1,2,0,4,5), show)); 1 - 2 - zero end 0
さて、CPSを使った継続はfull continuationなわけで、トップレベルの継続がshowだという知識を暗黙で使っているわけですね。大域変数に継続を保存する機構があれば、部分継続にすることができるはずです。
shift/reset
では、shift/resetを使った部分継続を作ってみます。こっちは例外を使って実装することにします。まず、下準備。
static class ResetException extends RuntimeException { Object res; private ResetException (Object a) { this.res = a;} } static class ShiftException extends RuntimeException { Object res; private ShiftException (Object a) { this.res = a;} } static abstract class cont<A> { abstract Object call (A a); } static abstract class reset <A> { final A run () { try{ return call(); }catch(ResetException e){ return (A) e.res; } } abstract A call (); } static abstract class shift<A> { final A run () { try{ Object res = call(new cont<A>(){ Object call (A a) { throw new ShiftException(a); } }); throw new ResetException(res); }catch(ShiftException e){ return (A) e.res; } } abstract Object call (cont<A> c); }
では、timesの実装です。
static int times_escape (IntList ls) { if (ls == null){ System.out.println("end"); return 1; }else if (ls.n == 0){ System.out.println("zero"); return new shift<Integer>() { Object call (cont<Integer> c) { return 0;} }.run(); }else{ System.out.print("" + ls.n + " - "); int s = ls.n * times_escape(ls.tl); System.out.print("" + s + " - "); return s; } }
実行するときは、resetを仕掛けておきます。
> show(new reset<Integer>() { Integer call () { return times_escape(il(1,2,0,4,5));} }.run()); 1 - 2 - zero end 0
さて、このプログラムですが、コンパイルすると
注:Main_Normal.java の操作は、未チェックまたは安全ではありません。
注:詳細については、-Xlint:unchecked オプションを指定して再コンパイルしてください。
という警告が出ます。図らずも、部分継続の型付けの問題を観測することになりました。