#!/usr/bin/env python3
#
# multibytecodec_support.py
#   Common Unittest Routines for CJK codecs
#

import codecs
import os
import re
import sys
import unittest
from http.client import HTTPException
from test import support
from io import BytesIO

class TestBase:
    encoding        = ''   # codec name
    codec           = None # codec tuple (with 4 elements)
19
    tstring         = None # must set. 2 strings to test StreamReader
20 21 22 23 24

    codectests      = None # must set. codec test tuple
    roundtriptest   = 1    # set if roundtrip is possible with unicode
    has_iso10646    = 0    # set if this encoding contains whole iso10646 map
    xmlcharnametest = None # string to test xmlcharrefreplace
25
    unmappedunicode = '\udeee' # a unicode codepoint that is not mapped.
26 27 28 29

    def setUp(self):
        if self.codec is None:
            self.codec = codecs.lookup(self.encoding)
30 31 32 33 34 35
        self.encode = self.codec.encode
        self.decode = self.codec.decode
        self.reader = self.codec.streamreader
        self.writer = self.codec.streamwriter
        self.incrementalencoder = self.codec.incrementalencoder
        self.incrementaldecoder = self.codec.incrementaldecoder
36 37

    def test_chunkcoding(self):
38 39 40 41 42 43 44 45
        tstring_lines = []
        for b in self.tstring:
            lines = b.split(b"\n")
            last = lines.pop()
            assert last == b""
            lines = [line + b"\n" for line in lines]
            tstring_lines.append(lines)
        for native, utf8 in zip(*tstring_lines):
46 47 48 49 50 51 52
            u = self.decode(native)[0]
            self.assertEqual(u, utf8.decode('utf-8'))
            if self.roundtriptest:
                self.assertEqual(native, self.encode(u)[0])

    def test_errorhandle(self):
        for source, scheme, expected in self.codectests:
53
            if isinstance(source, bytes):
54 55 56 57 58
                func = self.decode
            else:
                func = self.encode
            if expected:
                result = func(source, scheme)[0]
59
                if func is self.decode:
60
                    self.assertTrue(type(result) is str, type(result))
61 62 63 64
                    self.assertEqual(result, expected,
                                     '%a.decode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
65
                else:
66
                    self.assertTrue(type(result) is bytes, type(result))
67 68 69 70
                    self.assertEqual(result, expected,
                                     '%a.encode(%r, %r)=%a != %a'
                                     % (source, self.encoding, scheme, result,
                                        expected))
71 72 73
            else:
                self.assertRaises(UnicodeError, func, source, scheme)

74 75 76 77
    def test_xmlcharrefreplace(self):
        if self.has_iso10646:
            return

78
        s = "\u0b13\u0b23\u0b60 nd eggs"
79 80
        self.assertEqual(
            self.encode(s, "xmlcharrefreplace")[0],
81
            b"ଓଣୠ nd eggs"
82 83 84 85 86 87
        )

    def test_customreplace_encode(self):
        if self.has_iso10646:
            return

88
        from html.entities import codepoint2name
89 90 91 92 93 94 95

        def xmlcharnamereplace(exc):
            if not isinstance(exc, UnicodeEncodeError):
                raise TypeError("don't know how to handle %r" % exc)
            l = []
            for c in exc.object[exc.start:exc.end]:
                if ord(c) in codepoint2name:
96
                    l.append("&%s;" % codepoint2name[ord(c)])
97
                else:
98 99
                    l.append("&#%d;" % ord(c))
            return ("".join(l), exc.end)
100 101

        codecs.register_error("test.xmlcharnamereplace", xmlcharnamereplace)
102

103 104 105
        if self.xmlcharnametest:
            sin, sout = self.xmlcharnametest
        else:
106
            sin = "\xab\u211c\xbb = \u2329\u1234\u232a"
107
            sout = b"«ℜ» = ⟨ሴ⟩"
108 109 110
        self.assertEqual(self.encode(sin,
                                    "test.xmlcharnamereplace")[0], sout)

111 112 113 114 115 116 117
    def test_callback_returns_bytes(self):
        def myreplace(exc):
            return (b"1234", exc.end)
        codecs.register_error("test.cjktest", myreplace)
        enc = self.encode("abc" + self.unmappedunicode + "def", "test.cjktest")[0]
        self.assertEqual(enc, b"abc1234def")

118 119 120 121 122
    def test_callback_wrong_objects(self):
        def myreplace(exc):
            return (ret, exc.end)
        codecs.register_error("test.cjktest", myreplace)

123
        for ret in ([1, 2, 3], [], None, object()):
124 125 126 127 128
            self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                              'test.cjktest')

    def test_callback_long_index(self):
        def myreplace(exc):
129
            return ('x', int(exc.end))
130
        codecs.register_error("test.cjktest", myreplace)
131
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
132
                                     'test.cjktest'), (b'abcdxefgh', 9))
133 134

        def myreplace(exc):
135
            return ('x', sys.maxsize + 1)
136 137 138 139 140 141
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_None_index(self):
        def myreplace(exc):
142
            return ('x', None)
143 144 145 146 147 148 149 150
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(TypeError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_callback_backward_index(self):
        def myreplace(exc):
            if myreplace.limit > 0:
                myreplace.limit -= 1
151
                return ('REPLACED', 0)
152
            else:
153
                return ('TERMINAL', exc.end)
154 155
        myreplace.limit = 3
        codecs.register_error("test.cjktest", myreplace)
156
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
157
                                     'test.cjktest'),
158
                (b'abcdREPLACEDabcdREPLACEDabcdREPLACEDabcdTERMINALefgh', 9))
159 160 161

    def test_callback_forward_index(self):
        def myreplace(exc):
162
            return ('REPLACED', exc.end + 2)
163
        codecs.register_error("test.cjktest", myreplace)
164
        self.assertEqual(self.encode('abcd' + self.unmappedunicode + 'efgh',
165
                                     'test.cjktest'), (b'abcdREPLACEDgh', 9))
166 167 168

    def test_callback_index_outofbound(self):
        def myreplace(exc):
169
            return ('TERM', 100)
170 171 172 173 174 175
        codecs.register_error("test.cjktest", myreplace)
        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
                          'test.cjktest')

    def test_incrementalencoder(self):
        UTF8Reader = codecs.getreader('utf-8')
176
        for sizehint in [None] + list(range(1, 33)) + \
177
                        [64, 128, 256, 512, 1024]:
178 179
            istream = UTF8Reader(BytesIO(self.tstring[1]))
            ostream = BytesIO()
180 181 182 183 184 185
            encoder = self.incrementalencoder()
            while 1:
                if sizehint is not None:
                    data = istream.read(sizehint)
                else:
                    data = istream.read()
186

187 188 189 190
                if not data:
                    break
                e = encoder.encode(data)
                ostream.write(e)
191

192
            self.assertEqual(ostream.getvalue(), self.tstring[0])
193

194 195
    def test_incrementaldecoder(self):
        UTF8Writer = codecs.getwriter('utf-8')
196
        for sizehint in [None, -1] + list(range(1, 33)) + \
197
                        [64, 128, 256, 512, 1024]:
198 199
            istream = BytesIO(self.tstring[0])
            ostream = UTF8Writer(BytesIO())
200 201 202 203 204
            decoder = self.incrementaldecoder()
            while 1:
                data = istream.read(sizehint)
                if not data:
                    break
205
                else:
206 207 208 209 210 211 212 213 214 215 216 217
                    u = decoder.decode(data)
                    ostream.write(u)

            self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_incrementalencoder_error_callback(self):
        inv = self.unmappedunicode

        e = self.incrementalencoder()
        self.assertRaises(UnicodeEncodeError, e.encode, inv, True)

        e.errors = 'ignore'
218
        self.assertEqual(e.encode(inv, True), b'')
219 220 221

        e.reset()
        def tempreplace(exc):
222
            return ('called', exc.end)
223 224
        codecs.register_error('test.incremental_error_callback', tempreplace)
        e.errors = 'test.incremental_error_callback'
225
        self.assertEqual(e.encode(inv, True), b'called')
226 227 228

        # again
        e.errors = 'ignore'
229
        self.assertEqual(e.encode(inv, True), b'')
230 231 232 233

    def test_streamreader(self):
        UTF8Writer = codecs.getwriter('utf-8')
        for name in ["read", "readline", "readlines"]:
234
            for sizehint in [None, -1] + list(range(1, 33)) + \
235
                            [64, 128, 256, 512, 1024]:
236 237
                istream = self.reader(BytesIO(self.tstring[0]))
                ostream = UTF8Writer(BytesIO())
238 239 240 241 242 243 244 245 246 247 248 249 250
                func = getattr(istream, name)
                while 1:
                    data = func(sizehint)
                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[1])

    def test_streamwriter(self):
251
        readfuncs = ('read', 'readline', 'readlines')
252 253
        UTF8Reader = codecs.getreader('utf-8')
        for name in readfuncs:
254
            for sizehint in [None] + list(range(1, 33)) + \
255
                            [64, 128, 256, 512, 1024]:
256 257
                istream = UTF8Reader(BytesIO(self.tstring[1]))
                ostream = self.writer(BytesIO())
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
                func = getattr(istream, name)
                while 1:
                    if sizehint is not None:
                        data = func(sizehint)
                    else:
                        data = func()

                    if not data:
                        break
                    if name == "readlines":
                        ostream.writelines(data)
                    else:
                        ostream.write(data)

                self.assertEqual(ostream.getvalue(), self.tstring[0])


class TestBase_Mapping(unittest.TestCase):
    """Verify a codec against a published character mapping table.

    Concrete subclasses set ``encoding`` (the codec name) and ``mapfileurl``
    (fetched with ``support.open_urlresource``).  URLs ending in ``.xml`` are
    scanned for ``<a u="XXXX" b="XX XX"/>`` entries; any other file is read
    as a plain table of ``<bytes value> <code point>[+<code point>...]``
    lines with ``#`` comments.
    """
    pass_enctest = []  # (bytes, str) pairs whose encoding is not checked
    pass_dectest = []  # (bytes, str) pairs whose decoding is not checked
    supmaps = []       # extra (bytes, str) pairs checked by
                       # test_mapping_supplemental
    codectests = []    # (source, errors, expected) triples for
                       # test_errorhandle; expected=None means "must raise"

    def setUp(self):
        # The base class has no mapping table of its own.  If a test loader
        # picks it up directly, skip rather than fail on `mapfileurl`.
        if type(self) is TestBase_Mapping:
            self.skipTest("TestBase_Mapping is an abstract base class")
        # Probe the table before each test so an unreachable resource is
        # reported as a skip.  This lives in setUp (where skipTest is
        # documented to work) rather than __init__, which runs while the
        # loader is still building the suite.
        try:
            self.open_mapping_file().close()
        except (OSError, HTTPException):
            self.skipTest("Could not retrieve " + self.mapfileurl)

    def open_mapping_file(self):
        # Returns an open file object for the subclass's mapping table.
        return support.open_urlresource(self.mapfileurl)

    def test_mapping_file(self):
        if self.mapfileurl.endswith('.xml'):
            self._test_mapping_file_ucm()
        else:
            self._test_mapping_file_plain()

    def _test_mapping_file_plain(self):
        # Check every "<bytes> <code points>" line of a plain-text table.
        # int(x, 0) parses the integer literals the tables use ("0x8140")
        # without eval()-ing downloaded text.
        def unichrs(s):
            # "0x00CA+0x0304" -> "\u00ca\u0304"
            return ''.join(chr(int(x, 0)) for x in s.split('+'))

        urt_wa = {}

        with self.open_mapping_file() as f:
            for line in f:
                data = line.split('#')[0].split()
                if len(data) != 2:
                    continue

                csetval = int(data[0], 0)
                if 0x80 <= csetval <= 0xFF:
                    # Single bytes in 0x80-0xFF are not checked.
                    continue
                # Big-endian byte string of the minimal width (1-4 bytes
                # for the values these tables contain).
                nbytes = max(1, (csetval.bit_length() + 7) // 8)
                csetch = csetval.to_bytes(nbytes, 'big')

                unich = unichrs(data[1])
                # Skip U+FFFD placeholders, and check only the first byte
                # sequence listed for each character.  Comparing strings
                # (not ord()) also works for multi-code-point entries.
                if unich == '\ufffd' or unich in urt_wa:
                    continue
                urt_wa[unich] = csetch

                self._testpoint(csetch, unich)

    def _test_mapping_file_ucm(self):
        # Check every <a u="XXXX" b="XX XX"/> entry of an XML table.
        with self.open_mapping_file() as f:
            ucmdata = f.read()
        uc = re.findall('<a u="([A-F0-9]{4})" b="([0-9A-F ]+)"/>', ucmdata)
        for uni, coded in uc:
            unich = chr(int(uni, 16))
            codech = bytes(int(c, 16) for c in coded.split())
            self._testpoint(codech, unich)

    def test_mapping_supplemental(self):
        for mapping in self.supmaps:
            self._testpoint(*mapping)

    def _testpoint(self, csetch, unich):
        # Check both directions unless the pair is explicitly exempted.
        if (csetch, unich) not in self.pass_enctest:
            self.assertEqual(unich.encode(self.encoding), csetch)
        if (csetch, unich) not in self.pass_dectest:
            self.assertEqual(str(csetch, self.encoding), unich)

    def test_errorhandle(self):
        # bytes sources are decoded, str sources encoded, through the
        # built-in str/bytes methods.  Only expected=None means "must
        # raise"; an empty expected result is a legitimate value.
        for source, scheme, expected in self.codectests:
            if isinstance(source, bytes):
                op, result_type = 'decode', str
            else:
                op, result_type = 'encode', bytes
            func = getattr(source, op)
            if expected is None:
                self.assertRaises(UnicodeError, func, self.encoding, scheme)
                continue
            result = func(self.encoding, scheme)
            self.assertTrue(type(result) is result_type, type(result))
            self.assertEqual(result, expected,
                             '%a.%s(%r, %r)=%a != %a'
                             % (source, op, self.encoding, scheme, result,
                                expected))

def load_teststring(name):
    """Return the ``(encoded, utf8)`` byte strings of a named sample text.

    Both files are read in binary mode from the ``cjkencodings`` directory
    next to this module: ``<name>.txt`` is returned first, followed by
    ``<name>-utf8.txt``.
    """
    base = os.path.join(os.path.dirname(__file__), 'cjkencodings', name)

    def _slurp(suffix):
        # Read one sample file in full, as raw bytes.
        with open(base + suffix, 'rb') as fp:
            return fp.read()

    return _slurp('.txt'), _slurp('-utf8.txt')