Kaydet (Commit) 374274db authored tarafından Éric Araujo's avatar Éric Araujo

Fix the total_ordering decorator to handle cross-type comparisons

that could lead to infinite recursion (closes #10042).
üst 5136867c
...@@ -53,17 +53,17 @@ def wraps(wrapped, ...@@ -53,17 +53,17 @@ def wraps(wrapped,
def total_ordering(cls): def total_ordering(cls):
"""Class decorator that fills in missing ordering methods""" """Class decorator that fills in missing ordering methods"""
convert = { convert = {
'__lt__': [('__gt__', lambda self, other: other < self), '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
('__le__', lambda self, other: not other < self), ('__le__', lambda self, other: self < other or self == other),
('__ge__', lambda self, other: not self < other)], ('__ge__', lambda self, other: not self < other)],
'__le__': [('__ge__', lambda self, other: other <= self), '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
('__lt__', lambda self, other: not other <= self), ('__lt__', lambda self, other: self <= other and not self == other),
('__gt__', lambda self, other: not self <= other)], ('__gt__', lambda self, other: not self <= other)],
'__gt__': [('__lt__', lambda self, other: other > self), '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
('__ge__', lambda self, other: not other > self), ('__ge__', lambda self, other: self > other or self == other),
('__le__', lambda self, other: not self > other)], ('__le__', lambda self, other: not self > other)],
'__ge__': [('__le__', lambda self, other: other >= self), '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
('__gt__', lambda self, other: not other >= self), ('__gt__', lambda self, other: self >= other and not self == other),
('__lt__', lambda self, other: not self >= other)] ('__lt__', lambda self, other: not self >= other)]
} }
roots = set(dir(cls)) & set(convert) roots = set(dir(cls)) & set(convert)
......
...@@ -361,6 +361,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -361,6 +361,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __lt__(self, other): def __lt__(self, other):
return self.value < other.value return self.value < other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -375,6 +377,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -375,6 +377,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __le__(self, other): def __le__(self, other):
return self.value <= other.value return self.value <= other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -389,6 +393,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -389,6 +393,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __gt__(self, other): def __gt__(self, other):
return self.value > other.value return self.value > other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -403,6 +409,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -403,6 +409,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __ge__(self, other): def __ge__(self, other):
return self.value >= other.value return self.value >= other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -428,6 +436,22 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -428,6 +436,22 @@ class TestTotalOrdering(unittest.TestCase):
class A: class A:
pass pass
def test_bug_10042(self):
@functools.total_ordering
class TestTO:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, TestTO):
return self.value == other.value
return False
def __lt__(self, other):
if isinstance(other, TestTO):
return self.value < other.value
raise TypeError
with self.assertRaises(TypeError):
TestTO(8) <= ()
def test_main(verbose=None): def test_main(verbose=None):
test_classes = ( test_classes = (
TestPartial, TestPartial,
......
...@@ -669,6 +669,7 @@ Bernhard Reiter ...@@ -669,6 +669,7 @@ Bernhard Reiter
Steven Reiz Steven Reiz
Roeland Rengelink Roeland Rengelink
Tim Rice Tim Rice
Francesco Ricciardi
Jan Pieter Riegel Jan Pieter Riegel
Armin Rigo Armin Rigo
Nicholas Riley Nicholas Riley
......
...@@ -43,6 +43,9 @@ Core and Builtins ...@@ -43,6 +43,9 @@ Core and Builtins
Library Library
------- -------
- Issue #10042: Fixed the total_ordering decorator to handle cross-type
comparisons that could lead to infinite recursion.
- Issue #10979: unittest stdout buffering now works with class and module - Issue #10979: unittest stdout buffering now works with class and module
setup and teardown. setup and teardown.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment