Kaydet (Commit) 5b913e31 authored tarafından Raymond Hettinger's avatar Raymond Hettinger

Forward port r68394 for issue 4816.

üst 1a67f589
...@@ -108,6 +108,8 @@ loops that truncate the stream. ...@@ -108,6 +108,8 @@ loops that truncate the stream.
# combinations(range(4), 3) --> 012 013 023 123 # combinations(range(4), 3) --> 012 013 023 123
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
if r > n:
return
indices = range(r) indices = range(r)
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
while 1: while 1:
...@@ -132,6 +134,9 @@ loops that truncate the stream. ...@@ -132,6 +134,9 @@ loops that truncate the stream.
if sorted(indices) == list(indices): if sorted(indices) == list(indices):
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
or zero when ``r > n``.
.. versionadded:: 2.6 .. versionadded:: 2.6
.. function:: count([n]) .. function:: count([n])
...@@ -399,6 +404,8 @@ loops that truncate the stream. ...@@ -399,6 +404,8 @@ loops that truncate the stream.
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
r = n if r is None else r r = n if r is None else r
if r > n:
return
indices = range(n) indices = range(n)
cycles = range(n, n-r, -1) cycles = range(n, n-r, -1)
yield tuple(pool[i] for i in indices[:r]) yield tuple(pool[i] for i in indices[:r])
...@@ -428,6 +435,9 @@ loops that truncate the stream. ...@@ -428,6 +435,9 @@ loops that truncate the stream.
if len(set(indices)) == r: if len(set(indices)) == r:
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n``
or zero when ``r > n``.
.. versionadded:: 2.6 .. versionadded:: 2.6
.. function:: product(*iterables[, repeat]) .. function:: product(*iterables[, repeat])
...@@ -674,7 +684,8 @@ which incur interpreter overhead. ...@@ -674,7 +684,8 @@ which incur interpreter overhead.
return (d for d, s in izip(data, selectors) if s) return (d for d, s in izip(data, selectors) if s)
def combinations_with_replacement(iterable, r): def combinations_with_replacement(iterable, r):
"combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
# number items returned: (n+r-1)! / r! / (n-1)!
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
indices = [0] * r indices = [0] * r
......
...@@ -71,11 +71,11 @@ class TestBasicOps(unittest.TestCase): ...@@ -71,11 +71,11 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
def test_combinations(self): def test_combinations(self):
self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc') # missing r argument
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(TypeError, combinations, None) # pool is not iterable
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big self.assertEqual(list(combinations('abc', 32)), []) # r > n
self.assertEqual(list(combinations(range(4), 3)), self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
...@@ -83,6 +83,8 @@ class TestBasicOps(unittest.TestCase): ...@@ -83,6 +83,8 @@ class TestBasicOps(unittest.TestCase):
'Pure python version shown in the docs' 'Pure python version shown in the docs'
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
if r > n:
return
indices = range(r) indices = range(r)
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
while 1: while 1:
...@@ -106,9 +108,9 @@ class TestBasicOps(unittest.TestCase): ...@@ -106,9 +108,9 @@ class TestBasicOps(unittest.TestCase):
for n in range(7): for n in range(7):
values = [5*x-12 for x in range(n)] values = [5*x-12 for x in range(n)]
for r in range(n+1): for r in range(n+2):
result = list(combinations(values, r)) result = list(combinations(values, r))
self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs
self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(len(result), len(set(result))) # no repeats
self.assertEqual(result, sorted(result)) # lexicographic order self.assertEqual(result, sorted(result)) # lexicographic order
for c in result: for c in result:
...@@ -119,7 +121,7 @@ class TestBasicOps(unittest.TestCase): ...@@ -119,7 +121,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(c), self.assertEqual(list(c),
[e for e in values if e in c]) # comb is a subsequence of the input iterable [e for e in values if e in c]) # comb is a subsequence of the input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
# Test implementation detail: tuple re-use # Test implementation detail: tuple re-use
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
...@@ -130,7 +132,7 @@ class TestBasicOps(unittest.TestCase): ...@@ -130,7 +132,7 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(TypeError, permutations, None) # pool is not iterable
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big self.assertEqual(list(permutations('abc', 32)), []) # r > n
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)), self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
...@@ -140,6 +142,8 @@ class TestBasicOps(unittest.TestCase): ...@@ -140,6 +142,8 @@ class TestBasicOps(unittest.TestCase):
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
r = n if r is None else r r = n if r is None else r
if r > n:
return
indices = range(n) indices = range(n)
cycles = range(n, n-r, -1) cycles = range(n, n-r, -1)
yield tuple(pool[i] for i in indices[:r]) yield tuple(pool[i] for i in indices[:r])
...@@ -168,9 +172,9 @@ class TestBasicOps(unittest.TestCase): ...@@ -168,9 +172,9 @@ class TestBasicOps(unittest.TestCase):
for n in range(7): for n in range(7):
values = [5*x-12 for x in range(n)] values = [5*x-12 for x in range(n)]
for r in range(n+1): for r in range(n+2):
result = list(permutations(values, r)) result = list(permutations(values, r))
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms
self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(len(result), len(set(result))) # no repeats
self.assertEqual(result, sorted(result)) # lexicographic order self.assertEqual(result, sorted(result)) # lexicographic order
for p in result: for p in result:
...@@ -178,7 +182,7 @@ class TestBasicOps(unittest.TestCase): ...@@ -178,7 +182,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(p)), r) # no duplicate elements self.assertEqual(len(set(p)), r) # no duplicate elements
self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assert_(all(e in values for e in p)) # elements taken from input iterable
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
if r == n: if r == n:
self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r self.assertEqual(result, list(permutations(values))) # test default r
...@@ -1363,6 +1367,26 @@ perform as purported. ...@@ -1363,6 +1367,26 @@ perform as purported.
>>> list(combinations_with_replacement('abc', 2)) >>> list(combinations_with_replacement('abc', 2))
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
>>> list(combinations_with_replacement('01', 3))
[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
>>> def combinations_with_replacement2(iterable, r):
... 'Alternate version that filters from product()'
... pool = tuple(iterable)
... n = len(pool)
... for indices in product(range(n), repeat=r):
... if sorted(indices) == list(indices):
... yield tuple(pool[i] for i in indices)
>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
True
>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
True
>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
True
>>> list(unique_everseen('AAAABBBCCDAABBB')) >>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D'] ['A', 'B', 'C', 'D']
......
...@@ -2059,10 +2059,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) ...@@ -2059,10 +2059,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyErr_SetString(PyExc_ValueError, "r must be non-negative"); PyErr_SetString(PyExc_ValueError, "r must be non-negative");
goto error; goto error;
} }
if (r > n) {
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
goto error;
}
indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
if (indices == NULL) { if (indices == NULL) {
...@@ -2082,7 +2078,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) ...@@ -2082,7 +2078,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
co->indices = indices; co->indices = indices;
co->result = NULL; co->result = NULL;
co->r = r; co->r = r;
co->stopped = 0; co->stopped = r > n ? 1 : 0;
return (PyObject *)co; return (PyObject *)co;
...@@ -2318,10 +2314,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) ...@@ -2318,10 +2314,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyErr_SetString(PyExc_ValueError, "r must be non-negative"); PyErr_SetString(PyExc_ValueError, "r must be non-negative");
goto error; goto error;
} }
if (r > n) {
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
goto error;
}
indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
...@@ -2345,7 +2337,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) ...@@ -2345,7 +2337,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
po->cycles = cycles; po->cycles = cycles;
po->result = NULL; po->result = NULL;
po->r = r; po->r = r;
po->stopped = 0; po->stopped = r > n ? 1 : 0;
return (PyObject *)po; return (PyObject *)po;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment