PythonUtil: weightedChoice should throw IndexError on empty list

Also includes a unit test.

Closes #682
This commit is contained in:
pythonengineer 2019-07-10 21:09:41 -04:00 committed by rdb
parent ce6d02b8d7
commit 46a3a72029
2 changed files with 67 additions and 0 deletions

View File

@ -1130,6 +1130,10 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
"""given a list of (weight, item) pairs, chooses an item based on the
weights. rng must return 0..1. if you happen to have the sum of the
weights, pass it in 'sum'."""
# Throw an IndexError if we got an empty list.
if not choiceList:
raise IndexError('Cannot choose from an empty sequence')
# TODO: add support for dicts
if sum is None:
sum = 0.
@ -1138,6 +1142,7 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
rand = rng()
accum = rand * sum
item = None
for weight, item in choiceList:
accum -= weight
if accum <= 0.:

View File

@ -1,4 +1,5 @@
from direct.showbase import PythonUtil
import pytest
def test_queue():
@ -103,3 +104,64 @@ def test_priority_callbacks():
pc.clear()
pc()
assert len(l) == 0
def test_weighted_choice():
# Test PythonUtil.weightedChoice() with no valid list.
with pytest.raises(IndexError):
PythonUtil.weightedChoice([])
# Create a sample choice list.
# This contains a few tuples containing only a weight
# and an arbitrary item.
choicelist = [(3, 'item1'), (1, 'item2'), (7, 'item3')]
# These are the items that we expect.
items = ['item1', 'item2', 'item3']
# Test PythonUtil.weightedChoice() with our choice list.
item = PythonUtil.weightedChoice(choicelist)
# Assert that what we got was at least an available item.
assert item in items
# Create yet another sample choice list, but with a couple more items.
choicelist = [(2, 'item1'), (25, 'item2'), (14, 'item3'), (5, 'item4'),
(7, 'item5'), (3, 'item6'), (6, 'item7'), (50, 'item8')]
# Set the items that we expect again.
items = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7', 'item8']
# The sum of all of the weights is 112.
weightsum = 2 + 25 + 14 + 5 + 7 + 3 + 6 + 50
# Test PythonUtil.weightedChoice() with the sum.
item = PythonUtil.weightedChoice(choicelist, sum=weightsum)
# Assert that we got a valid item (most of the time this should be 'item8').
assert item in items
# Test PythonUtil.weightedChoice(), but with an invalid sum.
item = PythonUtil.weightedChoice(choicelist, sum=1)
# Assert that we got 'item1'.
assert item == items[0]
# Test PythonUtil.weightedChoice() with an invalid sum.
# This time, we're using 2000 so that regardless of the random
# number, we will still reach the very last item.
item = PythonUtil.weightedChoice(choicelist, sum=100000)
# Assert that we got 'item8', since we would get the last item.
assert item == items[-1]
# Create a bogus random function.
rnd = lambda: 0.5
# Test PythonUtil.weightedChoice() with the bogus function.
item = PythonUtil.weightedChoice(choicelist, rng=rnd, sum=weightsum)
# Assert that we got 'item6'.
# We expect 'item6' because 0.5 multiplied by 112 is 56.0.
# When subtracting that number by each weight, it will reach 0
# by the time it hits 'item6' in the iteration.
assert item == items[5]