#
# multibytecodec_support.py
#   Common Unittest Routines for CJK codecs
#

import codecs
import os
import re
import sys
import unittest
from http.client import HTTPException
from io import BytesIO

from test import support


class TestBase:
    encoding        = ''   # codec name
    codec           = None # codec tuple (with 4 elements)
18
    tstring         = None # must set. 2 strings to test StreamReader
19 20 21 22 23

    codectests      = None # must set. codec test tuple
    roundtriptest   = 1    # set if roundtrip is possible with unicode
    has_iso10646    = 0    # set if this encoding contains whole iso10646 map
    xmlcharnametest = None # string to test xmlcharrefreplace
24
    unmappedunicode = '\udeee' # a unicode code point that is not mapped.
25 26 27 28

    def setUp(self):
        if self.codec is None:
            self.codec = codecs.lookup(self.encoding)
29 30 31 32 33 34
        self.encode = self.codec.encode
        self.decode = self.codec.decode
        self.reader = self.codec.streamreader
        self.writer = self.codec.streamwriter
        self.incrementalencoder = self.codec.incrementalencoder
        self.incrementaldecoder = self.codec.incrementaldecoder
35 36

    def test_chunkcoding(self):
37 38 39 40 41 42 43 44
        tstring_lines = []
        for b in self.tstring:
            lines = b.split(b"\n")
            last = lines.pop()
            assert last == b""
            lines = [line + b"\n" for line in lines]
            tstring_lines.append(lines)
        for native, utf8 in zip(*tstring_lines):
45 46 47 48 49 50 51
            u = self.decode(native)[0]
            self.assertEqual(u, utf8.decode('utf-8'))
            if self.roundtriptest:
                self.assertEqual(native, self.encode(u)[0])

    def test_errorhandle(self):
        for source, scheme, expected in self.codectests:
52
            if isinstance(source, bytes):
53 54 55 56 57
                func = self.decode
            else:
                func = self.encode
            if expected:
                result = func(source, scheme)[0]
58
                if func is self.decode:
59
                    self.assertTrue(type(result) is str, type(result))
60 61 62 63
                    self.assertEqual(result, expected,
                                     '%a.decode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
64
                else:
65
                    self.assertTrue(type(result) is bytes, type(result))
66 67 68 69
                    self.assertEqual(result, expected,
                                     '%a.encode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
70 71 72
            else:
                self.assertRaises(UnicodeError, func, source, scheme)

73 74
    def test_xmlcharrefreplace(self):
        if self.has_iso10646:
75
            self.skipTest('encoding contains full ISO 10646 map')
76

77
        s = "\u0b13\u0b23\u0b60 nd eggs"
78 79
        self.assertEqual(
            self.encode(s, "xmlcharrefreplace")[0],
80
            b"ଓଣୠ nd eggs"
81 82 83 84
        )

    def test_customreplace_encode(self):
        if self.has_iso10646:
85
            self.skipTest('encoding contains full ISO 10646 map')
86

87
        from html.entities import codepoint2name
88 89 90 91 92 93 94

        def xmlcharnamereplace(exc):
            if not isinstance(exc, UnicodeEncodeError):
                raise TypeError("don't know how to handle %r" % exc)
            l = []
            for c in exc.object[exc.start:exc.end]:
                if ord(c) in codepoint2name:
95
                    l.append("&%s;" % codepoint2name[ord(c)])
96
                else:
97 98
                    l.append("&#%d;" % ord(c))
            return ("".join(l), exc.end)
99 100

        codecs.register_error("test.xmlcharnamereplace", xmlcharnamereplace)
101

102 103 104
        if self.xmlcharnametest:
            sin, sout = self.xmlcharnametest
        else:
105
            sin = "\xab\u211c\xbb = \u2329\u1234\u232a"
106
            sout = b"«ℜ» = ⟨ሴ⟩"
107 108 109
        self.assertEqual(self.encode(sin,
                                    "test.xmlcharnamereplace")[0], sout)

110 111 112 113 114 115 116
    def test_callback_returns_bytes(self):
        def myreplace(exc):
            return (b"1234", exc.end)
        codecs.register_error("test.cjktest", myreplace)
        enc = self.encode("abc" + self.unmappedunicode + "def", "test.cjktest")[0]
        self.assertEqual(enc, b"abc1234def")

117 118 119 120 121
    def test_callback_wrong_objects(self):
        def myreplace(exc):
            return (ret, exc.end)
        codecs.register_error("test.cjktest", myreplace)

122
        for ret in ([1, 2, 3], [], None, object()):
123 124 125 126 127
            self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                              'test.cjktest')

    def test_callback_long_index(self):
        def myreplace(exc):
128
            return ('x', int(exc.end))
129
        codecs.register_error("test.cjktest", myreplace)
130
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
131
                                     'test.cjktest'), (b'abcdxefgh', 9))
132 133

        def myreplace(exc):
134
            return ('x', sys.maxsize + 1)
135 136 137 138 139 140
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_None_index(self):
        def myreplace(exc):
141
            return ('x', None)
142 143 144 145 146 147 148 149
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_backward_index(self):
        def myreplace(exc):
            if myreplace.limit > 0:
                myreplace.limit -= 1
150
                return ('REPLACED', 0)
151
            else:
152
                return ('TERMINAL', exc.end)
153 154
        myreplace.limit = 3
        codecs.register_error("test.cjktest", myreplace)
155
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
156
                                     'test.cjktest'),
157
                (b'abcdREPLACEDabcdREPLACEDabcdREPLACEDabcdTERMINALefgh', 9))
158 159 160

    def test_callback_forward_index(self):
        def myreplace(exc):
161
            return ('REPLACED', exc.end + 2)
162
        codecs.register_error("test.cjktest", myreplace)
163
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
164
                                     'test.cjktest'), (b'abcdREPLACEDgh', 9))
165 166 167

    def test_callback_index_outofbound(self):
        def myreplace(exc):
168
            return ('TERM', 100)
169 170 171 172 173 174
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_incrementalencoder(self):
        UTF8Reader = codecs.getreader('utf-8')
175
        for sizehint in [None] + list(range(1, 33)) + \
176
                        [64, 128, 256, 512, 1024]:
177 178
            istream = UTF8Reader(BytesIO(self.tstring[1]))
            ostream = BytesIO()
179 180 181 182 183 184
            encoder = self.incrementalencoder()
            while 1:
                if sizehint is not None:
                    data = istream.read(sizehint)
                else:
                    data = istream.read()
185

186 187 188 189
                if not data:
                    break
                e = encoder.encode(data)
                ostream.write(e)
190

191
            self.assertEqual(ostream.getvalue(), self.tstring[0])
192

193 194
    def test_incrementaldecoder(self):
        UTF8Writer = codecs.getwriter('utf-8')
195
        for sizehint in [None, -1] + list(range(1, 33)) + \
196
                        [64, 128, 256, 512, 1024]:
197 198
            istream = BytesIO(self.tstring[0])
            ostream = UTF8Writer(BytesIO())
199 200 201 202 203
            decoder = self.incrementaldecoder()
            while 1:
                data = istream.read(sizehint)
                if not data:
                    break
204
                else:
205 206 207 208 209 210 211 212 213 214 215 216
                    u = decoder.decode(data)
                    ostream.write(u)

            self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_incrementalencoder_error_callback(self):
        inv = self.unmappedunicode

        e = self.incrementalencoder()
        self.assertRaises(UnicodeEncodeError, e.encode, inv, True)

        e.errors = 'ignore'
217
        self.assertEqual(e.encode(inv, True), b'')
218 219 220

        e.reset()
        def tempreplace(exc):
221
            return ('called', exc.end)
222 223
        codecs.register_error('test.incremental_error_callback', tempreplace)
        e.errors = 'test.incremental_error_callback'
224
        self.assertEqual(e.encode(inv, True), b'called')
225 226 227

        # again
        e.errors = 'ignore'
228
        self.assertEqual(e.encode(inv, True), b'')
229 230 231 232

    def test_streamreader(self):
        UTF8Writer = codecs.getwriter('utf-8')
        for name in ["read", "readline", "readlines"]:
233
            for sizehint in [None, -1] + list(range(1, 33)) + \
234
                            [64, 128, 256, 512, 1024]:
235 236
                istream = self.reader(BytesIO(self.tstring[0]))
                ostream = UTF8Writer(BytesIO())
237 238 239 240 241 242 243 244 245 246 247 248 249
                func = getattr(istream, name)
                while 1:
                    data = func(sizehint)
                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_streamwriter(self):
250
        readfuncs = ('read', 'readline', 'readlines')
251 252
        UTF8Reader = codecs.getreader('utf-8')
        for name in readfuncs:
253
            for sizehint in [None] + list(range(1, 33)) + \
254
                            [64, 128, 256, 512, 1024]:
255 256
                istream = UTF8Reader(BytesIO(self.tstring[1]))
                ostream = self.writer(BytesIO())
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
                func = getattr(istream, name)
                while 1:
                    if sizehint is not None:
                        data = func(sizehint)
                    else:
                        data = func()

                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[0])

273 274 275 276 277 278 279
    def test_streamwriter_reset_no_pending(self):
        # Issue #23247: Calling reset() on a fresh StreamWriter instance
        # (without pending data) must not crash
        stream = BytesIO()
        writer = self.writer(stream)
        writer.reset()

280 281 282 283 284

class TestBase_Mapping(unittest.TestCase):
    pass_enctest = []
    pass_dectest = []
    supmaps = []
285
    codectests = []
286

287
    def setUp(self):
288
        try:
289
            self.open_mapping_file().close() # test it to report the error early
290
        except (OSError, HTTPException):
291
            self.skipTest("Could not retrieve "+self.mapfileurl)
292 293

    def open_mapping_file(self):
294
        return support.open_urlresource(self.mapfileurl)
295 296

    def test_mapping_file(self):
297 298 299 300 301 302
        if self.mapfileurl.endswith('.xml'):
            self._test_mapping_file_ucm()
        else:
            self._test_mapping_file_plain()

    def _test_mapping_file_plain(self):
303
        unichrs = lambda s: ''.join(map(chr, map(eval, s.split('+'))))
304 305
        urt_wa = {}

306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
        with self.open_mapping_file() as f:
            for line in f:
                if not line:
                    break
                data = line.split('#')[0].strip().split()
                if len(data) != 2:
                    continue

                csetval = eval(data[0])
                if csetval <= 0x7F:
                    csetch = bytes([csetval & 0xff])
                elif csetval >= 0x1000000:
                    csetch = bytes([(csetval >> 24), ((csetval >> 16) & 0xff),
                                    ((csetval >> 8) & 0xff), (csetval & 0xff)])
                elif csetval >= 0x10000:
                    csetch = bytes([(csetval >> 16), ((csetval >> 8) & 0xff),
                                    (csetval & 0xff)])
                elif csetval >= 0x100:
                    csetch = bytes([(csetval >> 8), (csetval & 0xff)])
                else:
                    continue
327

328 329 330 331
                unich = unichrs(data[1])
                if ord(unich) == 0xfffd or unich in urt_wa:
                    continue
                urt_wa[unich] = csetch
332

333
                self._testpoint(csetch, unich)
334

335
    def _test_mapping_file_ucm(self):
336 337
        with self.open_mapping_file() as f:
            ucmdata = f.read()
338 339 340 341 342 343
        uc = re.findall('<a u="([A-F0-9]{4})" b="([0-9A-F ]+)"/>', ucmdata)
        for uni, coded in uc:
            unich = chr(int(uni, 16))
            codech = bytes(int(c, 16) for c in coded.split())
            self._testpoint(codech, unich)

344 345 346 347 348 349 350 351
    def test_mapping_supplemental(self):
        for mapping in self.supmaps:
            self._testpoint(*mapping)

    def _testpoint(self, csetch, unich):
        if (csetch, unich) not in self.pass_enctest:
            self.assertEqual(unich.encode(self.encoding), csetch)
        if (csetch, unich) not in self.pass_dectest:
352
            self.assertEqual(str(csetch, self.encoding), unich)
353

354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377
    def test_errorhandle(self):
        for source, scheme, expected in self.codectests:
            if isinstance(source, bytes):
                func = source.decode
            else:
                func = source.encode
            if expected:
                if isinstance(source, bytes):
                    result = func(self.encoding, scheme)
                    self.assertTrue(type(result) is str, type(result))
                    self.assertEqual(result, expected,
                                     '%a.decode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
                else:
                    result = func(self.encoding, scheme)
                    self.assertTrue(type(result) is bytes, type(result))
                    self.assertEqual(result, expected,
                                     '%a.encode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
            else:
                self.assertRaises(UnicodeError, func, self.encoding, scheme)

378 379 380 381 382 383 384
def load_teststring(name):
    """Return the ``(encoded, utf8)`` byte contents of sample text *name*.

    Both files are read in binary mode from the ``cjkencodings`` directory
    next to this module: ``<name>.txt`` first, then ``<name>-utf8.txt``.
    Judging by its name, the second file is the UTF-8 version of the same
    text.
    """
    basedir = os.path.join(os.path.dirname(__file__), 'cjkencodings')

    def _read(suffix):
        with open(os.path.join(basedir, name + suffix), 'rb') as fp:
            return fp.read()

    return _read('.txt'), _read('-utf8.txt')