fsm: Support asynchronous transitions via coroutine enter/exit funcs

Fixes #1037
This commit is contained in:
rdb 2021-02-08 16:07:11 +01:00
parent cdad2c6e58
commit 9cb3c7726f

View File

@ -13,6 +13,8 @@ from direct.showbase.MessengerGlobal import messenger
from direct.showbase import PythonUtil from direct.showbase import PythonUtil
from direct.directnotify import DirectNotifyGlobal from direct.directnotify import DirectNotifyGlobal
from direct.stdpy.threading import RLock from direct.stdpy.threading import RLock
from panda3d.core import AsyncTaskManager, AsyncFuture, PythonTask
import types
class FSMException(Exception): class FSMException(Exception):
@ -27,6 +29,19 @@ class RequestDenied(FSMException):
pass pass
class Transition(tuple):
"""Used for the return value of fsm.request(). Behaves like a tuple, for
historical reasons."""
_future = None
def __await__(self):
if self._future:
yield self._future
return tuple(self)
class FSM(DirectObject): class FSM(DirectObject):
""" """
A Finite State Machine. This is intended to be the base class A Finite State Machine. This is intended to be the base class
@ -154,6 +169,9 @@ class FSM(DirectObject):
# must be approved by some filter function. # must be approved by some filter function.
defaultTransitions = None defaultTransitions = None
__doneFuture = AsyncFuture()
__doneFuture.set_result(None)
# An enum class for special states like the DEFAULT or ANY state, # An enum class for special states like the DEFAULT or ANY state,
# that should be treatened by the FSM in a special way # that should be treatened by the FSM in a special way
class EnumStates(): class EnumStates():
@ -247,7 +265,13 @@ class FSM(DirectObject):
def forceTransition(self, request, *args): def forceTransition(self, request, *args):
"""Changes unconditionally to the indicated state. This """Changes unconditionally to the indicated state. This
bypasses the filterState() function, and just calls bypasses the filterState() function, and just calls
exitState() followed by enterState().""" exitState() followed by enterState().
If the FSM is currently undergoing a transition, this will
queue up the new transition.
Returns a future, which can be used to await the transition.
"""
self.fsmLock.acquire() self.fsmLock.acquire()
try: try:
@ -257,11 +281,13 @@ class FSM(DirectObject):
if not self.state: if not self.state:
# Queue up the request. # Queue up the request.
self.__requestQueue.append(PythonUtil.Functor( fut = AsyncFuture()
self.forceTransition, request, *args)) self.__requestQueue.append((PythonUtil.Functor(
return self.forceTransition, request, *args), fut))
return fut
self.__setState(request, *args) result = self.__setState(request, *args)
return result._future or self.__doneFuture
finally: finally:
self.fsmLock.release() self.fsmLock.release()
@ -275,6 +301,10 @@ class FSM(DirectObject):
request is queued up and will be executed when the current request is queued up and will be executed when the current
transition finishes. Multiple requests will queue up in transition finishes. Multiple requests will queue up in
sequence. sequence.
The return value of this function can be used in an `await`
expression to suspend the current coroutine until the
transition is done.
""" """
self.fsmLock.acquire() self.fsmLock.acquire()
@ -284,12 +314,15 @@ class FSM(DirectObject):
self._name, request, str(args)[1:])) self._name, request, str(args)[1:]))
if not self.state: if not self.state:
# Queue up the request. # Queue up the request.
self.__requestQueue.append(PythonUtil.Functor( fut = AsyncFuture()
self.demand, request, *args)) self.__requestQueue.append((PythonUtil.Functor(
return self.demand, request, *args), fut))
return fut
if not self.request(request, *args): result = self.request(request, *args)
if not result:
raise RequestDenied("%s (from state: %s)" % (request, self.state)) raise RequestDenied("%s (from state: %s)" % (request, self.state))
return result._future or self.__doneFuture
finally: finally:
self.fsmLock.release() self.fsmLock.release()
@ -314,7 +347,12 @@ class FSM(DirectObject):
executing an enterState or exitState function), an executing an enterState or exitState function), an
`AlreadyInTransition` exception is raised (but see `demand()`, `AlreadyInTransition` exception is raised (but see `demand()`,
which will queue these requests up and apply when the which will queue these requests up and apply when the
transition is complete).""" transition is complete).
If the previous state's exitFunc or the new state's enterFunc
is a coroutine, the state change may not have been applied by
the time request() returns, but you can use `await` on the
return value to await the transition."""
self.fsmLock.acquire() self.fsmLock.acquire()
try: try:
@ -331,7 +369,7 @@ class FSM(DirectObject):
result = (result,) + args result = (result,) + args
# Otherwise, assume it's a (name, *args) tuple # Otherwise, assume it's a (name, *args) tuple
self.__setState(*result) return self.__setState(*result)
return result return result
finally: finally:
@ -441,11 +479,11 @@ class FSM(DirectObject):
try: try:
if self.stateArray: if self.stateArray:
if not self.state in self.stateArray: if not self.state in self.stateArray:
self.request(self.stateArray[0]) return self.request(self.stateArray[0])
else: else:
cur_index = self.stateArray.index(self.state) cur_index = self.stateArray.index(self.state)
new_index = (cur_index + 1) % len(self.stateArray) new_index = (cur_index + 1) % len(self.stateArray)
self.request(self.stateArray[new_index], args) return self.request(self.stateArray[new_index], args)
else: else:
assert self.notifier.debug( assert self.notifier.debug(
"stateArray empty. Can't switch to next.") "stateArray empty. Can't switch to next.")
@ -459,11 +497,11 @@ class FSM(DirectObject):
try: try:
if self.stateArray: if self.stateArray:
if not self.state in self.stateArray: if not self.state in self.stateArray:
self.request(self.stateArray[0]) return self.request(self.stateArray[0])
else: else:
cur_index = self.stateArray.index(self.state) cur_index = self.stateArray.index(self.state)
new_index = (cur_index - 1) % len(self.stateArray) new_index = (cur_index - 1) % len(self.stateArray)
self.request(self.stateArray[new_index], args) return self.request(self.stateArray[new_index], args)
else: else:
assert self.notifier.debug( assert self.notifier.debug(
"stateArray empty. Can't switch to next.") "stateArray empty. Can't switch to next.")
@ -471,8 +509,26 @@ class FSM(DirectObject):
self.fsmLock.release() self.fsmLock.release()
def __setState(self, newState, *args): def __setState(self, newState, *args):
# Internal function to change unconditionally to the indicated # Internal function to change unconditionally to the indicated state.
# state.
transition = Transition((newState,) + args)
# See if we can transition immediately by polling the coroutine.
coro = self.__transition(newState, *args)
try:
coro.send(None)
except StopIteration:
# We managed to apply this straight away.
return transition
# Continue the state transition in a task.
task = PythonTask(coro)
mgr = AsyncTaskManager.get_global_ptr()
mgr.add(task)
transition._future = task
return transition
async def __transition(self, newState, *args):
assert self.state assert self.state
assert self.notify.debug("%s to state %s." % (self._name, newState)) assert self.notify.debug("%s to state %s." % (self._name, newState))
@ -482,8 +538,13 @@ class FSM(DirectObject):
try: try:
if not self.__callFromToFunc(self.oldState, self.newState, *args): if not self.__callFromToFunc(self.oldState, self.newState, *args):
self.__callExitFunc(self.oldState) result = self.__callExitFunc(self.oldState)
self.__callEnterFunc(self.newState, *args) if isinstance(result, types.CoroutineType):
await result
result = self.__callEnterFunc(self.newState, *args)
if isinstance(result, types.CoroutineType):
await result
except: except:
# If we got an exception during the enter or exit methods, # If we got an exception during the enter or exit methods,
# go directly to state "InternalError" and raise up the # go directly to state "InternalError" and raise up the
@ -503,9 +564,10 @@ class FSM(DirectObject):
del self.newState del self.newState
if self.__requestQueue: if self.__requestQueue:
request = self.__requestQueue.pop(0) request, fut = self.__requestQueue.pop(0)
assert self.notify.debug("%s continued queued request." % (self._name)) assert self.notify.debug("%s continued queued request." % (self._name))
request() await request()
fut.set_result(None)
def __callEnterFunc(self, name, *args): def __callEnterFunc(self, name, *args):
# Calls the appropriate enter function when transitioning into # Calls the appropriate enter function when transitioning into
@ -517,7 +579,7 @@ class FSM(DirectObject):
# If there's no matching enterFoo() function, call # If there's no matching enterFoo() function, call
# defaultEnter() instead. # defaultEnter() instead.
func = self.defaultEnter func = self.defaultEnter
func(*args) return func(*args)
def __callFromToFunc(self, oldState, newState, *args): def __callFromToFunc(self, oldState, newState, *args):
# Calls the appropriate fromTo function when transitioning into # Calls the appropriate fromTo function when transitioning into
@ -540,7 +602,7 @@ class FSM(DirectObject):
# If there's no matching exitFoo() function, call # If there's no matching exitFoo() function, call
# defaultExit() instead. # defaultExit() instead.
func = self.defaultExit func = self.defaultExit
func() return func()
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()