make fsm's thread safe

This commit is contained in:
David Rose 2008-10-29 22:02:18 +00:00
parent 89d53bbb75
commit 940a46fcfd

View File

@ -10,6 +10,7 @@ previously called FSM.py (now called ClassicFSM.py).
from direct.showbase.DirectObject import DirectObject
from direct.directnotify import DirectNotifyGlobal
from direct.showbase import PythonUtil
from direct.stdpy.threading import RLock
import types
import string
@ -142,6 +143,7 @@ class FSM(DirectObject):
defaultTransitions = None
def __init__(self, name):
self.lock = RLock()
self.name = name
self._serialNum = FSM.SerialNum
FSM.SerialNum += 1
@ -166,9 +168,13 @@ class FSM(DirectObject):
def cleanup(self):
# A convenience function to force the FSM to clean itself up
# by transitioning to the "Off" state.
assert self.state
if self.state != 'Off':
self.__setState('Off')
self.lock.acquire()
try:
assert self.state
if self.state != 'Off':
self.__setState('Off')
finally:
self.lock.release()
def setBroadcastStateChanges(self, doBroadcast):
self._broadcastStateChanges = doBroadcast
@ -183,29 +189,41 @@ class FSM(DirectObject):
# Returns the current state if we are in a state now, or the
# state we are transitioning into if we are currently within
# the enter or exit function for a state.
if self.state:
return self.state
return self.newState
self.lock.acquire()
try:
if self.state:
return self.state
return self.newState
finally:
self.lock.release()
def isInTransition(self):
return self.state == None
self.lock.acquire()
try:
return self.state == None
finally:
self.lock.release()
def forceTransition(self, request, *args):
"""Changes unconditionally to the indicated state. This
bypasses the filterState() function, and just calls
exitState() followed by enterState()."""
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.forceTransition(%s, %s" % (
self.name, request, str(args)[1:]))
self.lock.acquire()
try:
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.forceTransition(%s, %s" % (
self.name, request, str(args)[1:]))
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.forceTransition, request, *args))
return
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.forceTransition, request, *args))
return
self.__setState(request, *args)
self.__setState(request, *args)
finally:
self.lock.release()
def demand(self, request, *args):
"""Requests a state transition, by code that does not expect
@ -219,17 +237,21 @@ class FSM(DirectObject):
sequence.
"""
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.demand(%s, %s" % (
self.name, request, str(args)[1:]))
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.demand, request, *args))
return
self.lock.acquire()
try:
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.demand(%s, %s" % (
self.name, request, str(args)[1:]))
if not self.state:
# Queue up the request.
self.__requestQueue.append(PythonUtil.Functor(
self.demand, request, *args))
return
if not self.request(request, *args):
raise RequestDenied, "%s (from state: %s)" % (request, self.state)
if not self.request(request, *args):
raise RequestDenied, "%s (from state: %s)" % (request, self.state)
finally:
self.lock.release()
def request(self, request, *args):
"""Requests a state transition (or other behavior). The
@ -254,30 +276,34 @@ class FSM(DirectObject):
which will queue these requests up and apply when the
transition is complete)."""
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.request(%s, %s" % (
self.name, request, str(args)[1:]))
self.lock.acquire()
try:
assert isinstance(request, types.StringTypes)
self.notify.debug("%s.request(%s, %s" % (
self.name, request, str(args)[1:]))
if not self.state:
error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
raise AlreadyInTransition, error
if not self.state:
error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
raise AlreadyInTransition, error
func = getattr(self, "filter" + self.state, None)
if not func:
# If there's no matching filterState() function, call
# defaultFilter() instead.
func = self.defaultFilter
result = func(request, args)
if result:
if isinstance(result, types.StringTypes):
# If the return value is a string, it's just the name
# of the state. Wrap it in a tuple for consistency.
result = (result,) + args
func = getattr(self, "filter" + self.state, None)
if not func:
# If there's no matching filterState() function, call
# defaultFilter() instead.
func = self.defaultFilter
result = func(request, args)
if result:
if isinstance(result, types.StringTypes):
# If the return value is a string, it's just the name
# of the state. Wrap it in a tuple for consistency.
result = (result,) + args
# Otherwise, assume it's a (name, *args) tuple
self.__setState(*result)
# Otherwise, assume it's a (name, *args) tuple
self.__setState(*result)
return result
return result
finally:
self.lock.release()
def defaultEnter(self, *args):
""" This is the default function that is called if there is no
@ -353,25 +379,37 @@ class FSM(DirectObject):
def setStateArray(self, stateArray):
"""array of unique states to iterate through"""
self.stateArray = stateArray
self.lock.acquire()
try:
self.stateArray = stateArray
finally:
self.lock.release()
def requestNext(self, *args):
"""request the 'next' state in the predefined state array"""
assert self.state in self.stateArray
self.lock.acquire()
try:
assert self.state in self.stateArray
curIndex = self.stateArray.index(self.state)
newIndex = (curIndex + 1) % len(self.stateArray)
curIndex = self.stateArray.index(self.state)
newIndex = (curIndex + 1) % len(self.stateArray)
self.request(self.stateArray[newIndex], args)
self.request(self.stateArray[newIndex], args)
finally:
self.lock.release()
def requestPrev(self, *args):
"""request the 'previous' state in the predefined state array"""
assert self.state in self.stateArray
self.lock.acquire()
try:
assert self.state in self.stateArray
curIndex = self.stateArray.index(self.state)
newIndex = (curIndex - 1) % len(self.stateArray)
curIndex = self.stateArray.index(self.state)
newIndex = (curIndex - 1) % len(self.stateArray)
self.request(self.stateArray[newIndex], args)
self.request(self.stateArray[newIndex], args)
finally:
self.lock.release()
def __setState(self, newState, *args):
@ -441,9 +479,13 @@ class FSM(DirectObject):
"""
Print out something useful about the fsm
"""
className = self.__class__.__name__
if self.state:
str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
else:
str = ('%s FSM:%s not in any state' % (className, self.name))
return str
self.lock.acquire()
try:
className = self.__class__.__name__
if self.state:
str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
else:
str = ('%s FSM:%s not in any state' % (className, self.name))
return str
finally:
self.lock.release()