diff --git a/direct/src/stdpy/pickle.py b/direct/src/stdpy/pickle.py index fefdeae9ff..60d2b7d92f 100644 --- a/direct/src/stdpy/pickle.py +++ b/direct/src/stdpy/pickle.py @@ -33,19 +33,30 @@ else: # with the local pickle.py. pickle = __import__('pickle') -class Pickler(pickle.Pickler): +if sys.version_info >= (3, 0): + BasePickler = pickle._Pickler + BaseUnpickler = pickle._Unpickler +else: + BasePickler = pickle.Pickler + BaseUnpickler = pickle.Unpickler + + +class _Pickler(BasePickler): def __init__(self, *args, **kw): self.bamWriter = BamWriter() - pickle.Pickler.__init__(self, *args, **kw) + BasePickler.__init__(self, *args, **kw) # We have to duplicate most of the save() method, so we can add # support for __reduce_persist__(). - def save(self, obj): + def save(self, obj, save_persistent_id=True): + if self.proto >= 4: + self.framer.commit_frame() + # Check for persistent id (defined by a subclass) pid = self.persistent_id(obj) - if pid: + if pid is not None and save_persistent_id: self.save_pers(pid) return @@ -112,11 +123,12 @@ class Pickler(pickle.Pickler): # Save the reduce() output and finally memoize the object self.save_reduce(obj=obj, *rv) -class Unpickler(pickle.Unpickler): + +class Unpickler(BaseUnpickler): def __init__(self, *args, **kw): self.bamReader = BamReader() - pickle.Unpickler.__init__(self, *args, **kw) + BaseUnpickler.__init__(self, *args, **kw) # Duplicate the load_reduce() function, to provide a special case # for the reduction function. @@ -126,9 +138,10 @@ class Unpickler(pickle.Unpickler): args = stack.pop() func = stack[-1] - # If the function name ends with "Persist", then assume the + # If the function name ends with "_persist", then assume the # function wants the Unpickler as the first parameter. - if func.__name__.endswith('Persist'): + func_name = func.__name__ + if func_name.endswith('_persist') or func_name.endswith('Persist'): value = func(self, *args) else: # Otherwise, use the existing pickle convention. @@ -136,9 +149,32 @@ class Unpickler(pickle.Unpickler): stack[-1] = value - #FIXME: how to replace in Python 3? + if sys.version_info >= (3, 0): + BaseUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce + else: + BaseUnpickler.dispatch[pickle.REDUCE] = load_reduce + + +if sys.version_info >= (3, 8): + # In Python 3.8 and up, we can use the C implementation of Pickler, which + # supports a reducer_override method. + class Pickler(pickle.Pickler): + def __init__(self, *args, **kw): + self.bamWriter = BamWriter() + pickle.Pickler.__init__(self, *args, **kw) + + def reducer_override(self, obj): + reduce = getattr(obj, "__reduce_persist__", None) + if reduce: + return reduce(self) + + return NotImplemented +else: + # Otherwise, we have to use our custom version that overrides save(). + Pickler = _Pickler + if sys.version_info < (3, 0): - pickle.Unpickler.dispatch[pickle.REDUCE] = load_reduce + del _Pickler # Shorthands diff --git a/tests/stdpy/test_pickle.py b/tests/stdpy/test_pickle.py new file mode 100644 index 0000000000..303c5559c1 --- /dev/null +++ b/tests/stdpy/test_pickle.py @@ -0,0 +1,11 @@ +from direct.stdpy.pickle import dumps, loads + + +def test_reduce_persist(): + from panda3d.core import NodePath + + parent = NodePath("parent") + child = parent.attach_new_node("child") + + parent2, child2 = loads(dumps([parent, child])) + assert tuple(parent2.children) == (child2,)