継続 in Java

継続Festに刺激されて、Javaで継続を書いてみようという遊び。
題材は、times関数。とりあえず、リストを定義しておきます。

  static IntList il (Integer... ls){
    IntList res = null;
    for (int i=ls.length-1; i>=0; --i){
      res = new IntList(ls[i], res);
    }
    return res;
  }
  static class IntList {
    int n;
    IntList tl;
    IntList(int n, IntList tl) { this.n = n; this.tl = tl;}
  }

さて、まずは普通にtimes関数を作ります。

  static void show (Object s)
  {
    System.out.println("end");
    System.out.println(s);
  }

  static int times (IntList ls)
  {
    if (ls == null){
      System.out.println("end");
      return 1;
    }else{
      System.out.print("" + ls.n + " - ");
      int s = ls.n * times(ls.tl);
      System.out.print("" + s + " - ");
      return s;
    }
  }

呼んでみる。

> show(times(il(1,2,3,4,5)));
1 - 2 - 3 - 4 - 5 - end
5 - 20 - 60 - 120 - 120 - end
120
> show(times(il(1,2,0,4,5)));
1 - 2 - 0 - 4 - 5 - end
5 - 20 - 0 - 0 - 0 - end
0

CPS

さて、これを継続を使って無駄な計算を省きたいということですが、まずは、CPS変換を試してみます。最初は補助関数をいくつか。

  static void run (closure f) { while (null != (f = f._eval()));}
  static abstract class closure {
    abstract closure _eval ();
  }
  static abstract class fun1<A>{
    abstract closure eval (A a);
    final closure call (final A a) {
      return new closure(){
        closure _eval () { return eval(a);}
      };
    }
  }
  static abstract class fun2<A,B>{
    abstract closure eval (A a, B b);
    final closure call (final A a, final B b){
      return new closure(){
        closure _eval () { return eval(a,b);}
      };
    }
  }

CPSはtail callじゃないときついのですが、Javaはjump機能がないので、関数適用(call)と評価(eval)をわけて、runで評価サイクルを回すようになっているところに注意。
では、普通にtimesを作成。

  static fun1                                             <Object>
    show                                                  = new fun1<Object>(){closure eval
      (Object o){
        System.out.println("end");
        System.out.println(o);
        return null;
      }                                                   };

  static fun2                                             <IntList, fun1<? super Integer>>
    times                                                 = new fun2<IntList, fun1<? super Integer>>(){closure eval
      (final IntList ls, final fun1<? super Integer> c){
        if (null == ls){
          System.out.println("end");
          return c.call(1);
        }else{
          System.out.print("" + ls.n + " - ");
          return times.call(
            ls.tl,
            new fun1                                      <Integer>(){closure eval
              (Integer i){
                System.out.print("" + i + " - ");
                return c.call(ls.n * i);
              }                                           }
            );
          }
        }                                                 };

普通のプログラムにみえるように、見苦しいところを全部右の方に隠しています;p
実行してみると

> run(times.call(il(1,2,3,4,5), show));
1 - 2 - 3 - 4 - 5 - end
5 - 20 - 60 - 120 - 120 - end
120
> run(times.call(il(1,2,0,4,5), show));
1 - 2 - 0 - 4 - 5 - end
5 - 20 - 0 - 0 - 0 - end
0

普通ですね。
では、リストの中に0が出現したら、0を返すように変更してみます。Cleanで言うと下のようになります。

times [] = 1
times [0:_] = 0 //この行を追加
times [h:t] = h * times t
  static fun2                                             <IntList, fun1<? super Integer>>
    times_cancel                                          = new fun2<IntList, fun1<? super Integer>>(){closure eval
      (final IntList ls, final fun1<? super Integer> c){
        if (null == ls){
          System.out.println("end");
          return c.call(1);
        }else if (ls.n == 0){
          System.out.println("zero");
          return c.call(0);
        }else{
          System.out.print("" + ls.n + " - ");
          return times_cancel.call(
            ls.tl,
            new fun1                                      <Integer>(){closure eval
              (Integer i){
                System.out.print("" + i + " - ");
                return c.call(ls.n * i);
              }                                           }
            );
          }
        }                                                 };

実行結果

> run(times_cancel.call(il(1,2,0,4,5), show));
1 - 2 - zero
0 - 0 - end
0

さらに、大域脱出を実装。

  static fun2                                             <IntList, fun1<? super Integer>>
    times_escape                                          = new fun2<IntList, fun1<? super Integer>>(){closure eval
      (final IntList ls, final fun1<? super Integer> c){
        if (null == ls){
          System.out.println("end");
          return c.call(1);
        }else if (ls.n == 0){
          System.out.println("zero");
          return show.call(0); // <-- assume that top level continuation is 'show'
        }else{
          System.out.print("" + ls.n + " - ");
          return times_escape.call(
            ls.tl,
            new fun1                                      <Integer>(){closure eval
              (Integer i){
                System.out.print("" + i + " - ");
                return c.call(ls.n * i);
              }                                           }
            );
          }
        }                                                 };

実行結果

> run(times_escape.call(il(1,2,0,4,5), show));
1 - 2 - zero
end
0

さて、CPSを使った継続はfull continuationなわけで、トップレベルの継続がshowだという知識を暗黙で使っているわけですね。大域変数に継続を保存する機構があれば、部分継続にすることができるはずです。

shift/reset

では、shift/resetを使った部分継続を作ってみます。こっちは例外を使って実装することにします。まず、下準備。

  static class ResetException extends RuntimeException {
    Object res;
    private ResetException (Object a) { this.res = a;}
  }
  static class ShiftException extends RuntimeException {
    Object res;
    private ShiftException (Object a) { this.res = a;}
  }
  static abstract class cont<A> {
    abstract Object call (A a);
  }
  static abstract class reset <A> {
    final A run () {
      try{
        return call();
      }catch(ResetException e){
        return (A) e.res;
      }
    }
    abstract A call ();
  }
  static abstract class shift<A> {
    final A run () {
      try{
        Object res = call(new cont<A>(){
          Object call (A a) {
            throw new ShiftException(a);
          }
        });
        throw new ResetException(res);
      }catch(ShiftException e){
        return (A) e.res;
      }
    }
    abstract Object call (cont<A> c);
  }

では、timesの実装です。

  static int times_escape (IntList ls)
  {
    if (ls == null){
      System.out.println("end");
      return 1;
    }else if (ls.n == 0){
      System.out.println("zero");
      return new shift<Integer>() {
        Object call (cont<Integer> c) { return 0;}
      }.run();
    }else{
      System.out.print("" + ls.n + " - ");
      int s = ls.n * times_escape(ls.tl);
      System.out.print("" + s + " - ");
      return s;
    }
  }

実行するときは、resetを仕掛けておきます。

> show(new reset<Integer>() {
    Integer call () { return times_escape(il(1,2,0,4,5));}
  }.run());
1 - 2 - zero
end
0

さて、このプログラムですが、コンパイルすると

注:Main_Normal.java の操作は、未チェックまたは安全ではありません。
注:詳細については、-Xlint:unchecked オプションを指定して再コンパイルしてください。

という警告が出ます。図らずも、部分継続の型付けの問題を観測することになりました。