diff --git a/direct/src/task/Task.py b/direct/src/task/Task.py index 9002d5bd63..5ac2e95448 100644 --- a/direct/src/task/Task.py +++ b/direct/src/task/Task.py @@ -91,6 +91,7 @@ pause = AsyncTaskPause Task.DtoolClassDict['pause'] = staticmethod(pause) gather = Task.gather +shield = Task.shield def sequence(*taskList): seq = AsyncTaskSequence('sequence') diff --git a/panda/src/event/asyncFuture.I b/panda/src/event/asyncFuture.I index 9ab8a0086e..c38b04d880 100644 --- a/panda/src/event/asyncFuture.I +++ b/panda/src/event/asyncFuture.I @@ -133,6 +133,24 @@ gather(Futures futures) { } } +/** + * Creates a new future that shields the given future from cancellation. + * Calling `cancel()` on the returned future will not affect the given future. + */ +INLINE PT(AsyncFuture) AsyncFuture:: +shield(PT(AsyncFuture) future) { + if (future->try_lock_pending()) { + PT(AsyncFuture) outer = new AsyncFuture; + outer->_manager = future->_manager; + future->_waiting.push_back((AsyncFuture *)outer); + future->unlock(); + return outer; + } + else { + return future; + } +} + /** * Tries to atomically lock the future, assuming it is pending. Returns false * if it is not in the pending state, implying it's either done or about to be diff --git a/panda/src/event/asyncFuture.cxx b/panda/src/event/asyncFuture.cxx index fd2fe9c028..6828213942 100644 --- a/panda/src/event/asyncFuture.cxx +++ b/panda/src/event/asyncFuture.cxx @@ -148,13 +148,13 @@ notify_done(bool clean_exit) { // This will only be called by the thread that managed to set the // _future_state away from the "pending" state, so this is thread safe. - Futures::iterator it; - for (it = _waiting.begin(); it != _waiting.end(); ++it) { - AsyncFuture *fut = *it; + // Go through the futures that are waiting for this to finish. + for (AsyncFuture *fut : _waiting) { if (fut->is_task()) { // It's a task. Make it active again. wake_task((AsyncTask *)fut); - } else { + } + else if (fut->get_type() == AsyncGatheringFuture::get_class_type()) { // It's a gathering future. Decrease the pending count on it, and if // we're the last one, call notify_done() on it. AsyncGatheringFuture *gather = (AsyncGatheringFuture *)fut; @@ -164,6 +164,23 @@ notify_done(bool clean_exit) { } } } + else { + // It's a shielding future. The shielding only protects the inner future + // when the outer is cancelled, not the other way around, so we have to + // propagate any cancellation here as well. + if (clean_exit && _result != nullptr) { + // Propagate the result, if any. + if (fut->try_lock_pending()) { + fut->_result = _result; + fut->_result_ref = _result_ref; + fut->unlock(FS_finished); + fut->notify_done(true); + } + } + else if (fut->set_future_state(clean_exit ? FS_finished : FS_cancelled)) { + fut->notify_done(clean_exit); + } + } } _waiting.clear(); diff --git a/panda/src/event/asyncFuture.h b/panda/src/event/asyncFuture.h index 2e62c07d1a..e80918c149 100644 --- a/panda/src/event/asyncFuture.h +++ b/panda/src/event/asyncFuture.h @@ -78,6 +78,7 @@ PUBLISHED: EXTENSION(PyObject *add_done_callback(PyObject *self, PyObject *fn)); EXTENSION(static PyObject *gather(PyObject *args)); + INLINE static PT(AsyncFuture) shield(PT(AsyncFuture) future); virtual void output(std::ostream &out) const; diff --git a/tests/event/test_futures.py b/tests/event/test_futures.py index 56ab3073e0..d18d52c3f7 100644 --- a/tests/event/test_futures.py +++ b/tests/event/test_futures.py @@ -289,6 +289,66 @@ def test_future_gather_cancel_outer(): assert gather.result() +def test_future_shield(): + # An already done future is returned as-is (no cancellation can occur) + inner = core.AsyncFuture() + inner.set_result(None) + outer = core.AsyncFuture.shield(inner) + assert inner == outer + + # Normally finishing future + inner = core.AsyncFuture() + outer = core.AsyncFuture.shield(inner) + assert not outer.done() + inner.set_result(None) + assert outer.done() + assert not outer.cancelled() + assert inner.result() is None + + # Normally finishing future with result + inner = core.AsyncFuture() + outer = core.AsyncFuture.shield(inner) + assert not outer.done() + inner.set_result(123) + assert outer.done() + assert not outer.cancelled() + assert inner.result() == 123 + + # Cancelled inner future does propagate cancellation outward + inner = core.AsyncFuture() + outer = core.AsyncFuture.shield(inner) + assert not outer.done() + inner.cancel() + assert outer.done() + assert outer.cancelled() + + # Finished outer future does nothing to inner + inner = core.AsyncFuture() + outer = core.AsyncFuture.shield(inner) + outer.set_result(None) + assert not inner.done() + inner.cancel() + assert not outer.cancelled() + + # Cancelled outer future does nothing to inner + inner = core.AsyncFuture() + outer = core.AsyncFuture.shield(inner) + outer.cancel() + assert not inner.done() + inner.cancel() + + # Can be shielded multiple times + inner = core.AsyncFuture() + outer1 = core.AsyncFuture.shield(inner) + outer2 = core.AsyncFuture.shield(inner) + outer1.cancel() + assert not inner.done() + assert not outer2.done() + inner.cancel() + assert outer1.done() + assert outer2.done() + + def test_future_done_callback(): fut = core.AsyncFuture()