fsm: Support asynchronous transitions via coroutine enter/exit funcs

Fixes #1037
This commit is contained in:
rdb 2021-02-08 16:07:11 +01:00
parent cdad2c6e58
commit 9cb3c7726f

View File

@ -13,6 +13,8 @@ from direct.showbase.MessengerGlobal import messenger
from direct.showbase import PythonUtil
from direct.directnotify import DirectNotifyGlobal
from direct.stdpy.threading import RLock
from panda3d.core import AsyncTaskManager, AsyncFuture, PythonTask
import types
class FSMException(Exception):
@ -27,6 +29,19 @@ class RequestDenied(FSMException):
pass
class Transition(tuple):
"""Used for the return value of fsm.request(). Behaves like a tuple, for
historical reasons."""
_future = None
def __await__(self):
if self._future:
yield self._future
return tuple(self)
class FSM(DirectObject):
"""
A Finite State Machine. This is intended to be the base class
@ -154,6 +169,9 @@ class FSM(DirectObject):
# must be approved by some filter function.
defaultTransitions = None
__doneFuture = AsyncFuture()
__doneFuture.set_result(None)
# An enum class for special states like the DEFAULT or ANY state,
# that should be treatened by the FSM in a special way
class EnumStates():
@ -247,7 +265,13 @@ class FSM(DirectObject):
def forceTransition(self, request, *args):
"""Changes unconditionally to the indicated state. This
bypasses the filterState() function, and just calls
exitState() followed by enterState()."""
exitState() followed by enterState().
If the FSM is currently undergoing a transition, this will
queue up the new transition.
Returns a future, which can be used to await the transition.
"""
self.fsmLock.acquire()
try:
@ -257,11 +281,13 @@ class FSM(DirectObject):
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.forceTransition, request, *args))
return
fut = AsyncFuture()
self.__requestQueue.append((PythonUtil.Functor(
self.forceTransition, request, *args), fut))
return fut
self.__setState(request, *args)
result = self.__setState(request, *args)
return result._future or self.__doneFuture
finally:
self.fsmLock.release()
@ -275,6 +301,10 @@ class FSM(DirectObject):
request is queued up and will be executed when the current
transition finishes. Multiple requests will queue up in
sequence.
The return value of this function can be used in an `await`
expression to suspend the current coroutine until the
transition is done.
"""
self.fsmLock.acquire()
@ -284,12 +314,15 @@ class FSM(DirectObject):
self._name, request, str(args)[1:]))
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.demand, request, *args))
return
fut = AsyncFuture()
self.__requestQueue.append((PythonUtil.Functor(
self.demand, request, *args), fut))
return fut
if not self.request(request, *args):
result = self.request(request, *args)
if not result:
raise RequestDenied("%s (from state: %s)" % (request, self.state))
return result._future or self.__doneFuture
finally:
self.fsmLock.release()
@ -314,7 +347,12 @@ class FSM(DirectObject):
executing an enterState or exitState function), an
`AlreadyInTransition` exception is raised (but see `demand()`,
which will queue these requests up and apply when the
transition is complete)."""
transition is complete).
If the previous state's exitFunc or the new state's enterFunc
is a coroutine, the state change may not have been applied by
the time request() returns, but you can use `await` on the
return value to await the transition."""
self.fsmLock.acquire()
try:
@ -331,7 +369,7 @@ class FSM(DirectObject):
result = (result,) + args
# Otherwise, assume it's a (name, *args) tuple
self.__setState(*result)
return self.__setState(*result)
return result
finally:
@ -441,11 +479,11 @@ class FSM(DirectObject):
try:
if self.stateArray:
if not self.state in self.stateArray:
self.request(self.stateArray[0])
return self.request(self.stateArray[0])
else:
cur_index = self.stateArray.index(self.state)
new_index = (cur_index + 1) % len(self.stateArray)
self.request(self.stateArray[new_index], args)
return self.request(self.stateArray[new_index], args)
else:
assert self.notifier.debug(
"stateArray empty. Can't switch to next.")
@ -459,11 +497,11 @@ class FSM(DirectObject):
try:
if self.stateArray:
if not self.state in self.stateArray:
self.request(self.stateArray[0])
return self.request(self.stateArray[0])
else:
cur_index = self.stateArray.index(self.state)
new_index = (cur_index - 1) % len(self.stateArray)
self.request(self.stateArray[new_index], args)
return self.request(self.stateArray[new_index], args)
else:
assert self.notifier.debug(
"stateArray empty. Can't switch to next.")
@ -471,8 +509,26 @@ class FSM(DirectObject):
self.fsmLock.release()
def __setState(self, newState, *args):
# Internal function to change unconditionally to the indicated
# state.
# Internal function to change unconditionally to the indicated state.
transition = Transition((newState,) + args)
# See if we can transition immediately by polling the coroutine.
coro = self.__transition(newState, *args)
try:
coro.send(None)
except StopIteration:
# We managed to apply this straight away.
return transition
# Continue the state transition in a task.
task = PythonTask(coro)
mgr = AsyncTaskManager.get_global_ptr()
mgr.add(task)
transition._future = task
return transition
async def __transition(self, newState, *args):
assert self.state
assert self.notify.debug("%s to state %s." % (self._name, newState))
@ -482,8 +538,13 @@ class FSM(DirectObject):
try:
if not self.__callFromToFunc(self.oldState, self.newState, *args):
self.__callExitFunc(self.oldState)
self.__callEnterFunc(self.newState, *args)
result = self.__callExitFunc(self.oldState)
if isinstance(result, types.CoroutineType):
await result
result = self.__callEnterFunc(self.newState, *args)
if isinstance(result, types.CoroutineType):
await result
except:
# If we got an exception during the enter or exit methods,
# go directly to state "InternalError" and raise up the
@ -503,9 +564,10 @@ class FSM(DirectObject):
del self.newState
if self.__requestQueue:
request = self.__requestQueue.pop(0)
request, fut = self.__requestQueue.pop(0)
assert self.notify.debug("%s continued queued request." % (self._name))
request()
await request()
fut.set_result(None)
def __callEnterFunc(self, name, *args):
# Calls the appropriate enter function when transitioning into
@ -517,7 +579,7 @@ class FSM(DirectObject):
# If there's no matching enterFoo() function, call
# defaultEnter() instead.
func = self.defaultEnter
func(*args)
return func(*args)
def __callFromToFunc(self, oldState, newState, *args):
# Calls the appropriate fromTo function when transitioning into
@ -540,7 +602,7 @@ class FSM(DirectObject):
# If there's no matching exitFoo() function, call
# defaultExit() instead.
func = self.defaultExit
func()
return func()
def __repr__(self):
return self.__str__()