Kaydet (Commit) aee52ccc authored tarafından Brett Cannon's avatar Brett Cannon

Merge

...@@ -124,6 +124,27 @@ Functions for sequences: ...@@ -124,6 +124,27 @@ Functions for sequences:
Return a random element from the non-empty sequence *seq*. If *seq* is empty, Return a random element from the non-empty sequence *seq*. If *seq* is empty,
raises :exc:`IndexError`. raises :exc:`IndexError`.
.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None)
Return a *k* sized list of elements chosen from the *population* with replacement.
If the *population* is empty, raises :exc:`IndexError`.
If a *weights* sequence is specified, selections are made according to the
relative weights. Alternatively, if a *cum_weights* sequence is given, the
selections are made according to the cumulative weights. For example, the
relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative
weights ``[10, 15, 45, 50]``. Internally, the relative weights are
converted to cumulative weights before making selections, so supplying the
cumulative weights saves work.
If neither *weights* nor *cum_weights* is specified, selections are made
with equal probability. If a *weights* or *cum_weights* sequence is supplied,
it must be the same length as the *population* sequence. It is a :exc:`TypeError`
to specify both *weights* and *cum_weights*.
The *weights* or *cum_weights* can use any numeric type that interoperates
with the :class:`float` values returned by :func:`random` (that includes
integers, floats, and fractions but excludes decimals).
.. function:: shuffle(x[, random]) .. function:: shuffle(x[, random])
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
--------- ---------
pick random element pick random element
pick random sample pick random sample
pick weighted random sample
generate random permutation generate random permutation
distributions on the real line: distributions on the real line:
...@@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin ...@@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
from os import urandom as _urandom from os import urandom as _urandom
from _collections_abc import Set as _Set, Sequence as _Sequence from _collections_abc import Set as _Set, Sequence as _Sequence
from hashlib import sha512 as _sha512 from hashlib import sha512 as _sha512
import itertools as _itertools
import bisect as _bisect
__all__ = ["Random","seed","random","uniform","randint","choice","sample", __all__ = ["Random","seed","random","uniform","randint","choice","sample",
"randrange","shuffle","normalvariate","lognormvariate", "randrange","shuffle","normalvariate","lognormvariate",
"expovariate","vonmisesvariate","gammavariate","triangular", "expovariate","vonmisesvariate","gammavariate","triangular",
"gauss","betavariate","paretovariate","weibullvariate", "gauss","betavariate","paretovariate","weibullvariate",
"getstate","setstate", "getrandbits", "getstate","setstate", "getrandbits", "weighted_choices",
"SystemRandom"] "SystemRandom"]
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0) NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
...@@ -334,6 +337,28 @@ class Random(_random.Random): ...@@ -334,6 +337,28 @@ class Random(_random.Random):
result[i] = population[j] result[i] = population[j]
return result return result
def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
    """Return a k sized list of population elements chosen with replacement.

    If neither the relative weights nor the cumulative weights are
    specified, the selections are made with equal probability.

    *weights* may be any iterable of relative weights; it is turned into a
    list of running totals before selecting.  *cum_weights* (keyword-only)
    must already be such a sequence, supporting len() and indexing.

    Raises TypeError if both *weights* and *cum_weights* are given, and
    ValueError if the number of weights differs from len(population).
    """
    if weights is not None and cum_weights is not None:
        # Name the actual keyword (cum_weights) so the caller can find it.
        raise TypeError('Cannot specify both weights and cum_weights')

    if weights is None and cum_weights is None:
        # Unweighted case: defer to choice() for each of the k picks.
        choice = self.choice
        return [choice(population) for _ in range(k)]

    if cum_weights is None:
        cum_weights = list(_itertools.accumulate(weights))

    if len(cum_weights) != len(population):
        raise ValueError('The number of weights does not match the population')

    bisect = _bisect.bisect
    random = self.random
    # The last running total is the sum of all weights.
    total = cum_weights[-1]
    # Scale each draw by the total; bisect then yields the index of the first
    # cumulative weight strictly greater than the scaled draw.  If no such
    # entry exists (e.g. the weights sum to zero) the lookup raises IndexError.
    return [population[bisect(cum_weights, random() * total)] for _ in range(k)]
