# regression test for SAX 2.0
# $Id$

4 5
from xml.sax import make_parser, ContentHandler, \
                    SAXException, SAXReaderNotAvailable, SAXParseException
6 7
# Probe once at import time: if no SAX driver can be instantiated there is
# nothing in this module that can run.
try:
    make_parser()
except SAXReaderNotAvailable:
    # don't try to test this module if we cannot create a parser
    raise ImportError("no XML parsers available")
11 12 13
from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
                             XMLFilterBase
from xml.sax.expatreader import create_parser
14
from xml.sax.handler import feature_namespaces
15
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
16
from io import StringIO
17
from test.support import findfile, run_unittest
18
import unittest
19 20 21

# Input document and the expected XMLGenerator output for it, both located
# through test.support.findfile in the "xmltestdata" sub-directory.
TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")

# Skip the whole module when either path cannot be encoded as UTF-8.
try:
    TEST_XMLFILE.encode("utf8")
    TEST_XMLFILE_OUT.encode("utf8")
except UnicodeEncodeError:
    raise unittest.SkipTest("filename is not encodable to utf8")
27

28
# Namespace URI shared by the namespace-aware attribute and generator tests.
ns_uri = "http://www.python.org/xml-ns/saxtest/"
29

30 31 32 33 34 35 36
class XmlTestBase(unittest.TestCase):
    """Shared assertion helpers for SAX attribute-collection objects."""

    def verify_empty_attrs(self, attrs):
        """Assert that *attrs* behaves as an empty, non-namespace attribute set."""
        # Every name-based accessor must raise KeyError for an unknown name.
        for lookup in (attrs.getValue,
                       attrs.getValueByQName,
                       attrs.getNameByQName,
                       attrs.getQNameByName,
                       attrs.__getitem__):
            self.assertRaises(KeyError, lookup, "attr")

        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn("attr", attrs)

        # Mapping-style interface.
        self.assertEqual(list(attrs.keys()), [])
        self.assertEqual(attrs.get("attrs"), None)
        self.assertEqual(attrs.get("attrs", 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_empty_nsattrs(self, attrs):
        """Assert that *attrs* behaves as an empty namespace-aware attribute set."""
        ns_name = (ns_uri, "attr")
        qname = "ns:attr"

        # Accessors keyed by (uri, localname) and by qualified name alike
        # must raise KeyError.
        for lookup, key in ((attrs.getValue, ns_name),
                            (attrs.getValueByQName, qname),
                            (attrs.getNameByQName, qname),
                            (attrs.getQNameByName, ns_name),
                            (attrs.__getitem__, ns_name)):
            self.assertRaises(KeyError, lookup, key)

        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn(ns_name, attrs)

        # Mapping-style interface.
        self.assertEqual(list(attrs.keys()), [])
        self.assertEqual(attrs.get(ns_name), None)
        self.assertEqual(attrs.get(ns_name, 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_attrs_wattr(self, attrs):
        """Assert that *attrs* holds exactly one attribute, ``attr="val"``."""
        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), ["attr"])
        self.assertEqual(attrs.getQNames(), ["attr"])
        self.assertEqual(len(attrs), 1)
        self.assertIn("attr", attrs)

        # Mapping-style interface.
        self.assertEqual(list(attrs.keys()), ["attr"])
        self.assertEqual(attrs.get("attr"), "val")
        self.assertEqual(attrs.get("attr", 25), "val")
        self.assertEqual(list(attrs.items()), [("attr", "val")])
        self.assertEqual(list(attrs.values()), ["val"])

        # SAX accessor methods.
        self.assertEqual(attrs.getValue("attr"), "val")
        self.assertEqual(attrs.getValueByQName("attr"), "val")
        self.assertEqual(attrs.getNameByQName("attr"), "attr")
        self.assertEqual(attrs["attr"], "val")
        self.assertEqual(attrs.getQNameByName("attr"), "attr")
81 82 83

class MakeParserTest(unittest.TestCase):
    """Regression test for repeated parser creation."""

    def test_make_parser2(self):
        # Creating parsers several times in a row should succeed.
        # Testing this because there have been failures of this kind
        # before.
        for _ in range(6):
            # The import is deliberately repeated before every creation,
            # as in the original sequence of statements.
            from xml.sax import make_parser
            self.assertIsNotNone(make_parser())
99 100


101 102 103 104 105 106
# ===========================================================================
#
#   saxutils tests
#
# ===========================================================================

107 108 109
class SaxutilsTest(unittest.TestCase):
    """Tests for the escape / unescape / quoteattr helpers and make_parser."""

    # ===== escape
    def test_escape_basic(self):
        # The ampersand must come back as an entity reference.
        self.assertEqual(escape("Donald Duck & Co"), "Donald Duck &amp; Co")

    def test_escape_all(self):
        self.assertEqual(escape("<Donald Duck & Co>"),
                         "&lt;Donald Duck &amp; Co&gt;")

    def test_escape_extra(self):
        self.assertEqual(escape("Hei på deg", {"å" : "&aring;"}),
                         "Hei p&aring; deg")

    # ===== unescape
    def test_unescape_basic(self):
        self.assertEqual(unescape("Donald Duck &amp; Co"), "Donald Duck & Co")

    def test_unescape_all(self):
        self.assertEqual(unescape("&lt;Donald Duck &amp; Co&gt;"),
                         "<Donald Duck & Co>")

    def test_unescape_extra(self):
        self.assertEqual(unescape("Hei på deg", {"å" : "&aring;"}),
                         "Hei p&aring; deg")

    def test_unescape_amp_extra(self):
        # "&amp;foo;" must become "&foo;" and must not then be expanded
        # through the extra-entities mapping.
        self.assertEqual(unescape("&amp;foo;", {"&foo;": "splat"}), "&foo;")

    # ===== quoteattr
    def test_quoteattr_basic(self):
        self.assertEqual(quoteattr("Donald Duck & Co"),
                         '"Donald Duck &amp; Co"')

    def test_single_quoteattr(self):
        self.assertEqual(quoteattr('Includes "double" quotes'),
                         '\'Includes "double" quotes\'')

    def test_double_quoteattr(self):
        self.assertEqual(quoteattr("Includes 'single' quotes"),
                         "\"Includes 'single' quotes\"")

    def test_single_double_quoteattr(self):
        self.assertEqual(quoteattr("Includes 'single' and \"double\" quotes"),
                         "\"Includes 'single' and &quot;double&quot; quotes\"")

    # ===== make_parser
    def test_make_parser(self):
        # Creating a parser should succeed - it should fall back
        # to the expatreader
        self.assertIsNotNone(make_parser(['xml.parsers.no_such_parser']))
157 158


159 160 161 162
# ===== XMLGenerator

# XML declaration that every expected XMLGenerator document in the tests
# below begins with.
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'

163 164 165 166 167 168 169 170 171
class XmlgenTest(unittest.TestCase):
    """Tests for the text produced by XMLGenerator.

    Most scenarios come in pairs: the default mode and the
    ``short_empty_elements=True`` mode.
    """

    def _new_generator(self, **options):
        """Return ``(gen, out)``: an XMLGenerator writing to a fresh StringIO.

        Keyword *options* are passed through to the XMLGenerator constructor.
        """
        out = StringIO()
        return XMLGenerator(out, **options), out

    def test_xmlgen_basic(self):
        gen, out = self._new_generator()
        gen.startDocument()
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc></doc>")

    def test_xmlgen_basic_empty(self):
        gen, out = self._new_generator(short_empty_elements=True)
        gen.startDocument()
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc/>")

    def test_xmlgen_content(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("huhei")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc>huhei</doc>")

    def test_xmlgen_content_empty(self):
        # An element with content must not be collapsed.
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("huhei")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc>huhei</doc>")

    def test_xmlgen_pi(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.processingInstruction("test", "data")
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<?test data?><doc></doc>")

    def test_xmlgen_content_escape(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("<huhei&")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(),
            start + "<doc>&lt;huhei&amp;</doc>")

    def test_xmlgen_attr_escape(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startElement("doc", {"a": '"'})
        # Each nested element carries a differently awkward attribute value.
        for value in ("'", "'\"", "\n\r\t"):
            gen.startElement("e", {"a": value})
            gen.endElement("e")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start +
            ("<doc a='\"'><e a=\"'\"></e>"
             "<e a=\"'&quot;\"></e>"
             "<e a=\"&#10;&#13;&#9;\"></e></doc>"))

    def test_xmlgen_ignorable(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startElement("doc", {})
        gen.ignorableWhitespace(" ")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc> </doc>")

    def test_xmlgen_ignorable_empty(self):
        # Ignorable whitespace still counts as content.
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.ignorableWhitespace(" ")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + "<doc> </doc>")

    def test_xmlgen_ns(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startPrefixMapping("ns1", ns_uri)
        gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        gen.startElementNS((None, "udoc"), None, {})
        gen.endElementNS((None, "udoc"), None)
        gen.endElementNS((ns_uri, "doc"), "ns1:doc")
        gen.endPrefixMapping("ns1")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + \
           ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
                                         ns_uri))

    def test_xmlgen_ns_empty(self):
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping("ns1", ns_uri)
        gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        gen.startElementNS((None, "udoc"), None, {})
        gen.endElementNS((None, "udoc"), None)
        gen.endElementNS((ns_uri, "doc"), "ns1:doc")
        gen.endPrefixMapping("ns1")
        gen.endDocument()

        self.assertEqual(out.getvalue(), start + \
           ('<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
                                         ns_uri))

    def test_1463026_1(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS((None, 'a'), 'a')
        gen.endDocument()

        self.assertEqual(out.getvalue(), start+'<a b="c"></a>')

    def test_1463026_1_empty(self):
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS((None, 'a'), 'a')
        gen.endDocument()

        self.assertEqual(out.getvalue(), start+'<a b="c"/>')

    def test_1463026_2(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startPrefixMapping(None, 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping(None)
        gen.endDocument()

        self.assertEqual(out.getvalue(), start+'<a xmlns="qux"></a>')

    def test_1463026_2_empty(self):
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping(None, 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping(None)
        gen.endDocument()

        self.assertEqual(out.getvalue(), start+'<a xmlns="qux"/>')

    def test_1463026_3(self):
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startPrefixMapping('my', 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping('my')
        gen.endDocument()

        self.assertEqual(out.getvalue(),
            start+'<my:a xmlns:my="qux" b="c"></my:a>')

    def test_1463026_3_empty(self):
        gen, out = self._new_generator(short_empty_elements=True)

        gen.startDocument()
        gen.startPrefixMapping('my', 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping('my')
        gen.endDocument()

        self.assertEqual(out.getvalue(),
            start+'<my:a xmlns:my="qux" b="c"/>')

    def test_5027_1(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is thrown because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by parsing a document.
        test_xml = StringIO(
            '<?xml version="1.0"?>'
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>')

        gen, out = self._new_generator()
        parser = make_parser()
        parser.setFeature(feature_namespaces, True)
        parser.setContentHandler(gen)
        parser.parse(test_xml)

        self.assertEqual(out.getvalue(),
                         start + (
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))

    def test_5027_2(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is thrown because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by direct manipulation of the
        # XMLGenerator.
        gen, out = self._new_generator()

        gen.startDocument()
        gen.startPrefixMapping('a', 'http://example.com/ns')
        gen.startElementNS(('http://example.com/ns', 'g1'), 'g1', {})
        lang_attr = {('http://www.w3.org/XML/1998/namespace', 'lang'): 'en'}
        gen.startElementNS(('http://example.com/ns', 'g2'), 'g2', lang_attr)
        gen.characters('Hello')
        gen.endElementNS(('http://example.com/ns', 'g2'), 'g2')
        gen.endElementNS(('http://example.com/ns', 'g1'), 'g1')
        gen.endPrefixMapping('a')
        gen.endDocument()

        self.assertEqual(out.getvalue(),
                         start + (
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))
442

443 444 445 446 447 448 449 450 451 452 453 454 455 456 457

class XMLFilterBaseTest(unittest.TestCase):
    """Checks that XMLFilterBase forwards content events to its handler."""

    def test_filter_basic(self):
        result = StringIO()
        # Named "xml_filter" rather than "filter" to avoid shadowing the builtin.
        xml_filter = XMLFilterBase()
        xml_filter.setContentHandler(XMLGenerator(result))

        xml_filter.startDocument()
        xml_filter.startElement("doc", {})
        xml_filter.characters("content")
        xml_filter.ignorableWhitespace(" ")
        xml_filter.endElement("doc")
        xml_filter.endDocument()

        self.assertEqual(result.getvalue(), start + "<doc>content </doc>")
459 460 461 462 463 464 465

# ===========================================================================
#
#   expatreader tests
#
# ===========================================================================

466 467
# Expected generator output for TEST_XMLFILE, read once at import time and
# compared against in the ExpatReaderTest file/InputSource tests.
# NOTE(review): the file is opened in text mode with the platform default
# encoding -- confirm that matches the encoding of the data file.
with open(TEST_XMLFILE_OUT) as f:
    xml_test_out = f.read()
468

469
class ExpatReaderTest(XmlTestBase):
    """Tests for the expat-based reader returned by create_parser()."""

    def _generating_parser(self):
        """Return ``(parser, result)``.

        *parser* is a fresh expat reader whose content handler is an
        XMLGenerator writing into the StringIO *result*.
        """
        result = StringIO()
        parser = create_parser()
        parser.setContentHandler(XMLGenerator(result))
        return parser, result

    # ===== XMLReader support

    def test_expat_file(self):
        parser, result = self._generating_parser()
        with open(TEST_XMLFILE) as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== DTDHandler support

    class TestDTDHandler:
        """Records every notation and unparsed-entity declaration it receives."""

        def __init__(self):
            self._notations = []
            self._entities  = []

        def notationDecl(self, name, publicId, systemId):
            self._notations.append((name, publicId, systemId))

        def unparsedEntityDecl(self, name, publicId, systemId, ndata):
            self._entities.append((name, publicId, systemId, ndata))

    def test_expat_dtdhandler(self):
        parser = create_parser()
        handler = self.TestDTDHandler()
        parser.setDTDHandler(handler)

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
        parser.feed('  <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
        parser.feed(']>\n')
        parser.feed('<doc></doc>')
        parser.close()

        self.assertEqual(handler._notations,
            [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
        self.assertEqual(handler._entities, [("img", None, "expat.gif", "GIF")])

    # ===== EntityResolver support

    class TestEntityResolver:
        """Resolves every external entity to the document ``<entity/>``."""

        def resolveEntity(self, publicId, systemId):
            inpsrc = InputSource()
            inpsrc.setByteStream(StringIO("<entity/>"))
            return inpsrc

    def test_expat_entityresolver(self):
        parser, result = self._generating_parser()
        parser.setEntityResolver(self.TestEntityResolver())

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY test SYSTEM "whatever">\n')
        parser.feed(']>\n')
        parser.feed('<doc>&test;</doc>')
        parser.close()

        self.assertEqual(result.getvalue(), start +
                         "<doc><entity></entity></doc>")

    # ===== Attributes support

    class AttrGatherer(ContentHandler):
        """Remembers the attributes object of the last start-element event."""

        def startElement(self, name, attrs):
            self._attrs = attrs

        def startElementNS(self, name, qname, attrs):
            self._attrs = attrs

    def test_expat_attrs_empty(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_attrs(gather._attrs)

    def test_expat_attrs_wattr(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc attr='val'/>")
        parser.close()

        self.verify_attrs_wattr(gather._attrs)

    def test_expat_nsattrs_empty(self):
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_nsattrs(gather._attrs)

    def test_expat_nsattrs_wattr(self):
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
        parser.close()

        attrs = gather._attrs
        ns_name = (ns_uri, "attr")

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [ns_name])
        # Either an empty list or the single qualified name is accepted.
        self.assertIn(attrs.getQNames(), ([], ["ns:attr"]))
        self.assertEqual(len(attrs), 1)
        self.assertIn(ns_name, attrs)
        self.assertEqual(attrs.get(ns_name), "val")
        self.assertEqual(attrs.get(ns_name, 25), "val")
        self.assertEqual(list(attrs.items()), [(ns_name, "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue(ns_name), "val")
        self.assertEqual(attrs[ns_name], "val")

    # ===== InputSource support

    def test_expat_inpsource_filename(self):
        parser, result = self._generating_parser()
        parser.parse(TEST_XMLFILE)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_sysid(self):
        parser, result = self._generating_parser()
        parser.parse(InputSource(TEST_XMLFILE))

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_stream(self):
        parser, result = self._generating_parser()
        inpsrc = InputSource()
        with open(TEST_XMLFILE) as f:
            inpsrc.setByteStream(f)
            parser.parse(inpsrc)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== IncrementalParser support

    def test_expat_incremental(self):
        parser, result = self._generating_parser()

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + "<doc></doc>")

    def test_expat_incremental_reset(self):
        parser, result = self._generating_parser()

        # Start a document and abandon it half-way through.
        parser.feed("<doc>")
        parser.feed("text")

        # Install a fresh generator and reset; only the second document
        # may show up in the new buffer.
        result = StringIO()
        parser.setContentHandler(XMLGenerator(result))
        parser.reset()

        parser.feed("<doc>")
        parser.feed("text")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + "<doc>text</doc>")

    # ===== Locator support

    def test_expat_locator_noinfo(self):
        parser, _ = self._generating_parser()

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(parser.getSystemId(), None)
        self.assertEqual(parser.getPublicId(), None)
        self.assertEqual(parser.getLineNumber(), 1)

    def test_expat_locator_withinfo(self):
        parser, _ = self._generating_parser()
        parser.parse(TEST_XMLFILE)

        self.assertEqual(parser.getSystemId(), TEST_XMLFILE)
        self.assertEqual(parser.getPublicId(), None)
696

697 698 699 700 701 702 703

# ===========================================================================
#
#   error reporting
#
# ===========================================================================

704 705 706 707 708 709 710 711 712 713 714 715
class ErrorReportingTest(unittest.TestCase):
    """Tests for the exceptions raised on ill-formed or incomplete input."""

    class DummyLocator:
        """Minimal locator returning fixed ids and configurable line/column."""

        def __init__(self, lineno, colno):
            self._lineno = lineno
            self._colno = colno

        def getPublicId(self):
            return "pubid"

        def getSystemId(self):
            return "sysid"

        def getLineNumber(self):
            return self._lineno

        def getColumnNumber(self):
            return self._colno

    def test_expat_inpsource_location(self):
        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        name = "a file name"
        source = InputSource()
        source.setByteStream(StringIO("<foo bar foobar>"))   #ill-formed
        source.setSystemId(name)

        # Parsing must fail, and the exception must carry the system id
        # that was set on the input source.
        with self.assertRaises(SAXException) as caught:
            parser.parse(source)
        self.assertEqual(caught.exception.getSystemId(), name)

    def test_expat_incomplete(self):
        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        self.assertRaises(SAXParseException, parser.parse, StringIO("<foo>"))

    def test_sax_parse_exception_str(self):
        # pass various values from a locator to the SAXParseException to
        # make sure that the __str__() doesn't fall apart when None is
        # passed instead of an integer line and column number
        #
        # In order: "normal" values, None for the line number, None for the
        # column number, None for both.
        for lineno, colno in ((1, 1), (None, 1), (1, None), (None, None)):
            str(SAXParseException("message", None,
                                  self.DummyLocator(lineno, colno)))
757

758 759 760 761 762 763
# ===========================================================================
#
#   xmlreader tests
#
# ===========================================================================

764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779
class XmlReaderTest(XmlTestBase):
    """Tests for the attribute classes in xml.sax.xmlreader, plus two
    import-related regression tests."""

    # ===== AttributesImpl
    def test_attrs_empty(self):
        self.verify_empty_attrs(AttributesImpl({}))

    def test_attrs_wattr(self):
        self.verify_attrs_wattr(AttributesImpl({"attr" : "val"}))

    def test_nsattrs_empty(self):
        self.verify_empty_nsattrs(AttributesNSImpl({}, {}))

    def test_nsattrs_wattr(self):
        ns_name = (ns_uri, "attr")
        attrs = AttributesNSImpl({ns_name : "val"},
                                 {ns_name : "ns:attr"})

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [ns_name])
        self.assertEqual(attrs.getQNames(), ["ns:attr"])
        self.assertEqual(len(attrs), 1)
        self.assertIn(ns_name, attrs)

        # Mapping-style interface.
        self.assertEqual(list(attrs.keys()), [ns_name])
        self.assertEqual(attrs.get(ns_name), "val")
        self.assertEqual(attrs.get(ns_name, 25), "val")
        self.assertEqual(list(attrs.items()), [(ns_name, "val")])
        self.assertEqual(list(attrs.values()), ["val"])

        # SAX accessor methods.
        self.assertEqual(attrs.getValue(ns_name), "val")
        self.assertEqual(attrs.getValueByQName("ns:attr"), "val")
        self.assertEqual(attrs.getNameByQName("ns:attr"), ns_name)
        self.assertEqual(attrs[ns_name], "val")
        self.assertEqual(attrs.getQNameByName(ns_name), "ns:attr")


    # During the development of Python 2.5, an attempt to move the "xml"
    # package implementation to a new package ("xmlcore") proved painful.
    # The goal of this change was to allow applications to be able to
    # obtain and rely on behavior in the standard library implementation
    # of the XML support without needing to be concerned about the
    # availability of the PyXML implementation.
    #
    # While the existing import hackery in Lib/xml/__init__.py can cause
    # PyXML's _xmlplus package to supplant the "xml" package, that only
    # works because either implementation uses the "xml" package name for
    # imports.
    #
    # The move resulted in a number of problems related to the fact that
    # the import machinery's "package context" is based on the name that's
    # being imported rather than the __name__ of the actual package
    # containment; it wasn't possible for the "xml" package to be replaced
    # by a simple module that indirected imports to the "xmlcore" package.
    #
    # The following two tests exercised bugs that were introduced in that
    # attempt.  Keeping these tests around will help detect problems with
    # other attempts to provide reliable access to the standard library's
    # implementation of the XML support.

    def test_sf_1511497(self):
        # Bug report: http://www.python.org/sf/1511497
        import sys
        old_modules = sys.modules.copy()
        # Drop every cached "xml.*" submodule so the import below is redone
        # from scratch; iterate over a snapshot because the dict is mutated.
        for modname in list(sys.modules):
            if modname.startswith("xml."):
                del sys.modules[modname]
        try:
            import xml.sax.expatreader
            module = xml.sax.expatreader
            self.assertEqual(module.__name__, "xml.sax.expatreader")
        finally:
            # Put the original module objects back.
            sys.modules.update(old_modules)

    def test_sf_1513611(self):
        # Bug report: http://www.python.org/sf/1513611
        sio = StringIO("invalid")
        parser = make_parser()
        from xml.sax import SAXParseException
        self.assertRaises(SAXParseException, parser.parse, sio)


842
def test_main():
    """Run every test case class defined in this module."""
    run_unittest(MakeParserTest,
                 SaxutilsTest,
                 XmlgenTest,
                 # Defined in this module but previously never run from here.
                 XMLFilterBaseTest,
                 ExpatReaderTest,
                 ErrorReportingTest,
                 XmlReaderTest)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()