Java : Coroutine

前回のコードをもう少し汎用化して、コルーチンを作ってみた。コルーチンでは一度に1つのスレッドしか動かないのでスレッド同士が競合しなくなり、排他制御などを考えるのがずいぶん楽になる。
これがサンプルコード。あるスレッドがsleepしている最中にも、他のスレッドが割り込んで出力していないことに注目。

/**
 * Demo driver for {@link Coroutine}: five {@code Test} bodies take turns with main, and
 * because only one of them runs at a time, each print / sleep / print sequence finishes
 * without another body's output appearing in the middle of it.
 */
public class CoMain{
  /** Scheduler shared by main and every Test body. */
  static Coroutine c;

  /**
   * Registers the demo bodies, then keeps handing control away until no other
   * coroutine is left alive.
   */
  public static void main (String[] args){
    init();
    boolean othersAlive = c.yield();
    while (othersAlive) {
      othersAlive = c.yield();
    }
    System.out.println("main finished");
  }

  /** Demo body: five rounds of "yield, then a short sleep bracketed by two prints". */
  static class Test implements Runnable{
    int p;

    Test (int p) {
      this.p = p;
    }

    @Override
    public void run (){
      int round = 0;
      while (round < 5) {
        c.yield();
        final String label = "[" + p + "] - " + round;
        System.out.print(label + " => sleep(5) => ");
        try {
          Thread.sleep(5);
        } catch (InterruptedException e) {
          System.out.print("interrupted => ");
        }
        System.out.println(label);
        round++;
      }
      System.out.println("[" + p + "] finished");
    }
  }

  /** Creates the shared scheduler on the calling thread and registers Test bodies 0-4. */
  static void init ()
  {
    c = new Coroutine();
    for (int p = 0; p < 5; p++) {
      c.add(new Test(p));
    }
  }
}

出力結果はこれ。

$ java CoMain
[3] - 0 => sleep(5) => [3] - 0
[4] - 0 => sleep(5) => [4] - 0
[4] - 1 => sleep(5) => [4] - 1
[3] - 1 => sleep(5) => [3] - 1
[4] - 2 => sleep(5) => [4] - 2
[3] - 2 => sleep(5) => [3] - 2
[3] - 3 => sleep(5) => [3] - 3
[1] - 0 => sleep(5) => [1] - 0
[2] - 0 => sleep(5) => [2] - 0
[3] - 4 => sleep(5) => [3] - 4
[3] finished
[1] - 1 => sleep(5) => [1] - 1
[2] - 1 => sleep(5) => [2] - 1
[0] - 0 => sleep(5) => [0] - 0
[2] - 2 => sleep(5) => [2] - 2
[1] - 2 => sleep(5) => [1] - 2
[0] - 1 => sleep(5) => [0] - 1
[4] - 3 => sleep(5) => [4] - 3
[0] - 2 => sleep(5) => [0] - 2
[0] - 3 => sleep(5) => [0] - 3
[2] - 3 => sleep(5) => [2] - 3
[4] - 4 => sleep(5) => [4] - 4
[4] finished
[1] - 3 => sleep(5) => [1] - 3
[0] - 4 => sleep(5) => [0] - 4
[0] finished
[2] - 4 => sleep(5) => [2] - 4
[2] finished
[1] - 4 => sleep(5) => [1] - 4
[1] finished
main finished

以下がソースコード。前回と同様に shutdown() メソッドは用意していないので、必要に応じて拡張してほしい（コルーチンを最後まで実行しきらないと、待機中のスレッドが残ったままになる）。

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;

/**
 * Cooperative coroutines built on ordinary threads.
 *
 * <p>Every participating thread is registered in {@code pool} under an integer id, and
 * {@code currentId} names the one thread that is allowed to run. All other registered
 * threads wait on {@code mutex} until {@code currentId} is switched to them, so at most one
 * coroutine body executes at a time.
 *
 * <p>The thread that constructs the {@code Coroutine} is registered as id 0. Each
 * {@link Runnable} passed to {@link #add} runs on its own new thread with the next id.
 * Only the thread currently holding control may call {@link #add}, {@link #yield} or
 * {@link #yieldTo}; any other caller gets an {@link IllegalStateException}.
 *
 * <p>There is no {@code shutdown()}: a coroutine that is never given control again stays
 * parked in its wait loop.
 */
public class Coroutine{

  /** Registered threads indexed by id; a slot becomes {@code null} once its body ends. */
  private final ArrayList<Thread> pool = new ArrayList<Thread>();
  private final Random rand = new Random();
  /** Guards {@code pool} and {@code currentId}; also the monitor used for hand-offs. */
  private final Object mutex = new Object();
  /** Id of the single thread currently allowed to run. */
  private volatile int currentId;

  /** Registers the calling thread as coroutine 0, which initially holds control. */
  public Coroutine (){
    currentId = 0;
    pool.add(Thread.currentThread());
  }

  /**
   * Registers {@code c} as a new coroutine running on its own thread.
   *
   * <p>The new thread is started and immediately hands control back to the caller, so
   * {@code c.run()} does not begin until a later {@link #yield}/{@link #yieldTo} selects it.
   * When {@code c.run()} returns or throws, the slot is cleared and control passes to a
   * randomly chosen live coroutine (an exception then reaches the thread's uncaught-exception
   * handler).
   *
   * @param c the coroutine body
   * @return the id assigned to the new coroutine, usable with {@link #yieldTo}
   * @throws NullPointerException if {@code c} is null
   * @throws IllegalStateException if the calling thread does not currently hold control
   */
  public int add (final Runnable c){
    if (c == null)
      throw new NullPointerException("c");
    synchronized (mutex){
      checkOwner();
      final int ownerId = currentId;
      final int id = pool.size();
      final Thread t = new Thread(){
        @Override
        public void run (){
          // Park immediately: switch straight back to the thread blocked inside add().
          yieldTo(ownerId);
          try{
            c.run();
          }finally{
            // Hand off even when the body throws; otherwise every other coroutine would
            // keep waiting for a thread that no longer exists.
            finish(id);
          }
        }
      };
      pool.add(t);
      t.start();
      // Run the new thread up to its initial yieldTo(ownerId), which switches back here.
      yieldTo(id);
      return id;
    }
  }

  /** Clears slot {@code id} and passes control to a randomly chosen live coroutine. */
  private void finish (final int id){
    synchronized (mutex){
      pool.set(id, null);
      final int next = getNextThread(id);
      if (next >= 0){
        currentId = next;
        mutex.notifyAll();
      }
    }
  }

  /**
   * Picks a live coroutine other than {@code myId}: starts at a random slot and scans
   * forward, wrapping around, for the first non-null slot that is not {@code myId}.
   * Caller must hold {@code mutex}.
   *
   * @return the chosen id, or -1 if no other coroutine is alive
   */
  private int getNextThread (final int myId){
    final int size = pool.size();
    int id = rand.nextInt(size);
    for (int step = 0; step < size; ++step, id = (id + 1) % size){
      if (id != myId && pool.get(id) != null)
        return id;
    }
    return -1;
  }

  /**
   * Passes control to a randomly chosen other live coroutine and blocks until control
   * comes back to the calling thread.
   *
   * <p>Since Java 14 {@code yield} is a restricted identifier, so invoke this method
   * qualified (for example {@code co.yield()}).
   *
   * @return {@code false} immediately, without blocking, if no other coroutine is alive;
   *         {@code true} once control has returned to the caller
   * @throws IllegalStateException if the calling thread does not currently hold control
   */
  public boolean yield (){
    synchronized (mutex){
      checkOwner();
      final int myId = currentId;
      final int next = getNextThread(myId);
      if (next < 0) return false;
      switchTo(myId, next);
    }
    return true;
  }

  /**
   * Passes control to coroutine {@code id} and blocks until control comes back.
   * Yielding to the caller's own id returns {@code true} without blocking.
   *
   * @param id target coroutine id
   * @return {@code true} once control has returned; {@code false}, without switching, if
   *         coroutine {@code id} has already finished
   * @throws IndexOutOfBoundsException if {@code id} was never assigned
   * @throws IllegalStateException if the calling thread does not currently hold control
   */
  public boolean yieldTo (final int id){
    synchronized (mutex){
      if (id < 0 || pool.size() <= id)
        throw new IndexOutOfBoundsException("no coroutine with id " + id);
      checkOwner();
      if (pool.get(id) == null) return false;
      switchTo(currentId, id);
      return true;
    }
  }

  /**
   * Throws unless the calling thread currently holds control. Caller must hold
   * {@code mutex}.
   */
  private void checkOwner (){
    // Identity check: also safe if the slot at currentId has already been cleared.
    if (pool.get(currentId) != Thread.currentThread())
      throw new IllegalStateException("calling thread does not hold control of this Coroutine");
  }

  /**
   * Makes {@code target} current, wakes all waiters and blocks until {@code myId} is current
   * again. Caller must hold {@code mutex}.
   */
  private void switchTo (final int myId, final int target){
    currentId = target;
    mutex.notifyAll();
    boolean interrupted = false;
    while (currentId != myId){
      try{
        // Waiters also re-check currentId at least every 100 ms.
        mutex.wait(100);
      }catch(InterruptedException e){
        // Keep waiting: returning early would let two coroutines run at once.
        interrupted = true;
      }
    }
    if (interrupted)
      Thread.currentThread().interrupt();   // restore the flag for the coroutine's own code
  }
}