# test_richcmp.py
# Tests for rich comparisons

import unittest
from test import support

import operator

class Number:
    """Wrap a value and forward every rich comparison to it.

    Each comparison compares the wrapped ``self.x`` against *other* as-is,
    so ``Number(a) < b`` behaves like ``a < b``.  When *other* is itself a
    Number, the inner comparison is retried via the reflected method of the
    right-hand Number.
    """

    def __init__(self, x):
        self.x = x

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        # Python 3 never consults __cmp__; this is a tripwire so that any
        # regression back to three-way comparison fails loudly.
        raise support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x, )

class Vector:
    """A sequence whose rich comparisons are element-wise.

    Comparing a Vector with another Vector, or with a plain sequence of the
    same length, returns a new Vector holding the per-element results rather
    than a single bool.  Such a result has no single truth value, so
    ``__bool__`` raises TypeError.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    __hash__ = None # Vectors cannot be hashed

    def __bool__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        # Python 3 never consults __cmp__; this is a tripwire so that any
        # regression back to three-way comparison fails loudly.
        raise support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data, )

    def __lt__(self, other):
        return self.__compare(other, operator.lt)

    def __le__(self, other):
        return self.__compare(other, operator.le)

    def __eq__(self, other):
        return self.__compare(other, operator.eq)

    def __ne__(self, other):
        return self.__compare(other, operator.ne)

    def __gt__(self, other):
        return self.__compare(other, operator.gt)

    def __ge__(self, other):
        return self.__compare(other, operator.ge)

    def __compare(self, other, op):
        """Apply the binary *op* pairwise and wrap the results in a Vector."""
        return Vector([op(a, b) for a, b in zip(self.data, self.__cast(other))])

    def __cast(self, other):
        """Return the raw sequence behind *other* for pairwise comparison.

        Raises ValueError if its length differs from this Vector's.
        """
        if isinstance(other, Vector):
            other = other.data
        if len(self.data) != len(other):
            raise ValueError("Cannot compare vectors of different length")
        return other

# Map each comparison name to three equivalent callables: an explicit
# operator expression, the public ``operator`` function, and its dunder
# alias.  Tests loop over all three so every spelling gets exercised.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)

class VectorTest(unittest.TestCase):
    """Comparisons involving Vector, whose results are Vectors of bools."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* raises *error* on *args*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert every spelling of *opname* on (a, b) yields *expres*."""
        for op in opmap[opname]:
            realres = op(a, b)
            # can't use assertEqual(realres, expres) here: == on a Vector
            # returns another Vector, whose __bool__ raises TypeError.
            self.assertEqual(len(realres), len(expres))
            for i, want in enumerate(expres):
                # results are bool, so we can use "is" here
                self.assertIs(realres[i], want)

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = list(range(5))
        b = 5 * [2]
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for left, right in args:
            self.checkequal("lt", left, right, [True,  True,  False, False, False])
            self.checkequal("le", left, right, [True,  True,  True,  False, False])
            self.checkequal("eq", left, right, [False, False, True,  False, False])
            self.checkequal("ne", left, right, [True,  True,  False, True,  True ])
            self.checkequal("gt", left, right, [False, False, False, True,  True ])
            self.checkequal("ge", left, right, [False, False, True,  True,  True ])

            for ops in opmap.values():
                for op in ops:
                    # calls __bool__, which should fail
                    self.assertRaises(TypeError, bool, op(left, right))

class NumberTest(unittest.TestCase):
    """Comparisons on Number wrappers must agree with those on plain ints."""

    def test_basic(self):
        # Check that comparisons involving Number objects
        # give the same results as comparing the
        # corresponding ints
        for a in range(3):
            for b in range(3):
                for typea in (int, Number):
                    for typeb in (int, Number):
                        if typea is int and typeb is int:
                            continue # the combination int, int is useless
                        ta = typea(a)
                        tb = typeb(b)
                        for ops in opmap.values():
                            for op in ops:
                                realoutcome = op(a, b)
                                testoutcome = op(ta, tb)
                                self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Check that *opname* on every int/Number mix of a, b is *expres*."""
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    # Defensive: unwrap a result that carries an ``x``
                    # attribute; plain bools pass through unchanged.
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres)

    def test_values(self):
        # check all operators and all comparison results
        self.checkvalue("lt", 0, 0, False)
        self.checkvalue("le", 0, 0, True )
        self.checkvalue("eq", 0, 0, True )
        self.checkvalue("ne", 0, 0, False)
        self.checkvalue("gt", 0, 0, False)
        self.checkvalue("ge", 0, 0, True )

        self.checkvalue("lt", 0, 1, True )
        self.checkvalue("le", 0, 1, True )
        self.checkvalue("eq", 0, 1, False)
        self.checkvalue("ne", 0, 1, True )
        self.checkvalue("gt", 0, 1, False)
        self.checkvalue("ge", 0, 1, False)

        self.checkvalue("lt", 1, 0, False)
        self.checkvalue("le", 1, 0, False)
        self.checkvalue("eq", 1, 0, False)
        self.checkvalue("ne", 1, 0, True )
        self.checkvalue("gt", 1, 0, True )
        self.checkvalue("ge", 1, 0, True )

class MiscTest(unittest.TestCase):
    """Assorted rich-comparison edge cases."""

    def test_misbehavin(self):
        # Rich comparison methods may return non-bool values, which must be
        # passed through unchanged.  The self.fail tripwires (note the
        # ``self_`` parameter, so ``self`` still refers to the TestCase)
        # check that evaluating <, == and > never consults the other three.
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)

    def test_not(self):
        # Check that exceptions in __bool__ are properly
        # propagated by the not operator
        class Exc(Exception):
            pass
        class Bad:
            def __bool__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    @support.no_tracing
    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully.
        # RuntimeError is asserted; RecursionError is a subclass of it.
        from collections import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        for op in (operator.eq, operator.ne, operator.lt,
                   operator.le, operator.gt, operator.ge):
            self.assertRaises(RuntimeError, op, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        for op in (operator.lt, operator.le, operator.gt, operator.ge):
            self.assertRaises(RuntimeError, op, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        # A differing first element decides the comparison before the
        # recursive elements are ever reached.
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)

class DictTest(unittest.TestCase):
    """Equality and ordering behaviour of dicts."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__).  Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
        items = list(imag1a.items())
        random.shuffle(items)
        # Same contents as imag1a, built in a different insertion order.
        imag1b = dict(items)
        # imag2 differs from imag1b in exactly one value (the last shuffled
        # item), chosen explicitly rather than via leaked loop variables.
        k, v = items[-1]
        imag2 = imag1b.copy()
        imag2[k] = v + 1.0
        self.assertEqual(imag1a, imag1a)
        self.assertEqual(imag1a, imag1b)
        self.assertEqual(imag2, imag2)
        self.assertNotEqual(imag1a, imag2)
        # Dicts define no ordering at all.
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)

class ListTest(unittest.TestCase):
    """Rich comparisons on built-in lists."""

    def test_coverage(self):
        # exercise all comparisons for lists
        x = [42]
        self.assertIs(x<x, False)
        self.assertIs(x<=x, True)
        self.assertIs(x==x, True)
        self.assertIs(x!=x, False)
        self.assertIs(x>x, False)
        self.assertIs(x>=x, True)
        y = [42, 42]
        self.assertIs(x<y, True)
        self.assertIs(x<=y, True)
        self.assertIs(x==y, False)
        self.assertIs(x!=y, True)
        self.assertIs(x>y, False)
        self.assertIs(x>=y, False)

    def test_badentry(self):
        # make sure that exceptions for item comparison are properly
        # propagated in list comparisons
        class Exc(Exception):
            pass
        class Bad:
            def __eq__(self, other):
                raise Exc

        x = [Bad()]
        y = [Bad()]

        for op in opmap["eq"]:
            self.assertRaises(Exc, op, x, y)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class Good:
            def __lt__(self, other):
                return True

        x = [Good()]
        y = [Good()]

        for op in opmap["lt"]:
            self.assertIs(op(x, y), True)

def test_main():
    """Run every test case class defined in this module."""
    support.run_unittest(VectorTest, NumberTest, MiscTest, DictTest, ListTest)


if __name__ == "__main__":
    test_main()