event: cancel future being awaited when cancelling coroutine task

Fixes #1136
This commit is contained in:
rdb 2021-04-18 17:05:00 +02:00
parent cdf5b16ddd
commit 5ef1b44455
6 changed files with 324 additions and 35 deletions

View File

@ -67,7 +67,7 @@ PUBLISHED:
INLINE bool done() const;
INLINE bool cancelled() const;
EXTENSION(PyObject *result(PyObject *timeout = Py_None) const);
EXTENSION(PyObject *result(PyObject *self, PyObject *timeout = Py_None) const);
virtual bool cancel();

View File

@ -222,16 +222,35 @@ set_result(PyObject *result) {
* raises TimeoutError.
*/
PyObject *Extension<AsyncFuture>::
result(PyObject *timeout) const {
result(PyObject *self, PyObject *timeout) const {
double timeout_val;
if (timeout != Py_None) {
timeout_val = PyFloat_AsDouble(timeout);
if (timeout_val == -1.0 && _PyErr_OCCURRED()) {
return nullptr;
}
}
if (!_this->done()) {
// Not yet done? Wait until it is done, or until a timeout occurs. But
// first check to make sure we're not trying to deadlock the thread.
Thread *current_thread = Thread::get_current_thread();
if (_this == (const AsyncFuture *)current_thread->get_current_task()) {
AsyncTask *current_task = (AsyncTask *)current_thread->get_current_task();
if (_this == current_task) {
PyErr_SetString(PyExc_RuntimeError, "cannot call task.result() from within the task");
return nullptr;
}
PythonTask *python_task = nullptr;
if (current_task != nullptr &&
current_task->is_of_type(PythonTask::get_class_type())) {
// If we are calling result() inside a coroutine, mark it as awaiting this
// future. That makes it possible to cancel() us from another thread.
python_task = (PythonTask *)current_task;
nassertr(python_task->_fut_waiter == nullptr, nullptr);
python_task->_fut_waiter = self;
}
// Release the GIL for the duration.
#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
PyThreadState *_save;
@ -239,18 +258,18 @@ result(PyObject *timeout) const {
#endif
if (timeout == Py_None) {
_this->wait();
} else {
PyObject *num = PyNumber_Float(timeout);
if (num != nullptr) {
_this->wait(PyFloat_AS_DOUBLE(num));
} else {
return Dtool_Raise_ArgTypeError(timeout, 0, "result", "float");
}
}
else {
_this->wait(timeout_val);
}
#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
Py_BLOCK_THREADS
#endif
if (python_task != nullptr) {
python_task->_fut_waiter = nullptr;
}
if (!_this->done()) {
// It timed out. Raise an exception.
static PyObject *exc_type = nullptr;

View File

@ -30,7 +30,7 @@ public:
static PyObject *__iter__(PyObject *self) { return __await__(self); }
void set_result(PyObject *result);
PyObject *result(PyObject *timeout = Py_None) const;
PyObject *result(PyObject *self, PyObject *timeout = Py_None) const;
PyObject *add_done_callback(PyObject *self, PyObject *fn);

View File

@ -45,7 +45,7 @@ PythonTask(PyObject *func_or_coro, const std::string &name) :
_exc_value(nullptr),
_exc_traceback(nullptr),
_generator(nullptr),
_future_done(nullptr),
_fut_waiter(nullptr),
_ignore_return(false),
_retrieved_exception(false) {
@ -404,20 +404,58 @@ cancel() {
<< "Cancelling " << *this << "\n";
}
bool must_cancel = true;
if (_fut_waiter != nullptr) {
// Cancel the future that this task is waiting on. Note that we do this
// before grabbing the lock, since this operation may also grab it. This
// means that _fut_waiter is only protected by the GIL.
#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
// Use PyGILState to protect this asynchronous call.
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
#endif
// Shortcut for unextended AsyncFuture.
if (Py_TYPE(_fut_waiter) == (PyTypeObject *)&Dtool_AsyncFuture) {
AsyncFuture *fut = (AsyncFuture *)DtoolInstance_VOID_PTR(_fut_waiter);
if (!fut->done()) {
fut->cancel();
}
if (fut->done()) {
// We don't need this anymore.
Py_DECREF(_fut_waiter);
_fut_waiter = nullptr;
}
}
else {
PyObject *result = PyObject_CallMethod(_fut_waiter, "cancel", nullptr);
Py_XDECREF(result);
}
#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
PyGILState_Release(gstate);
#endif
// Keep _fut_waiter in any case, because we may need to cancel it again
// later if it ignores the cancellation.
}
MutexHolder holder(manager->_lock);
if (_state == S_awaiting) {
// Reactivate it so that it can receive a CancelledException.
_must_cancel = true;
if (must_cancel) {
_must_cancel = true;
}
_state = AsyncTask::S_active;
_chain->_active.push_back(this);
--_chain->_num_awaiting_tasks;
return true;
}
else if (_future_done != nullptr) {
// We are polling, waiting for a non-Panda future to be done.
Py_DECREF(_future_done);
_future_done = nullptr;
_must_cancel = true;
else if (must_cancel || _fut_waiter != nullptr) {
// We may be polling an external future, so we still need to throw a
// CancelledException and allow it to be caught.
if (must_cancel) {
_must_cancel = true;
}
return true;
}
else if (_chain->do_remove(this, true)) {
@ -477,17 +515,24 @@ AsyncTask::DoneStatus PythonTask::
do_python_task() {
PyObject *result = nullptr;
// Are we waiting for a future to finish?
if (_future_done != nullptr) {
PyObject *is_done = PyObject_CallNoArgs(_future_done);
if (!PyObject_IsTrue(is_done)) {
// Nope, ask again next frame.
// Are we waiting for a future to finish? Short-circuit all the logic below
// by simply calling done().
{
PyObject *fut_waiter = _fut_waiter;
if (fut_waiter != nullptr) {
PyObject *is_done = PyObject_CallMethod(fut_waiter, "done", nullptr);
if (is_done == nullptr) {
return DS_interrupt;
}
if (!PyObject_IsTrue(is_done)) {
// Nope, ask again next frame.
Py_DECREF(is_done);
return DS_cont;
}
Py_DECREF(is_done);
return DS_cont;
Py_DECREF(fut_waiter);
_fut_waiter = nullptr;
}
Py_DECREF(is_done);
Py_DECREF(_future_done);
_future_done = nullptr;
}
if (_generator == nullptr) {
@ -664,7 +709,9 @@ do_python_task() {
task_cat.error()
<< *this << " cannot await itself\n";
}
Py_DECREF(result);
// Store the Python object in case we need to cancel it (it may be a
// subclass of AsyncFuture that overrides cancel() from Python)
_fut_waiter = result;
return DS_await;
}
} else {
@ -674,8 +721,9 @@ do_python_task() {
if (check != nullptr && check != Py_None) {
Py_DECREF(check);
// Next frame, check whether this future is done.
_future_done = PyObject_GetAttrString(result, "done");
if (_future_done == nullptr || !PyCallable_Check(_future_done)) {
PyObject *fut_done = PyObject_GetAttrString(result, "done");
if (fut_done == nullptr || !PyCallable_Check(fut_done)) {
Py_XDECREF(fut_done);
task_cat.error()
<< "future.done is not callable\n";
return DS_interrupt;
@ -686,7 +734,7 @@ do_python_task() {
<< *this << " is now polling " << PyUnicode_AsUTF8(str) << ".done()\n";
Py_DECREF(str);
}
Py_DECREF(result);
_fut_waiter = result;
return DS_cont;
}
PyErr_Clear();
@ -802,9 +850,10 @@ upon_death(AsyncTaskManager *manager, bool clean_exit) {
AsyncTask::upon_death(manager, clean_exit);
// If we were polling something when we were removed, get rid of it.
if (_future_done != nullptr) {
Py_DECREF(_future_done);
_future_done = nullptr;
//TODO: should we call cancel() on it?
if (_fut_waiter != nullptr) {
Py_DECREF(_fut_waiter);
_fut_waiter = nullptr;
}
if (_upon_death != Py_None) {

View File

@ -115,7 +115,7 @@ private:
PyObject *_exc_traceback;
PyObject *_generator;
PyObject *_future_done;
PyObject *_fut_waiter;
bool _append_task;
bool _ignore_return;

View File

@ -9,6 +9,33 @@ else:
from concurrent.futures._base import TimeoutError, CancelledError
class MockFuture:
_asyncio_future_blocking = False
_state = 'PENDING'
_cancel_return = False
_result = None
def __await__(self):
while self._state == 'PENDING':
yield self
return self.result()
def done(self):
return self._state != 'PENDING'
def cancelled(self):
return self._state == 'CANCELLED'
def cancel(self):
return self._cancel_return
def result(self):
if self._state == 'CANCELLED':
raise CancelledError
return self._result
def test_future_cancelled():
fut = core.AsyncFuture()
@ -123,6 +150,66 @@ def test_task_cancel_during_run():
task.result()
def test_task_cancel_waiting():
# Calling result() in a threaded task chain should cancel the future being
# waited on if the surrounding task is cancelled.
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_chain = task_mgr.make_task_chain("test_task_cancel_waiting")
task_chain.set_num_threads(1)
fut = core.AsyncFuture()
async def task_main(task):
# This will block the thread this task is in until the future is done,
# or until the task is cancelled (which implicitly cancels the future).
fut.result()
return task.done
task = core.PythonTask(task_main, 'task_main')
task.set_task_chain(task_chain.name)
task_mgr.add(task)
task_chain.start_threads()
try:
assert not task.done()
fut.cancel()
task.wait()
assert task.cancelled()
assert fut.cancelled()
finally:
task_chain.stop_threads()
def test_task_cancel_awaiting():
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_chain = task_mgr.make_task_chain("test_task_cancel_awaiting")
fut = core.AsyncFuture()
async def task_main(task):
await fut
return task.done
task = core.PythonTask(task_main, 'task_main')
task.set_task_chain(task_chain.name)
task_mgr.add(task)
task_chain.poll()
assert not task.done()
task_chain.poll()
assert not task.done()
task.cancel()
task_chain.poll()
assert task.done()
assert task.cancelled()
assert fut.done()
assert fut.cancelled()
def test_task_result():
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_chain = task_mgr.make_task_chain("test_task_result")
@ -144,6 +231,140 @@ def test_task_result():
assert task.result() == 42
def test_coro_await_coro():
# Await another coro in a coro.
fut = core.AsyncFuture()
async def coro2():
await fut
async def coro_main():
await coro2()
task = core.PythonTask(coro_main())
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_mgr.add(task)
for i in range(5):
task_mgr.poll()
assert not task.done()
fut.set_result(None)
task_mgr.poll()
assert task.done()
assert not task.cancelled()
def test_coro_await_cancel_resistant_coro():
# Await another coro in a coro, but cancel the outer.
fut = core.AsyncFuture()
cancelled_caught = [0]
keep_going = [False]
async def cancel_resistant_coro():
while not fut.done():
try:
await core.AsyncFuture.shield(fut)
except CancelledError as ex:
cancelled_caught[0] += 1
async def coro_main():
await cancel_resistant_coro()
task = core.PythonTask(coro_main(), 'coro_main')
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_mgr.add(task)
assert not task.done()
task_mgr.poll()
assert not task.done()
# No cancelling it once it started...
for i in range(3):
assert task.cancel()
assert not task.done()
for j in range(3):
task_mgr.poll()
assert not task.done()
assert cancelled_caught[0] == 3
fut.set_result(None)
task_mgr.poll()
assert task.done()
assert not task.cancelled()
def test_coro_await_external():
# Await an external future in a coro.
fut = MockFuture()
fut._result = 12345
res = []
async def coro_main():
res.append(await fut)
task = core.PythonTask(coro_main(), 'coro_main')
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_mgr.add(task)
for i in range(5):
task_mgr.poll()
assert not task.done()
fut._state = 'FINISHED'
task_mgr.poll()
assert task.done()
assert not task.cancelled()
assert res == [12345]
def test_coro_await_external_cancel_inner():
# Cancel external future being awaited by a coro.
fut = MockFuture()
async def coro_main():
await fut
task = core.PythonTask(coro_main(), 'coro_main')
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_mgr.add(task)
for i in range(5):
task_mgr.poll()
assert not task.done()
fut._state = 'CANCELLED'
assert not task.done()
task_mgr.poll()
assert task.done()
assert task.cancelled()
def test_coro_await_external_cancel_outer():
# Cancel task that is awaiting external future.
fut = MockFuture()
result = []
async def coro_main():
result.append(await fut)
task = core.PythonTask(coro_main(), 'coro_main')
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_mgr.add(task)
for i in range(5):
task_mgr.poll()
assert not task.done()
fut._state = 'CANCELLED'
assert not task.done()
task_mgr.poll()
assert task.done()
assert task.cancelled()
def test_coro_exception():
task_mgr = core.AsyncTaskManager.get_global_ptr()
task_chain = task_mgr.make_task_chain("test_coro_exception")