Kaydet (Commit) cb1d96f7 authored tarafından Raymond Hettinger's avatar Raymond Hettinger

Issue #18594: Make the C code more closely match the pure python code.

üst 5b22dd87
...@@ -818,6 +818,24 @@ class TestCollectionABCs(ABCTestCase): ...@@ -818,6 +818,24 @@ class TestCollectionABCs(ABCTestCase):
### Counter ### Counter
################################################################################ ################################################################################
class CounterSubclassWithSetItem(Counter):
# Test a counter subclass that overrides __setitem__
def __init__(self, *args, **kwds):
self.called = False
Counter.__init__(self, *args, **kwds)
def __setitem__(self, key, value):
self.called = True
Counter.__setitem__(self, key, value)
class CounterSubclassWithGet(Counter):
# Test a counter subclass that overrides get()
def __init__(self, *args, **kwds):
self.called = False
Counter.__init__(self, *args, **kwds)
def get(self, key, default):
self.called = True
return Counter.get(self, key, default)
class TestCounter(unittest.TestCase): class TestCounter(unittest.TestCase):
def test_basics(self): def test_basics(self):
...@@ -1022,6 +1040,12 @@ class TestCounter(unittest.TestCase): ...@@ -1022,6 +1040,12 @@ class TestCounter(unittest.TestCase):
self.assertEqual(m, self.assertEqual(m,
OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])) OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
# test fidelity to the pure python version
c = CounterSubclassWithSetItem('abracadabra')
self.assertTrue(c.called)
c = CounterSubclassWithGet('abracadabra')
self.assertTrue(c.called)
################################################################################ ################################################################################
### OrderedDict ### OrderedDict
......
...@@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping"); ...@@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping");
static PyObject * static PyObject *
_count_elements(PyObject *self, PyObject *args) _count_elements(PyObject *self, PyObject *args)
{ {
_Py_IDENTIFIER(__getitem__); _Py_IDENTIFIER(get);
_Py_IDENTIFIER(__setitem__); _Py_IDENTIFIER(__setitem__);
PyObject *it, *iterable, *mapping, *oldval; PyObject *it, *iterable, *mapping, *oldval;
PyObject *newval = NULL; PyObject *newval = NULL;
PyObject *key = NULL; PyObject *key = NULL;
PyObject *zero = NULL; PyObject *zero = NULL;
PyObject *one = NULL; PyObject *one = NULL;
PyObject *mapping_get = NULL; PyObject *bound_get = NULL;
PyObject *mapping_getitem; PyObject *mapping_get;
PyObject *dict_get;
PyObject *mapping_setitem; PyObject *mapping_setitem;
PyObject *dict_getitem;
PyObject *dict_setitem; PyObject *dict_setitem;
if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
...@@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args) ...@@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args)
if (one == NULL) if (one == NULL)
goto done; goto done;
mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__); /* Only take the fast path when get() and __setitem__()
dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__); * have not been overridden.
*/
mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get);
dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get);
mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__); mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__);
dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__); dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__);
if (mapping_getitem != NULL && if (mapping_get != NULL && mapping_get == dict_get &&
mapping_getitem == dict_getitem && mapping_setitem != NULL && mapping_setitem == dict_setitem) {
mapping_setitem != NULL &&
mapping_setitem == dict_setitem) {
while (1) { while (1) {
key = PyIter_Next(it); key = PyIter_Next(it);
if (key == NULL) if (key == NULL)
...@@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args) ...@@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args)
Py_DECREF(key); Py_DECREF(key);
} }
} else { } else {
mapping_get = PyObject_GetAttrString(mapping, "get"); bound_get = PyObject_GetAttrString(mapping, "get");
if (mapping_get == NULL) if (bound_get == NULL)
goto done; goto done;
zero = PyLong_FromLong(0); zero = PyLong_FromLong(0);
...@@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args) ...@@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args)
key = PyIter_Next(it); key = PyIter_Next(it);
if (key == NULL) if (key == NULL)
break; break;
oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL); oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL);
if (oldval == NULL) if (oldval == NULL)
break; break;
newval = PyNumber_Add(oldval, one); newval = PyNumber_Add(oldval, one);
...@@ -1771,7 +1772,7 @@ done: ...@@ -1771,7 +1772,7 @@ done:
Py_DECREF(it); Py_DECREF(it);
Py_XDECREF(key); Py_XDECREF(key);
Py_XDECREF(newval); Py_XDECREF(newval);
Py_XDECREF(mapping_get); Py_XDECREF(bound_get);
Py_XDECREF(zero); Py_XDECREF(zero);
Py_XDECREF(one); Py_XDECREF(one);
if (PyErr_Occurred()) if (PyErr_Occurred())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment