Kaydet (Commit) 1be413e3 authored tarafından R David Murray's avatar R David Murray

Don't use metaclasses when class decorators can do the job.

Thanks to Nick Coghlan for pointing out that I'd forgotten about class
decorators.
üst 8e0ed333
...@@ -91,31 +91,25 @@ class _PolicyBase: ...@@ -91,31 +91,25 @@ class _PolicyBase:
return self.clone(**other.__dict__) return self.clone(**other.__dict__)
# Conceptually this isn't a subclass of ABCMeta, but since we want Policy to def _append_doc(doc, added_doc):
# use ABCMeta as a metaclass *and* we want it to use this one as well, we have doc = doc.rsplit('\n', 1)[0]
# to make this one a subclas of ABCMeta. added_doc = added_doc.split('\n', 1)[1]
class _DocstringExtenderMetaclass(abc.ABCMeta): return doc + '\n' + added_doc
def __new__(meta, classname, bases, classdict): def _extend_docstrings(cls):
if classdict.get('__doc__') and classdict['__doc__'].startswith('+'): if cls.__doc__ and cls.__doc__.startswith('+'):
classdict['__doc__'] = meta._append_doc(bases[0].__doc__, cls.__doc__ = _append_doc(cls.__bases__[0].__doc__, cls.__doc__)
classdict['__doc__']) for name, attr in cls.__dict__.items():
for name, attr in classdict.items(): if attr.__doc__ and attr.__doc__.startswith('+'):
if attr.__doc__ and attr.__doc__.startswith('+'): for c in (c for base in cls.__bases__ for c in base.mro()):
for cls in (cls for base in bases for cls in base.mro()): doc = getattr(getattr(c, name), '__doc__')
doc = getattr(getattr(cls, name), '__doc__') if doc:
if doc: attr.__doc__ = _append_doc(doc, attr.__doc__)
attr.__doc__ = meta._append_doc(doc, attr.__doc__) break
break return cls
return super().__new__(meta, classname, bases, classdict)
@staticmethod class Policy(_PolicyBase, metaclass=abc.ABCMeta):
def _append_doc(doc, added_doc):
added_doc = added_doc.split('\n', 1)[1]
return doc + '\n' + added_doc
class Policy(_PolicyBase, metaclass=_DocstringExtenderMetaclass):
r"""Controls for how messages are interpreted and formatted. r"""Controls for how messages are interpreted and formatted.
...@@ -264,6 +258,7 @@ class Policy(_PolicyBase, metaclass=_DocstringExtenderMetaclass): ...@@ -264,6 +258,7 @@ class Policy(_PolicyBase, metaclass=_DocstringExtenderMetaclass):
raise NotImplementedError raise NotImplementedError
@_extend_docstrings
class Compat32(Policy): class Compat32(Policy):
"""+ """+
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
code that adds all the email6 features. code that adds all the email6 features.
""" """
from email._policybase import Policy, Compat32, compat32 from email._policybase import Policy, Compat32, compat32, _extend_docstrings
from email.utils import _has_surrogates from email.utils import _has_surrogates
from email.headerregistry import HeaderRegistry as HeaderRegistry from email.headerregistry import HeaderRegistry as HeaderRegistry
...@@ -17,6 +17,7 @@ __all__ = [ ...@@ -17,6 +17,7 @@ __all__ = [
'HTTP', 'HTTP',
] ]
@_extend_docstrings
class EmailPolicy(Policy): class EmailPolicy(Policy):
"""+ """+
......
...@@ -73,10 +73,8 @@ class TestEmailBase(unittest.TestCase): ...@@ -73,10 +73,8 @@ class TestEmailBase(unittest.TestCase):
'item {}'.format(i)) 'item {}'.format(i))
# Metaclass to allow for parameterized tests def parameterize(cls):
class Parameterized(type): """A test method parameterization class decorator.
"""Provide a test method parameterization facility.
Parameters are specified as the value of a class attribute that ends with Parameters are specified as the value of a class attribute that ends with
the string '_params'. Call the portion before '_params' the prefix. Then the string '_params'. Call the portion before '_params' the prefix. Then
...@@ -92,9 +90,10 @@ class Parameterized(type): ...@@ -92,9 +90,10 @@ class Parameterized(type):
In a _params dictioanry, the keys become part of the name of the generated In a _params dictioanry, the keys become part of the name of the generated
tests. In a _params list, the values in the list are converted into a tests. In a _params list, the values in the list are converted into a
string by joining the string values of the elements of the tuple by '_' and string by joining the string values of the elements of the tuple by '_' and
converting any blanks into '_'s, and this become part of the name. The converting any blanks into '_'s, and this become part of the name.
full name of a generated test is the portion of the _params name before the The full name of a generated test is a 'test_' prefix, the portion of the
'_params' portion, plus an '_', plus the name derived as explained above. test function name after the '_as_' separator, plus an '_', plus the name
derived as explained above.
For example, if we have: For example, if we have:
...@@ -123,30 +122,29 @@ class Parameterized(type): ...@@ -123,30 +122,29 @@ class Parameterized(type):
be used to select the test individually from the unittest command line. be used to select the test individually from the unittest command line.
""" """
paramdicts = {}
def __new__(meta, classname, bases, classdict): for name, attr in cls.__dict__.items():
paramdicts = {} if name.endswith('_params'):
for name, attr in classdict.items(): if not hasattr(attr, 'keys'):
if name.endswith('_params'): d = {}
if not hasattr(attr, 'keys'): for x in attr:
d = {} if not hasattr(x, '__iter__'):
for x in attr: x = (x,)
if not hasattr(x, '__iter__'): n = '_'.join(str(v) for v in x).replace(' ', '_')
x = (x,) d[n] = x
n = '_'.join(str(v) for v in x).replace(' ', '_') attr = d
d[n] = x paramdicts[name[:-7] + '_as_'] = attr
attr = d testfuncs = {}
paramdicts[name[:-7] + '_as_'] = attr for name, attr in cls.__dict__.items():
testfuncs = {} for paramsname, paramsdict in paramdicts.items():
for name, attr in classdict.items(): if name.startswith(paramsname):
for paramsname, paramsdict in paramdicts.items(): testnameroot = 'test_' + name[len(paramsname):]
if name.startswith(paramsname): for paramname, params in paramsdict.items():
testnameroot = 'test_' + name[len(paramsname):] test = (lambda self, name=name, params=params:
for paramname, params in paramsdict.items(): getattr(self, name)(*params))
test = (lambda self, name=name, params=params: testname = testnameroot + '_' + paramname
getattr(self, name)(*params)) test.__name__ = testname
testname = testnameroot + '_' + paramname testfuncs[testname] = test
test.__name__ = testname for key, value in testfuncs.items():
testfuncs[testname] = test setattr(cls, key, value)
classdict.update(testfuncs) return cls
return super().__new__(meta, classname, bases, classdict)
...@@ -4,10 +4,11 @@ import unittest ...@@ -4,10 +4,11 @@ import unittest
from email import message_from_string, message_from_bytes from email import message_from_string, message_from_bytes
from email.generator import Generator, BytesGenerator from email.generator import Generator, BytesGenerator
from email import policy from email import policy
from test.test_email import TestEmailBase, Parameterized from test.test_email import TestEmailBase, parameterize
class TestGeneratorBase(metaclass=Parameterized): @parameterize
class TestGeneratorBase:
policy = policy.default policy = policy.default
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
from email import errors from email import errors
from email import policy from email import policy
from email.message import Message from email.message import Message
from test.test_email import TestEmailBase, Parameterized from test.test_email import TestEmailBase, parameterize
from email import headerregistry from email import headerregistry
from email.headerregistry import Address, Group from email.headerregistry import Address, Group
...@@ -175,7 +175,8 @@ class TestDateHeader(TestHeaderBase): ...@@ -175,7 +175,8 @@ class TestDateHeader(TestHeaderBase):
self.assertEqual(m['Date'].datetime, self.dt) self.assertEqual(m['Date'].datetime, self.dt)
class TestAddressHeader(TestHeaderBase, metaclass=Parameterized): @parameterize
class TestAddressHeader(TestHeaderBase):
example_params = { example_params = {
......
...@@ -6,9 +6,11 @@ import email ...@@ -6,9 +6,11 @@ import email
import email.message import email.message
from email import policy from email import policy
from email.headerregistry import HeaderRegistry from email.headerregistry import HeaderRegistry
from test.test_email import TestEmailBase, Parameterized from test.test_email import TestEmailBase, parameterize
class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized):
@parameterize
class TestPickleCopyHeader(TestEmailBase):
header_factory = HeaderRegistry() header_factory = HeaderRegistry()
...@@ -33,7 +35,8 @@ class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized): ...@@ -33,7 +35,8 @@ class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized):
self.assertEqual(str(h), str(header)) self.assertEqual(str(h), str(header))
class TestPickleCopyMessage(TestEmailBase, metaclass=Parameterized): @parameterize
class TestPickleCopyMessage(TestEmailBase):
# Message objects are a sequence, so we have to make them a one-tuple in # Message objects are a sequence, so we have to make them a one-tuple in
# msg_params so they get passed to the parameterized test method as a # msg_params so they get passed to the parameterized test method as a
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment