Kaydet (Commit) d8ff4658 authored tarafından Raymond Hettinger's avatar Raymond Hettinger

Simplify the signature for itertools.accumulate() to match numpy. Handle one…

Simplify the signature for itertools.accumulate() to match numpy.  Handle one item iterable the same way as min()/max().
üst a7a0e1a0
...@@ -90,13 +90,15 @@ loops that truncate the stream. ...@@ -90,13 +90,15 @@ loops that truncate the stream.
parameter (which defaults to :const:`0`). Elements may be any addable type parameter (which defaults to :const:`0`). Elements may be any addable type
including :class:`Decimal` or :class:`Fraction`. Equivalent to:: including :class:`Decimal` or :class:`Fraction`. Equivalent to::
def accumulate(iterable, start=0): def accumulate(iterable):
'Return running totals' 'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
total = start it = iter(iterable)
for element in iterable: total = next(it)
total += element yield total
yield total for element in it:
total += element
yield total
.. versionadded:: 3.2 .. versionadded:: 3.2
......
...@@ -59,18 +59,18 @@ class TestBasicOps(unittest.TestCase): ...@@ -59,18 +59,18 @@ class TestBasicOps(unittest.TestCase):
def test_accumulate(self): def test_accumulate(self):
self.assertEqual(list(accumulate(range(10))), # one positional arg self.assertEqual(list(accumulate(range(10))), # one positional arg
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
self.assertEqual(list(accumulate(range(10), 100)), # two positional args self.assertEqual(list(accumulate(iterable=range(10))), # kw arg
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
for typ in int, complex, Decimal, Fraction: # multiple types for typ in int, complex, Decimal, Fraction: # multiple types
self.assertEqual(list(accumulate(range(10), typ(0))), self.assertEqual(
list(accumulate(map(typ, range(10)))),
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]))) list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
self.assertEqual(list(accumulate([])), []) # empty iterable self.assertEqual(list(accumulate([])), []) # empty iterable
self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
self.assertRaises(TypeError, accumulate) # too few args self.assertRaises(TypeError, accumulate) # too few args
self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
def test_chain(self): def test_chain(self):
......
...@@ -2597,41 +2597,27 @@ static PyTypeObject accumulate_type; ...@@ -2597,41 +2597,27 @@ static PyTypeObject accumulate_type;
static PyObject * static PyObject *
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{ {
static char *kwargs[] = {"iterable", "start", NULL}; static char *kwargs[] = {"iterable", NULL};
PyObject *iterable; PyObject *iterable;
PyObject *it; PyObject *it;
PyObject *start = NULL;
accumulateobject *lz; accumulateobject *lz;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
kwargs, &iterable, &start)) return NULL;
return NULL;
/* Get iterator. */ /* Get iterator. */
it = PyObject_GetIter(iterable); it = PyObject_GetIter(iterable);
if (it == NULL) if (it == NULL)
return NULL; return NULL;
/* Default start value */
if (start == NULL) {
start = PyLong_FromLong(0);
if (start == NULL) {
Py_DECREF(it);
return NULL;
}
} else {
Py_INCREF(start);
}
/* create accumulateobject structure */ /* create accumulateobject structure */
lz = (accumulateobject *)type->tp_alloc(type, 0); lz = (accumulateobject *)type->tp_alloc(type, 0);
if (lz == NULL) { if (lz == NULL) {
Py_DECREF(it); Py_DECREF(it);
Py_DECREF(start); return NULL;
return NULL;
} }
lz->total = start; lz->total = NULL;
lz->it = it; lz->it = it;
return (PyObject *)lz; return (PyObject *)lz;
} }
...@@ -2661,11 +2647,17 @@ accumulate_next(accumulateobject *lz) ...@@ -2661,11 +2647,17 @@ accumulate_next(accumulateobject *lz)
val = PyIter_Next(lz->it); val = PyIter_Next(lz->it);
if (val == NULL) if (val == NULL)
return NULL; return NULL;
if (lz->total == NULL) {
Py_INCREF(val);
lz->total = val;
return lz->total;
}
newtotal = PyNumber_Add(lz->total, val); newtotal = PyNumber_Add(lz->total, val);
Py_DECREF(val); Py_DECREF(val);
if (newtotal == NULL) if (newtotal == NULL)
return NULL; return NULL;
oldtotal = lz->total; oldtotal = lz->total;
lz->total = newtotal; lz->total = newtotal;
...@@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz) ...@@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz)
} }
PyDoc_STRVAR(accumulate_doc, PyDoc_STRVAR(accumulate_doc,
"accumulate(iterable, start=0) --> accumulate object\n\ "accumulate(iterable) --> accumulate object\n\
\n\ \n\
Return series of accumulated sums."); Return series of accumulated sums.");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment