# test_sax.py
# regression test for SAX 2.0
# $Id$

from xml.sax import make_parser, ContentHandler, \
                    SAXException, SAXReaderNotAvailable, SAXParseException
import unittest

# Probe for a usable SAX driver before importing the expat-specific
# modules below; the whole module is skipped if no parser can be created.
try:
    make_parser()
except SAXReaderNotAvailable:
    # don't try to test this module if we cannot create a parser
    raise unittest.SkipTest("no XML parsers available")

from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
                             XMLFilterBase
from xml.sax.expatreader import create_parser
from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from io import BytesIO, StringIO
import codecs
import os.path
import shutil
from test import support
from test.support import findfile, run_unittest

# Input document and the exact bytes XMLGenerator is expected to produce
# when the document is parsed and re-serialized.
TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")
try:
    TEST_XMLFILE.encode("utf-8")
    TEST_XMLFILE_OUT.encode("utf-8")
except UnicodeEncodeError:
    raise unittest.SkipTest("filename is not encodable to utf8")


# Decide whether the tests that copy the input to a non-ASCII file name
# can run on this platform.
supports_nonascii_filenames = True
if not os.path.supports_unicode_filenames:
    try:
        support.TESTFN_UNICODE.encode(support.TESTFN_ENCODING)
    except (UnicodeError, TypeError):
        # Either the file system encoding is None, or the file name
        # cannot be encoded in the file system encoding.
        supports_nonascii_filenames = False
requires_nonascii_filenames = unittest.skipUnless(
        supports_nonascii_filenames,
        'Requires non-ascii filenames support')

# Namespace URI used by the namespace-aware tests.
ns_uri = "http://www.python.org/xml-ns/saxtest/"


class XmlTestBase(unittest.TestCase):
    """Shared assertions for SAX ``Attributes`` objects.

    The helpers are called with the attributes object a parser passed to
    ``startElement``/``startElementNS`` and check the whole mapping-like
    API for a known expected content.
    """

    def verify_empty_attrs(self, attrs):
        """Assert that *attrs* is an empty non-namespaced attribute set."""
        # Every lookup flavour must reject a missing name with KeyError.
        self.assertRaises(KeyError, attrs.getValue, "attr")
        self.assertRaises(KeyError, attrs.getValueByQName, "attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "attr")
        self.assertRaises(KeyError, attrs.getQNameByName, "attr")
        self.assertRaises(KeyError, attrs.__getitem__, "attr")

        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn("attr", attrs)
        self.assertEqual(list(attrs.keys()), [])
        # Probe the same name as the lookups above ("attr").
        self.assertIsNone(attrs.get("attr"))
        self.assertEqual(attrs.get("attr", 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_empty_nsattrs(self, attrs):
        """Assert that *attrs* is an empty namespace-aware attribute set.

        Names are ``(ns_uri, localname)`` pairs; qualified names are
        ``"prefix:localname"`` strings.
        """
        self.assertRaises(KeyError, attrs.getValue, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.getValueByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getQNameByName, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.__getitem__, (ns_uri, "attr"))

        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn((ns_uri, "attr"), attrs)
        self.assertEqual(list(attrs.keys()), [])
        self.assertIsNone(attrs.get((ns_uri, "attr")))
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_attrs_wattr(self, attrs):
        """Assert that *attrs* holds exactly one attribute, ``attr="val"``."""
        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), ["attr"])
        self.assertEqual(attrs.getQNames(), ["attr"])
        self.assertEqual(len(attrs), 1)
        self.assertIn("attr", attrs)
        self.assertEqual(list(attrs.keys()), ["attr"])
        self.assertEqual(attrs.get("attr"), "val")
        self.assertEqual(attrs.get("attr", 25), "val")
        self.assertEqual(list(attrs.items()), [("attr", "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue("attr"), "val")
        self.assertEqual(attrs.getValueByQName("attr"), "val")
        self.assertEqual(attrs.getNameByQName("attr"), "attr")
        self.assertEqual(attrs["attr"], "val")
        self.assertEqual(attrs.getQNameByName("attr"), "attr")


class MakeParserTest(unittest.TestCase):
    def test_make_parser2(self):
        """Creating parsers several times in a row should succeed.

        Testing this because there have been failures of this kind
        before.  Each round repeats the import before creating a parser,
        like the original back-to-back import/create sequence.
        """
        for _ in range(6):
            from xml.sax import make_parser
            p = make_parser()
            self.assertIsNotNone(p)


# ===========================================================================
#
#   saxutils tests
#
# ===========================================================================

class SaxutilsTest(unittest.TestCase):
    """Tests for the xml.sax.saxutils helpers and make_parser()."""

    # ===== escape
    def test_escape_basic(self):
        # escape() always replaces "&", even with no other markup present.
        self.assertEqual(escape("Donald Duck & Co"), "Donald Duck &amp; Co")

    def test_escape_all(self):
        self.assertEqual(escape("<Donald Duck & Co>"),
                         "&lt;Donald Duck &amp; Co&gt;")

    def test_escape_extra(self):
        # Caller-supplied replacements are applied on top of the defaults.
        self.assertEqual(escape("Hei på deg", {"å" : "&aring;"}),
                         "Hei p&aring; deg")

    # ===== unescape
    def test_unescape_basic(self):
        self.assertEqual(unescape("Donald Duck &amp; Co"), "Donald Duck & Co")

    def test_unescape_all(self):
        self.assertEqual(unescape("&lt;Donald Duck &amp; Co&gt;"),
                         "<Donald Duck & Co>")

    def test_unescape_extra(self):
        self.assertEqual(unescape("Hei på deg", {"å" : "&aring;"}),
                         "Hei p&aring; deg")

    def test_unescape_amp_extra(self):
        # The "&" produced by decoding "&amp;" must not be re-expanded
        # through the caller-supplied entities.
        self.assertEqual(unescape("&amp;foo;", {"&foo;": "splat"}), "&foo;")

    # ===== quoteattr
    def test_quoteattr_basic(self):
        self.assertEqual(quoteattr("Donald Duck & Co"),
                         '"Donald Duck &amp; Co"')

    def test_single_quoteattr(self):
        # A value containing only double quotes is wrapped in single quotes.
        self.assertEqual(quoteattr('Includes "double" quotes'),
                         '\'Includes "double" quotes\'')

    def test_double_quoteattr(self):
        self.assertEqual(quoteattr("Includes 'single' quotes"),
                         "\"Includes 'single' quotes\"")

    def test_single_double_quoteattr(self):
        # With both quote kinds present, double quotes are escaped.
        self.assertEqual(quoteattr("Includes 'single' and \"double\" quotes"),
                         "\"Includes 'single' and &quot;double&quot; quotes\"")

    # ===== make_parser
    def test_make_parser(self):
        # Creating a parser should succeed - it should fall back
        # to the expatreader
        p = make_parser(['xml.parsers.no_such_parser'])
        self.assertIsNotNone(p)


# ===== XMLGenerator

class XmlgenTest:
    """XMLGenerator test mixin.

    Concrete subclasses combine this with ``unittest.TestCase`` and supply:

    * ``ioclass`` -- called as ``self.ioclass()`` to create the output
      stream handed to ``XMLGenerator``; the stream must provide
      ``getvalue()`` (and ``closed`` for ``test_no_close_file``).
    * ``xml(doc, encoding=...)`` -- the expected serialized output for the
      document body *doc*, including the XML declaration.
    """

    # Encodings exercised by the encoding round-trip tests.
    _ENCODINGS = ('iso-8859-15', 'utf-8', 'utf-8-sig',
                  'utf-16', 'utf-16be', 'utf-16le',
                  'utf-32', 'utf-32be', 'utf-32le')

    def test_xmlgen_basic(self):
        result = self.ioclass()
        gen = XMLGenerator(result)
        gen.startDocument()
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc></doc>"))

    def test_xmlgen_basic_empty(self):
        # short_empty_elements=True turns a content-less element into a
        # self-closing tag.
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)
        gen.startDocument()
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc/>"))

    def test_xmlgen_content(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("huhei")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))

    def test_xmlgen_content_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("huhei")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))

    def test_xmlgen_pi(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.processingInstruction("test", "data")
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            self.xml("<?test data?><doc></doc>"))

    def test_xmlgen_content_escape(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("<huhei&")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            self.xml("<doc>&lt;huhei&amp;</doc>"))

    def test_xmlgen_attr_escape(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {"a": '"'})
        gen.startElement("e", {"a": "'"})
        gen.endElement("e")
        gen.startElement("e", {"a": "'\""})
        gen.endElement("e")
        gen.startElement("e", {"a": "\n\r\t"})
        gen.endElement("e")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml(
            "<doc a='\"'><e a=\"'\"></e>"
            "<e a=\"'&quot;\"></e>"
            "<e a=\"&#10;&#13;&#9;\"></e></doc>"))

    def test_xmlgen_encoding(self):
        for encoding in self._ENCODINGS:
            result = self.ioclass()
            gen = XMLGenerator(result, encoding=encoding)

            gen.startDocument()
            gen.startElement("doc", {"a": '\u20ac'})
            gen.characters("\u20ac")
            gen.endElement("doc")
            gen.endDocument()

            self.assertEqual(result.getvalue(),
                self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=encoding))

    def test_xmlgen_unencodable(self):
        # Characters the target encoding cannot represent are expected as
        # numeric character references.
        result = self.ioclass()
        gen = XMLGenerator(result, encoding='ascii')

        gen.startDocument()
        gen.startElement("doc", {"a": '\u20ac'})
        gen.characters("\u20ac")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))

    def test_xmlgen_ignorable(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.ignorableWhitespace(" ")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))

    def test_xmlgen_ignorable_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.ignorableWhitespace(" ")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))

    def test_xmlgen_encoding_bytes(self):
        # Text may also be passed as bytes already encoded in the
        # generator's output encoding.
        for encoding in self._ENCODINGS:
            result = self.ioclass()
            gen = XMLGenerator(result, encoding=encoding)

            gen.startDocument()
            gen.startElement("doc", {"a": '\u20ac'})
            gen.characters("\u20ac".encode(encoding))
            gen.ignorableWhitespace(" ".encode(encoding))
            gen.endElement("doc")
            gen.endDocument()

            self.assertEqual(result.getvalue(),
                self.xml('<doc a="\u20ac">\u20ac </doc>', encoding=encoding))

    def test_xmlgen_ns(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping("ns1", ns_uri)
        gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        gen.startElementNS((None, "udoc"), None, {})
        gen.endElementNS((None, "udoc"), None)
        gen.endElementNS((ns_uri, "doc"), "ns1:doc")
        gen.endPrefixMapping("ns1")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml(
           '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
                                         ns_uri))

    def test_xmlgen_ns_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping("ns1", ns_uri)
        gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        gen.startElementNS((None, "udoc"), None, {})
        gen.endElementNS((None, "udoc"), None)
        gen.endElementNS((ns_uri, "doc"), "ns1:doc")
        gen.endPrefixMapping("ns1")
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml(
           '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
                                         ns_uri))

    def test_1463026_1(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS((None, 'a'), 'a')
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml('<a b="c"></a>'))

    def test_1463026_1_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS((None, 'a'), 'a')
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml('<a b="c"/>'))

    def test_1463026_2(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping(None, 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping(None)
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"></a>'))

    def test_1463026_2_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping(None, 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping(None)
        gen.endDocument()

        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"/>'))

    def test_1463026_3(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping('my', 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping('my')
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))

    def test_1463026_3_empty(self):
        result = self.ioclass()
        gen = XMLGenerator(result, short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping('my', 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping('my')
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            self.xml('<my:a xmlns:my="qux" b="c"/>'))

    def test_5027_1(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is raised because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by parsing a document.
        test_xml = StringIO(
            '<?xml version="1.0"?>'
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>')

        parser = make_parser()
        parser.setFeature(feature_namespaces, True)
        result = self.ioclass()
        gen = XMLGenerator(result)
        parser.setContentHandler(gen)
        parser.parse(test_xml)

        self.assertEqual(result.getvalue(),
                         self.xml(
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))

    def test_5027_2(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is raised because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by direct manipulation of the
        # XMLGenerator.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping('a', 'http://example.com/ns')
        gen.startElementNS(('http://example.com/ns', 'g1'), 'g1', {})
        lang_attr = {('http://www.w3.org/XML/1998/namespace', 'lang'): 'en'}
        gen.startElementNS(('http://example.com/ns', 'g2'), 'g2', lang_attr)
        gen.characters('Hello')
        gen.endElementNS(('http://example.com/ns', 'g2'), 'g2')
        gen.endElementNS(('http://example.com/ns', 'g1'), 'g1')
        gen.endPrefixMapping('a')
        gen.endDocument()

        self.assertEqual(result.getvalue(),
                         self.xml(
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))

    def test_no_close_file(self):
        # The generator going out of scope must leave the caller's stream
        # open.
        result = self.ioclass()
        def func(out):
            gen = XMLGenerator(out)
            gen.startDocument()
            gen.startElement("doc", {})
        func(result)
        self.assertFalse(result.closed)

    def test_xmlgen_fragment(self):
        result = self.ioclass()
        gen = XMLGenerator(result)

        # Don't call gen.startDocument()
        gen.startElement("foo", {"a": "1.0"})
        gen.characters("Hello")
        gen.endElement("foo")
        gen.startElement("bar", {"b": "2.0"})
        gen.endElement("bar")
        # Don't call gen.endDocument()

        # Without startDocument() no XML declaration is written, so strip
        # the declaration prefix from the expected value.
        self.assertEqual(result.getvalue(),
            self.xml('<foo a="1.0">Hello</foo><bar b="2.0"></bar>')[len(self.xml('')):])


class StringXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an in-memory text stream."""

    ioclass = StringIO

    # Shadowing the inherited test with None removes it from this class;
    # presumably because a text stream performs no encoding, so no
    # character references are produced -- confirm.
    test_xmlgen_unencodable = None

    def xml(self, doc, encoding='iso-8859-1'):
        """Return the expected text: XML declaration followed by *doc*."""
        template = '<?xml version="1.0" encoding="{}"?>\n{}'
        return template.format(encoding, doc)

class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an in-memory binary stream."""

    ioclass = BytesIO

    def xml(self, doc, encoding='iso-8859-1'):
        """Return the expected bytes: declaration plus *doc*, encoded.

        Characters not representable in *encoding* become numeric
        character references.
        """
        text = '<?xml version="1.0" encoding="{}"?>\n{}'.format(encoding, doc)
        return text.encode(encoding, 'xmlcharrefreplace')

class WriterXmlgenTest(BytesXmlgenTest):
    """Run the bytes tests against a minimal list-backed writer object."""

    class ioclass(list):
        """Bare-bones writable stream: each write() appends one chunk."""

        def getvalue(self):
            # Concatenate every chunk written so far.
            return bytes().join(self)

        def tell(self):
            # return 0 at start and not 0 after start
            return len(self)

        def seekable(self):
            return True

        closed = False
        write = list.append

class StreamWriterXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against a codecs ASCII StreamWriter."""

    def ioclass(self):
        """Return an ASCII StreamWriter over a BytesIO buffer.

        Non-ASCII output is turned into character references by the
        'xmlcharrefreplace' error handler.
        """
        buf = BytesIO()
        make_writer = codecs.getwriter('ascii')
        stream = make_writer(buf, 'xmlcharrefreplace')
        # Let the tests read back what was written, as with BytesIO.
        stream.getvalue = buf.getvalue
        return stream

    def xml(self, doc, encoding='iso-8859-1'):
        """Return the expected ASCII bytes for *doc* with its declaration."""
        text = '<?xml version="1.0" encoding="{}"?>\n{}'.format(encoding, doc)
        return text.encode('ascii', 'xmlcharrefreplace')

class StreamReaderWriterXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against a file opened with codecs.open()."""

    fname = support.TESTFN + '-codecs'

    def ioclass(self):
        """Open an ASCII codecs writer on a scratch file.

        Test cleanup closes the writer and deletes the file.  A
        ``getvalue()`` function is attached so the written bytes can be
        compared like an in-memory stream's contents.
        """
        writer = codecs.open(self.fname, 'w', encoding='ascii',
                             errors='xmlcharrefreplace', buffering=0)

        def cleanup():
            writer.close()
            support.unlink(self.fname)
        self.addCleanup(cleanup)

        def getvalue():
            # Windows will not let us reopen the file without first closing it
            writer.close()
            with open(writer.name, 'rb') as f:
                return f.read()
        writer.getvalue = getvalue

        return writer

    def xml(self, doc, encoding='iso-8859-1'):
        return ('<?xml version="1.0" encoding="%s"?>\n%s' %
                (encoding, doc)).encode('ascii', 'xmlcharrefreplace')


# XML declaration that XMLGenerator writes with its default encoding; the
# byte-output tests below prepend it to the expected document body.
start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'


class XMLFilterBaseTest(unittest.TestCase):
    def test_filter_basic(self):
        """Events sent to an XMLFilterBase reach its content handler."""
        result = BytesIO()
        gen = XMLGenerator(result)
        # Named xml_filter to avoid shadowing the builtin filter().
        xml_filter = XMLFilterBase()
        xml_filter.setContentHandler(gen)

        xml_filter.startDocument()
        xml_filter.startElement("doc", {})
        xml_filter.characters("content")
        xml_filter.ignorableWhitespace(" ")
        xml_filter.endElement("doc")
        xml_filter.endDocument()

        self.assertEqual(result.getvalue(), start + b"<doc>content </doc>")


# ===========================================================================
#
#   expatreader tests
#
# ===========================================================================

# Reference bytes that re-serializing TEST_XMLFILE through XMLGenerator
# must reproduce.
with open(TEST_XMLFILE_OUT, 'rb') as f:
    xml_test_out = f.read()


class ExpatReaderTest(XmlTestBase):
    """Tests for the expat-based SAX reader returned by create_parser()."""

    # ===== XMLReader support

    def test_expat_file(self):
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        with open(TEST_XMLFILE, 'rb') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    @requires_nonascii_filenames
    def test_expat_file_nonascii(self):
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        # Binary mode, like test_expat_file, and close the handle
        # deterministically instead of leaking it.
        with open(fname, 'rb') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== DTDHandler support

    class TestDTDHandler:
        """DTD handler that records every declaration it receives."""

        def __init__(self):
            self._notations = []
            self._entities = []

        def notationDecl(self, name, publicId, systemId):
            self._notations.append((name, publicId, systemId))

        def unparsedEntityDecl(self, name, publicId, systemId, ndata):
            self._entities.append((name, publicId, systemId, ndata))

    def test_expat_dtdhandler(self):
        parser = create_parser()
        handler = self.TestDTDHandler()
        parser.setDTDHandler(handler)

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
        parser.feed('  <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
        parser.feed(']>\n')
        parser.feed('<doc></doc>')
        parser.close()

        self.assertEqual(handler._notations,
            [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
        self.assertEqual(handler._entities, [("img", None, "expat.gif", "GIF")])

    # ===== EntityResolver support

    class TestEntityResolver:
        """Resolver that substitutes ``<entity/>`` for any external entity."""

        def resolveEntity(self, publicId, systemId):
            inpsrc = InputSource()
            inpsrc.setByteStream(BytesIO(b"<entity/>"))
            return inpsrc

    def test_expat_entityresolver(self):
        parser = create_parser()
        parser.setEntityResolver(self.TestEntityResolver())
        result = BytesIO()
        parser.setContentHandler(XMLGenerator(result))

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY test SYSTEM "whatever">\n')
        parser.feed(']>\n')
        parser.feed('<doc>&test;</doc>')
        parser.close()

        self.assertEqual(result.getvalue(), start +
                         b"<doc><entity></entity></doc>")

    # ===== Attributes support

    class AttrGatherer(ContentHandler):
        """Content handler that keeps the attributes of the last start tag."""

        def startElement(self, name, attrs):
            self._attrs = attrs

        def startElementNS(self, name, qname, attrs):
            self._attrs = attrs

    def test_expat_attrs_empty(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_attrs(gather._attrs)

    def test_expat_attrs_wattr(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc attr='val'/>")
        parser.close()

        self.verify_attrs_wattr(gather._attrs)

    def test_expat_nsattrs_empty(self):
        # Namespace-aware parser: names are checked as (uri, localname).
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_nsattrs(gather._attrs)

    def test_expat_nsattrs_wattr(self):
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
        parser.close()

        attrs = gather._attrs

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
        # Qualified names may or may not be reported; accept either.
        self.assertIn(attrs.getQNames(), ([], ["ns:attr"]))
        self.assertEqual(len(attrs), 1)
        self.assertIn((ns_uri, "attr"), attrs)
        self.assertEqual(attrs.get((ns_uri, "attr")), "val")
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
        self.assertEqual(list(attrs.items()), [((ns_uri, "attr"), "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
        self.assertEqual(attrs[(ns_uri, "attr")], "val")

    # ===== InputSource support

    def test_expat_inpsource_filename(self):
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(TEST_XMLFILE)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_sysid(self):
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(InputSource(TEST_XMLFILE))

        self.assertEqual(result.getvalue(), xml_test_out)

    @requires_nonascii_filenames
    def test_expat_inpsource_sysid_nonascii(self):
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(InputSource(fname))

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_stream(self):
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        inpsrc = InputSource()
        with open(TEST_XMLFILE, 'rb') as f:
            inpsrc.setByteStream(f)
            parser.parse(inpsrc)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== IncrementalParser support

    def test_expat_incremental(self):
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + b"<doc></doc>")

    def test_expat_incremental_reset(self):
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("text")

        # Switch to a fresh output buffer and reset(); only the document
        # fed afterwards may appear in the new output.
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser.setContentHandler(xmlgen)
        parser.reset()

        parser.feed("<doc>")
        parser.feed("text")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")

    # ===== Locator support

    def test_expat_locator_noinfo(self):
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(parser.getSystemId(), None)
        self.assertEqual(parser.getPublicId(), None)
        self.assertEqual(parser.getLineNumber(), 1)

    def test_expat_locator_withinfo(self):
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)
        parser.parse(TEST_XMLFILE)

        self.assertEqual(parser.getSystemId(), TEST_XMLFILE)
        self.assertEqual(parser.getPublicId(), None)

    @requires_nonascii_filenames
    def test_expat_locator_withinfo_nonascii(self):
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)
        parser.parse(fname)

        self.assertEqual(parser.getSystemId(), fname)
        self.assertEqual(parser.getPublicId(), None)


# ===========================================================================
#
#   error reporting
#
# ===========================================================================

class ErrorReportingTest(unittest.TestCase):
    """Check how parse errors are reported through SAX exceptions."""

    def test_expat_inpsource_location(self):
        # An ill-formed document read through an InputSource must raise a
        # SAXParseException carrying the InputSource's system id.
        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        source = InputSource()
        source.setByteStream(BytesIO(b"<foo bar foobar>"))   #ill-formed
        name = "a file name"
        source.setSystemId(name)
        # getSystemId() is provided by SAXParseException, not by the
        # SAXException base class, so expect the specific subclass.
        with self.assertRaises(SAXParseException) as cm:
            parser.parse(source)
        self.assertEqual(cm.exception.getSystemId(), name)

    def test_expat_incomplete(self):
        # A document that ends inside an open element is a parse error.
        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        self.assertRaises(SAXParseException, parser.parse, StringIO("<foo>"))

    def test_sax_parse_exception_str(self):
        # Pass various values from a locator to SAXParseException and make
        # sure __str__() renders them.  It must not fall apart when None is
        # passed instead of an integer line and/or column number; a missing
        # value is rendered as "?".
        cases = [
            # "normal" values for the locator
            (1, 1, "sysid:1:1: message"),
            # None for the line number
            (None, 1, "sysid:?:1: message"),
            # None for the column number
            (1, None, "sysid:1:?: message"),
            # None for both
            (None, None, "sysid:?:?: message"),
        ]
        for lineno, colno, expected in cases:
            exc = SAXParseException("message", None,
                                    self.DummyLocator(lineno, colno))
            self.assertEqual(str(exc), expected)

    class DummyLocator:
        # Minimal locator stub.  The public and system ids are fixed; the
        # line and column numbers are whatever the test passes in
        # (possibly None).
        def __init__(self, lineno, colno):
            self._lineno = lineno
            self._colno = colno

        def getPublicId(self):
            return "pubid"

        def getSystemId(self):
            return "sysid"

        def getLineNumber(self):
            return self._lineno

        def getColumnNumber(self):
            return self._colno
954

955 956 957 958 959 960
# ===========================================================================
#
#   xmlreader tests
#
# ===========================================================================

961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976
class XmlReaderTest(XmlTestBase):
    """Tests for the AttributesImpl / AttributesNSImpl containers."""

    # ===== AttributesImpl
    def test_attrs_empty(self):
        empty = AttributesImpl({})
        self.verify_empty_attrs(empty)

    def test_attrs_wattr(self):
        single = AttributesImpl({"attr" : "val"})
        self.verify_attrs_wattr(single)

    # ===== AttributesNSImpl
    def test_nsattrs_empty(self):
        empty = AttributesNSImpl({}, {})
        self.verify_empty_nsattrs(empty)

    def test_nsattrs_wattr(self):
        # One namespaced attribute: (ns_uri, "attr") -> "val", with the
        # qualified name "ns:attr".
        name = (ns_uri, "attr")
        qname = "ns:attr"
        attrs = AttributesNSImpl({name : "val"}, {name : qname})

        # Size and name listings.
        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [name])
        self.assertEqual(attrs.getQNames(), [qname])
        self.assertEqual(len(attrs), 1)

        # Mapping-style access.
        self.assertIn(name, attrs)
        self.assertEqual(list(attrs.keys()), [name])
        self.assertEqual(attrs.get(name), "val")
        self.assertEqual(attrs.get(name, 25), "val")
        self.assertEqual(list(attrs.items()), [(name, "val")])
        self.assertEqual(list(attrs.values()), ["val"])

        # SAX-specific lookups by name and by qualified name.
        self.assertEqual(attrs.getValue(name), "val")
        self.assertEqual(attrs.getValueByQName(qname), "val")
        self.assertEqual(attrs.getNameByQName(qname), name)
        self.assertEqual(attrs[name], "val")
        self.assertEqual(attrs.getQNameByName(name), qname)
992 993


994
def test_main():
    # Hand this module's TestCase classes, in this fixed order, to
    # test.support.run_unittest.
    test_classes = (
        MakeParserTest,
        SaxutilsTest,
        StringXmlgenTest,
        BytesXmlgenTest,
        WriterXmlgenTest,
        StreamWriterXmlgenTest,
        StreamReaderWriterXmlgenTest,
        ExpatReaderTest,
        ErrorReportingTest,
        XmlReaderTest,
    )
    run_unittest(*test_classes)

if __name__ == '__main__':
    # Allow running this file directly as a script.
    test_main()