saxutils.py 11.8 KB
Newer Older
1 2
"""\
A library of useful helper classes to the SAX classes, for the
convenience of application and driver writers.
"""

6
import os, urllib.parse, urllib.request
7
import io
8
import codecs
9 10
from . import handler
from . import xmlreader
11

12 13 14 15 16
def __dict_replace(s, d):
    """Return *s* after applying every ``key -> value`` substitution in *d*.

    Substitutions are applied one after another in the iteration order of
    ``d.items()``, so a later substitution operates on the result of the
    earlier ones.
    """
    for pair in d.items():
        # pair is (old_substring, new_substring)
        s = s.replace(*pair)
    return s
Martin v. Löwis's avatar
Martin v. Löwis committed
17

18
def escape(data, entities={}):
    """Escape &, <, and > in a string of data.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.

    Returns the escaped string.
    """
    # The shared default dict is only ever read here, never mutated.

    # must do ampersand first, otherwise the '&' introduced by the
    # '&gt;' / '&lt;' replacements would itself be escaped again
    data = data.replace("&", "&amp;")
    data = data.replace(">", "&gt;")
    data = data.replace("<", "&lt;")
    if entities:
        # Caller-supplied replacements run after the three built-in ones.
        data = __dict_replace(data, entities)
    return data
33 34 35 36 37 38 39 40

def unescape(data, entities={}):
    """Unescape &amp;, &lt;, and &gt; in a string of data.

    You can unescape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.

    Returns the unescaped string.
    """
    # The shared default dict is only ever read here, never mutated.
    data = data.replace("&lt;", "<")
    data = data.replace("&gt;", ">")
    if entities:
        # Caller-supplied replacements run before '&amp;' is resolved.
        data = __dict_replace(data, entities)
    # must do ampersand last, so that e.g. '&amp;lt;' becomes the literal
    # text '&lt;' instead of being unescaped twice into '<'
    return data.replace("&amp;", "&")
47

48 49 50 51 52 53 54 55 56 57 58
def quoteattr(data, entities={}):
    """Escape and quote an attribute value.

    Escape &, <, and > in a string of data, then quote it for use as
    an attribute value.  The \" character will be escaped as well, if
    necessary.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.

    Returns the escaped value including its surrounding quote characters.
    """
    # Work on a copy so neither the caller's dict nor the shared default
    # is mutated.  Newline, carriage return and tab are always written as
    # character references; these entries override caller-supplied ones.
    entities = entities.copy()
    entities.update({'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'})
    data = escape(data, entities)
    if '"' in data:
        if "'" in data:
            # Both quote kinds occur: use double quotes and escape the
            # embedded double quotes.
            data = '"%s"' % data.replace('"', "&quot;")
        else:
            # Only double quotes occur: single-quote the value as is.
            data = "'%s'" % data
    else:
        data = '"%s"' % data
    return data

71

72 73 74 75 76 77 78 79 80
def _gettextwriter(out, encoding):
    """Return an object with a text ``write`` method for *out*.

    out      -- None, a text stream, a codecs stream writer, a raw binary
                stream, or any other object that has a ``write`` method.
    encoding -- encoding passed to the TextIOWrapper that is created when
                *out* has to be wrapped.

    None yields ``sys.stdout``; text streams and codecs writers are returned
    unchanged; anything else is wrapped in an ``io.TextIOWrapper`` that
    replaces unencodable characters with XML character references.
    """
    if out is None:
        import sys
        return sys.stdout

    if isinstance(out, io.TextIOBase):
        # use a text writer as is
        return out

    if isinstance(out, (codecs.StreamWriter, codecs.StreamReaderWriter)):
        # use a codecs stream writer as is
        return out

    # wrap a binary writer with TextIOWrapper
    if isinstance(out, io.RawIOBase):
        # Keep the original file open when the TextIOWrapper is
        # destroyed
        class _wrapper:
            # Report the wrapped object's class and forward every other
            # attribute lookup to it.
            __class__ = out.__class__
            def __getattr__(self, name):
                return getattr(out, name)
        buffer = _wrapper()
        # close() on the proxy is a no-op, so *out* itself is not closed.
        buffer.close = lambda: None
    else:
        # This is to handle passed objects that aren't in the
        # IOBase hierarchy, but just have a write method
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if BOM (for UTF-16, etc) should be added
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            # *out* lacks seekable()/tell(); leave the BufferedIOBase
            # defaults in place.
            pass
    return io.TextIOWrapper(buffer, encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)

113 114
class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the SAX events it receives as XML text.

    out                  -- destination; converted to a text writer by
                            _gettextwriter() (None selects sys.stdout).
    encoding             -- encoding named in the XML declaration and used
                            for the output wrapper and for decoding
                            non-str content.
    short_empty_elements -- if true, an element that is ended without any
                            content in between is written as ``<name/>``.
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs announced by startPrefixMapping() that still
        # have to be written as xmlns attributes on the next NS element.
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing '>'.
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self,endElement=False):
        """Write the '>' of a start tag that was left open, if any.

        The endElement argument is accepted but not used.
        """
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        """Write the XML declaration naming the configured encoding."""
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def endDocument(self):
        """Flush the underlying writer."""
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        """Record a uri -> prefix binding and queue its xmlns declaration."""
        # Save a snapshot of the bindings as they were before this mapping;
        # endPrefixMapping() restores it.
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Restore the bindings saved by the matching startPrefixMapping()."""
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        """Write the start tag for *name* with its attributes."""
        self._finish_pending_start_element()
        self._write('<' + name)
        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        if self._short_empty_elements:
            # Defer the '>' so an immediately following endElement() can
            # emit '/>' instead.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        """Write the end tag for *name*, or '/>' if the start tag is open."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        """Write a namespace-aware start tag.

        *name* is a (uri, localname) pair; the written tag name is derived
        from it via _qname(), the *qname* argument is not used.
        """
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))

        # Declare every prefix mapping announced since the last element.
        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                self._write(' xmlns:%s="%s"' % (prefix, uri))
            else:
                self._write(' xmlns="%s"' % uri)
        self._undeclared_ns_maps = []

        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        if self._short_empty_elements:
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        """Write the namespace-aware end tag, or '/>' if the tag is open."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        """Write *content* with &, < and > escaped; empty content is ignored."""
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                # Non-str content is decoded with the generator's encoding.
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        """Write *content* unescaped; empty content is ignored."""
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        """Write a ``<?target data?>`` processing instruction."""
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))
226

227

228
class XMLFilterBase(xmlreader.XMLReader):
    """This class is designed to sit between an XMLReader and the
    client application's event handlers.  By default, it does nothing
    but pass requests up to the reader and events on to the handlers
    unmodified, but subclasses can override specific methods to modify
    the event stream or the configuration requests as they pass
    through."""

    def __init__(self, parent = None):
        xmlreader.XMLReader.__init__(self)
        # The reader this filter forwards configuration requests to.
        self._parent = parent

    # ErrorHandler methods

    def error(self, exception):
        self._err_handler.error(exception)

    def fatalError(self, exception):
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        self._err_handler.warning(exception)

    # ContentHandler methods

    def setDocumentLocator(self, locator):
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        self._cont_handler.startDocument()

    def endDocument(self):
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods

    def notationDecl(self, name, publicId, systemId):
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods

    def resolveEntity(self, publicId, systemId):
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods

    def parse(self, source):
        """Register this filter as every handler of the parent, then let
        the parent parse *source*."""
        self._parent.setContentHandler(self)
        self._parent.setErrorHandler(self)
        self._parent.setEntityResolver(self)
        self._parent.setDTDHandler(self)
        self._parent.parse(source)

    def setLocale(self, locale):
        self._parent.setLocale(locale)

    def getFeature(self, name):
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        """Return the reader this filter is attached to (may be None)."""
        return self._parent

    def setParent(self, parent):
        """Attach this filter to the reader *parent*."""
        self._parent = parent

337 338
# --- Utility functions

339
def prepare_input_source(source, base=""):
    """This function takes an InputSource and an optional base URL and
    returns a fully resolved InputSource object ready for reading.

    source -- an InputSource, a system identifier string, or an object
              with a ``read`` method (used as the byte stream).
    base   -- base location against which a system identifier without a
              byte stream is resolved.

    If the resulting InputSource has no byte stream, its system
    identifier is opened: as a local file (binary mode) when it names an
    existing file relative to the directory of *base*, otherwise as a URL
    joined with *base* via urllib.  The caller owns the opened stream.
    """

    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        source.setByteStream(f)
        # Only a string-valued name is adopted as the system identifier.
        if hasattr(f, "name") and isinstance(f.name, str):
            source.setSystemId(f.name)

    if source.getByteStream() is None:
        sysid = source.getSystemId()
        basehead = os.path.dirname(os.path.normpath(base))
        sysidfilename = os.path.join(basehead, sysid)
        if os.path.isfile(sysidfilename):
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            # Not a local file: treat the identifier as a URL.
            source.setSystemId(urllib.parse.urljoin(base, sysid))
            f = urllib.request.urlopen(source.getSystemId())

        source.setByteStream(f)

    return source