diff --git a/src/org/python/modules/_weakref/GlobalRef.java b/src/org/python/modules/_weakref/GlobalRef.java --- a/src/org/python/modules/_weakref/GlobalRef.java +++ b/src/org/python/modules/_weakref/GlobalRef.java @@ -6,11 +6,14 @@ import java.lang.ref.WeakReference; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.locks.ReentrantReadWriteLock; import org.python.core.Py; import org.python.core.PyList; import org.python.core.PyObject; +import org.python.core.PySystemState; import org.python.util.Generic; public class GlobalRef extends WeakReference { @@ -34,14 +37,11 @@ private static ReferenceQueue referenceQueue = new ReferenceQueue(); - private static RefReaperThread reaperThread; + private static Thread reaperThread; + private static ReentrantReadWriteLock reaperLock = new ReentrantReadWriteLock(); private static ConcurrentMap objects = Generic.concurrentMap(); - static { - initReaperThread(); - } - public GlobalRef(PyObject object) { super(object, referenceQueue); hashCode = System.identityHashCode(object); @@ -117,14 +117,36 @@ * @return a new tracked GlobalRef */ public static GlobalRef newInstance(PyObject object) { - GlobalRef ref = objects.get(new GlobalRef(object)); + createReaperThreadIfAbsent(); + + GlobalRef newRef = new GlobalRef(object); + GlobalRef ref = objects.putIfAbsent(newRef, newRef); if (ref == null) { - ref = new GlobalRef(object); - objects.put(ref, ref); + ref = newRef; } return ref; } + private static void createReaperThreadIfAbsent() { + reaperLock.readLock().lock(); + try { + if (reaperThread == null || !reaperThread.isAlive()) { + reaperLock.readLock().unlock(); + reaperLock.writeLock().lock(); + if (reaperThread == null || !reaperThread.isAlive()) { + try { + initReaperThread(); + } finally { + reaperLock.readLock().lock(); + reaperLock.writeLock().unlock(); + } + } + } + } finally { + reaperLock.readLock().unlock(); + } + } + /** * Return the number of references to the specified PyObject. * @@ -197,16 +219,18 @@ } private static void initReaperThread() { - reaperThread = new RefReaperThread(); + RefReaper reaper = new RefReaper(); + PySystemState systemState = Py.getSystemState(); + systemState.registerCloser(reaper); + + reaperThread = new Thread(reaper, "weakref reaper"); reaperThread.setDaemon(true); reaperThread.start(); } - private static class RefReaperThread extends Thread { - - RefReaperThread() { - super("weakref reaper"); - } + private static class RefReaper implements Runnable, Callable { + private volatile boolean exit = false; + private Thread thread; public void collect() throws InterruptedException { GlobalRef gr = (GlobalRef)referenceQueue.remove(); @@ -216,13 +240,32 @@ } public void run() { + // Store the actual reaper thread so that when PySystemState.cleanup() + // is called this thread can be interrupted and die. + this.thread = Thread.currentThread(); + while (true) { try { collect(); } catch (InterruptedException exc) { - // ok + // Is cleanup time so break out and die. + if (exit) { + break; + } } } } + + @Override + public Void call() throws Exception { + this.exit = true; + + if (thread != null && thread.isAlive()) { + this.thread.interrupt(); + this.thread = null; + } + + return null; + } } }