PythonUtil: weightedChoice should throw IndexError on empty list

Also includes a unit test.

Closes #682
This commit is contained in:
pythonengineer 2019-07-10 21:09:41 -04:00 committed by rdb
parent ce6d02b8d7
commit 46a3a72029
2 changed files with 67 additions and 0 deletions

View File

@ -1130,6 +1130,10 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
"""given a list of (weight, item) pairs, chooses an item based on the """given a list of (weight, item) pairs, chooses an item based on the
weights. rng must return 0..1. if you happen to have the sum of the weights. rng must return 0..1. if you happen to have the sum of the
weights, pass it in 'sum'.""" weights, pass it in 'sum'."""
# Throw an IndexError if we got an empty list.
if not choiceList:
raise IndexError('Cannot choose from an empty sequence')
# TODO: add support for dicts # TODO: add support for dicts
if sum is None: if sum is None:
sum = 0. sum = 0.
@ -1138,6 +1142,7 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
rand = rng() rand = rng()
accum = rand * sum accum = rand * sum
item = None
for weight, item in choiceList: for weight, item in choiceList:
accum -= weight accum -= weight
if accum <= 0.: if accum <= 0.:

View File

@ -1,4 +1,5 @@
from direct.showbase import PythonUtil from direct.showbase import PythonUtil
import pytest
def test_queue(): def test_queue():
@ -103,3 +104,64 @@ def test_priority_callbacks():
pc.clear() pc.clear()
pc() pc()
assert len(l) == 0 assert len(l) == 0
def test_weighted_choice():
# Test PythonUtil.weightedChoice() with no valid list.
with pytest.raises(IndexError):
PythonUtil.weightedChoice([])
# Create a sample choice list.
# This contains a few tuples containing only a weight
# and an arbitrary item.
choicelist = [(3, 'item1'), (1, 'item2'), (7, 'item3')]
# These are the items that we expect.
items = ['item1', 'item2', 'item3']
# Test PythonUtil.weightedChoice() with our choice list.
item = PythonUtil.weightedChoice(choicelist)
# Assert that what we got was at least an available item.
assert item in items
# Create yet another sample choice list, but with a couple more items.
choicelist = [(2, 'item1'), (25, 'item2'), (14, 'item3'), (5, 'item4'),
(7, 'item5'), (3, 'item6'), (6, 'item7'), (50, 'item8')]
# Set the items that we expect again.
items = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7', 'item8']
# The sum of all of the weights is 112.
weightsum = 2 + 25 + 14 + 5 + 7 + 3 + 6 + 50
# Test PythonUtil.weightedChoice() with the sum.
item = PythonUtil.weightedChoice(choicelist, sum=weightsum)
# Assert that we got a valid item (most of the time this should be 'item8').
assert item in items
# Test PythonUtil.weightedChoice(), but with an invalid sum.
item = PythonUtil.weightedChoice(choicelist, sum=1)
# Assert that we got 'item1'.
assert item == items[0]
# Test PythonUtil.weightedChoice() with an invalid sum.
# This time, we're using 2000 so that regardless of the random
# number, we will still reach the very last item.
item = PythonUtil.weightedChoice(choicelist, sum=100000)
# Assert that we got 'item8', since we would get the last item.
assert item == items[-1]
# Create a bogus random function.
rnd = lambda: 0.5
# Test PythonUtil.weightedChoice() with the bogus function.
item = PythonUtil.weightedChoice(choicelist, rng=rnd, sum=weightsum)
# Assert that we got 'item6'.
# We expect 'item6' because 0.5 multiplied by 112 is 56.0.
# When subtracting that number by each weight, it will reach 0
# by the time it hits 'item6' in the iteration.
assert item == items[5]