zipfile.py 52.4 KB
Newer Older
1 2
"""
Read and write ZIP files.
3 4

XXX references to utf-8 need further investigation.
5
"""
6
import io
Barry Warsaw's avatar
Barry Warsaw committed
7
import os
8
import re
Barry Warsaw's avatar
Barry Warsaw committed
9 10 11 12 13 14 15 16
import imp
import sys
import time
import stat
import shutil
import struct
import binascii

17 18

# Prefer zlib's CRC-32 implementation; when zlib is not available, fall back
# to binascii and leave ``zlib`` as None so callers can detect that
# compression support is missing.
try:
    import zlib  # We may need its compression method
    crc32 = zlib.crc32
except ImportError:
    zlib = None
    crc32 = binascii.crc32
24

25
# Names exported by ``from zipfile import *``.
__all__ = ["BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "is_zipfile",
           "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile" ]
27

28
class BadZipfile(Exception):
    """Raised when a file is not a valid (or not a supported) ZIP archive."""
30 31 32 33 34 35 36 37


class LargeZipFile(Exception):
    """Signal that ZIP64 extensions are needed but have been disabled.

    Raised while writing a zipfile when the zipfile requires ZIP64
    extensions and those extensions are disabled.
    """

38
error = BadZipfile      # Alias for BadZipfile, the exception raised by this module
39

40
# Sizes above ZIP64_LIMIT do not fit the classic header fields and make
# ZipInfo.FileHeader() fall back to the ZIP64 extension.
ZIP64_LIMIT = (1 << 31) - 1
ZIP_FILECOUNT_LIMIT = 1 << 16
ZIP_MAX_COMMENT = (1 << 16) - 1

# constants for Zip file compression methods
ZIP_STORED = 0
ZIP_DEFLATED = 8
# Other ZIP compression methods not supported

49 50 51 52 53 54 55 56
# Below are some formats and associated data for reading/writing headers using
# the struct module.  The names and structures of headers/records are those used
# in the PKWARE description of the ZIP file format:
#     http://www.pkware.com/documents/casestudies/APPNOTE.TXT
# (URL valid as of January 2008)

# The "end of central directory" structure, magic number, size, and indices
# (section V.I in the format document)
structEndArchive = "<4s4H2LH"
stringEndArchive = b"PK\005\006"
sizeEndCentDir = struct.calcsize(structEndArchive)

_ECD_SIGNATURE = 0
_ECD_DISK_NUMBER = 1
_ECD_DISK_START = 2
_ECD_ENTRIES_THIS_DISK = 3
_ECD_ENTRIES_TOTAL = 4
_ECD_SIZE = 5
_ECD_OFFSET = 6
_ECD_COMMENT_SIZE = 7
# These last two indices are not part of the structure as defined in the
# spec, but they are used internally by this module as a convenience
_ECD_COMMENT = 8
_ECD_LOCATION = 9

# The "central directory" structure, magic number, size, and indices
# of entries in the structure (section V.F in the format document)
structCentralDir = "<4s4B4HL2L5H2L"
stringCentralDir = b"PK\001\002"
sizeCentralDir = struct.calcsize(structCentralDir)

# indexes of entries in the central directory structure
_CD_SIGNATURE = 0
_CD_CREATE_VERSION = 1
_CD_CREATE_SYSTEM = 2
_CD_EXTRACT_VERSION = 3
_CD_EXTRACT_SYSTEM = 4
_CD_FLAG_BITS = 5
_CD_COMPRESS_TYPE = 6
_CD_TIME = 7
_CD_DATE = 8
_CD_CRC = 9
_CD_COMPRESSED_SIZE = 10
_CD_UNCOMPRESSED_SIZE = 11
_CD_FILENAME_LENGTH = 12
_CD_EXTRA_FIELD_LENGTH = 13
_CD_COMMENT_LENGTH = 14
_CD_DISK_NUMBER_START = 15
_CD_INTERNAL_FILE_ATTRIBUTES = 16
_CD_EXTERNAL_FILE_ATTRIBUTES = 17
_CD_LOCAL_HEADER_OFFSET = 18

101 102 103
# The "local file header" structure, magic number, size, and indices
# (section V.A in the format document)
structFileHeader = "<4s2B4HL2L2H"
stringFileHeader = b"PK\003\004"
sizeFileHeader = struct.calcsize(structFileHeader)

_FH_SIGNATURE = 0
_FH_EXTRACT_VERSION = 1
_FH_EXTRACT_SYSTEM = 2
_FH_GENERAL_PURPOSE_FLAG_BITS = 3
_FH_COMPRESSION_METHOD = 4
_FH_LAST_MOD_TIME = 5
_FH_LAST_MOD_DATE = 6
_FH_CRC = 7
_FH_COMPRESSED_SIZE = 8
_FH_UNCOMPRESSED_SIZE = 9
_FH_FILENAME_LENGTH = 10
_FH_EXTRA_FIELD_LENGTH = 11

120
# The "Zip64 end of central directory locator" structure, magic number, and size
structEndArchive64Locator = "<4sLQL"
stringEndArchive64Locator = b"PK\x06\x07"
sizeEndCentDir64Locator = struct.calcsize(structEndArchive64Locator)

# The "Zip64 end of central directory" record, magic number, size, and indices
# (section V.G in the format document)
structEndArchive64 = "<4sQ2H2L4Q"
stringEndArchive64 = b"PK\x06\x06"
sizeEndCentDir64 = struct.calcsize(structEndArchive64)

_CD64_SIGNATURE = 0
_CD64_DIRECTORY_RECSIZE = 1
_CD64_CREATE_VERSION = 2
_CD64_EXTRACT_VERSION = 3
_CD64_DISK_NUMBER = 4
_CD64_DISK_NUMBER_START = 5
_CD64_NUMBER_ENTRIES_THIS_DISK = 6
_CD64_NUMBER_ENTRIES_TOTAL = 7
_CD64_DIRECTORY_SIZE = 8
_CD64_OFFSET_START_CENTDIR = 9

142
def _check_zipfile(fp):
    """Return True if *fp* contains an "end of central directory" record.

    *fp* is a seekable binary file object.  An IOError raised while probing
    the file is treated as "not a ZIP file" and yields False.
    """
    try:
        if _EndRecData(fp):
            return True         # file has correct magic number
    except IOError:
        pass
    return False
149

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
def is_zipfile(filename):
    """Report whether *filename* looks like a ZIP file (magic number check).

    *filename* is either a path, which is opened in binary mode and closed
    again here, or an object with a ``read`` attribute, which is probed
    directly.  Any IOError (for example an unreadable path) gives False.
    """
    is_zip = False
    try:
        if not hasattr(filename, "read"):
            # A path: open it ourselves so that it is closed afterwards.
            with open(filename, "rb") as stream:
                is_zip = _check_zipfile(stream)
        else:
            is_zip = _check_zipfile(fp=filename)
    except IOError:
        pass
    return is_zip