## -------------------- real-valued distributions ------------------- ## -------------------- real-valued distributions -------------------
## -------------------- uniform distribution ------------------- ## -------------------- uniform distribution -------------------
...@@ -724,6 +749,7 @@ choice = _inst.choice ...@@ -724,6 +749,7 @@ choice = _inst.choice
randrange = _inst.randrange randrange = _inst.randrange
sample = _inst.sample sample = _inst.sample
shuffle = _inst.shuffle shuffle = _inst.shuffle
weighted_choices = _inst.weighted_choices
normalvariate = _inst.normalvariate normalvariate = _inst.normalvariate
lognormvariate = _inst.lognormvariate lognormvariate = _inst.lognormvariate
expovariate = _inst.expovariate expovariate = _inst.expovariate
......
...@@ -7,6 +7,7 @@ import warnings ...@@ -7,6 +7,7 @@ import warnings
from functools import partial from functools import partial
from math import log, exp, pi, fsum, sin from math import log, exp, pi, fsum, sin
from test import support from test import support
from fractions import Fraction
class TestBasicOps: class TestBasicOps:
# Superclass with tests common to all generators. # Superclass with tests common to all generators.
...@@ -141,6 +142,73 @@ class TestBasicOps: ...@@ -141,6 +142,73 @@ class TestBasicOps:
def test_sample_on_dicts(self): def test_sample_on_dicts(self):
self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2) self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
def test_weighted_choices(self):
    # Exercise self.gen.weighted_choices(): result shape, argument
    # validation, and the numeric types accepted for weights/cum_weights.
    weighted_choices = self.gen.weighted_choices
    data = ['red', 'green', 'blue', 'yellow']
    str_data = 'abcd'
    range_data = range(4)
    set_data = set(range(4))

    # basic functionality
    for sample in [
        weighted_choices(5, data),
        weighted_choices(5, data, range(4)),
        weighted_choices(k=5, population=data, weights=range(4)),
        weighted_choices(k=5, population=data, cum_weights=range(4)),
    ]:
        self.assertEqual(len(sample), 5)
        self.assertEqual(type(sample), list)
        self.assertTrue(set(sample) <= set(data))

    # test argument handling
    with self.assertRaises(TypeError):                               # missing arguments
        weighted_choices(2)

    self.assertEqual(weighted_choices(0, data), [])                  # k == 0
    self.assertEqual(weighted_choices(-1, data), [])                 # negative k behaves like ``[0] * -1``
    with self.assertRaises(TypeError):
        weighted_choices(2.5, data)                                  # k is a float

    self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data))      # population is a string sequence
    self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data))  # population is a range
    with self.assertRaises(TypeError):
        # k must be a valid integer here; otherwise the TypeError would come
        # from the bad k and the set population would never be exercised.
        weighted_choices(2, set_data)                                # population is not a sequence

    self.assertTrue(set(weighted_choices(5, data, None)) <= set(data))        # weights is None
    self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data))
    with self.assertRaises(ValueError):
        weighted_choices(5, data, [1,2])                             # len(weights) != len(population)
    with self.assertRaises(IndexError):
        weighted_choices(5, data, [0]*4)                             # weights sum to zero
    with self.assertRaises(TypeError):
        weighted_choices(5, data, 10)                                # non-iterable weights
    with self.assertRaises(TypeError):
        weighted_choices(5, data, [None]*4)                          # non-numeric weights
    for weights in [
            [15, 10, 25, 30],                                                 # integer weights
            [15.1, 10.2, 25.2, 30.3],                                         # float weights
            [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights
            [True, False, True, False]                                        # booleans (include / exclude)
    ]:
        self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data))

    with self.assertRaises(ValueError):
        weighted_choices(5, data, cum_weights=[1,2])                 # len(weights) != len(population)
    with self.assertRaises(IndexError):
        weighted_choices(5, data, cum_weights=[0]*4)                 # cum_weights sum to zero
    with self.assertRaises(TypeError):
        weighted_choices(5, data, cum_weights=10)                    # non-iterable cum_weights
    with self.assertRaises(TypeError):
        weighted_choices(5, data, cum_weights=[None]*4)              # non-numeric cum_weights
    with self.assertRaises(TypeError):
        weighted_choices(5, data, range(4), cum_weights=range(4))    # both weights and cum_weights
    for weights in [
            [15, 10, 25, 30],                                                 # integer cum_weights
            [15.1, 10.2, 25.2, 30.3],                                         # float cum_weights
            [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights
    ]:
        self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data))
def test_gauss(self): def test_gauss(self):
# Ensure that the seed() method initializes all the hidden state. In # Ensure that the seed() method initializes all the hidden state. In
# particular, through 2.2.1 it failed to reset a piece of state used # particular, through 2.2.1 it failed to reset a piece of state used
......
...@@ -101,6 +101,8 @@ Library ...@@ -101,6 +101,8 @@ Library
- Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name - Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
fields in X.509 certs. fields in X.509 certs.
- Issue #18844: Add random.weighted_choices().
- Issue #25761: Improved error reporting about truncated pickle data in - Issue #25761: Improved error reporting about truncated pickle data in
C implementation of unpickler. UnpicklingError is now raised instead of C implementation of unpickler. UnpicklingError is now raised instead of
AttributeError and ValueError in some cases. AttributeError and ValueError in some cases.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment