test_sort.py 8.96 KB
Newer Older
1
from test import support
2
import random
3 4
import sys
import unittest
5

6
# Mirror the test framework's verbosity flag; it gates the progress output
# printed by check() and by the stress test.
verbose = support.verbose
# Number of mismatches recorded by check() since this module was imported.
nerrors = 0

9 10 11 12 13 14 15 16 17
def CmpToKey(mycmp):
    """Convert a cmp= function into a key= function.

    ``mycmp(a, b)`` follows the classic three-way convention: a negative
    number when ``a`` sorts before ``b``, zero when they tie, and a positive
    number otherwise.

    Returns a class whose instances wrap one element each and define only
    ``__lt__``, which is the comparison ``list.sort()`` relies on.
    """
    class K(object):
        def __init__(self, obj):
            self.obj = obj

        def __lt__(self, other):
            # Any negative result means "less than", not just exactly -1,
            # so comparators such as ``lambda a, b: a - b`` order correctly.
            return mycmp(self.obj, other.obj) < 0
    return K

18 19 20 21
def check(tag, expected, raw, compare=None):
    """Sort ``raw`` in place and compare the result against ``expected``.

    tag      -- label printed in progress and error messages.
    expected -- list holding the objects of ``raw`` in the order the sort
                should produce.
    raw      -- list to sort; it is modified in place.
    compare  -- optional cmp-style function; when given, the sort runs with
                ``key=CmpToKey(compare)`` instead of the default ordering.

    Elements are compared by identity (``is``), not equality, so the check
    also notices equal-looking objects that ended up in different positions.
    A failure is printed and counted in the module-level ``nerrors``; nothing
    is raised and nothing is returned.
    """
    global nerrors

    if verbose:
        print("    checking", tag)

    orig = raw[:]   # save input in case of error
    if compare:
        raw.sort(key=CmpToKey(compare))
    else:
        raw.sort()

    # Find the first problem, if any, as the argument tuple of the line that
    # describes it; only the first problem is reported.
    problem = None
    if len(expected) != len(raw):
        problem = ("length mismatch;", len(expected), len(raw))
    else:
        for i, (good, maybe) in enumerate(zip(expected, raw)):
            if good is not maybe:
                problem = ("out of order at index", i, good, maybe)
                break

    if problem is None:
        return

    print("error in", tag)
    print(*problem)
    print(expected)
    print(orig)
    print(raw)
    nerrors += 1

50 51 52 53 54 55 56 57
class TestBase(unittest.TestCase):
    """Stress test of list.sort() over many sizes and input shapes."""

    def testStressfully(self):
        # check() only prints and bumps the module-level error counter, so
        # remember where the counter stood and assert on it at the end;
        # otherwise a detected mismatch would never fail this test.
        errors_before = nerrors

        # Try a variety of sizes at and around powers of 2, and at powers of 10.
        sizes = [0]
        for power in range(1, 10):
            n = 2 ** power
            sizes.extend(range(n-1, n+2))
        sizes.extend([10, 100, 1000])

        class Complains(object):
            # Element whose comparison raises RuntimeError about once per
            # thousand calls while the class-level switch is on.
            maybe_complain = True

            def __init__(self, i):
                self.i = i

            def __lt__(self, other):
                if Complains.maybe_complain and random.random() < 0.001:
                    if verbose:
                        print("        complaining at", self, other)
                    raise RuntimeError
                return self.i < other.i

            def __repr__(self):
                return "Complains(%d)" % self.i

        class Stable(object):
            # Element ordered by key only; index records its original
            # position so that stability can be verified.
            def __init__(self, key, i):
                self.key = key
                self.index = i

            def __lt__(self, other):
                return self.key < other.key

            def __repr__(self):
                return "Stable(%d, %d)" % (self.key, self.index)

        for n in sizes:
            x = list(range(n))
            if verbose:
                print("Testing size", n)

            s = x[:]
            check("identity", x, s)

            s = x[:]
            s.reverse()
            check("reversed", x, s)

            s = x[:]
            random.shuffle(s)
            check("random permutation", x, s)

            y = x[:]
            y.reverse()
            s = x[:]
            check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))

            if verbose:
                print("    Checking against an insane comparison function.")
                print("        If the implementation isn't careful, this may segfault.")
            s = x[:]
            # The comparator answers -1, 0 or 1 at random; the sort must
            # still leave s holding a permutation of its elements.
            s.sort(key=CmpToKey(lambda a, b:  int(random.random() * 3) - 1))
            check("an insane function left some permutation", x, s)

            x = [Complains(i) for i in x]
            s = x[:]
            random.shuffle(s)
            Complains.maybe_complain = True
            it_complained = False
            try:
                s.sort()
            except RuntimeError:
                it_complained = True
            if it_complained:
                # Switch the failures off so check() can sort s for real and
                # confirm the aborted sort lost or duplicated nothing.
                Complains.maybe_complain = False
                check("exception during sort left some permutation", x, s)

            s = [Stable(random.randrange(10), i) for i in range(n)]
            # Expected result of a stable sort: order by key, ties by original
            # index.  An explicit (key, index) sort key has no ties, so the
            # expectation does not itself depend on the sort being stable.
            # (Stable defines no __eq__, so sorting (e, e.index) tuples would
            # never reach the index to break a tie.)
            x = sorted(s, key=lambda e: (e.key, e.index))
            check("stability", x, s)

        self.assertEqual(nerrors, errors_before,
                         "check() reported sort failures; see printed output")
132

133
#==============================================================================
134

135
class TestBugs(unittest.TestCase):
    """Regression tests for lists that are mutated while being sorted."""

    def test_bug453523(self):
        # bug 453523 -- list.sort() crasher.
        # If this fails, the most likely outcome is a core dump.
        # Mutations during a list sort should raise a ValueError.

        class C:
            def __lt__(self, other):
                # Every comparison shrinks or grows the list being sorted
                # and then returns a random answer.
                if L and random.random() < 0.75:
                    L.pop()
                else:
                    L.append(3)
                return random.random() < 0.5

        L = [C() for i in range(50)]
        self.assertRaises(ValueError, L.sort)

    def test_undetected_mutation(self):
        # Python 2.4a1 did not always detect mutation
        # NOTE(review): memorywaster is never read; presumably it just keeps
        # extra allocations alive between iterations -- confirm the intent.
        memorywaster = []
        for i in range(20):
            # First variant: append then pop, so the length is unchanged
            # after every comparison.
            def mutating_cmp(x, y):
                L.append(3)
                L.pop()
                return (x > y) - (x < y)
            L = [1,2]
            self.assertRaises(ValueError, L.sort, key=CmpToKey(mutating_cmp))

            # Second variant: append then clear the list entirely.
            def mutating_cmp(x, y):
                L.append(3)
                del L[:]
                return (x > y) - (x < y)
            self.assertRaises(ValueError, L.sort, key=CmpToKey(mutating_cmp))
            memorywaster = [memorywaster]

170 171 172 173 174 175 176 177 178
#==============================================================================

class TestDecorateSortUndecorate(unittest.TestCase):
    """Tests for the key= and reverse= arguments of list.sort()."""

    def test_decorated(self):
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        copy = data[:]
        random.shuffle(data)
        data.sort(key=str.lower)
        def my_cmp(x, y):
            xlower, ylower = x.lower(), y.lower()
            return (xlower > ylower) - (xlower < ylower)
        copy.sort(key=CmpToKey(my_cmp))
        # Both lists were ordered case-insensitively.  The only tie is the
        # duplicated 'The', whose two copies are equal strings, so the
        # results must compare equal whatever the starting order was.
        self.assertEqual(data, copy)

    def test_baddecorator(self):
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        # A key function is called with one argument; this one needs two.
        self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)

    def test_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy = data[:]
        data.sort(key=lambda t: t[0])   # sort on the random first field
        copy.sort()                     # sort using both fields
        self.assertEqual(data, copy)    # should get the same result

    def test_key_with_exception(self):
        # Verify that the wrapper has been removed
        data = list(range(-2, 2))
        dup = data[:]
        # The key raises ZeroDivisionError for the element 0.
        self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
        self.assertEqual(data, dup)

    def test_key_with_mutation(self):
        data = list(range(10))
        def k(x):
            # Replace the contents of the list while it is being sorted.
            del data[:]
            data[:] = range(20)
            return x
        self.assertRaises(ValueError, data.sort, key=k)

    def test_key_with_mutating_del(self):
        data = list(range(10))
        class SortKiller(object):
            # Key object that rewrites the list when it is finalized.
            def __init__(self, x):
                pass
            def __del__(self):
                del data[:]
                data[:] = range(20)
            def __lt__(self, other):
                return id(self) < id(other)
        self.assertRaises(ValueError, data.sort, key=SortKiller)

    def test_key_with_mutating_del_and_exception(self):
        data = list(range(10))
        ## dup = data[:]
        class SortKiller(object):
            # Key construction fails for elements above 2, and finalizing a
            # key object rewrites the list.
            def __init__(self, x):
                if x > 2:
                    raise RuntimeError
            def __del__(self):
                del data[:]
                data[:] = list(range(20))
        self.assertRaises(RuntimeError, data.sort, key=SortKiller)
        ## major honking subtlety: we *can't* do:
        ##
        ## self.assertEqual(data, dup)
        ##
        ## because there is a reference to a SortKiller in the
        ## traceback and by the time it dies we're outside the call to
        ## .sort() and so the list protection gimmicks are out of
        ## date (this cost some brain cells to figure out...).

    def test_reverse(self):
        data = list(range(100))
        random.shuffle(data)
        data.sort(reverse=True)
        self.assertEqual(data, list(range(99,-1,-1)))

    def test_reverse_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy1 = data[:]
        copy2 = data[:]
        # Both comparators look at the first field only, so the second field
        # reveals the order in which tied elements come out.
        def my_cmp(x, y):
            x0, y0 = x[0], y[0]
            return (x0 > y0) - (x0 < y0)
        def my_cmp_reversed(x, y):
            x0, y0 = x[0], y[0]
            return (y0 > x0) - (y0 < x0)
        data.sort(key=CmpToKey(my_cmp), reverse=True)
        copy1.sort(key=CmpToKey(my_cmp_reversed))
        self.assertEqual(data, copy1)
        copy2.sort(key=lambda x: x[0], reverse=True)
        self.assertEqual(data, copy2)

#==============================================================================

def test_main(verbose=None):
    """Run every test class in this module through support.run_unittest().

    verbose -- when true, and the interpreter provides sys.gettotalrefcount,
               the suite is run five more times and the total reference
               count after each run is printed as a list.  This parameter
               shadows the module-level ``verbose`` inside the function.
    """
    test_classes = (
        TestBase,
        TestDecorateSortUndecorate,
        TestBugs,
    )

    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        # NOTE(review): the list is preallocated, presumably so it does not
        # grow while the counts are being taken -- confirm before changing.
        counts = [None] * 5
        for i in range(len(counts)):
            support.run_unittest(*test_classes)
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        print(counts)
284 285 286

# When run as a script, pass verbose=True so test_main() also performs its
# reference-count pass if the interpreter provides sys.gettotalrefcount.
if __name__ == "__main__":
    test_main(verbose=True)