166 167 168 169
def _EndRecData64(fpin, offset, endrec):
    """
    Read the ZIP64 end-of-archive records and use that to update endrec

    fpin:   seekable binary file object positioned anywhere.
    offset: position of the classic "end of central directory" record,
            relative to the end of the file (so zero or negative).
    endrec: list built by _EndRecData(); it is modified in place and also
            returned.

    If no ZIP64 locator/record is found (wrong signature, or the file is too
    small to contain one) endrec is returned unchanged.  Raises BadZipfile
    for archives that span multiple disks.
    """
    try:
        fpin.seek(offset - sizeEndCentDir64Locator, 2)
    except IOError:
        # The seek lands before the start of the file, so the file is too
        # small to contain a ZIP64 locator: keep the classic record.
        return endrec

    data = fpin.read(sizeEndCentDir64Locator)
    if len(data) != sizeEndCentDir64Locator:
        # Short read; unpacking would raise struct.error.
        return endrec
    sig, diskno, reloff, disks = struct.unpack(structEndArchive64Locator, data)
    if sig != stringEndArchive64Locator:
        return endrec

    if diskno != 0 or disks != 1:
        raise BadZipfile("zipfiles that span multiple disks are not supported")

    # Assume no 'zip64 extensible data'
    fpin.seek(offset - sizeEndCentDir64Locator - sizeEndCentDir64, 2)
    data = fpin.read(sizeEndCentDir64)
    if len(data) != sizeEndCentDir64:
        return endrec
    sig, sz, create_version, read_version, disk_num, disk_dir, \
            dircount, dircount2, dirsize, diroffset = \
            struct.unpack(structEndArchive64, data)
    if sig != stringEndArchive64:
        return endrec

    # Update the original endrec using data from the ZIP64 record
    endrec[_ECD_SIGNATURE] = sig
    endrec[_ECD_DISK_NUMBER] = disk_num
    endrec[_ECD_DISK_START] = disk_dir
    endrec[_ECD_ENTRIES_THIS_DISK] = dircount
    endrec[_ECD_ENTRIES_TOTAL] = dircount2
    endrec[_ECD_SIZE] = dirsize
    endrec[_ECD_OFFSET] = diroffset
    return endrec


199 200 201 202 203
def _EndRecData(fpin):
    """Return data from the "End of Central Directory" record, or None.

    The data is a list of the nine items in the ZIP "End of central dir"
    record followed by a tenth item, the file seek offset of this record.

    fpin is a seekable binary file object.  None is returned when no valid
    record can be located, including when the record is truncated."""

    # Determine file size
    fpin.seek(0, 2)
    filesize = fpin.tell()

    # Check to see if this is ZIP file with no archive comment (the
    # "end of central directory" structure should be the last item in the
    # file if this is the case).
    try:
        fpin.seek(-sizeEndCentDir, 2)
    except IOError:
        return None
    data = fpin.read()
    if (len(data) == sizeEndCentDir and
            data[0:4] == stringEndArchive and
            data[-2:] == b"\000\000"):
        # the signature is correct and there's no comment, unpack structure
        endrec = list(struct.unpack(structEndArchive, data))

        # Append a blank comment and record start offset
        endrec.append(b"")
        endrec.append(filesize - sizeEndCentDir)

        # Try to read the "Zip64 end of central directory" structure
        return _EndRecData64(fpin, -sizeEndCentDir, endrec)

    # Either this is not a ZIP file, or it is a ZIP file with an archive
    # comment.  Search the end of the file for the "end of central directory"
    # record signature. The comment is the last item in the ZIP file and may be
    # up to 64K long.  It is assumed that the "end of central directory" magic
    # number does not appear in the comment.
    maxCommentStart = max(filesize - (1 << 16) - sizeEndCentDir, 0)
    fpin.seek(maxCommentStart, 0)
    data = fpin.read()
    start = data.rfind(stringEndArchive)
    if start >= 0:
        # found the magic number; attempt to unpack and interpret
        recData = data[start:start+sizeEndCentDir]
        if len(recData) != sizeEndCentDir:
            # The signature sits too close to the end of the file for a
            # complete record; unpacking would raise struct.error.
            return None
        endrec = list(struct.unpack(structEndArchive, recData))
        comment = data[start+sizeEndCentDir:]
        # check that comment length is correct
        if endrec[_ECD_COMMENT_SIZE] == len(comment):
            # Append the archive comment and start offset
            endrec.append(comment)
            endrec.append(maxCommentStart + start)

            # Try to read the "Zip64 end of central directory" structure
            return _EndRecData64(fpin, maxCommentStart + start - filesize,
                                 endrec)

    # Unable to find a valid end of central directory structure
    return None
255

256

257
class ZipInfo (object):
    """Class with attributes describing each file in the ZIP archive."""

    __slots__ = (
            'orig_filename',
            'filename',
            'date_time',
            'compress_type',
            'comment',
            'extra',
            'create_system',
            'create_version',
            'extract_version',
            'reserved',
            'flag_bits',
            'volume',
            'internal_attr',
            'external_attr',
            'header_offset',
            'CRC',
            'compress_size',
            'file_size',
            '_raw_time',
        )

    def __init__(self, filename="NoName", date_time=(1980,1,1,0,0,0)):
        """Create an entry description with standard default values.

        filename:  member name; it is cut at the first null character and
                   os.sep is replaced by "/" for the ``filename`` attribute,
                   while ``orig_filename`` keeps the value as passed in.
        date_time: (year, month, day, hour, min, sec) tuple.
        """
        self.orig_filename = filename   # Original file name in archive

        # Terminate the file name at the first null byte.  Null bytes in file
        # names are used as tricks by viruses in archives.
        null_byte = filename.find(chr(0))
        if null_byte >= 0:
            filename = filename[0:null_byte]
        # This is used to ensure paths in generated ZIP files always use
        # forward slashes as the directory separator, as required by the
        # ZIP format specification.
        if os.sep != "/" and os.sep in filename:
            filename = filename.replace(os.sep, "/")

        self.filename = filename        # Normalized file name
        self.date_time = date_time      # year, month, day, hour, min, sec
        # Standard values:
        self.compress_type = ZIP_STORED # Type of compression for the file
        self.comment = b""              # Comment for each file
        self.extra = b""                # ZIP extra data
        if sys.platform == 'win32':
            self.create_system = 0          # System which created ZIP archive
        else:
            # Assume everything else is unix-y
            self.create_system = 3          # System which created ZIP archive
        self.create_version = 20        # Version which created ZIP archive
        self.extract_version = 20       # Version needed to extract archive
        self.reserved = 0               # Must be zero
        self.flag_bits = 0              # ZIP flag bits
        self.volume = 0                 # Volume number of file header
        self.internal_attr = 0          # Internal attributes
        self.external_attr = 0          # External file attributes
        # Other attributes are set by class ZipFile:
        # header_offset         Byte offset to the file header
        # CRC                   CRC-32 of the uncompressed file
        # compress_size         Size of the compressed file
        # file_size             Size of the uncompressed file

    def FileHeader(self):
        """Return the per-file (local) header as bytes.

        The result is the packed fixed-size header followed by the encoded
        file name and the extra field.  When a size exceeds ZIP64_LIMIT a
        ZIP64 extra record is appended and, as a side effect,
        ``extract_version`` and ``create_version`` are raised to at least 45.
        """
        dt = self.date_time
        # Pack the date and time into the 16-bit MS-DOS formats; seconds are
        # stored with two-second resolution.
        dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
        dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
        if self.flag_bits & 0x08:
            # Set these to zero because we write them after the file data
            CRC = compress_size = file_size = 0
        else:
            CRC = self.CRC
            compress_size = self.compress_size
            file_size = self.file_size

        extra = self.extra

        if file_size > ZIP64_LIMIT or compress_size > ZIP64_LIMIT:
            # File is larger than what fits into a 4 byte integer,
            # fall back to the ZIP64 extension
            fmt = '<HHQQ'
            extra = extra + struct.pack(fmt,
                    1, struct.calcsize(fmt)-4, file_size, compress_size)
            file_size = 0xffffffff
            compress_size = 0xffffffff
            self.extract_version = max(45, self.extract_version)
            self.create_version = max(45, self.extract_version)

        filename, flag_bits = self._encodeFilenameFlags()
        header = struct.pack(structFileHeader, stringFileHeader,
                 self.extract_version, self.reserved, flag_bits,
                 self.compress_type, dostime, dosdate, CRC,
                 compress_size, file_size,
                 len(filename), len(extra))
        return header + filename + extra

    def _encodeFilenameFlags(self):
        """Return (encoded file name, flag bits) for writing a header.

        ASCII names are encoded as ASCII with the flags unchanged; any other
        name is encoded as UTF-8 and flag bit 0x800 is set in the result.
        """
        try:
            return self.filename.encode('ascii'), self.flag_bits
        except UnicodeEncodeError:
            return self.filename.encode('utf-8'), self.flag_bits | 0x800

    def _decodeExtra(self):
        """Apply ZIP64 values found in the extra field to this entry.

        Records of type 1 carry 64-bit replacements for ``file_size``,
        ``compress_size`` and ``header_offset`` (in that order) for whichever
        of them hold the "see extra field" marker value.  Raises RuntimeError
        when a type 1 record has an unexpected length.
        """
        # Try to decode the extra field.
        extra = self.extra
        unpack = struct.unpack
        # Every record starts with a 4 byte (type, length) header; fewer
        # trailing bytes than that cannot be a record and are ignored.
        while len(extra) >= 4:
            tp, ln = unpack('<HH', extra[:4])
            if tp == 1:
                if ln >= 24:
                    counts = unpack('<QQQ', extra[4:28])
                elif ln == 16:
                    counts = unpack('<QQ', extra[4:20])
                elif ln == 8:
                    counts = unpack('<Q', extra[4:12])
                elif ln == 0:
                    counts = ()
                else:
                    raise RuntimeError("Corrupt extra field %s"%(ln,))

                idx = 0

                # ZIP64 extension (large files and/or large archives)
                if self.file_size in (0xffffffffffffffff, 0xffffffff):
                    self.file_size = counts[idx]
                    idx += 1

                if self.compress_size == 0xFFFFFFFF:
                    self.compress_size = counts[idx]
                    idx += 1

                if self.header_offset == 0xffffffff:
                    self.header_offset = counts[idx]
                    idx+=1

            extra = extra[ln+4:]
395 396


397 398 399 400 401
class _ZipDecrypter:
    """Class to handle decryption of files stored within a ZIP archive.

    ZIP supports a password-based form of encryption. Even though known
    plaintext attacks have been found against it, it is still useful
402
    to be able to get data out of such a file.
403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431

    Usage:
        zd = _ZipDecrypter(mypwd)
        plain_char = zd(cypher_char)
        plain_text = map(zd, cypher_text)
    """

    def _GenerateCRCTable():
        """Generate a CRC-32 table.

        ZIP encryption uses the CRC32 one-byte primitive for scrambling some
        internal keys. We noticed that a direct implementation is faster than
        relying on binascii.crc32().
        """
        poly = 0xedb88320
        table = [0] * 256
        for i in range(256):
            crc = i
            for j in range(8):
                if crc & 1:
                    crc = ((crc >> 1) & 0x7FFFFFFF) ^ poly
                else:
                    crc = ((crc >> 1) & 0x7FFFFFFF)
            table[i] = crc
        return table
    crctable = _GenerateCRCTable()

    def _crc32(self, ch, crc):
        """Compute the CRC32 primitive on one byte."""
432
        return ((crc >> 8) & 0xffffff) ^ self.crctable[(crc ^ ch) & 0xff]
433 434 435 436 437 438 439 440 441 442 443 444

    def __init__(self, pwd):
        self.key0 = 305419896
        self.key1 = 591751049
        self.key2 = 878082192
        for p in pwd:
            self._UpdateKeys(p)

    def _UpdateKeys(self, c):
        self.key0 = self._crc32(c, self.key0)
        self.key1 = (self.key1 + (self.key0 & 255)) & 4294967295
        self.key1 = (self.key1 * 134775813 + 1) & 4294967295
445
        self.key2 = self._crc32((self.key1 >> 24) & 255, self.key2)
446 447 448

    def __call__(self, c):
        """Decrypt a single character."""
449
        assert isinstance(c, int)
450 451 452 453 454
        k = self.key2 | 2
        c = c ^ (((k * (k^1)) >> 8) & 255)
        self._UpdateKeys(c)
        return c

455
class ZipExtFile(io.BufferedIOBase):
    """File-like object for reading an archive member.
       Is returned by ZipFile.open().
    """

    # Max size supported by decompressor.  The parentheses matter: without
    # them "1 << 31 - 1" is parsed as "1 << 30".
    MAX_N = (1 << 31) - 1

    # Read from compressed files in 4k blocks.
    MIN_READ_SIZE = 4096

    # Search for universal newlines or line chunks.
    PATTERN = re.compile(br'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')

    def __init__(self, fileobj, mode, zipinfo, decrypter=None):
        """Wrap *fileobj*, positioned at the member's data.

        fileobj:   binary file object the (compressed) bytes are read from.
        mode:      mode string; a 'U' in it enables universal newlines in
                   readline().
        zipinfo:   entry description supplying compress_type, compress_size
                   and filename.
        decrypter: optional callable mapping one encrypted byte value to its
                   decrypted value.
        """
        self._fileobj = fileobj
        self._decrypter = decrypter

        self._compress_type = zipinfo.compress_type
        self._compress_size = zipinfo.compress_size
        self._compress_left = zipinfo.compress_size

        if self._compress_type == ZIP_DEFLATED:
            self._decompressor = zlib.decompressobj(-15)
        self._unconsumed = b''

        self._readbuffer = b''
        self._offset = 0

        self._universal = 'U' in mode
        self.newlines = None

        # Adjust read size for encrypted files since the first 12 bytes
        # are for the encryption/password information.
        if self._decrypter is not None:
            self._compress_left -= 12

        self.mode = mode
        self.name = zipinfo.filename

    def readline(self, limit=-1):
        """Read and return a line from the stream.

        If limit is specified, at most limit bytes will be read.
        """

        if not self._universal and limit < 0:
            # Shortcut common case - newline found in buffer.
            i = self._readbuffer.find(b'\n', self._offset) + 1
            if i > 0:
                line = self._readbuffer[self._offset: i]
                self._offset = i
                return line

        if not self._universal:
            return io.BufferedIOBase.readline(self, limit)

        line = b''
        while limit < 0 or len(line) < limit:
            readahead = self.peek(2)
            if readahead == b'':
                return line

            #
            # Search for universal newlines or line chunks.
            #
            # The pattern returns either a line chunk or a newline, but not
            # both. Combined with peek(2), we are assured that the sequence
            # '\r\n' is always retrieved completely and never split into
            # separate newlines - '\r', '\n' due to coincidental readaheads.
            #
            match = self.PATTERN.search(readahead)
            newline = match.group('newline')
            if newline is not None:
                if self.newlines is None:
                    self.newlines = []
                if newline not in self.newlines:
                    self.newlines.append(newline)
                self._offset += len(newline)
                return line + b'\n'

            chunk = match.group('chunk')
            if limit >= 0:
                chunk = chunk[: limit - len(line)]

            self._offset += len(chunk)
            line += chunk

        return line

    def peek(self, n=1):
        """Returns buffered bytes without advancing the position."""
        if n > len(self._readbuffer) - self._offset:
            chunk = self.read(n)
            self._offset -= len(chunk)

        # Return up to 512 bytes to reduce allocation overhead for tight loops.
        return self._readbuffer[self._offset: self._offset + 512]

    def readable(self):
        """Return True: the object supports reading."""
        return True

    def read(self, n=-1):
        """Read and return up to n bytes.
        If the argument is omitted, None, or negative, data is read and
        returned until EOF is reached.
        """

        buf = b''
        # Test for None before comparing: "None < 0" raises TypeError.
        if n is None:
            n = -1
        while True:
            if n < 0:
                data = self.read1(n)
            elif n > len(buf):
                # Ask only for what is still missing so that a short
                # read1() result can never make the total exceed n.
                data = self.read1(n - len(buf))
            else:
                return buf
            if len(data) == 0:
                return buf
            buf += data

    def read1(self, n):
        """Read up to n bytes with at most one read() system call."""

        # Simplify algorithm (branching) by transforming negative n to large n.
        if n is None or n < 0:
            n = self.MAX_N

        # Bytes available in read buffer.
        len_readbuffer = len(self._readbuffer) - self._offset

        # Read from file.
        if self._compress_left > 0 and n > len_readbuffer + len(self._unconsumed):
            nbytes = n - len_readbuffer - len(self._unconsumed)
            nbytes = max(nbytes, self.MIN_READ_SIZE)
            nbytes = min(nbytes, self._compress_left)

            data = self._fileobj.read(nbytes)
            self._compress_left -= len(data)

            if data and self._decrypter is not None:
                data = bytes(map(self._decrypter, data))

            if self._compress_type == ZIP_STORED:
                self._readbuffer = self._readbuffer[self._offset:] + data
                self._offset = 0
            else:
                # Prepare deflated bytes for decompression.
                self._unconsumed += data

        # Handle unconsumed data.
        if (len(self._unconsumed) > 0 and n > len_readbuffer and
            self._compress_type == ZIP_DEFLATED):
            data = self._decompressor.decompress(
                self._unconsumed,
                max(n - len_readbuffer, self.MIN_READ_SIZE)
            )

            self._unconsumed = self._decompressor.unconsumed_tail
            if len(self._unconsumed) == 0 and self._compress_left == 0:
                data += self._decompressor.flush()

            self._readbuffer = self._readbuffer[self._offset:] + data
            self._offset = 0

        # Read from buffer.
        data = self._readbuffer[self._offset: self._offset + n]
        self._offset += len(data)
        return data
620 621


622

623
class ZipFile:
Tim Peters's avatar
Tim Peters committed
624 625
    """ Class with methods to open, read, write, close, list zip files.

626
    z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=False)
Tim Peters's avatar
Tim Peters committed
627

628 629 630 631
    file: Either the path to the file, or a file-like object.
          If it is a path, the file will be opened and closed by ZipFile.
    mode: The mode can be either read "r", write "w" or append "a".
    compression: ZIP_STORED (no compression) or ZIP_DEFLATED (requires zlib).
632 633 634 635
    allowZip64: if True ZipFile will create files with ZIP64 extensions when
                needed, otherwise it will raise an exception when this would
                be necessary.

636
    """
637

638 639
    fp = None                   # Set here since __del__ checks it

640
    def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False):
        """Open the ZIP file with mode read "r", write "w" or append "a".

        file:        a path (str), which is opened here, or a file-like
                     object, which is used as given.
        compression: ZIP_STORED or ZIP_DEFLATED.
        allowZip64:  stored for later use as ``_allowZip64``.

        Raises RuntimeError for an unknown mode, an unsupported compression
        method, or ZIP_DEFLATED without the zlib module.  In mode "r" the
        table of contents is read immediately, which raises BadZipfile for a
        file that is not a ZIP file.
        """
        if mode not in ("r", "w", "a"):
            raise RuntimeError('ZipFile() requires mode "r", "w", or "a"')

        if compression == ZIP_STORED:
            pass
        elif compression == ZIP_DEFLATED:
            if not zlib:
                raise RuntimeError(
                      "Compression requires the (missing) zlib module")
        else:
            raise RuntimeError("That compression method is not supported")

        self._allowZip64 = allowZip64
        self._didModify = False
        self.debug = 0  # Level of printing: 0 through 3
        self.NameToInfo = {}    # Find file info given name
        self.filelist = []      # List of ZipInfo instances for archive
        self.compression = compression  # Method of compression
        self.mode = key = mode.replace('b', '')[0]
        self.pwd = None
        self.comment = b''

        # Check if we were passed a file-like object
        if isinstance(file, str):
            # No, it's a filename
            self._filePassed = 0
            self.filename = file
            modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'}
            try:
                self.fp = io.open(file, modeDict[mode])
            except IOError:
                if mode == 'a':
                    # Appending to a file that cannot be opened for update:
                    # retry as a fresh archive in write mode.
                    mode = key = 'w'
                    self.fp = io.open(file, modeDict[mode])
                else:
                    raise
        else:
            self._filePassed = 1
            self.fp = file
            self.filename = getattr(file, 'name', None)

        if key == 'r':
            self._GetContents()
        elif key == 'w':
            pass
        elif key == 'a':
            try:                        # See if file is a zip file
                self._RealGetContents()
                # seek to start of directory and overwrite
                self.fp.seek(self.start_dir, 0)
            except BadZipfile:          # file is not a zip file, just append
                self.fp.seek(0, 2)
        else:
            if not self._filePassed:
                self.fp.close()
                self.fp = None
            raise RuntimeError('Mode must be "r", "w" or "a"')
699

700 701 702 703 704 705
    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

706
    def _GetContents(self):
707 708 709 710 711 712 713 714 715 716 717
        """Read the directory, making sure we close the file if the format
        is bad."""
        try:
            self._RealGetContents()
        except BadZipfile:
            if not self._filePassed:
                self.fp.close()
                self.fp = None
            raise

    def _RealGetContents(self):
        """Read in the table of contents for the ZIP file.

        Fills ``filelist`` and ``NameToInfo`` with one ZipInfo per central
        directory entry and sets ``comment`` and ``start_dir``.  Raises
        BadZipfile when no end-of-archive record is found or when a central
        directory entry has a bad signature or is truncated.
        """
        fp = self.fp
        endrec = _EndRecData(fp)
        if not endrec:
            raise BadZipfile("File is not a zip file")
        if self.debug > 1:
            print(endrec)
        size_cd = endrec[_ECD_SIZE]             # bytes in central directory
        offset_cd = endrec[_ECD_OFFSET]         # offset of central directory
        self.comment = endrec[_ECD_COMMENT]     # archive comment

        # "concat" is zero, unless zip was concatenated to another file
        concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
        if endrec[_ECD_SIGNATURE] == stringEndArchive64:
            # If Zip64 extension structures are present, account for them
            concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)

        if self.debug > 2:
            inferred = concat + offset_cd
            print("given, inferred, offset", offset_cd, inferred, concat)
        # self.start_dir:  Position of start of central directory
        self.start_dir = offset_cd + concat
        fp.seek(self.start_dir, 0)
        data = fp.read(size_cd)
        # Parse the directory from memory rather than from the file itself.
        fp = io.BytesIO(data)
        total = 0
        while total < size_cd:
            centdir = fp.read(sizeCentralDir)
            if centdir[0:4] != stringCentralDir:
                raise BadZipfile("Bad magic number for central directory")
            if len(centdir) != sizeCentralDir:
                # Correct signature but incomplete entry; unpacking would
                # raise struct.error instead of a ZIP-specific error.
                raise BadZipfile("Truncated central directory")
            centdir = struct.unpack(structCentralDir, centdir)
            if self.debug > 2:
                print(centdir)
            filename = fp.read(centdir[_CD_FILENAME_LENGTH])
            flags = centdir[_CD_FLAG_BITS]
            if flags & 0x800:
                # UTF-8 file names extension
                filename = filename.decode('utf-8')
            else:
                # Historical ZIP filename encoding
                filename = filename.decode('cp437')
            # Create ZipInfo instance to store file information
            x = ZipInfo(filename)
            x.extra = fp.read(centdir[_CD_EXTRA_FIELD_LENGTH])
            x.comment = fp.read(centdir[_CD_COMMENT_LENGTH])
            x.header_offset = centdir[_CD_LOCAL_HEADER_OFFSET]
            (x.create_version, x.create_system, x.extract_version, x.reserved,
                x.flag_bits, x.compress_type, t, d,
                x.CRC, x.compress_size, x.file_size) = centdir[1:12]
            x.volume, x.internal_attr, x.external_attr = centdir[15:18]
            # Convert date/time code to (year, month, day, hour, min, sec)
            x._raw_time = t
            x.date_time = ( (d>>9)+1980, (d>>5)&0xF, d&0x1F,
                                     t>>11, (t>>5)&0x3F, (t&0x1F) * 2 )

            x._decodeExtra()
            x.header_offset = x.header_offset + concat
            self.filelist.append(x)
            self.NameToInfo[x.filename] = x

            # update total bytes read from central directory
            total = (total + sizeCentralDir + centdir[_CD_FILENAME_LENGTH]
                     + centdir[_CD_EXTRA_FIELD_LENGTH]
                     + centdir[_CD_COMMENT_LENGTH])

            if self.debug > 2:
                print("total", total)
785

786 787

    def namelist(self):
        """Return the member names of the archive as a new list.

        Names are taken from self.filelist, in the order they are
        stored there.
        """
        return [entry.filename for entry in self.filelist]

    def infolist(self):
        """Return the list of ZipInfo instances, one per archive member.

        The list returned is the archive's own self.filelist object,
        not a copy.
        """
        members = self.filelist
        return members

799
    def printdir(self, file=None):
        """Print a table of contents for the zip file.

        A header row is written first, followed by one row per member
        giving its name, modification date/time and uncompressed size.
        'file' is handed to print() unchanged.
        """
        header = ("File Name", "Modified    ", "Size")
        print("%-46s %19s %12s" % header, file=file)
        for entry in self.filelist:
            stamp = "%d-%02d-%02d %02d:%02d:%02d" % entry.date_time[:6]
            row = (entry.filename, stamp, entry.file_size)
            print("%-46s %s %12d" % row, file=file)
807 808

    def testzip(self):
809
        """Read all the files and check the CRC."""
810
        chunk_size = 2 ** 20
811 812
        for zinfo in self.filelist:
            try:
813 814 815 816 817
                # Read by chunks, to avoid an OverflowError or a
                # MemoryError with very large embedded files.
                f = self.open(zinfo.filename, "r")
                while f.read(chunk_size):     # Check CRC-32
                    pass
818
            except BadZipfile:
819 820 821
                return zinfo.filename

    def getinfo(self, name):
        """Look up 'name' in the archive and return its ZipInfo instance.

        Raises KeyError when the archive has no member of that name.
        """
        found = self.NameToInfo.get(name)
        if found is not None:
            return found

        raise KeyError(
            'There is no item named %r in the archive' % name)
829

830 831
    def setpassword(self, pwd):
        """Set default password for encrypted files.

        'pwd' must be a bytes object; it is stored on the archive and used
        by open() whenever no explicit password is supplied there.

        Raises TypeError when 'pwd' is not bytes.  (An explicit check is
        used instead of an assert so that the validation is still performed
        when Python runs with -O.)
        """
        if not isinstance(pwd, bytes):
            raise TypeError(
                "pwd: expected bytes, got %s" % type(pwd).__name__)
        self.pwd = pwd

    def read(self, name, pwd=None):
        """Return the complete contents of the member 'name' as bytes.

        'name' and 'pwd' are passed on to open(), which is called with
        mode "r".
        """
        member = self.open(name, "r", pwd)
        return member.read()

    def open(self, name, mode="r", pwd=None):
        """Return file-like object for 'name'.

        'name' is either a member name or a ZipInfo instance.  'mode' must
        be "r", "U" or "rU".  'pwd' is the password for an encrypted
        member; when it is empty the default stored by setpassword() is
        used instead.

        Raises RuntimeError for an unsupported mode, an archive that was
        already closed, a missing password or a failed password check;
        BadZipfile when the local file header has a bad magic number or
        its name differs from the central directory; and KeyError (from
        getinfo()) when 'name' is not in the archive.
        """
        if mode not in ("r", "U", "rU"):
            raise RuntimeError('open() requires mode "r", "U", or "rU"')
        if not self.fp:
            raise RuntimeError(
                  "Attempt to read ZIP archive that was already closed")

        # Only open a new file for instances where we were not
        # given a file object in the constructor
        if self._filePassed:
            zef_file = self.fp
        else:
            zef_file = io.open(self.filename, 'rb')

        try:
            # Make sure we have an info object
            if isinstance(name, ZipInfo):
                # 'name' is already an info object
                zinfo = name
            else:
                # Get info object for name
                zinfo = self.getinfo(name)

            zef_file.seek(zinfo.header_offset, 0)

            # Skip the file header:
            fheader = zef_file.read(sizeFileHeader)
            if fheader[0:4] != stringFileHeader:
                raise BadZipfile("Bad magic number for file header")

            fheader = struct.unpack(structFileHeader, fheader)
            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
            if fheader[_FH_EXTRA_FIELD_LENGTH]:
                zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])

            # The stored name is raw bytes.  Names are decoded as UTF-8 or
            # as cp437 depending on flag bit 0x800, so accept the header
            # name when it matches either encoding of the directory name;
            # comparing against UTF-8 only rejects valid cp437 names.
            candidates = [zinfo.orig_filename.encode("utf-8")]
            try:
                candidates.append(zinfo.orig_filename.encode("cp437"))
            except UnicodeEncodeError:
                # Name has characters outside cp437; UTF-8 is the only form.
                pass
            if fname not in candidates:
                raise BadZipfile(
                      'File name in directory %r and header %r differ.'
                      % (zinfo.orig_filename, fname))

            # check for encrypted flag & handle password
            is_encrypted = zinfo.flag_bits & 0x1
            zd = None
            if is_encrypted:
                if not pwd:
                    pwd = self.pwd
                if not pwd:
                    raise RuntimeError("File %s is encrypted, "
                                       "password required for extraction" % name)

                zd = _ZipDecrypter(pwd)
                # The first 12 bytes in the cypher stream is an encryption
                #  header used to strengthen the algorithm. The first 11
                #  bytes are completely random, while the 12th contains the
                #  MSB of the CRC, or the MSB of the file time depending on
                #  the header type and is used to check the correctness of
                #  the password.
                crypt_header = zef_file.read(12)
                h = list(map(zd, crypt_header[0:12]))
                if zinfo.flag_bits & 0x8:
                    # compare against the file type from extended local headers
                    check_byte = (zinfo._raw_time >> 8) & 0xff
                else:
                    # compare against the CRC otherwise
                    check_byte = (zinfo.CRC >> 24) & 0xff
                if h[11] != check_byte:
                    raise RuntimeError("Bad password for file", name)

            return ZipExtFile(zef_file, mode, zinfo, zd)
        except BaseException:
            # A handle opened above belongs to this call alone; release it
            # instead of leaking it when anything goes wrong.  A file object
            # handed to the constructor is the caller's and stays open.
            if not self._filePassed:
                zef_file.close()
            raise
907

908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939
    def extract(self, member, path=None, pwd=None):
        """Extract one member of the archive and return the path written.

           `member' is either a member name or a ZipInfo object; a name is
           resolved through getinfo().  The member is extracted under its
           full name below `path', or below the current working directory
           when `path' is None.  `pwd' is handed on unchanged.
        """
        info = member if isinstance(member, ZipInfo) else self.getinfo(member)
        target_dir = os.getcwd() if path is None else path
        return self._extract_member(info, target_dir, pwd)

    def extractall(self, path=None, members=None, pwd=None):
        """Extract several members by calling extract() on each in turn.

           `members' defaults to everything returned by namelist(); when
           given it must be a subset of that list.  `path' and `pwd' are
           passed to extract() unchanged, so a `path' of None means the
           current working directory.
        """
        selected = self.namelist() if members is None else members
        for entry in selected:
            self.extract(entry, path, pwd)

    def _extract_member(self, member, targetpath, pwd):
        """Extract the ZipInfo object 'member' to a physical
           file on the path targetpath.
        """
        # build the destination pathname, replacing
        # forward slashes to platform specific separators.
940 941 942
        # Strip trailing path separator, unless it represents the root.
        if (targetpath[-1:] in (os.path.sep, os.path.altsep)
            and len(os.path.splitdrive(targetpath)[1]) > 1):
943 944 945
            targetpath = targetpath[:-1]

        # don't include leading "/" from file name if present
946
        if member.filename[0] == '/':
947 948 949 950 951 952 953 954 955 956 957
            targetpath = os.path.join(targetpath, member.filename[1:])
        else:
            targetpath = os.path.join(targetpath, member.filename)

        targetpath = os.path.normpath(targetpath)

        # Create all upper directories if necessary.
        upperdirs = os.path.dirname(targetpath)
        if upperdirs and not os.path.exists(upperdirs):
            os.makedirs(upperdirs)

958
        if member.filename[-1] == '/':
959 960
            if not os.path.isdir(targetpath):
                os.mkdir(targetpath)
961 962
            return targetpath

Georg Brandl's avatar
Georg Brandl committed
963
        source = self.open(member, pwd=pwd)
964 965 966 967 968 969 970
        target = open(targetpath, "wb")
        shutil.copyfileobj(source, target)
        source.close()
        target.close()

        return targetpath

971
    def _writecheck(self, zinfo):
972
        """Check for errors before writing a file to the archive."""
973
        if zinfo.filename in self.NameToInfo:
974
            if self.debug:      # Warning for duplicate names
975
                print("Duplicate name:", zinfo.filename)
976
        if self.mode not in ("w", "a"):
977
            raise RuntimeError('write() requires mode "w" or "a"')
978
        if not self.fp:
979 980
            raise RuntimeError(
                  "Attempt to write ZIP archive that was already closed")
981
        if zinfo.compress_type == ZIP_DEFLATED and not zlib:
982 983
            raise RuntimeError(
                  "Compression requires the (missing) zlib module")
984
        if zinfo.compress_type not in (ZIP_STORED, ZIP_DEFLATED):
985
            raise RuntimeError("That compression method is not supported")
986 987 988 989 990
        if zinfo.file_size > ZIP64_LIMIT:
            if not self._allowZip64:
                raise LargeZipFile("Filesize would require ZIP64 extensions")
        if zinfo.header_offset > ZIP64_LIMIT:
            if not self._allowZip64:
991 992
                raise LargeZipFile(
                      "Zipfile size would require ZIP64 extensions")
993 994

    def write(self, filename, arcname=None, compress_type=None):
        """Add the file or directory 'filename' to the archive.

        The member is stored as 'arcname' (default: 'filename') with any
        drive and leading separators removed; a directory gets a trailing
        '/' and no data.  'compress_type' overrides the archive's default
        compression for this member.

        Raises RuntimeError when the archive was already closed, plus
        whatever _writecheck() raises.
        """
        if not self.fp:
            raise RuntimeError(
                  "Attempt to write to ZIP archive that was already closed")

        st = os.stat(filename)
        isdir = stat.S_ISDIR(st.st_mode)
        date_time = time.localtime(st.st_mtime)[0:6]

        # Work out the name to store the member under.
        if arcname is None:
            arcname = filename
        arcname = os.path.normpath(os.path.splitdrive(arcname)[1])
        while arcname[0] in (os.sep, os.altsep):
            arcname = arcname[1:]
        if isdir:
            arcname += '/'

        # Create ZipInfo instance to store file information
        zinfo = ZipInfo(arcname, date_time)
        zinfo.external_attr = (st[0] & 0xFFFF) << 16      # Unix attributes
        zinfo.compress_type = (self.compression if compress_type is None
                               else compress_type)
        zinfo.file_size = st.st_size
        zinfo.flag_bits = 0x00
        zinfo.header_offset = self.fp.tell()    # Start of header bytes

        self._writecheck(zinfo)
        self._didModify = True

        if isdir:
            # A directory entry consists of a header only.
            zinfo.file_size = 0
            zinfo.compress_size = 0
            zinfo.CRC = 0
            self.filelist.append(zinfo)
            self.NameToInfo[zinfo.filename] = zinfo
            self.fp.write(zinfo.FileHeader())
            return

        crc = 0
        raw_size = 0
        packed_size = 0
        with open(filename, "rb") as fp:
            # The header goes out with zeroed CRC and sizes; the real
            # values are patched in once the data has been written.
            zinfo.CRC = 0
            zinfo.compress_size = 0
            zinfo.file_size = 0
            self.fp.write(zinfo.FileHeader())
            if zinfo.compress_type == ZIP_DEFLATED:
                cmpr = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                     zlib.DEFLATED, -15)
            else:
                cmpr = None
            for buf in iter(lambda: fp.read(1024 * 8), b""):
                raw_size += len(buf)
                crc = crc32(buf, crc) & 0xffffffff
                if cmpr:
                    buf = cmpr.compress(buf)
                    packed_size += len(buf)
                self.fp.write(buf)

        if cmpr:
            tail = cmpr.flush()
            packed_size += len(tail)
            self.fp.write(tail)
        else:
            # Stored members occupy exactly their uncompressed size.
            packed_size = raw_size
        zinfo.compress_size = packed_size
        zinfo.CRC = crc
        zinfo.file_size = raw_size

        # Seek backwards and write CRC and file sizes
        resume_at = self.fp.tell()      # Preserve current position in file
        self.fp.seek(zinfo.header_offset + 14, 0)
        self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
              zinfo.file_size))
        self.fp.seek(resume_at, 0)
        self.filelist.append(zinfo)
        self.NameToInfo[zinfo.filename] = zinfo

1075
    def writestr(self, zinfo_or_arcname, data, compress_type=None):
        """Write a file into the archive.  The contents is 'data', which
        may be either a 'str' or a 'bytes' instance; if it is a 'str',
        it is encoded as UTF-8 first.
        'zinfo_or_arcname' is either a ZipInfo instance or
        the name of the file in the archive.

        When only a name is given, a ZipInfo is created for it with the
        current local time, the archive's default compression and
        permission bits 0o600.  'compress_type', if not None, overrides
        the compression of the ZipInfo that is used.

        Raises RuntimeError when the archive was already closed, plus
        whatever _writecheck() raises.
        """
        # Fail fast: nothing below makes sense on a closed archive.
        if not self.fp:
            raise RuntimeError(
                  "Attempt to write to ZIP archive that was already closed")

        if isinstance(data, str):
            data = data.encode("utf-8")

        if isinstance(zinfo_or_arcname, ZipInfo):
            zinfo = zinfo_or_arcname
        else:
            zinfo = ZipInfo(filename=zinfo_or_arcname,
                            date_time=time.localtime(time.time())[:6])
            zinfo.compress_type = self.compression
            zinfo.external_attr = 0o600 << 16
        if compress_type is not None:
            zinfo.compress_type = compress_type

        zinfo.file_size = len(data)            # Uncompressed size
        # Nothing is written between here and the header, so the offset
        # only has to be computed once.
        zinfo.header_offset = self.fp.tell()    # Start of header data

        self._writecheck(zinfo)
        self._didModify = True
        zinfo.CRC = crc32(data) & 0xffffffff       # CRC-32 checksum
        if zinfo.compress_type == ZIP_DEFLATED:
            co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                 zlib.DEFLATED, -15)
            data = co.compress(data) + co.flush()
            zinfo.compress_size = len(data)    # Compressed size
        else:
            zinfo.compress_size = zinfo.file_size

        self.fp.write(zinfo.FileHeader())
        self.fp.write(data)
        self.fp.flush()
        if zinfo.flag_bits & 0x08:
            # Write CRC and file sizes after the file data
            self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
                  zinfo.file_size))
        self.filelist.append(zinfo)
        self.NameToInfo[zinfo.filename] = zinfo

    def __del__(self):
1122
        """Call the "close()" method in case the user forgot."""
1123
        self.close()
1124 1125

    def close(self):
        """Close the file, and for mode "w" and "a" write the ending
        records.

        The ending records (central directory, optional ZIP64 records and
        the end-of-archive record with the archive comment) are written
        only when the archive was modified.  An archive comment longer
        than ZIP_MAX_COMMENT bytes is truncated.

        Calling close() on an archive that is already closed does nothing.
        A file object that was passed in by the caller is left open; a
        file opened by this object is closed even when writing the ending
        records raises.
        """
        if self.fp is None:
            return

        try:
            if self.mode in ("w", "a") and self._didModify: # write ending records
                count = 0
                pos1 = self.fp.tell()
                for zinfo in self.filelist:         # write central directory
                    count = count + 1
                    dt = zinfo.date_time
                    dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
                    dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
                    extra = []
                    if zinfo.file_size > ZIP64_LIMIT \
                            or zinfo.compress_size > ZIP64_LIMIT:
                        extra.append(zinfo.file_size)
                        extra.append(zinfo.compress_size)
                        file_size = 0xffffffff
                        compress_size = 0xffffffff
                    else:
                        file_size = zinfo.file_size
                        compress_size = zinfo.compress_size

                    if zinfo.header_offset > ZIP64_LIMIT:
                        extra.append(zinfo.header_offset)
                        header_offset = 0xffffffff
                    else:
                        header_offset = zinfo.header_offset

                    extra_data = zinfo.extra
                    if extra:
                        # Append a ZIP64 field to the extra's
                        extra_data = struct.pack(
                                '<HH' + 'Q'*len(extra),
                                1, 8*len(extra), *extra) + extra_data

                        extract_version = max(45, zinfo.extract_version)
                        create_version = max(45, zinfo.create_version)
                    else:
                        extract_version = zinfo.extract_version
                        create_version = zinfo.create_version

                    try:
                        filename, flag_bits = zinfo._encodeFilenameFlags()
                        centdir = struct.pack(structCentralDir,
                            stringCentralDir, create_version,
                            zinfo.create_system, extract_version, zinfo.reserved,
                            flag_bits, zinfo.compress_type, dostime, dosdate,
                            zinfo.CRC, compress_size, file_size,
                            len(filename), len(extra_data), len(zinfo.comment),
                            0, zinfo.internal_attr, zinfo.external_attr,
                            header_offset)
                    except DeprecationWarning:
                        # Dump the values that failed to pack, then re-raise.
                        print((structCentralDir, stringCentralDir, create_version,
                            zinfo.create_system, extract_version, zinfo.reserved,
                            zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
                            zinfo.CRC, compress_size, file_size,
                            len(zinfo.filename), len(extra_data), len(zinfo.comment),
                            0, zinfo.internal_attr, zinfo.external_attr,
                            header_offset), file=sys.stderr)
                        raise
                    self.fp.write(centdir)
                    self.fp.write(filename)
                    self.fp.write(extra_data)
                    self.fp.write(zinfo.comment)

                pos2 = self.fp.tell()
                # Write end-of-zip-archive record
                centDirCount = count
                centDirSize = pos2 - pos1
                centDirOffset = pos1
                if (centDirCount >= ZIP_FILECOUNT_LIMIT or
                    centDirOffset > ZIP64_LIMIT or
                    centDirSize > ZIP64_LIMIT):
                    # Need to write the ZIP64 end-of-archive records
                    zip64endrec = struct.pack(
                            structEndArchive64, stringEndArchive64,
                            44, 45, 45, 0, 0, centDirCount, centDirCount,
                            centDirSize, centDirOffset)
                    self.fp.write(zip64endrec)

                    zip64locrec = struct.pack(
                            structEndArchive64Locator,
                            stringEndArchive64Locator, 0, pos2, 1)
                    self.fp.write(zip64locrec)
                    # The classic record can only hold the clamped values.
                    centDirCount = min(centDirCount, 0xFFFF)
                    centDirSize = min(centDirSize, 0xFFFFFFFF)
                    centDirOffset = min(centDirOffset, 0xFFFFFFFF)

                # check for valid comment length; a comment of exactly
                # ZIP_MAX_COMMENT bytes still fits and is left alone
                if len(self.comment) > ZIP_MAX_COMMENT:
                    if self.debug > 0:
                        print('Archive comment is too long; truncating to %d '
                              'bytes' % ZIP_MAX_COMMENT)
                    self.comment = self.comment[:ZIP_MAX_COMMENT]

                endrec = struct.pack(structEndArchive, stringEndArchive,
                                     0, 0, centDirCount, centDirCount,
                                     centDirSize, centDirOffset, len(self.comment))
                self.fp.write(endrec)
                self.fp.write(self.comment)
                self.fp.flush()
        finally:
            # Always drop the handle, so that a failure above neither leaks
            # the file nor leaves the object looking open.
            fp = self.fp
            self.fp = None
            if not self._filePassed:
                fp.close()


class PyZipFile(ZipFile):
    """Class to create ZIP archives with Python library files and packages."""

    def writepy(self, pathname, basename=""):
        """Add all files from "pathname" to the ZIP archive.

        If pathname is a package directory, search the directory and
        all package subdirectories recursively for all *.py and enter
        the modules into the archive.  If pathname is a plain
        directory, listdir *.py and enter all modules.  Else, pathname
        must be a Python *.py file and the module will be put into the
        archive.  Added modules are always module.pyo or module.pyc.
        This method will compile the module.py into module.pyc if
        necessary; if that compilation fails the module.py source is
        added instead.

        'basename' is the archive directory the entries are placed under.
        Raises RuntimeError when pathname is a file not ending in ".py".
        """
        name = os.path.split(pathname)[1]
        if os.path.isdir(pathname):
            initname = os.path.join(pathname, "__init__.py")
            if os.path.isfile(initname):
                # This is a package directory, add it
                if basename:
                    basename = "%s/%s" % (basename, name)
                else:
                    basename = name
                if self.debug:
                    print("Adding package in", pathname, "as", basename)
                fname, arcname = self._get_codename(initname[0:-3], basename)
                if self.debug:
                    print("Adding", arcname)
                self.write(fname, arcname)
                dirlist = os.listdir(pathname)
                # __init__ was added above; do not add it a second time.
                dirlist.remove("__init__.py")
                # Add all *.py files and package subdirectories
                for filename in dirlist:
                    path = os.path.join(pathname, filename)
                    root, ext = os.path.splitext(filename)
                    if os.path.isdir(path):
                        if os.path.isfile(os.path.join(path, "__init__.py")):
                            # This is a package directory, add it
                            self.writepy(path, basename)  # Recursive call
                    elif ext == ".py":
                        fname, arcname = self._get_codename(path[0:-3],
                                         basename)
                        if self.debug:
                            print("Adding", arcname)
                        self.write(fname, arcname)
            else:
                # This is NOT a package directory, add its files at top level
                if self.debug:
                    print("Adding files from directory", pathname)
                for filename in os.listdir(pathname):
                    path = os.path.join(pathname, filename)
                    root, ext = os.path.splitext(filename)
                    if ext == ".py":
                        fname, arcname = self._get_codename(path[0:-3],
                                         basename)
                        if self.debug:
                            print("Adding", arcname)
                        self.write(fname, arcname)
        else:
            if pathname[-3:] != ".py":
                raise RuntimeError(
                      'Files added with writepy() must end with ".py"')
            fname, arcname = self._get_codename(pathname[0:-3], basename)
            if self.debug:
                print("Adding file", arcname)
            self.write(fname, arcname)

    def _get_codename(self, pathname, basename):
        """Return (filename, archivename) for the path.

        Given a module name path, return the correct file path and
        archive name, compiling if necessary.  For example, given
        /python/lib/string, return (/python/lib/string.pyc, string.pyc).

        An up-to-date legacy .pyo or .pyc next to the source is preferred,
        then an up-to-date file in __pycache__ (archived under the legacy
        name).  Otherwise the source is compiled; when compilation fails
        the error message is printed and the .py source itself is
        returned.  'basename', when not empty, is prefixed to the archive
        name with a '/'.
        """
        file_py  = pathname + ".py"
        file_pyc = pathname + ".pyc"
        file_pyo = pathname + ".pyo"
        pycache_pyc = imp.cache_from_source(file_py, True)
        pycache_pyo = imp.cache_from_source(file_py, False)
        if (os.path.isfile(file_pyo) and
            os.stat(file_pyo).st_mtime >= os.stat(file_py).st_mtime):
            # Use .pyo file.
            arcname = fname = file_pyo
        elif (os.path.isfile(file_pyc) and
              os.stat(file_pyc).st_mtime >= os.stat(file_py).st_mtime):
            # Use .pyc file.
            arcname = fname = file_pyc
        elif (os.path.isfile(pycache_pyc) and
              os.stat(pycache_pyc).st_mtime >= os.stat(file_py).st_mtime):
            # Use the __pycache__/*.pyc file, but write it to the legacy pyc
            # file name in the archive.
            fname = pycache_pyc
            arcname = file_pyc
        elif (os.path.isfile(pycache_pyo) and
              os.stat(pycache_pyo).st_mtime >= os.stat(file_py).st_mtime):
            # Use the __pycache__/*.pyo file, but write it to the legacy pyo
            # file name in the archive.
            fname = pycache_pyo
            arcname = file_pyo
        else:
            # Compile py into PEP 3147 pyc file.
            import py_compile
            if self.debug:
                print("Compiling", file_py)
            try:
                py_compile.compile(file_py, doraise=True)
            except py_compile.PyCompileError as err:
                print(err.msg)
                # Fall back to archiving the source; arcname has to be set
                # here too, it is used unconditionally below.
                fname = arcname = file_py
            else:
                fname = (pycache_pyc if __debug__ else pycache_pyo)
                arcname = (file_pyc if __debug__ else file_pyo)
        archivename = os.path.split(arcname)[1]
        if basename:
            archivename = "%s/%s" % (basename, archivename)
        return (fname, archivename)
1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366


def main(args = None):
    """Command line interface: list, test, extract or create a zipfile.

    'args' is the argument list without the program name and defaults to
    sys.argv[1:].  The first argument selects the action (-l, -t, -e or
    -c, see USAGE below).  On an unknown action or a wrong number of
    arguments the usage text is printed and sys.exit(1) is called.
    """
    import textwrap
    USAGE=textwrap.dedent("""\
        Usage:
            zipfile.py -l zipfile.zip        # Show listing of a zipfile
            zipfile.py -t zipfile.zip        # Test if a zipfile is valid
            zipfile.py -e zipfile.zip target # Extract zipfile into target dir
            zipfile.py -c zipfile.zip src ... # Create zipfile from sources
        """)
    if args is None:
        args = sys.argv[1:]

    if not args or args[0] not in ('-l', '-c', '-e', '-t'):
        print(USAGE)
        sys.exit(1)

    if args[0] == '-l':
        if len(args) != 2:
            print(USAGE)
            sys.exit(1)
        zf = ZipFile(args[1], 'r')
        try:
            zf.printdir()
        finally:
            zf.close()

    elif args[0] == '-t':
        if len(args) != 2:
            print(USAGE)
            sys.exit(1)
        zf = ZipFile(args[1], 'r')
        try:
            badfile = zf.testzip()
        finally:
            zf.close()
        # testzip() returns the name of the first bad member; report it
        # rather than discarding the result of the test.
        if badfile:
            print("The following enclosed file is corrupted: {!r}".format(
                badfile))
        print("Done testing")

    elif args[0] == '-e':
        if len(args) != 3:
            print(USAGE)
            sys.exit(1)

        zf = ZipFile(args[1], 'r')
        try:
            # extractall() creates parent directories and handles
            # directory entries, which cannot be opened as files.
            zf.extractall(args[2])
        finally:
            zf.close()

    elif args[0] == '-c':
        if len(args) < 3:
            print(USAGE)
            sys.exit(1)

        def addToZip(zf, path, zippath):
            # Add a file, or recurse into a directory; anything else is
            # ignored.
            if os.path.isfile(path):
                zf.write(path, zippath, ZIP_DEFLATED)
            elif os.path.isdir(path):
                for nm in os.listdir(path):
                    addToZip(zf,
                            os.path.join(path, nm), os.path.join(zippath, nm))
            # else: ignore

        zf = ZipFile(args[1], 'w', allowZip64=True)
        try:
            for src in args[2:]:
                addToZip(zf, src, os.path.basename(src))
        finally:
            zf.close()

# Run the command line interface when the module is executed as a script.
if __name__ == "__main__":
    main()