"""Unit tests for the memoryview type.

XXX We need more tests! Some tests are in test_bytes
"""

import unittest
import test.support
import sys
import gc
import weakref
import array
class AbstractMemoryTests:
    """Mixin holding the memoryview tests shared by every variation below.

    A concrete test case combines unittest.TestCase with this mixin, one
    source-object mixin and one view mixin.  Together they must provide:

    * ``ro_type`` / ``rw_type``: callables building a read-only / writable
      buffer object from a bytes value, or None to skip that variant;
    * ``getitem_type``: callable turning a one-byte bytes value into the
      value expected from indexing the view at the matching position;
    * ``itemsize`` and ``format``: the expected memoryview attributes;
    * ``_view(obj)``: returns the memoryview under test for ``obj``;
    * ``_check_contents(tp, obj, contents)``: asserts that the region of
      ``obj`` seen through the view equals ``tp(contents)``.
    """
    # Raw data the buffer objects are built from.  View mixins that slice
    # the memoryview override this with a padded value.
    source_bytes = b"abcdef"

    @property
    def _source(self):
        return self.source_bytes

    @property
    def _types(self):
        # A fresh iterator on every access, yielding only the configured
        # (non-None) buffer types.
        return filter(None, [self.ro_type, self.rw_type])

    def check_getitem_with_type(self, tp):
        """Check indexing, bounds and key-type errors on a view of tp."""
        item = self.getitem_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], item(b"a"))
        self.assertIsInstance(m[0], bytes)
        self.assertEqual(m[5], item(b"f"))
        self.assertEqual(m[-1], item(b"f"))
        self.assertEqual(m[-6], item(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Dropping the view must give back its reference to b.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    def test_iter(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    def test_setitem_readonly(self):
        if not self.ro_type:
            return
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        if not self.rw_type:
            return
        tp = self.rw_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Trying to resize the memory object
        self.assertRaises(ValueError, setitem, 0, b"")
        self.assertRaises(ValueError, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(bytes([c])) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertIsInstance(b, bytes)

    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            l = m.tolist()
            self.assertEqual(l, list(b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == "abcdef")
            self.assertTrue(m != "abcdef")
            self.assertFalse("abcdef" == m)
            self.assertTrue("abcdef" != m)

            # Unordered comparisons
            for c in (m, b"abcdef"):
                self.assertRaises(TypeError, lambda: m < c)
                self.assertRaises(TypeError, lambda: c <= m)
                self.assertRaises(TypeError, lambda: m >= c)
                self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        """Check the shape attributes of a view of tp and return the view."""
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertEqual(m.itemsize, self.itemsize)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, None)
        return m

    def test_attributes_readonly(self):
        if not self.ro_type:
            return
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        if not self.rw_type:
            return
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        # Test PyObject_GetBuffer() on a memoryview object.
        for tp in self._types:
            b = tp(self._source)
            oldrefcount = sys.getrefcount(b)
            m = self._view(b)
            oldviewrefcount = sys.getrefcount(m)
            s = str(m, "utf-8")
            self._check_contents(tp, b, s.encode("utf-8"))
            self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            m = None
            self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertIsNone(wr(), wr())


# Variations on source objects for the buffer: bytes-like objects, then arrays
# with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.

class BaseBytesMemoryTests(AbstractMemoryTests):
    """Single-byte sources: bytes (read-only) and bytearray (writable)."""
    format = 'B'
    itemsize = 1
    # The value expected from m[i] is bytes() of the matching source byte.
    getitem_type = bytes
    rw_type = bytearray
    ro_type = bytes
class BaseArrayMemoryTests(AbstractMemoryTests):
    """Sources backed by array.array('i'), i.e. items wider than one byte."""
    # No read-only array variant is exercised.
    ro_type = None
    itemsize = array.array('i').itemsize
    format = 'i'

    def rw_type(self, b):
        """Build a writable 'i' array holding one element per byte of b."""
        return array.array('i', list(b))

    def getitem_type(self, b):
        """Return the raw machine bytes of the 'i' elements built from b."""
        arr = array.array('i', list(b))
        # array.tostring() was superseded by tobytes() in Python 3.2 and
        # removed in 3.9; fall back to it only where tobytes() is missing.
        if hasattr(arr, 'tobytes'):
            return arr.tobytes()
        return arr.tostring()

    def test_getbuffer(self):
        # XXX Test should be adapted for non-byte buffers
        pass

    def test_tolist(self):
        # XXX NotImplementedError: tolist() only supports byte views
        pass
# Variations on indirection levels: memoryview, slice of memoryview,
# slice of slice of memoryview.
# This is important to test allocation subtleties.

class BaseMemoryviewTests:
    """View variation: a plain memoryview over the whole source object."""

    def _view(self, obj):
        return memoryview(obj)

    def _check_contents(self, tp, obj, contents):
        # The view covers all of obj, so the whole object must match.
        self.assertEqual(obj, tp(contents))
class BaseMemorySliceTests:
    """View variation: a slice m[1:7] of a memoryview over a padded source."""
    # "X" and "Y" pad the data so the slice m[1:7] covers exactly b"abcdef".
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        m = memoryview(obj)
        return m[1:7]

    def _check_contents(self, tp, obj, contents):
        # Only the sliced region of obj is visible through the view.
        self.assertEqual(obj[1:7], tp(contents))

    def test_refs(self):
        # Creating and discarding a slice must leave the parent view's
        # reference count unchanged.
        for tp in self._types:
            m = memoryview(tp(self._source))
            oldrefcount = sys.getrefcount(m)
            m[1:2]
            self.assertEqual(sys.getrefcount(m), oldrefcount)
class BaseMemorySliceSliceTests:
    """View variation: a slice of a slice, m[:7][1:], over a padded source."""
    # "X" and "Y" pad the data so the double slice covers exactly b"abcdef".
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        m = memoryview(obj)
        return m[:7][1:]

    def _check_contents(self, tp, obj, contents):
        # Only bytes 1..6 of obj are visible through the view.
        self.assertEqual(obj[1:7], tp(contents))
# Concrete test classes

class BytesMemoryviewTest(unittest.TestCase,
    BaseMemoryviewTests, BaseBytesMemoryTests):

    def test_constructor(self):
        """memoryview() takes exactly one buffer, positionally or as 'object'."""
        # Calling with no argument does not depend on the source type.
        self.assertRaises(TypeError, memoryview)
        for tp in self._types:
            ob = tp(self._source)
            self.assertTrue(memoryview(ob))
            self.assertTrue(memoryview(object=ob))
            self.assertRaises(TypeError, memoryview, ob, ob)
            self.assertRaises(TypeError, memoryview, argument=ob)
            self.assertRaises(TypeError, memoryview, ob, argument=True)

class ArrayMemoryviewTest(unittest.TestCase,
    BaseMemoryviewTests, BaseArrayMemoryTests):

    def test_array_assign(self):
        # Issue #4569: segfault when mutating a memoryview with itemsize != 1
        a = array.array('i', range(10))
        m = memoryview(a)
        new_a = array.array('i', range(9, -1, -1))
        # Whole-view assignment must copy every multi-byte item into a.
        m[:] = new_a
        self.assertEqual(a, new_a)


class BytesMemorySliceTest(
        unittest.TestCase, BaseMemorySliceTests, BaseBytesMemoryTests):
    """bytes/bytearray sources viewed through a memoryview slice."""

class ArrayMemorySliceTest(
        unittest.TestCase, BaseMemorySliceTests, BaseArrayMemoryTests):
    """array('i') sources viewed through a memoryview slice."""

class BytesMemorySliceSliceTest(
        unittest.TestCase, BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    """bytes/bytearray sources viewed through a slice of a slice."""

class ArrayMemorySliceSliceTest(
        unittest.TestCase, BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    """array('i') sources viewed through a slice of a slice."""


def test_main():
    test.support.run_unittest(__name__)
332 333 334

if "__main__" == __name__:
    # Only run the suite when executed as a script, not when imported.
    test_main()