Kaydet (Commit) 4cda01e2 authored tarafından Raymond Hettinger's avatar Raymond Hettinger

* Increase test coverage.

* Have groupby() be careful about decreffing structure members.
üst aec3c9b5
...@@ -53,6 +53,10 @@ class TestBasicOps(unittest.TestCase): ...@@ -53,6 +53,10 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, count, 'a') self.assertRaises(TypeError, count, 'a')
c = count(sys.maxint-2) # verify that rollover doesn't crash c = count(sys.maxint-2) # verify that rollover doesn't crash
c.next(); c.next(); c.next(); c.next(); c.next() c.next(); c.next(); c.next(); c.next(); c.next()
c = count(3)
self.assertEqual(repr(c), 'count(3)')
c.next()
self.assertEqual(repr(c), 'count(4)')
def test_cycle(self): def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
...@@ -67,6 +71,7 @@ class TestBasicOps(unittest.TestCase): ...@@ -67,6 +71,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual([], list(groupby([], key=id))) self.assertEqual([], list(groupby([], key=id)))
self.assertRaises(TypeError, list, groupby('abc', [])) self.assertRaises(TypeError, list, groupby('abc', []))
self.assertRaises(TypeError, groupby, None) self.assertRaises(TypeError, groupby, None)
self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10)
# Check normal input # Check normal input
s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22), s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
...@@ -199,6 +204,12 @@ class TestBasicOps(unittest.TestCase): ...@@ -199,6 +204,12 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, repeat) self.assertRaises(TypeError, repeat)
self.assertRaises(TypeError, repeat, None, 3, 4) self.assertRaises(TypeError, repeat, None, 3, 4)
self.assertRaises(TypeError, repeat, None, 'a') self.assertRaises(TypeError, repeat, None, 'a')
r = repeat(1+0j)
self.assertEqual(repr(r), 'repeat((1+0j))')
r = repeat(1+0j, 5)
self.assertEqual(repr(r), 'repeat((1+0j), 5)')
list(r)
self.assertEqual(repr(r), 'repeat((1+0j), 0)')
def test_imap(self): def test_imap(self):
self.assertEqual(list(imap(operator.pow, range(3), range(1,7))), self.assertEqual(list(imap(operator.pow, range(3), range(1,7))),
...@@ -275,6 +286,9 @@ class TestBasicOps(unittest.TestCase): ...@@ -275,6 +286,9 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra') self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra')
self.assertRaises(TypeError, takewhile(10, [(4,5)]).next) self.assertRaises(TypeError, takewhile(10, [(4,5)]).next)
self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next) self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next)
t = takewhile(bool, [1, 1, 1, 0, 0, 0])
self.assertEqual(list(t), [1, 1, 1])
self.assertRaises(StopIteration, t.next)
def test_dropwhile(self): def test_dropwhile(self):
data = [1, 3, 5, 20, 2, 4, 6, 8] data = [1, 3, 5, 20, 2, 4, 6, 8]
...@@ -347,11 +361,26 @@ class TestBasicOps(unittest.TestCase): ...@@ -347,11 +361,26 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(a), range(100,2000)) self.assertEqual(list(a), range(100,2000))
self.assertEqual(list(c), range(2,2000)) self.assertEqual(list(c), range(2,2000))
# test values of n
self.assertRaises(TypeError, tee, 'abc', 'invalid')
for n in xrange(5):
result = tee('abc', n)
self.assertEqual(type(result), tuple)
self.assertEqual(len(result), n)
self.assertEqual(map(list, result), [list('abc')]*n)
# tee pass-through to copyable iterator # tee pass-through to copyable iterator
a, b = tee('abc') a, b = tee('abc')
c, d = tee(a) c, d = tee(a)
self.assert_(a is c) self.assert_(a is c)
# test tee_new
t1, t2 = tee('abc')
tnew = type(t1)
self.assertRaises(TypeError, tnew)
self.assertRaises(TypeError, tnew, 10)
t3 = tnew(t1)
self.assert_(list(t1) == list(t2) == list(t3) == list('abc'))
def test_StopIteration(self): def test_StopIteration(self):
self.assertRaises(StopIteration, izip().next) self.assertRaises(StopIteration, izip().next)
......
...@@ -75,7 +75,7 @@ groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg) ...@@ -75,7 +75,7 @@ groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg)
static PyObject * static PyObject *
groupby_next(groupbyobject *gbo) groupby_next(groupbyobject *gbo)
{ {
PyObject *newvalue, *newkey, *r, *grouper; PyObject *newvalue, *newkey, *r, *grouper, *tmp;
/* skip to next iteration group */ /* skip to next iteration group */
for (;;) { for (;;) {
...@@ -110,15 +110,19 @@ groupby_next(groupbyobject *gbo) ...@@ -110,15 +110,19 @@ groupby_next(groupbyobject *gbo)
} }
} }
Py_XDECREF(gbo->currkey); tmp = gbo->currkey;
gbo->currkey = newkey; gbo->currkey = newkey;
Py_XDECREF(gbo->currvalue); Py_XDECREF(tmp);
tmp = gbo->currvalue;
gbo->currvalue = newvalue; gbo->currvalue = newvalue;
Py_XDECREF(tmp);
} }
Py_XDECREF(gbo->tgtkey);
gbo->tgtkey = gbo->currkey;
Py_INCREF(gbo->currkey); Py_INCREF(gbo->currkey);
tmp = gbo->tgtkey;
gbo->tgtkey = gbo->currkey;
Py_XDECREF(tmp);
grouper = _grouper_create(gbo, gbo->tgtkey); grouper = _grouper_create(gbo, gbo->tgtkey);
if (grouper == NULL) if (grouper == NULL)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment