task: Add AsyncFuture::shield() ability, part of #1136

This is modelled after `asyncio.shield()` and can be used to protect an inner future from cancellation when the outer future is cancelled.
This commit is contained in:
rdb 2021-04-09 18:23:44 +02:00
parent e9e1bb1bc6
commit cdf5b16ddd
5 changed files with 101 additions and 4 deletions

View File

@ -91,6 +91,7 @@ pause = AsyncTaskPause
Task.DtoolClassDict['pause'] = staticmethod(pause)
gather = Task.gather
shield = Task.shield
def sequence(*taskList):
seq = AsyncTaskSequence('sequence')

View File

@ -133,6 +133,24 @@ gather(Futures futures) {
}
}
/**
* Creates a new future that shields the given future from cancellation.
* Calling `cancel()` on the returned future will not affect the given future.
*/
INLINE PT(AsyncFuture) AsyncFuture::
shield(PT(AsyncFuture) future) {
if (future->try_lock_pending()) {
PT(AsyncFuture) outer = new AsyncFuture;
outer->_manager = future->_manager;
future->_waiting.push_back((AsyncFuture *)outer);
future->unlock();
return outer;
}
else {
return future;
}
}
/**
* Tries to atomically lock the future, assuming it is pending. Returns false
* if it is not in the pending state, implying it's either done or about to be

View File

@ -148,13 +148,13 @@ notify_done(bool clean_exit) {
// This will only be called by the thread that managed to set the
// _future_state away from the "pending" state, so this is thread safe.
Futures::iterator it;
for (it = _waiting.begin(); it != _waiting.end(); ++it) {
AsyncFuture *fut = *it;
// Go through the futures that are waiting for this to finish.
for (AsyncFuture *fut : _waiting) {
if (fut->is_task()) {
// It's a task. Make it active again.
wake_task((AsyncTask *)fut);
} else {
}
else if (fut->get_type() == AsyncGatheringFuture::get_class_type()) {
// It's a gathering future. Decrease the pending count on it, and if
// we're the last one, call notify_done() on it.
AsyncGatheringFuture *gather = (AsyncGatheringFuture *)fut;
@ -164,6 +164,23 @@ notify_done(bool clean_exit) {
}
}
}
else {
// It's a shielding future. The shielding only protects the inner future
// when the outer is cancelled, not the other way around, so we have to
// propagate any cancellation here as well.
if (clean_exit && _result != nullptr) {
// Propagate the result, if any.
if (fut->try_lock_pending()) {
fut->_result = _result;
fut->_result_ref = _result_ref;
fut->unlock(FS_finished);
fut->notify_done(true);
}
}
else if (fut->set_future_state(clean_exit ? FS_finished : FS_cancelled)) {
fut->notify_done(clean_exit);
}
}
}
_waiting.clear();

View File

@ -78,6 +78,7 @@ PUBLISHED:
EXTENSION(PyObject *add_done_callback(PyObject *self, PyObject *fn));
EXTENSION(static PyObject *gather(PyObject *args));
INLINE static PT(AsyncFuture) shield(PT(AsyncFuture) future);
virtual void output(std::ostream &out) const;

View File

@ -289,6 +289,66 @@ def test_future_gather_cancel_outer():
assert gather.result()
def test_future_shield():
# An already done future is returned as-is (no cancellation can occur)
inner = core.AsyncFuture()
inner.set_result(None)
outer = core.AsyncFuture.shield(inner)
assert inner == outer
# Normally finishing future
inner = core.AsyncFuture()
outer = core.AsyncFuture.shield(inner)
assert not outer.done()
inner.set_result(None)
assert outer.done()
assert not outer.cancelled()
assert inner.result() is None
# Normally finishing future with result
inner = core.AsyncFuture()
outer = core.AsyncFuture.shield(inner)
assert not outer.done()
inner.set_result(123)
assert outer.done()
assert not outer.cancelled()
assert inner.result() == 123
# Cancelled inner future does propagate cancellation outward
inner = core.AsyncFuture()
outer = core.AsyncFuture.shield(inner)
assert not outer.done()
inner.cancel()
assert outer.done()
assert outer.cancelled()
# Finished outer future does nothing to inner
inner = core.AsyncFuture()
outer = core.AsyncFuture.shield(inner)
outer.set_result(None)
assert not inner.done()
inner.cancel()
assert not outer.cancelled()
# Cancelled outer future does nothing to inner
inner = core.AsyncFuture()
outer = core.AsyncFuture.shield(inner)
outer.cancel()
assert not inner.done()
inner.cancel()
# Can be shielded multiple times
inner = core.AsyncFuture()
outer1 = core.AsyncFuture.shield(inner)
outer2 = core.AsyncFuture.shield(inner)
outer1.cancel()
assert not inner.done()
assert not outer2.done()
inner.cancel()
assert outer1.done()
assert outer2.done()
def test_future_done_callback():
fut = core.AsyncFuture()