diff --git a/direct/src/fsm/FSM.py b/direct/src/fsm/FSM.py index 93d11df21e..f625b4a046 100644 --- a/direct/src/fsm/FSM.py +++ b/direct/src/fsm/FSM.py @@ -13,6 +13,8 @@ from direct.showbase.MessengerGlobal import messenger from direct.showbase import PythonUtil from direct.directnotify import DirectNotifyGlobal from direct.stdpy.threading import RLock +from panda3d.core import AsyncTaskManager, AsyncFuture, PythonTask +import types class FSMException(Exception): @@ -27,6 +29,19 @@ class RequestDenied(FSMException): pass +class Transition(tuple): + """Used for the return value of fsm.request(). Behaves like a tuple, for + historical reasons.""" + + _future = None + + def __await__(self): + if self._future: + yield self._future + + return tuple(self) + + class FSM(DirectObject): """ A Finite State Machine. This is intended to be the base class @@ -154,6 +169,9 @@ class FSM(DirectObject): # must be approved by some filter function. defaultTransitions = None + __doneFuture = AsyncFuture() + __doneFuture.set_result(None) + # An enum class for special states like the DEFAULT or ANY state, # that should be treatened by the FSM in a special way class EnumStates(): @@ -247,7 +265,13 @@ class FSM(DirectObject): def forceTransition(self, request, *args): """Changes unconditionally to the indicated state. This bypasses the filterState() function, and just calls - exitState() followed by enterState().""" + exitState() followed by enterState(). + + If the FSM is currently undergoing a transition, this will + queue up the new transition. + + Returns a future, which can be used to await the transition. + """ self.fsmLock.acquire() try: @@ -257,11 +281,13 @@ class FSM(DirectObject): if not self.state: # Queue up the request. - self.__requestQueue.append(PythonUtil.Functor( - self.forceTransition, request, *args)) - return + fut = AsyncFuture() + self.__requestQueue.append((PythonUtil.Functor( + self.forceTransition, request, *args), fut)) + return fut - self.__setState(request, *args) + result = self.__setState(request, *args) + return result._future or self.__doneFuture finally: self.fsmLock.release() @@ -275,6 +301,10 @@ class FSM(DirectObject): request is queued up and will be executed when the current transition finishes. Multiple requests will queue up in sequence. + + The return value of this function can be used in an `await` + expression to suspend the current coroutine until the + transition is done. """ self.fsmLock.acquire() @@ -284,12 +314,15 @@ class FSM(DirectObject): self._name, request, str(args)[1:])) if not self.state: # Queue up the request. - self.__requestQueue.append(PythonUtil.Functor( - self.demand, request, *args)) - return + fut = AsyncFuture() + self.__requestQueue.append((PythonUtil.Functor( + self.demand, request, *args), fut)) + return fut - if not self.request(request, *args): + result = self.request(request, *args) + if not result: raise RequestDenied("%s (from state: %s)" % (request, self.state)) + return result._future or self.__doneFuture finally: self.fsmLock.release() @@ -314,7 +347,12 @@ class FSM(DirectObject): executing an enterState or exitState function), an `AlreadyInTransition` exception is raised (but see `demand()`, which will queue these requests up and apply when the - transition is complete).""" + transition is complete). + + If the previous state's exitFunc or the new state's enterFunc + is a coroutine, the state change may not have been applied by + the time request() returns, but you can use `await` on the + return value to await the transition.""" self.fsmLock.acquire() try: @@ -331,7 +369,7 @@ class FSM(DirectObject): result = (result,) + args # Otherwise, assume it's a (name, *args) tuple - self.__setState(*result) + return self.__setState(*result) return result finally: @@ -441,11 +479,11 @@ class FSM(DirectObject): try: if self.stateArray: if not self.state in self.stateArray: - self.request(self.stateArray[0]) + return self.request(self.stateArray[0]) else: cur_index = self.stateArray.index(self.state) new_index = (cur_index + 1) % len(self.stateArray) - self.request(self.stateArray[new_index], args) + return self.request(self.stateArray[new_index], args) else: assert self.notifier.debug( "stateArray empty. Can't switch to next.") @@ -459,11 +497,11 @@ class FSM(DirectObject): try: if self.stateArray: if not self.state in self.stateArray: - self.request(self.stateArray[0]) + return self.request(self.stateArray[0]) else: cur_index = self.stateArray.index(self.state) new_index = (cur_index - 1) % len(self.stateArray) - self.request(self.stateArray[new_index], args) + return self.request(self.stateArray[new_index], args) else: assert self.notifier.debug( "stateArray empty. Can't switch to next.") @@ -471,8 +509,26 @@ class FSM(DirectObject): self.fsmLock.release() def __setState(self, newState, *args): - # Internal function to change unconditionally to the indicated - # state. + # Internal function to change unconditionally to the indicated state. + + transition = Transition((newState,) + args) + + # See if we can transition immediately by polling the coroutine. + coro = self.__transition(newState, *args) + try: + coro.send(None) + except StopIteration: + # We managed to apply this straight away. + return transition + + # Continue the state transition in a task. + task = PythonTask(coro) + mgr = AsyncTaskManager.get_global_ptr() + mgr.add(task) + transition._future = task + return transition + + async def __transition(self, newState, *args): assert self.state assert self.notify.debug("%s to state %s." % (self._name, newState)) @@ -482,8 +538,13 @@ class FSM(DirectObject): try: if not self.__callFromToFunc(self.oldState, self.newState, *args): - self.__callExitFunc(self.oldState) - self.__callEnterFunc(self.newState, *args) + result = self.__callExitFunc(self.oldState) + if isinstance(result, types.CoroutineType): + await result + + result = self.__callEnterFunc(self.newState, *args) + if isinstance(result, types.CoroutineType): + await result except: # If we got an exception during the enter or exit methods, # go directly to state "InternalError" and raise up the @@ -503,9 +564,10 @@ class FSM(DirectObject): del self.newState if self.__requestQueue: - request = self.__requestQueue.pop(0) + request, fut = self.__requestQueue.pop(0) assert self.notify.debug("%s continued queued request." % (self._name)) - request() + await request() + fut.set_result(None) def __callEnterFunc(self, name, *args): # Calls the appropriate enter function when transitioning into @@ -517,7 +579,7 @@ class FSM(DirectObject): # If there's no matching enterFoo() function, call # defaultEnter() instead. func = self.defaultEnter - func(*args) + return func(*args) def __callFromToFunc(self, oldState, newState, *args): # Calls the appropriate fromTo function when transitioning into @@ -540,7 +602,7 @@ class FSM(DirectObject): # If there's no matching exitFoo() function, call # defaultExit() instead. func = self.defaultExit - func() + return func() def __repr__(self): return self.__str__()