Kaydet (Commit) 35ac5f82 authored tarafından Serhiy Storchaka's avatar Serhiy Storchaka

Issue #22955: attrgetter, itemgetter and methodcaller objects in the operator

module now support pickling.  Added readable and evaluable repr for these
objects.  Based on patch by Josh Rosenberg.
üst 5418d0bf
...@@ -231,10 +231,13 @@ class attrgetter: ...@@ -231,10 +231,13 @@ class attrgetter:
After h = attrgetter('name.first', 'name.last'), the call h(r) returns After h = attrgetter('name.first', 'name.last'), the call h(r) returns
(r.name.first, r.name.last). (r.name.first, r.name.last).
""" """
__slots__ = ('_attrs', '_call')
def __init__(self, attr, *attrs): def __init__(self, attr, *attrs):
if not attrs: if not attrs:
if not isinstance(attr, str): if not isinstance(attr, str):
raise TypeError('attribute name must be a string') raise TypeError('attribute name must be a string')
self._attrs = (attr,)
names = attr.split('.') names = attr.split('.')
def func(obj): def func(obj):
for name in names: for name in names:
...@@ -242,7 +245,8 @@ class attrgetter: ...@@ -242,7 +245,8 @@ class attrgetter:
return obj return obj
self._call = func self._call = func
else: else:
getters = tuple(map(attrgetter, (attr,) + attrs)) self._attrs = (attr,) + attrs
getters = tuple(map(attrgetter, self._attrs))
def func(obj): def func(obj):
return tuple(getter(obj) for getter in getters) return tuple(getter(obj) for getter in getters)
self._call = func self._call = func
...@@ -250,19 +254,30 @@ class attrgetter: ...@@ -250,19 +254,30 @@ class attrgetter:
def __call__(self, obj): def __call__(self, obj):
return self._call(obj) return self._call(obj)
def __repr__(self):
return '%s.%s(%s)' % (self.__class__.__module__,
self.__class__.__qualname__,
', '.join(map(repr, self._attrs)))
def __reduce__(self):
return self.__class__, self._attrs
class itemgetter: class itemgetter:
""" """
Return a callable object that fetches the given item(s) from its operand. Return a callable object that fetches the given item(s) from its operand.
After f = itemgetter(2), the call f(r) returns r[2]. After f = itemgetter(2), the call f(r) returns r[2].
After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3]) After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3])
""" """
__slots__ = ('_items', '_call')
def __init__(self, item, *items): def __init__(self, item, *items):
if not items: if not items:
self._items = (item,)
def func(obj): def func(obj):
return obj[item] return obj[item]
self._call = func self._call = func
else: else:
items = (item,) + items self._items = items = (item,) + items
def func(obj): def func(obj):
return tuple(obj[i] for i in items) return tuple(obj[i] for i in items)
self._call = func self._call = func
...@@ -270,6 +285,14 @@ class itemgetter: ...@@ -270,6 +285,14 @@ class itemgetter:
def __call__(self, obj): def __call__(self, obj):
return self._call(obj) return self._call(obj)
def __repr__(self):
return '%s.%s(%s)' % (self.__class__.__module__,
self.__class__.__name__,
', '.join(map(repr, self._items)))
def __reduce__(self):
return self.__class__, self._items
class methodcaller: class methodcaller:
""" """
Return a callable object that calls the given method on its operand. Return a callable object that calls the given method on its operand.
...@@ -277,6 +300,7 @@ class methodcaller: ...@@ -277,6 +300,7 @@ class methodcaller:
After g = methodcaller('name', 'date', foo=1), the call g(r) returns After g = methodcaller('name', 'date', foo=1), the call g(r) returns
r.name('date', foo=1). r.name('date', foo=1).
""" """
__slots__ = ('_name', '_args', '_kwargs')
def __init__(*args, **kwargs): def __init__(*args, **kwargs):
if len(args) < 2: if len(args) < 2:
...@@ -284,12 +308,30 @@ class methodcaller: ...@@ -284,12 +308,30 @@ class methodcaller:
raise TypeError(msg) raise TypeError(msg)
self = args[0] self = args[0]
self._name = args[1] self._name = args[1]
if not isinstance(self._name, str):
raise TypeError('method name must be a string')
self._args = args[2:] self._args = args[2:]
self._kwargs = kwargs self._kwargs = kwargs
def __call__(self, obj): def __call__(self, obj):
return getattr(obj, self._name)(*self._args, **self._kwargs) return getattr(obj, self._name)(*self._args, **self._kwargs)
def __repr__(self):
args = [repr(self._name)]
args.extend(map(repr, self._args))
args.extend('%s=%r' % (k, v) for k, v in self._kwargs.items())
return '%s.%s(%s)' % (self.__class__.__module__,
self.__class__.__name__,
', '.join(args))
def __reduce__(self):
if not self._kwargs:
return self.__class__, (self._name,) + self._args
else:
from functools import partial
return partial(self.__class__, self._name, **self._kwargs), self._args
# In-place Operations *********************************************************# # In-place Operations *********************************************************#
def iadd(a, b): def iadd(a, b):
......
import unittest import unittest
import pickle
import sys
from test import support from test import support
...@@ -35,6 +37,9 @@ class Seq2(object): ...@@ -35,6 +37,9 @@ class Seq2(object):
class OperatorTestCase: class OperatorTestCase:
def setUp(self):
sys.modules['operator'] = self.module
def test_lt(self): def test_lt(self):
operator = self.module operator = self.module
self.assertRaises(TypeError, operator.lt) self.assertRaises(TypeError, operator.lt)
...@@ -396,6 +401,7 @@ class OperatorTestCase: ...@@ -396,6 +401,7 @@ class OperatorTestCase:
def test_methodcaller(self): def test_methodcaller(self):
operator = self.module operator = self.module
self.assertRaises(TypeError, operator.methodcaller) self.assertRaises(TypeError, operator.methodcaller)
self.assertRaises(TypeError, operator.methodcaller, 12)
class A: class A:
def foo(self, *args, **kwds): def foo(self, *args, **kwds):
return args[0] + args[1] return args[0] + args[1]
...@@ -491,5 +497,108 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase): ...@@ -491,5 +497,108 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
class COperatorTestCase(OperatorTestCase, unittest.TestCase): class COperatorTestCase(OperatorTestCase, unittest.TestCase):
module = c_operator module = c_operator
class OperatorPickleTestCase:
def copy(self, obj, proto):
with support.swap_item(sys.modules, 'operator', self.module):
pickled = pickle.dumps(obj, proto)
with support.swap_item(sys.modules, 'operator', self.module2):
return pickle.loads(pickled)
def test_attrgetter(self):
attrgetter = self.module.attrgetter
attrgetter = self.module.attrgetter
class A:
pass
a = A()
a.x = 'X'
a.y = 'Y'
a.z = 'Z'
a.t = A()
a.t.u = A()
a.t.u.v = 'V'
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
f = attrgetter('x')
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
# multiple gets
f = attrgetter('x', 'y', 'z')
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
# recursive gets
f = attrgetter('t.u.v')
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
def test_itemgetter(self):
itemgetter = self.module.itemgetter
a = 'ABCDE'
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
f = itemgetter(2)
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
# multiple gets
f = itemgetter(2, 0, 4)
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
def test_methodcaller(self):
methodcaller = self.module.methodcaller
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]
def bar(self, f=42):
return f
def baz(*args, **kwds):
return kwds['name'], kwds['self']
a = A()
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
f = methodcaller('bar')
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
# positional args
f = methodcaller('foo', 1, 2)
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
# keyword args
f = methodcaller('bar', f=5)
f2 = self.copy(f, proto)
self.assertEqual(repr(f2), repr(f))
self.assertEqual(f2(a), f(a))
f = methodcaller('baz', self='eggs', name='spam')
f2 = self.copy(f, proto)
# Can't test repr consistently with multiple keyword args
self.assertEqual(f2(a), f(a))
class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
module = py_operator
module2 = py_operator
@unittest.skipUnless(c_operator, 'requires _operator')
class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
module = py_operator
module2 = c_operator
@unittest.skipUnless(c_operator, 'requires _operator')
class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
module = c_operator
module2 = py_operator
@unittest.skipUnless(c_operator, 'requires _operator')
class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
module = c_operator
module2 = c_operator
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -52,6 +52,10 @@ Core and Builtins ...@@ -52,6 +52,10 @@ Core and Builtins
Library Library
------- -------
- Issue #22955: attrgetter, itemgetter and methodcaller objects in the operator
module now support pickling. Added readable and evaluable repr for these
objects. Based on patch by Josh Rosenberg.
- Issue #22107: tempfile.gettempdir() and tempfile.mkdtemp() now try again - Issue #22107: tempfile.gettempdir() and tempfile.mkdtemp() now try again
when a directory with the chosen name already exists on Windows as well as when a directory with the chosen name already exists on Windows as well as
on Unix. tempfile.mkstemp() now fails early if parent directory is not on Unix. tempfile.mkstemp() now fails early if parent directory is not
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment