From 5ef1b44455b05db11e1a89673fdccdce4830f9a2 Mon Sep 17 00:00:00 2001 From: rdb Date: Sun, 18 Apr 2021 17:05:00 +0200 Subject: [PATCH] event: cancel future being awaited when cancelling coroutine task Fixes #1136 --- panda/src/event/asyncFuture.h | 2 +- panda/src/event/asyncFuture_ext.cxx | 37 +++-- panda/src/event/asyncFuture_ext.h | 2 +- panda/src/event/pythonTask.cxx | 95 +++++++++--- panda/src/event/pythonTask.h | 2 +- tests/event/test_futures.py | 221 ++++++++++++++++++++++++++++ 6 files changed, 324 insertions(+), 35 deletions(-) diff --git a/panda/src/event/asyncFuture.h b/panda/src/event/asyncFuture.h index e80918c149..73abe9a3a3 100644 --- a/panda/src/event/asyncFuture.h +++ b/panda/src/event/asyncFuture.h @@ -67,7 +67,7 @@ PUBLISHED: INLINE bool done() const; INLINE bool cancelled() const; - EXTENSION(PyObject *result(PyObject *timeout = Py_None) const); + EXTENSION(PyObject *result(PyObject *self, PyObject *timeout = Py_None) const); virtual bool cancel(); diff --git a/panda/src/event/asyncFuture_ext.cxx b/panda/src/event/asyncFuture_ext.cxx index fcd49598fa..7bc2376913 100644 --- a/panda/src/event/asyncFuture_ext.cxx +++ b/panda/src/event/asyncFuture_ext.cxx @@ -222,16 +222,35 @@ set_result(PyObject *result) { * raises TimeoutError. */ PyObject *Extension:: -result(PyObject *timeout) const { +result(PyObject *self, PyObject *timeout) const { + double timeout_val; + if (timeout != Py_None) { + timeout_val = PyFloat_AsDouble(timeout); + if (timeout_val == -1.0 && _PyErr_OCCURRED()) { + return nullptr; + } + } + if (!_this->done()) { // Not yet done? Wait until it is done, or until a timeout occurs. But // first check to make sure we're not trying to deadlock the thread. Thread *current_thread = Thread::get_current_thread(); - if (_this == (const AsyncFuture *)current_thread->get_current_task()) { + AsyncTask *current_task = (AsyncTask *)current_thread->get_current_task(); + if (_this == current_task) { PyErr_SetString(PyExc_RuntimeError, "cannot call task.result() from within the task"); return nullptr; } + PythonTask *python_task = nullptr; + if (current_task != nullptr && + current_task->is_of_type(PythonTask::get_class_type())) { + // If we are calling result() inside a coroutine, mark it as awaiting this + // future. That makes it possible to cancel() us from another thread. + python_task = (PythonTask *)current_task; + nassertr(python_task->_fut_waiter == nullptr, nullptr); + python_task->_fut_waiter = self; + } + // Release the GIL for the duration. #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS) PyThreadState *_save; @@ -239,18 +258,18 @@ result(PyObject *timeout) const { #endif if (timeout == Py_None) { _this->wait(); - } else { - PyObject *num = PyNumber_Float(timeout); - if (num != nullptr) { - _this->wait(PyFloat_AS_DOUBLE(num)); - } else { - return Dtool_Raise_ArgTypeError(timeout, 0, "result", "float"); - } + } + else { + _this->wait(timeout_val); } #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS) Py_BLOCK_THREADS #endif + if (python_task != nullptr) { + python_task->_fut_waiter = nullptr; + } + if (!_this->done()) { // It timed out. Raise an exception. static PyObject *exc_type = nullptr; diff --git a/panda/src/event/asyncFuture_ext.h b/panda/src/event/asyncFuture_ext.h index fbf834331f..e4f8ba9bcc 100644 --- a/panda/src/event/asyncFuture_ext.h +++ b/panda/src/event/asyncFuture_ext.h @@ -30,7 +30,7 @@ public: static PyObject *__iter__(PyObject *self) { return __await__(self); } void set_result(PyObject *result); - PyObject *result(PyObject *timeout = Py_None) const; + PyObject *result(PyObject *self, PyObject *timeout = Py_None) const; PyObject *add_done_callback(PyObject *self, PyObject *fn); diff --git a/panda/src/event/pythonTask.cxx b/panda/src/event/pythonTask.cxx index 2091fa9214..d9c28b4611 100644 --- a/panda/src/event/pythonTask.cxx +++ b/panda/src/event/pythonTask.cxx @@ -45,7 +45,7 @@ PythonTask(PyObject *func_or_coro, const std::string &name) : _exc_value(nullptr), _exc_traceback(nullptr), _generator(nullptr), - _future_done(nullptr), + _fut_waiter(nullptr), _ignore_return(false), _retrieved_exception(false) { @@ -404,20 +404,58 @@ cancel() { << "Cancelling " << *this << "\n"; } + bool must_cancel = true; + if (_fut_waiter != nullptr) { + // Cancel the future that this task is waiting on. Note that we do this + // before grabbing the lock, since this operation may also grab it. This + // means that _fut_waiter is only protected by the GIL. +#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS) + // Use PyGILState to protect this asynchronous call. + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); +#endif + + // Shortcut for unextended AsyncFuture. + if (Py_TYPE(_fut_waiter) == (PyTypeObject *)&Dtool_AsyncFuture) { + AsyncFuture *fut = (AsyncFuture *)DtoolInstance_VOID_PTR(_fut_waiter); + if (!fut->done()) { + fut->cancel(); + } + if (fut->done()) { + // We don't need this anymore. + Py_DECREF(_fut_waiter); + _fut_waiter = nullptr; + } + } + else { + PyObject *result = PyObject_CallMethod(_fut_waiter, "cancel", nullptr); + Py_XDECREF(result); + } + +#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS) + PyGILState_Release(gstate); +#endif + // Keep _fut_waiter in any case, because we may need to cancel it again + // later if it ignores the cancellation. + } + MutexHolder holder(manager->_lock); if (_state == S_awaiting) { // Reactivate it so that it can receive a CancelledException. - _must_cancel = true; + if (must_cancel) { + _must_cancel = true; + } _state = AsyncTask::S_active; _chain->_active.push_back(this); --_chain->_num_awaiting_tasks; return true; } - else if (_future_done != nullptr) { - // We are polling, waiting for a non-Panda future to be done. - Py_DECREF(_future_done); - _future_done = nullptr; - _must_cancel = true; + else if (must_cancel || _fut_waiter != nullptr) { + // We may be polling an external future, so we still need to throw a + // CancelledException and allow it to be caught. + if (must_cancel) { + _must_cancel = true; + } return true; } else if (_chain->do_remove(this, true)) { @@ -477,17 +515,24 @@ AsyncTask::DoneStatus PythonTask:: do_python_task() { PyObject *result = nullptr; - // Are we waiting for a future to finish? - if (_future_done != nullptr) { - PyObject *is_done = PyObject_CallNoArgs(_future_done); - if (!PyObject_IsTrue(is_done)) { - // Nope, ask again next frame. + // Are we waiting for a future to finish? Short-circuit all the logic below + // by simply calling done(). + { + PyObject *fut_waiter = _fut_waiter; + if (fut_waiter != nullptr) { + PyObject *is_done = PyObject_CallMethod(fut_waiter, "done", nullptr); + if (is_done == nullptr) { + return DS_interrupt; + } + if (!PyObject_IsTrue(is_done)) { + // Nope, ask again next frame. + Py_DECREF(is_done); + return DS_cont; + } Py_DECREF(is_done); - return DS_cont; + Py_DECREF(fut_waiter); + _fut_waiter = nullptr; } - Py_DECREF(is_done); - Py_DECREF(_future_done); - _future_done = nullptr; } if (_generator == nullptr) { @@ -664,7 +709,9 @@ do_python_task() { task_cat.error() << *this << " cannot await itself\n"; } - Py_DECREF(result); + // Store the Python object in case we need to cancel it (it may be a + // subclass of AsyncFuture that overrides cancel() from Python) + _fut_waiter = result; return DS_await; } } else { @@ -674,8 +721,9 @@ do_python_task() { if (check != nullptr && check != Py_None) { Py_DECREF(check); // Next frame, check whether this future is done. - _future_done = PyObject_GetAttrString(result, "done"); - if (_future_done == nullptr || !PyCallable_Check(_future_done)) { + PyObject *fut_done = PyObject_GetAttrString(result, "done"); + if (fut_done == nullptr || !PyCallable_Check(fut_done)) { + Py_XDECREF(fut_done); task_cat.error() << "future.done is not callable\n"; return DS_interrupt; @@ -686,7 +734,7 @@ do_python_task() { << *this << " is now polling " << PyUnicode_AsUTF8(str) << ".done()\n"; Py_DECREF(str); } - Py_DECREF(result); + _fut_waiter = result; return DS_cont; } PyErr_Clear(); @@ -802,9 +850,10 @@ upon_death(AsyncTaskManager *manager, bool clean_exit) { AsyncTask::upon_death(manager, clean_exit); // If we were polling something when we were removed, get rid of it. - if (_future_done != nullptr) { - Py_DECREF(_future_done); - _future_done = nullptr; + //TODO: should we call cancel() on it? + if (_fut_waiter != nullptr) { + Py_DECREF(_fut_waiter); + _fut_waiter = nullptr; } if (_upon_death != Py_None) { diff --git a/panda/src/event/pythonTask.h b/panda/src/event/pythonTask.h index 7af6598ab1..88bee56bd9 100644 --- a/panda/src/event/pythonTask.h +++ b/panda/src/event/pythonTask.h @@ -115,7 +115,7 @@ private: PyObject *_exc_traceback; PyObject *_generator; - PyObject *_future_done; + PyObject *_fut_waiter; bool _append_task; bool _ignore_return; diff --git a/tests/event/test_futures.py b/tests/event/test_futures.py index d18d52c3f7..db09d8c677 100644 --- a/tests/event/test_futures.py +++ b/tests/event/test_futures.py @@ -9,6 +9,33 @@ else: from concurrent.futures._base import TimeoutError, CancelledError +class MockFuture: + _asyncio_future_blocking = False + _state = 'PENDING' + _cancel_return = False + _result = None + + def __await__(self): + while self._state == 'PENDING': + yield self + return self.result() + + def done(self): + return self._state != 'PENDING' + + def cancelled(self): + return self._state == 'CANCELLED' + + def cancel(self): + return self._cancel_return + + def result(self): + if self._state == 'CANCELLED': + raise CancelledError + + return self._result + + def test_future_cancelled(): fut = core.AsyncFuture() @@ -123,6 +150,66 @@ def test_task_cancel_during_run(): task.result() +def test_task_cancel_waiting(): + # Calling result() in a threaded task chain should cancel the future being + # waited on if the surrounding task is cancelled. + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_chain = task_mgr.make_task_chain("test_task_cancel_waiting") + task_chain.set_num_threads(1) + + fut = core.AsyncFuture() + + async def task_main(task): + # This will block the thread this task is in until the future is done, + # or until the task is cancelled (which implicitly cancels the future). + fut.result() + return task.done + + task = core.PythonTask(task_main, 'task_main') + task.set_task_chain(task_chain.name) + task_mgr.add(task) + + task_chain.start_threads() + try: + assert not task.done() + fut.cancel() + task.wait() + + assert task.cancelled() + assert fut.cancelled() + + finally: + task_chain.stop_threads() + + +def test_task_cancel_awaiting(): + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_chain = task_mgr.make_task_chain("test_task_cancel_awaiting") + + fut = core.AsyncFuture() + + async def task_main(task): + await fut + return task.done + + task = core.PythonTask(task_main, 'task_main') + task.set_task_chain(task_chain.name) + task_mgr.add(task) + + task_chain.poll() + assert not task.done() + + task_chain.poll() + assert not task.done() + + task.cancel() + task_chain.poll() + assert task.done() + assert task.cancelled() + assert fut.done() + assert fut.cancelled() + + def test_task_result(): task_mgr = core.AsyncTaskManager.get_global_ptr() task_chain = task_mgr.make_task_chain("test_task_result") @@ -144,6 +231,140 @@ def test_task_result(): assert task.result() == 42 +def test_coro_await_coro(): + # Await another coro in a coro. + fut = core.AsyncFuture() + async def coro2(): + await fut + + async def coro_main(): + await coro2() + + task = core.PythonTask(coro_main()) + + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_mgr.add(task) + for i in range(5): + task_mgr.poll() + + assert not task.done() + fut.set_result(None) + task_mgr.poll() + assert task.done() + assert not task.cancelled() + + +def test_coro_await_cancel_resistant_coro(): + # Await another coro in a coro, but cancel the outer. + fut = core.AsyncFuture() + cancelled_caught = [0] + keep_going = [False] + + async def cancel_resistant_coro(): + while not fut.done(): + try: + await core.AsyncFuture.shield(fut) + except CancelledError as ex: + cancelled_caught[0] += 1 + + async def coro_main(): + await cancel_resistant_coro() + + task = core.PythonTask(coro_main(), 'coro_main') + + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_mgr.add(task) + assert not task.done() + + task_mgr.poll() + assert not task.done() + + # No cancelling it once it started... + for i in range(3): + assert task.cancel() + assert not task.done() + + for j in range(3): + task_mgr.poll() + assert not task.done() + + assert cancelled_caught[0] == 3 + + fut.set_result(None) + task_mgr.poll() + assert task.done() + assert not task.cancelled() + + +def test_coro_await_external(): + # Await an external future in a coro. + fut = MockFuture() + fut._result = 12345 + res = [] + + async def coro_main(): + res.append(await fut) + + task = core.PythonTask(coro_main(), 'coro_main') + + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_mgr.add(task) + for i in range(5): + task_mgr.poll() + + assert not task.done() + fut._state = 'FINISHED' + task_mgr.poll() + assert task.done() + assert not task.cancelled() + assert res == [12345] + + +def test_coro_await_external_cancel_inner(): + # Cancel external future being awaited by a coro. + fut = MockFuture() + + async def coro_main(): + await fut + + task = core.PythonTask(coro_main(), 'coro_main') + + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_mgr.add(task) + for i in range(5): + task_mgr.poll() + + assert not task.done() + fut._state = 'CANCELLED' + assert not task.done() + task_mgr.poll() + assert task.done() + assert task.cancelled() + + +def test_coro_await_external_cancel_outer(): + # Cancel task that is awaiting external future. + fut = MockFuture() + result = [] + + async def coro_main(): + result.append(await fut) + + task = core.PythonTask(coro_main(), 'coro_main') + + task_mgr = core.AsyncTaskManager.get_global_ptr() + task_mgr.add(task) + for i in range(5): + task_mgr.poll() + + assert not task.done() + fut._state = 'CANCELLED' + assert not task.done() + task_mgr.poll() + assert task.done() + assert task.cancelled() + + def test_coro_exception(): task_mgr = core.AsyncTaskManager.get_global_ptr() task_chain = task_mgr.make_task_chain("test_coro_exception")