Unverified Kaydet (Commit) ea8fc52e authored tarafından Eric V. Smith's avatar Eric V. Smith Kaydeden (comit) GitHub

bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)

Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder methods is defined in the class, it will not overwrite it.
üst 2a2247ce
...@@ -18,6 +18,142 @@ __all__ = ['dataclass', ...@@ -18,6 +18,142 @@ __all__ = ['dataclass',
'is_dataclass', 'is_dataclass',
] ]
# Conditions for adding methods. The boxes indicate what action the
# dataclass decorator takes. For all of these tables, when I talk
# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
# to the arguments to the @dataclass decorator. When checking if a
# dunder method already exists, I mean check for an entry in the
# class's __dict__. I never check to see if an attribute is defined
# in a base class.
# Key:
# +=========+=========================================+
# + Value | Meaning |
# +=========+=========================================+
# | <blank> | No action: no method is added. |
# +---------+-----------------------------------------+
# | add | Generated method is added. |
# +---------+-----------------------------------------+
# | add* | Generated method is added only if the |
# | | existing attribute is None and if the |
# | | user supplied a __eq__ method in the |
# | | class definition. |
# +---------+-----------------------------------------+
# | raise | TypeError is raised. |
# +---------+-----------------------------------------+
# | None | Attribute is set to None. |
# +=========+=========================================+
# __init__
#
# +--- init= parameter
# |
# v | | |
# | no | yes | <--- class has __init__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __repr__
#
# +--- repr= parameter
# |
# v | | |
# | no | yes | <--- class has __repr__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __setattr__
# __delattr__
#
# +--- frozen= parameter
# |
# v | | |
# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
# +=======+=======+=======+
# | False | | | <- the default
# +-------+-------+-------+
# | True | add | raise |
# +=======+=======+=======+
# Raise because not adding these methods would break the "frozen-ness"
# of the class.
# __eq__
#
# +--- eq= parameter
# |
# v | | |
# | no | yes | <--- class has __eq__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __lt__
# __le__
# __gt__
# __ge__
#
# +--- order= parameter
# |
# v | | |
# | no | yes | <--- class has any comparison method in __dict__?
# +=======+=======+=======+
# | False | | | <- the default
# +-------+-------+-------+
# | True | add | raise |
# +=======+=======+=======+
# Raise because to allow this case would interfere with using
# functools.total_ordering.
# __hash__
# +------------------- hash= parameter
# | +----------- eq= parameter
# | | +--- frozen= parameter
# | | |
# v v v | | |
# | no | yes | <--- class has __hash__ in __dict__?
# +=========+=======+=======+========+========+
# | 1 None | False | False | | | No __eq__, use the base class __hash__
# +---------+-------+-------+--------+--------+
# | 2 None | False | True | | | No __eq__, use the base class __hash__
# +---------+-------+-------+--------+--------+
# | 3 None | True | False | None | | <-- the default, not hashable
# +---------+-------+-------+--------+--------+
# | 4 None | True | True | add | add* | Frozen, so hashable
# +---------+-------+-------+--------+--------+
# | 5 False | False | False | | |
# +---------+-------+-------+--------+--------+
# | 6 False | False | True | | |
# +---------+-------+-------+--------+--------+
# | 7 False | True | False | | |
# +---------+-------+-------+--------+--------+
# | 8 False | True | True | | |
# +---------+-------+-------+--------+--------+
# | 9 True | False | False | add | add* | Has no __eq__, but hashable
# +---------+-------+-------+--------+--------+
# |10 True | False | True | add | add* | Has no __eq__, but hashable
# +---------+-------+-------+--------+--------+
# |11 True | True | False | add | add* | Not frozen, but hashable
# +---------+-------+-------+--------+--------+
# |12 True | True | True | add | add* | Frozen, so hashable
# +=========+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class. If the base is object, then
# id-based hashing is used.
# Note that a class may have already __hash__=None if it specified an
# __eq__ method in the class body (not one that was created by
# @dataclass).
# Raised when an attempt is made to modify a frozen class. # Raised when an attempt is made to modify a frozen class.
class FrozenInstanceError(AttributeError): pass class FrozenInstanceError(AttributeError): pass
...@@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields): ...@@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
# return "(self.x,self.y)". # return "(self.x,self.y)".
# Special case for the 0-tuple. # Special case for the 0-tuple.
if len(fields) == 0: if not fields:
return '()' return '()'
# Note the trailing comma, needed if this turns out to be a 1-tuple. # Note the trailing comma, needed if this turns out to be a 1-tuple.
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
def _create_fn(name, args, body, globals=None, locals=None, def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING): return_type=MISSING):
# Note that we mutate locals when exec() is called. Caller beware! # Note that we mutate locals when exec() is called. Caller beware!
if locals is None: if locals is None:
...@@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name): ...@@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})'] body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
# If no body lines, use 'pass'. # If no body lines, use 'pass'.
if len(body_lines) == 0: if not body_lines:
body_lines = ['pass'] body_lines = ['pass']
locals = {f'_type_{f.name}': f.type for f in fields} locals = {f'_type_{f.name}': f.type for f in fields}
...@@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple): ...@@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
'return NotImplemented']) 'return NotImplemented'])
def _set_eq_fns(cls, fields):
# Create and set the equality comparison methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple = _tuple_str('self', fields)
other_tuple = _tuple_str('other', fields)
for name, op in [('__eq__', '=='),
('__ne__', '!='),
]:
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
def _set_order_fns(cls, fields):
# Create and set the ordering methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple = _tuple_str('self', fields)
other_tuple = _tuple_str('other', fields)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
def _hash_fn(fields): def _hash_fn(fields):
self_tuple = _tuple_str('self', fields) self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__', return _create_fn('__hash__',
...@@ -431,20 +541,20 @@ def _find_fields(cls): ...@@ -431,20 +541,20 @@ def _find_fields(cls):
# a Field(), then it contains additional info beyond (and # a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields # possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that # ClassVars and InitVars are included, despite the fact that
# they're not real fields. That's deal with later. # they're not real fields. That's dealt with later.
annotations = getattr(cls, '__annotations__', {}) annotations = getattr(cls, '__annotations__', {})
return [_get_field(cls, a_name, a_type) return [_get_field(cls, a_name, a_type)
for a_name, a_type in annotations.items()] for a_name, a_type in annotations.items()]
def _set_attribute(cls, name, value): def _set_new_attribute(cls, name, value):
# Raise TypeError if an attribute by this name already exists. # Never overwrites an existing attribute. Returns True if the
# attribute already exists.
if name in cls.__dict__: if name in cls.__dict__:
raise TypeError(f'Cannot overwrite attribute {name} ' return True
f'in {cls.__name__}')
setattr(cls, name, value) setattr(cls, name, value)
return False
def _process_class(cls, repr, eq, order, hash, init, frozen): def _process_class(cls, repr, eq, order, hash, init, frozen):
...@@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): ...@@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
# be inherited down. # be inherited down.
is_frozen = frozen or cls.__setattr__ is _frozen_setattr is_frozen = frozen or cls.__setattr__ is _frozen_setattr
# Was this class defined with an __eq__? Used in __hash__ logic.
auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None
# If we're generating ordering methods, we must be generating # If we're generating ordering methods, we must be generating
# the eq methods. # the eq methods.
if order and not eq: if order and not eq:
...@@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): ...@@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
has_post_init = hasattr(cls, _POST_INIT_NAME) has_post_init = hasattr(cls, _POST_INIT_NAME)
# Include InitVars and regular fields (so, not ClassVars). # Include InitVars and regular fields (so, not ClassVars).
_set_attribute(cls, '__init__', flds = [f for f in fields.values()
_init_fn(list(filter(lambda f: f._field_type if f._field_type in (_FIELD, _FIELD_INITVAR)]
in (_FIELD, _FIELD_INITVAR), _set_new_attribute(cls, '__init__',
fields.values())), _init_fn(flds,
is_frozen, is_frozen,
has_post_init, has_post_init,
# The name to use for the "self" param # The name to use for the "self" param
# in __init__. Use "self" if possible. # in __init__. Use "self" if possible.
'__dataclass_self__' if 'self' in fields '__dataclass_self__' if 'self' in fields
else 'self', else 'self',
)) ))
# Get the fields as a list, and include only real fields. This is # Get the fields as a list, and include only real fields. This is
# used in all of the following methods. # used in all of the following methods.
field_list = list(filter(lambda f: f._field_type is _FIELD, field_list = [f for f in fields.values() if f._field_type is _FIELD]
fields.values()))
if repr: if repr:
_set_attribute(cls, '__repr__', flds = [f for f in field_list if f.repr]
_repr_fn(list(filter(lambda f: f.repr, field_list)))) _set_new_attribute(cls, '__repr__', _repr_fn(flds))
if is_frozen:
_set_attribute(cls, '__setattr__', _frozen_setattr)
_set_attribute(cls, '__delattr__', _frozen_delattr)
generate_hash = False
if hash is None:
if eq and frozen:
# Generate a hash function.
generate_hash = True
elif eq and not frozen:
# Not hashable.
_set_attribute(cls, '__hash__', None)
elif not eq:
# Otherwise, use the base class definition of hash(). That is,
# don't set anything on this class.
pass
else:
assert "can't get here"
else:
generate_hash = hash
if generate_hash:
_set_attribute(cls, '__hash__',
_hash_fn(list(filter(lambda f: f.compare
if f.hash is None
else f.hash,
field_list))))
if eq: if eq:
# Create and __eq__ and __ne__ methods. # Create _eq__ method. There's no need for a __ne__ method,
_set_eq_fns(cls, list(filter(lambda f: f.compare, field_list))) # since python will call __eq__ and negate it.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
if order: if order:
# Create and __lt__, __le__, __gt__, and __ge__ methods. # Create and set the ordering methods.
# Create and set the comparison functions. flds = [f for f in field_list if f.compare]
_set_order_fns(cls, list(filter(lambda f: f.compare, field_list))) self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in {cls.__name__}. Consider using '
'functools.total_ordering')
if is_frozen:
for name, fn in [('__setattr__', _frozen_setattr),
('__delattr__', _frozen_delattr)]:
if _set_new_attribute(cls, name, fn):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in {cls.__name__}')
# Decide if/how we're going to create a hash function.
# TODO: Move this table to module scope, so it's not recreated
# all the time.
generate_hash = {(None, False, False): ('', ''),
(None, False, True): ('', ''),
(None, True, False): ('none', ''),
(None, True, True): ('fn', 'fn-x'),
(False, False, False): ('', ''),
(False, False, True): ('', ''),
(False, True, False): ('', ''),
(False, True, True): ('', ''),
(True, False, False): ('fn', 'fn-x'),
(True, False, True): ('fn', 'fn-x'),
(True, True, False): ('fn', 'fn-x'),
(True, True, True): ('fn', 'fn-x'),
}[None if hash is None else bool(hash), # Force bool() if not None.
bool(eq),
bool(frozen)]['__hash__' in cls.__dict__]
# No need to call _set_new_attribute here, since we already know if
# we're overwriting a __hash__ or not.
if generate_hash == '':
# Do nothing.
pass
elif generate_hash == 'none':
cls.__hash__ = None
elif generate_hash in ('fn', 'fn-x'):
if generate_hash == 'fn' or auto_hash_test:
flds = [f for f in field_list
if (f.compare if f.hash is None else f.hash)]
cls.__hash__ = _hash_fn(flds)
else:
assert False, f"can't get here: {generate_hash}"
if not getattr(cls, '__doc__'): if not getattr(cls, '__doc__'):
# Create a class doc-string. # Create a class doc-string.
......
...@@ -9,6 +9,7 @@ import unittest ...@@ -9,6 +9,7 @@ import unittest
from unittest.mock import Mock from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
from collections import deque, OrderedDict, namedtuple from collections import deque, OrderedDict, namedtuple
from functools import total_ordering
# Just any custom exception we can catch. # Just any custom exception we can catch.
class CustomError(Exception): pass class CustomError(Exception): pass
...@@ -82,68 +83,12 @@ class TestCase(unittest.TestCase): ...@@ -82,68 +83,12 @@ class TestCase(unittest.TestCase):
class C(B): class C(B):
x: int = 0 x: int = 0
def test_overwriting_init(self): def test_overwriting_hash(self):
with self.assertRaisesRegex(TypeError, @dataclass(frozen=True)
'Cannot overwrite attribute __init__ '
'in C'):
@dataclass
class C:
x: int
def __init__(self, x):
self.x = 2 * x
@dataclass(init=False)
class C:
x: int
def __init__(self, x):
self.x = 2 * x
self.assertEqual(C(5).x, 10)
def test_overwriting_repr(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __repr__ '
'in C'):
@dataclass
class C:
x: int
def __repr__(self):
pass
@dataclass(repr=False)
class C:
x: int
def __repr__(self):
return 'x'
self.assertEqual(repr(C(0)), 'x')
def test_overwriting_cmp(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __eq__ '
'in C'):
# This will generate the comparison functions, make sure we can't
# overwrite them.
@dataclass(hash=False, frozen=False)
class C:
x: int
def __eq__(self):
pass
@dataclass(order=False, eq=False)
class C: class C:
x: int x: int
def __eq__(self, other): def __hash__(self):
return True pass
self.assertEqual(C(0), 'x')
def test_overwriting_hash(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __hash__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __hash__(self):
pass
@dataclass(frozen=True,hash=False) @dataclass(frozen=True,hash=False)
class C: class C:
...@@ -152,14 +97,11 @@ class TestCase(unittest.TestCase): ...@@ -152,14 +97,11 @@ class TestCase(unittest.TestCase):
return 600 return 600
self.assertEqual(hash(C(0)), 600) self.assertEqual(hash(C(0)), 600)
with self.assertRaisesRegex(TypeError, @dataclass(frozen=True)
'Cannot overwrite attribute __hash__ ' class C:
'in C'): x: int
@dataclass(frozen=True) def __hash__(self):
class C: pass
x: int
def __hash__(self):
pass
@dataclass(frozen=True, hash=False) @dataclass(frozen=True, hash=False)
class C: class C:
...@@ -168,33 +110,6 @@ class TestCase(unittest.TestCase): ...@@ -168,33 +110,6 @@ class TestCase(unittest.TestCase):
return 600 return 600
self.assertEqual(hash(C(0)), 600) self.assertEqual(hash(C(0)), 600)
def test_overwriting_frozen(self):
# frozen uses __setattr__ and __delattr__
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __setattr__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __setattr__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __delattr__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __delattr__(self):
pass
@dataclass(frozen=False)
class C:
x: int
def __setattr__(self, name, value):
self.__dict__['x'] = value * 2
self.assertEqual(C(10).x, 20)
def test_overwrite_fields_in_derived_class(self): def test_overwrite_fields_in_derived_class(self):
# Note that x from C1 replaces x in Base, but the order remains # Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base. # the same as defined in Base.
...@@ -239,34 +154,6 @@ class TestCase(unittest.TestCase): ...@@ -239,34 +154,6 @@ class TestCase(unittest.TestCase):
first = next(iter(sig.parameters)) first = next(iter(sig.parameters))
self.assertEqual('self', first) self.assertEqual('self', first)
def test_repr(self):
@dataclass
class B:
x: int
@dataclass
class C(B):
y: int = 10
o = C(4)
self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
@dataclass
class D(C):
x: int = 20
self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
@dataclass
class C:
@dataclass
class D:
i: int
@dataclass
class E:
pass
self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
def test_0_field_compare(self): def test_0_field_compare(self):
# Ensure that order=False is the default. # Ensure that order=False is the default.
@dataclass @dataclass
...@@ -420,80 +307,8 @@ class TestCase(unittest.TestCase): ...@@ -420,80 +307,8 @@ class TestCase(unittest.TestCase):
self.assertEqual(hash(C(4)), hash((4,))) self.assertEqual(hash(C(4)), hash((4,)))
self.assertEqual(hash(C(42)), hash((42,))) self.assertEqual(hash(C(42)), hash((42,)))
def test_hash(self):
@dataclass(hash=True)
class C:
x: int
y: str
self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
def test_no_hash(self):
@dataclass(hash=None)
class C:
x: int
with self.assertRaisesRegex(TypeError,
"unhashable type: 'C'"):
hash(C(1))
def test_hash_rules(self):
# There are 24 cases of:
# hash=True/False/None
# eq=True/False
# order=True/False
# frozen=True/False
for (hash, eq, order, frozen, result ) in [
(False, False, False, False, 'absent'),
(False, False, False, True, 'absent'),
(False, False, True, False, 'exception'),
(False, False, True, True, 'exception'),
(False, True, False, False, 'absent'),
(False, True, False, True, 'absent'),
(False, True, True, False, 'absent'),
(False, True, True, True, 'absent'),
(True, False, False, False, 'fn'),
(True, False, False, True, 'fn'),
(True, False, True, False, 'exception'),
(True, False, True, True, 'exception'),
(True, True, False, False, 'fn'),
(True, True, False, True, 'fn'),
(True, True, True, False, 'fn'),
(True, True, True, True, 'fn'),
(None, False, False, False, 'absent'),
(None, False, False, True, 'absent'),
(None, False, True, False, 'exception'),
(None, False, True, True, 'exception'),
(None, True, False, False, 'none'),
(None, True, False, True, 'fn'),
(None, True, True, False, 'none'),
(None, True, True, True, 'fn'),
]:
with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
if result == 'exception':
with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
class C:
pass
else:
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
class C:
pass
# See if the result matches what's expected.
if result == 'fn':
# __hash__ contains the function we generated.
self.assertIn('__hash__', C.__dict__)
self.assertIsNotNone(C.__dict__['__hash__'])
elif result == 'absent':
# __hash__ is not present in our class.
self.assertNotIn('__hash__', C.__dict__)
elif result == 'none':
# __hash__ is set to None.
self.assertIn('__hash__', C.__dict__)
self.assertIsNone(C.__dict__['__hash__'])
else:
assert False, f'unknown result {result!r}'
def test_eq_order(self): def test_eq_order(self):
# Test combining eq and order.
for (eq, order, result ) in [ for (eq, order, result ) in [
(False, False, 'neither'), (False, False, 'neither'),
(False, True, 'exception'), (False, True, 'exception'),
...@@ -513,21 +328,18 @@ class TestCase(unittest.TestCase): ...@@ -513,21 +328,18 @@ class TestCase(unittest.TestCase):
if result == 'neither': if result == 'neither':
self.assertNotIn('__eq__', C.__dict__) self.assertNotIn('__eq__', C.__dict__)
self.assertNotIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__) self.assertNotIn('__gt__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__) self.assertNotIn('__ge__', C.__dict__)
elif result == 'both': elif result == 'both':
self.assertIn('__eq__', C.__dict__) self.assertIn('__eq__', C.__dict__)
self.assertIn('__ne__', C.__dict__)
self.assertIn('__lt__', C.__dict__) self.assertIn('__lt__', C.__dict__)
self.assertIn('__le__', C.__dict__) self.assertIn('__le__', C.__dict__)
self.assertIn('__gt__', C.__dict__) self.assertIn('__gt__', C.__dict__)
self.assertIn('__ge__', C.__dict__) self.assertIn('__ge__', C.__dict__)
elif result == 'eq_only': elif result == 'eq_only':
self.assertIn('__eq__', C.__dict__) self.assertIn('__eq__', C.__dict__)
self.assertIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__) self.assertNotIn('__gt__', C.__dict__)
...@@ -811,19 +623,6 @@ class TestCase(unittest.TestCase): ...@@ -811,19 +623,6 @@ class TestCase(unittest.TestCase):
y: int y: int
self.assertNotEqual(Point(1, 3), C(1, 3)) self.assertNotEqual(Point(1, 3), C(1, 3))
def test_base_has_init(self):
class B:
def __init__(self):
pass
# Make sure that declaring this class doesn't raise an error.
# The issue is that we can't override __init__ in our class,
# but it should be okay to add __init__ to us if our base has
# an __init__.
@dataclass
class C(B):
x: int = 0
def test_frozen(self): def test_frozen(self):
@dataclass(frozen=True) @dataclass(frozen=True)
class C: class C:
...@@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase): ...@@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase):
'y': int, 'y': int,
'z': 'typing.Any'}) 'z': 'typing.Any'})
class TestDocString(unittest.TestCase): class TestDocString(unittest.TestCase):
def assertDocStrEqual(self, a, b): def assertDocStrEqual(self, a, b):
# Because 3.6 and 3.7 differ in how inspect.signature work # Because 3.6 and 3.7 differ in how inspect.signature work
...@@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase): ...@@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
class TestInit(unittest.TestCase):
def test_base_has_init(self):
class B:
def __init__(self):
self.z = 100
pass
# Make sure that declaring this class doesn't raise an error.
# The issue is that we can't override __init__ in our class,
# but it should be okay to add __init__ to us if our base has
# an __init__.
@dataclass
class C(B):
x: int = 0
c = C(10)
self.assertEqual(c.x, 10)
self.assertNotIn('z', vars(c))
# Make sure that if we don't add an init, the base __init__
# gets called.
@dataclass(init=False)
class C(B):
x: int = 10
c = C()
self.assertEqual(c.x, 10)
self.assertEqual(c.z, 100)
def test_no_init(self):
dataclass(init=False)
class C:
i: int = 0
self.assertEqual(C().i, 0)
dataclass(init=False)
class C:
i: int = 2
def __init__(self):
self.i = 3
self.assertEqual(C().i, 3)
def test_overwriting_init(self):
# If the class has __init__, use it no matter the value of
# init=.
@dataclass
class C:
x: int
def __init__(self, x):
self.x = 2 * x
self.assertEqual(C(3).x, 6)
@dataclass(init=True)
class C:
x: int
def __init__(self, x):
self.x = 2 * x
self.assertEqual(C(4).x, 8)
@dataclass(init=False)
class C:
x: int
def __init__(self, x):
self.x = 2 * x
self.assertEqual(C(5).x, 10)
class TestRepr(unittest.TestCase):
def test_repr(self):
@dataclass
class B:
x: int
@dataclass
class C(B):
y: int = 10
o = C(4)
self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
@dataclass
class D(C):
x: int = 20
self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
@dataclass
class C:
@dataclass
class D:
i: int
@dataclass
class E:
pass
self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
def test_no_repr(self):
# Test a class with no __repr__ and repr=False.
@dataclass(repr=False)
class C:
x: int
self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
repr(C(3)))
# Test a class with a __repr__ and repr=False.
@dataclass(repr=False)
class C:
x: int
def __repr__(self):
return 'C-class'
self.assertEqual(repr(C(3)), 'C-class')
def test_overwriting_repr(self):
# If the class has __repr__, use it no matter the value of
# repr=.
@dataclass
class C:
x: int
def __repr__(self):
return 'x'
self.assertEqual(repr(C(0)), 'x')
@dataclass(repr=True)
class C:
x: int
def __repr__(self):
return 'x'
self.assertEqual(repr(C(0)), 'x')
@dataclass(repr=False)
class C:
x: int
def __repr__(self):
return 'x'
self.assertEqual(repr(C(0)), 'x')
class TestFrozen(unittest.TestCase):
def test_overwriting_frozen(self):
# frozen uses __setattr__ and __delattr__
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __setattr__'):
@dataclass(frozen=True)
class C:
x: int
def __setattr__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __delattr__'):
@dataclass(frozen=True)
class C:
x: int
def __delattr__(self):
pass
@dataclass(frozen=False)
class C:
x: int
def __setattr__(self, name, value):
self.__dict__['x'] = value * 2
self.assertEqual(C(10).x, 20)
class TestEq(unittest.TestCase):
def test_no_eq(self):
# Test a class with no __eq__ and eq=False.
@dataclass(eq=False)
class C:
x: int
self.assertNotEqual(C(0), C(0))
c = C(3)
self.assertEqual(c, c)
# Test a class with an __eq__ and eq=False.
@dataclass(eq=False)
class C:
x: int
def __eq__(self, other):
return other == 10
self.assertEqual(C(3), 10)
def test_overwriting_eq(self):
# If the class has __eq__, use it no matter the value of
# eq=.
@dataclass
class C:
x: int
def __eq__(self, other):
return other == 3
self.assertEqual(C(1), 3)
self.assertNotEqual(C(1), 1)
@dataclass(eq=True)
class C:
x: int
def __eq__(self, other):
return other == 4
self.assertEqual(C(1), 4)
self.assertNotEqual(C(1), 1)
@dataclass(eq=False)
class C:
x: int
def __eq__(self, other):
return other == 5
self.assertEqual(C(1), 5)
self.assertNotEqual(C(1), 1)
class TestOrdering(unittest.TestCase):
def test_functools_total_ordering(self):
# Test that functools.total_ordering works with this class.
@total_ordering
@dataclass
class C:
x: int
def __lt__(self, other):
# Perform the test "backward", just to make
# sure this is being called.
return self.x >= other
self.assertLess(C(0), -1)
self.assertLessEqual(C(0), -1)
self.assertGreater(C(0), 1)
self.assertGreaterEqual(C(0), 1)
def test_no_order(self):
# Test that no ordering functions are added by default.
@dataclass(order=False)
class C:
x: int
# Make sure no order methods are added.
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
# Test that __lt__ is still called
@dataclass(order=False)
class C:
x: int
def __lt__(self, other):
return False
# Make sure other methods aren't added.
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
def test_overwriting_order(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __lt__'
'.*using functools.total_ordering'):
@dataclass(order=True)
class C:
x: int
def __lt__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __le__'
'.*using functools.total_ordering'):
@dataclass(order=True)
class C:
x: int
def __le__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __gt__'
'.*using functools.total_ordering'):
@dataclass(order=True)
class C:
x: int
def __gt__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __ge__'
'.*using functools.total_ordering'):
@dataclass(order=True)
class C:
x: int
def __ge__(self):
pass
class TestHash(unittest.TestCase):
def test_hash(self):
@dataclass(hash=True)
class C:
x: int
y: str
self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
def test_hash_false(self):
@dataclass(hash=False)
class C:
x: int
y: str
self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
def test_hash_none(self):
@dataclass(hash=None)
class C:
x: int
with self.assertRaisesRegex(TypeError,
"unhashable type: 'C'"):
hash(C(1))
def test_hash_rules(self):
def non_bool(value):
# Map to something else that's True, but not a bool.
if value is None:
return None
if value:
return (3,)
return 0
def test(case, hash, eq, frozen, with_hash, result):
with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
if with_hash:
@dataclass(hash=hash, eq=eq, frozen=frozen)
class C:
def __hash__(self):
return 0
else:
@dataclass(hash=hash, eq=eq, frozen=frozen)
class C:
pass
# See if the result matches what's expected.
if result in ('fn', 'fn-x'):
# __hash__ contains the function we generated.
self.assertIn('__hash__', C.__dict__)
self.assertIsNotNone(C.__dict__['__hash__'])
if result == 'fn-x':
# This is the "auto-hash test" case. We
# should overwrite __hash__ iff there's an
# __eq__ and if __hash__=None.
# There are two ways of getting __hash__=None:
# explicitely, and by defining __eq__. If
# __eq__ is defined, python will add __hash__
# when the class is created.
@dataclass(hash=hash, eq=eq, frozen=frozen)
class C:
def __eq__(self, other): pass
__hash__ = None
# Hash should be overwritten (non-None).
self.assertIsNotNone(C.__dict__['__hash__'])
# Same test as above, but we don't provide
# __hash__, it will implicitely set to None.
@dataclass(hash=hash, eq=eq, frozen=frozen)
class C:
def __eq__(self, other): pass
# Hash should be overwritten (non-None).
self.assertIsNotNone(C.__dict__['__hash__'])
elif result == '':
# __hash__ is not present in our class.
if not with_hash:
self.assertNotIn('__hash__', C.__dict__)
elif result == 'none':
# __hash__ is set to None.
self.assertIn('__hash__', C.__dict__)
self.assertIsNone(C.__dict__['__hash__'])
else:
assert False, f'unknown result {result!r}'
# There are 12 cases of:
# hash=True/False/None
# eq=True/False
# frozen=True/False
# And for each of these, a different result if
# __hash__ is defined or not.
for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
(None, False, False, '', ''),
(None, False, True, '', ''),
(None, True, False, 'none', ''),
(None, True, True, 'fn', 'fn-x'),
(False, False, False, '', ''),
(False, False, True, '', ''),
(False, True, False, '', ''),
(False, True, True, '', ''),
(True, False, False, 'fn', 'fn-x'),
(True, False, True, 'fn', 'fn-x'),
(True, True, False, 'fn', 'fn-x'),
(True, True, True, 'fn', 'fn-x'),
], 1):
test(case, hash, eq, frozen, False, result_no)
test(case, hash, eq, frozen, True, result_yes)
# Test non-bool truth values, too. This is just to
# make sure the data-driven table in the decorator
# handles non-bool values.
test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
def test_eq_only(self):
# If a class defines __eq__, __hash__ is automatically added
# and set to None. This is normal Python behavior, not
# related to dataclasses. Make sure we don't interfere with
# that (see bpo=32546).
@dataclass
class C:
i: int
def __eq__(self, other):
return self.i == other.i
self.assertEqual(C(1), C(1))
self.assertNotEqual(C(1), C(4))
# And make sure things work in this case if we specify
# hash=True.
@dataclass(hash=True)
class C:
i: int
def __eq__(self, other):
return self.i == other.i
self.assertEqual(C(1), C(1.0))
self.assertEqual(hash(C(1)), hash(C(1.0)))
# And check that the classes __eq__ is being used, despite
# specifying eq=True.
@dataclass(hash=True, eq=True)
class C:
i: int
def __eq__(self, other):
return self.i == 3 and self.i == other.i
self.assertEqual(C(3), C(3))
self.assertNotEqual(C(1), C(1))
self.assertEqual(hash(C(1)), hash(C(1.0)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
In dataclasses, allow easier overriding of dunder methods without specifying
decorator parameters.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment