#!/usr/bin/env python3
#
# test_multibytecodec_support.py
#   Common Unittest Routines for CJK codecs
#

import codecs
import os
import re
import sys
import unittest
from http.client import HTTPException
from io import BytesIO
from test import support


class TestBase:
    encoding        = ''   # codec name
    codec           = None # codec tuple (with 4 elements)
19
    tstring         = None # must set. 2 strings to test StreamReader
20 21 22 23 24

    codectests      = None # must set. codec test tuple
    roundtriptest   = 1    # set if roundtrip is possible with unicode
    has_iso10646    = 0    # set if this encoding contains whole iso10646 map
    xmlcharnametest = None # string to test xmlcharrefreplace
25
    unmappedunicode = '\udeee' # a unicode codepoint that is not mapped.
26 27 28 29

    def setUp(self):
        if self.codec is None:
            self.codec = codecs.lookup(self.encoding)
30 31 32 33 34 35
        self.encode = self.codec.encode
        self.decode = self.codec.decode
        self.reader = self.codec.streamreader
        self.writer = self.codec.streamwriter
        self.incrementalencoder = self.codec.incrementalencoder
        self.incrementaldecoder = self.codec.incrementaldecoder
36 37

    def test_chunkcoding(self):
38 39 40 41 42 43 44 45
        tstring_lines = []
        for b in self.tstring:
            lines = b.split(b"\n")
            last = lines.pop()
            assert last == b""
            lines = [line + b"\n" for line in lines]
            tstring_lines.append(lines)
        for native, utf8 in zip(*tstring_lines):
46 47 48 49 50 51 52
            u = self.decode(native)[0]
            self.assertEqual(u, utf8.decode('utf-8'))
            if self.roundtriptest:
                self.assertEqual(native, self.encode(u)[0])

    def test_errorhandle(self):
        for source, scheme, expected in self.codectests:
53
            if isinstance(source, bytes):
54 55 56 57 58
                func = self.decode
            else:
                func = self.encode
            if expected:
                result = func(source, scheme)[0]
59
                if func is self.decode:
60
                    self.assertTrue(type(result) is str, type(result))
61 62 63 64
                    self.assertEqual(result, expected,
                                     '%a.decode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
65
                else:
66
                    self.assertTrue(type(result) is bytes, type(result))
67 68 69 70
                    self.assertEqual(result, expected,
                                     '%a.encode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
71 72 73
            else:
                self.assertRaises(UnicodeError, func, source, scheme)

74 75 76 77
    def test_xmlcharrefreplace(self):
        if self.has_iso10646:
            return

78
        s = "\u0b13\u0b23\u0b60 nd eggs"
79 80
        self.assertEqual(
            self.encode(s, "xmlcharrefreplace")[0],
81
            b"ଓଣୠ nd eggs"
82 83 84 85 86 87
        )

    def test_customreplace_encode(self):
        if self.has_iso10646:
            return

88
        from html.entities import codepoint2name
89 90 91 92 93 94 95

        def xmlcharnamereplace(exc):
            if not isinstance(exc, UnicodeEncodeError):
                raise TypeError("don't know how to handle %r" % exc)
            l = []
            for c in exc.object[exc.start:exc.end]:
                if ord(c) in codepoint2name:
96
                    l.append("&%s;" % codepoint2name[ord(c)])
97
                else:
98 99
                    l.append("&#%d;" % ord(c))
            return ("".join(l), exc.end)
100 101

        codecs.register_error("test.xmlcharnamereplace", xmlcharnamereplace)
102

103 104 105
        if self.xmlcharnametest:
            sin, sout = self.xmlcharnametest
        else:
106
            sin = "\xab\u211c\xbb = \u2329\u1234\u232a"
107
            sout = b"«ℜ» = ⟨ሴ⟩"
108 109 110 111 112 113 114 115
        self.assertEqual(self.encode(sin,
                                    "test.xmlcharnamereplace")[0], sout)

    def test_callback_wrong_objects(self):
        def myreplace(exc):
            return (ret, exc.end)
        codecs.register_error("test.cjktest", myreplace)

116
        for ret in ([1, 2, 3], [], None, object(), b'string', b''):
117 118 119 120 121
            self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                              'test.cjktest')

    def test_callback_long_index(self):
        def myreplace(exc):
122
            return ('x', int(exc.end))
123
        codecs.register_error("test.cjktest", myreplace)
124
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
125
                                     'test.cjktest'), (b'abcdxefgh', 9))
126 127

        def myreplace(exc):
128
            return ('x', sys.maxsize + 1)
129 130 131 132 133 134
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_None_index(self):
        def myreplace(exc):
135
            return ('x', None)
136 137 138 139 140 141 142 143
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_backward_index(self):
        def myreplace(exc):
            if myreplace.limit > 0:
                myreplace.limit -= 1
144
                return ('REPLACED', 0)
145
            else:
146
                return ('TERMINAL', exc.end)
147 148
        myreplace.limit = 3
        codecs.register_error("test.cjktest", myreplace)
149
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
150
                                     'test.cjktest'),
151
                (b'abcdREPLACEDabcdREPLACEDabcdREPLACEDabcdTERMINALefgh', 9))
152 153 154

    def test_callback_forward_index(self):
        def myreplace(exc):
155
            return ('REPLACED', exc.end + 2)
156
        codecs.register_error("test.cjktest", myreplace)
157
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
158
                                     'test.cjktest'), (b'abcdREPLACEDgh', 9))
159 160 161

    def test_callback_index_outofbound(self):
        def myreplace(exc):
162
            return ('TERM', 100)
163 164 165 166 167 168
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_incrementalencoder(self):
        UTF8Reader = codecs.getreader('utf-8')
169
        for sizehint in [None] + list(range(1, 33)) + \
170
                        [64, 128, 256, 512, 1024]:
171 172
            istream = UTF8Reader(BytesIO(self.tstring[1]))
            ostream = BytesIO()
173 174 175 176 177 178
            encoder = self.incrementalencoder()
            while 1:
                if sizehint is not None:
                    data = istream.read(sizehint)
                else:
                    data = istream.read()
179

180 181 182 183
                if not data:
                    break
                e = encoder.encode(data)
                ostream.write(e)
184

185
            self.assertEqual(ostream.getvalue(), self.tstring[0])
186

187 188
    def test_incrementaldecoder(self):
        UTF8Writer = codecs.getwriter('utf-8')
189
        for sizehint in [None, -1] + list(range(1, 33)) + \
190
                        [64, 128, 256, 512, 1024]:
191 192
            istream = BytesIO(self.tstring[0])
            ostream = UTF8Writer(BytesIO())
193 194 195 196 197
            decoder = self.incrementaldecoder()
            while 1:
                data = istream.read(sizehint)
                if not data:
                    break
198
                else:
199 200 201 202 203 204 205 206 207 208 209 210
                    u = decoder.decode(data)
                    ostream.write(u)

            self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_incrementalencoder_error_callback(self):
        inv = self.unmappedunicode

        e = self.incrementalencoder()
        self.assertRaises(UnicodeEncodeError, e.encode, inv, True)

        e.errors = 'ignore'
211
        self.assertEqual(e.encode(inv, True), b'')
212 213 214

        e.reset()
        def tempreplace(exc):
215
            return ('called', exc.end)
216 217
        codecs.register_error('test.incremental_error_callback', tempreplace)
        e.errors = 'test.incremental_error_callback'
218
        self.assertEqual(e.encode(inv, True), b'called')
219 220 221

        # again
        e.errors = 'ignore'
222
        self.assertEqual(e.encode(inv, True), b'')
223 224 225 226

    def test_streamreader(self):
        UTF8Writer = codecs.getwriter('utf-8')
        for name in ["read", "readline", "readlines"]:
227
            for sizehint in [None, -1] + list(range(1, 33)) + \
228
                            [64, 128, 256, 512, 1024]:
229 230
                istream = self.reader(BytesIO(self.tstring[0]))
                ostream = UTF8Writer(BytesIO())
231 232 233 234 235 236 237 238 239 240 241 242 243
                func = getattr(istream, name)
                while 1:
                    data = func(sizehint)
                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_streamwriter(self):
244
        readfuncs = ('read', 'readline', 'readlines')
245 246
        UTF8Reader = codecs.getreader('utf-8')
        for name in readfuncs:
247
            for sizehint in [None] + list(range(1, 33)) + \
248
                            [64, 128, 256, 512, 1024]:
249 250
                istream = UTF8Reader(BytesIO(self.tstring[1]))
                ostream = self.writer(BytesIO())
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
                func = getattr(istream, name)
                while 1:
                    if sizehint is not None:
                        data = func(sizehint)
                    else:
                        data = func()

                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[0])

# On narrow (UCS-2) builds, where len('\U00012345') == 2, a non-BMP code
# point is held as a surrogate pair.  Wrap chr()/ord() for this module so
# that code points above U+FFFF map to and from two-character surrogate pairs.
if len('\U00012345') == 2: # ucs2 build
    _unichr = chr
    def chr(v):
        if v >= 0x10000:
            return _unichr(0xd800 + ((v - 0x10000) >> 10)) + \
                   _unichr(0xdc00 + ((v - 0x10000) & 0x3ff))
        else:
            return _unichr(v)
    _ord = ord
    def ord(c):
        if len(c) == 2:
            # Use the builtin for both halves; the low surrogate is a single
            # character, so there is no need to re-enter this wrapper.
            return 0x10000 + ((_ord(c[0]) - 0xd800) << 10) + \
                          (_ord(c[1]) - 0xdc00)
        else:
            return _ord(c)

class TestBase_Mapping(unittest.TestCase):
    pass_enctest = []
    pass_dectest = []
    supmaps = []
287
    codectests = []
288 289 290

    def __init__(self, *args, **kw):
        unittest.TestCase.__init__(self, *args, **kw)
291
        try:
292
            self.open_mapping_file().close() # test it to report the error early
293
        except (IOError, HTTPException):
294
            self.skipTest("Could not retrieve "+self.mapfileurl)
295 296

    def open_mapping_file(self):
297
        return support.open_urlresource(self.mapfileurl)
298 299

    def test_mapping_file(self):
300 301 302 303 304 305
        if self.mapfileurl.endswith('.xml'):
            self._test_mapping_file_ucm()
        else:
            self._test_mapping_file_plain()

    def _test_mapping_file_plain(self):
306
        unichrs = lambda s: ''.join(map(chr, map(eval, s.split('+'))))
307 308
        urt_wa = {}

309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
        with self.open_mapping_file() as f:
            for line in f:
                if not line:
                    break
                data = line.split('#')[0].strip().split()
                if len(data) != 2:
                    continue

                csetval = eval(data[0])
                if csetval <= 0x7F:
                    csetch = bytes([csetval & 0xff])
                elif csetval >= 0x1000000:
                    csetch = bytes([(csetval >> 24), ((csetval >> 16) & 0xff),
                                    ((csetval >> 8) & 0xff), (csetval & 0xff)])
                elif csetval >= 0x10000:
                    csetch = bytes([(csetval >> 16), ((csetval >> 8) & 0xff),
                                    (csetval & 0xff)])
                elif csetval >= 0x100:
                    csetch = bytes([(csetval >> 8), (csetval & 0xff)])
                else:
                    continue
330

331 332 333 334
                unich = unichrs(data[1])
                if ord(unich) == 0xfffd or unich in urt_wa:
                    continue
                urt_wa[unich] = csetch
335

336
                self._testpoint(csetch, unich)
337

338
    def _test_mapping_file_ucm(self):
339 340
        with self.open_mapping_file() as f:
            ucmdata = f.read()
341 342 343 344 345 346
        uc = re.findall('<a u="([A-F0-9]{4})" b="([0-9A-F ]+)"/>', ucmdata)
        for uni, coded in uc:
            unich = chr(int(uni, 16))
            codech = bytes(int(c, 16) for c in coded.split())
            self._testpoint(codech, unich)

347 348 349 350 351 352 353 354
    def test_mapping_supplemental(self):
        for mapping in self.supmaps:
            self._testpoint(*mapping)

    def _testpoint(self, csetch, unich):
        if (csetch, unich) not in self.pass_enctest:
            self.assertEqual(unich.encode(self.encoding), csetch)
        if (csetch, unich) not in self.pass_dectest:
355
            self.assertEqual(str(csetch, self.encoding), unich)
356

357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
    def test_errorhandle(self):
        for source, scheme, expected in self.codectests:
            if isinstance(source, bytes):
                func = source.decode
            else:
                func = source.encode
            if expected:
                if isinstance(source, bytes):
                    result = func(self.encoding, scheme)
                    self.assertTrue(type(result) is str, type(result))
                    self.assertEqual(result, expected,
                                     '%a.decode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
                else:
                    result = func(self.encoding, scheme)
                    self.assertTrue(type(result) is bytes, type(result))
                    self.assertEqual(result, expected,
                                     '%a.encode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
            else:
                self.assertRaises(UnicodeError, func, self.encoding, scheme)

381 382 383 384 385 386 387
def load_teststring(name):
    """Return the ``(encoded, utf8)`` byte strings for the sample *name*.

    Both files are read in binary mode from the ``cjkencodings`` directory
    next to this module: ``<name>.txt`` first, then ``<name>-utf8.txt``
    (its UTF-8 counterpart, going by the file naming).
    """
    stem = os.path.join(os.path.dirname(__file__), 'cjkencodings', name)

    def slurp(suffix):
        # Read one companion file completely and close it straight away.
        with open(stem + suffix, 'rb') as handle:
            return handle.read()

    return slurp('.txt'), slurp('-utf8.txt')