diff --git a/direct/src/showbase/PythonUtil.py b/direct/src/showbase/PythonUtil.py index 2370742957..fbef9ecc70 100644 --- a/direct/src/showbase/PythonUtil.py +++ b/direct/src/showbase/PythonUtil.py @@ -1130,6 +1130,10 @@ def weightedChoice(choiceList, rng=random.random, sum=None): """given a list of (weight, item) pairs, chooses an item based on the weights. rng must return 0..1. if you happen to have the sum of the weights, pass it in 'sum'.""" + # Throw an IndexError if we got an empty list. + if not choiceList: + raise IndexError('Cannot choose from an empty sequence') + # TODO: add support for dicts if sum is None: sum = 0. @@ -1138,6 +1142,7 @@ def weightedChoice(choiceList, rng=random.random, sum=None): rand = rng() accum = rand * sum + item = None for weight, item in choiceList: accum -= weight if accum <= 0.: diff --git a/tests/showbase/test_PythonUtil.py b/tests/showbase/test_PythonUtil.py index faf5da269f..c4c1b8076a 100644 --- a/tests/showbase/test_PythonUtil.py +++ b/tests/showbase/test_PythonUtil.py @@ -1,4 +1,5 @@ from direct.showbase import PythonUtil +import pytest def test_queue(): @@ -103,3 +104,64 @@ def test_priority_callbacks(): pc.clear() pc() assert len(l) == 0 + +def test_weighted_choice(): + # Test PythonUtil.weightedChoice() with no valid list. + with pytest.raises(IndexError): + PythonUtil.weightedChoice([]) + + # Create a sample choice list. + # This contains a few tuples containing only a weight + # and an arbitrary item. + choicelist = [(3, 'item1'), (1, 'item2'), (7, 'item3')] + + # These are the items that we expect. + items = ['item1', 'item2', 'item3'] + + # Test PythonUtil.weightedChoice() with our choice list. + item = PythonUtil.weightedChoice(choicelist) + + # Assert that what we got was at least an available item. + assert item in items + + # Create yet another sample choice list, but with a couple more items. + choicelist = [(2, 'item1'), (25, 'item2'), (14, 'item3'), (5, 'item4'), + (7, 'item5'), (3, 'item6'), (6, 'item7'), (50, 'item8')] + + # Set the items that we expect again. + items = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7', 'item8'] + + # The sum of all of the weights is 112. + weightsum = 2 + 25 + 14 + 5 + 7 + 3 + 6 + 50 + + # Test PythonUtil.weightedChoice() with the sum. + item = PythonUtil.weightedChoice(choicelist, sum=weightsum) + + # Assert that we got a valid item (most of the time this should be 'item8'). + assert item in items + + # Test PythonUtil.weightedChoice(), but with an invalid sum. + item = PythonUtil.weightedChoice(choicelist, sum=1) + + # Assert that we got 'item1'. + assert item == items[0] + + # Test PythonUtil.weightedChoice() with an invalid sum. + # This time, we're using 2000 so that regardless of the random + # number, we will still reach the very last item. + item = PythonUtil.weightedChoice(choicelist, sum=100000) + + # Assert that we got 'item8', since we would get the last item. + assert item == items[-1] + + # Create a bogus random function. + rnd = lambda: 0.5 + + # Test PythonUtil.weightedChoice() with the bogus function. + item = PythonUtil.weightedChoice(choicelist, rng=rnd, sum=weightsum) + + # Assert that we got 'item6'. + # We expect 'item6' because 0.5 multiplied by 112 is 56.0. + # When subtracting that number by each weight, it will reach 0 + # by the time it hits 'item6' in the iteration. + assert item == items[5]