# wave.py
"""Stuff to parse WAVE files.

Usage.

Reading WAVE files:
      f = wave.open(file, 'r')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.

This returns an instance of a class with the following public methods:
      getnchannels()  -- returns number of audio channels (1 for
                         mono, 2 for stereo)
      getsampwidth()  -- returns sample width in bytes
      getframerate()  -- returns sampling frequency
      getnframes()    -- returns number of audio frames
      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
      getparams()     -- returns a tuple consisting of all of the
                         above in the above order
      getmarkers()    -- returns None (for compatibility with the
                         aifc module)
      getmark(id)     -- raises an error since the mark does not
                         exist (for compatibility with the aifc module)
      readframes(n)   -- returns at most n frames of audio
      rewind()        -- rewind to the beginning of the audio stream
      setpos(pos)     -- seek to the specified position
      tell()          -- return the current position
      close()         -- close the instance (make it unusable)
The position returned by tell() and the position given to setpos()
are compatible and have nothing to do with the actual position in the
file.
The close() method is called automatically when the class instance
is destroyed.

Writing WAVE files:
      f = wave.open(file, 'w')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods write(), tell(), seek(), and
close().

This returns an instance of a class with the following public methods:
      setnchannels(n) -- set the number of channels
      setsampwidth(n) -- set the sample width
      setframerate(n) -- set the frame rate
      setnframes(n)   -- set the number of frames
      setcomptype(type, name)
                      -- set the compression type and the
                         human-readable compression type
      setparams(tuple)
                      -- set all parameters at once
      tell()          -- return current position in output file
      writeframesraw(data)
                      -- write audio frames without patching up the
                         file header
      writeframes(data)
                      -- write audio frames and patch up the file header
      close()         -- patch up the file header and close the
                         output file
You should set the parameters before the first writeframesraw or
writeframes.  The total number of frames does not need to be set,
but when it is set to the correct value, the header does not have to
be patched up.
It is best to first set all parameters, optionally including the
compression type, and then write audio frames using writeframesraw.
When all frames have been written, either call writeframes(b'') or
close() to patch up the sizes in the header.
The close() method is called automatically when the class instance
is destroyed.
"""
73

74
import builtins
75

__all__ = ["open", "openfp", "Error"]


class Error(Exception):
    """Raised for malformed WAVE input and for invalid reader/writer usage."""


# Format tag identifying uncompressed PCM data in the 'fmt ' chunk.
WAVE_FORMAT_PCM = 0x0001

# array typecodes indexed by sample width in bytes; used to byte-swap
# multi-byte samples on big-endian hosts.  Width 4 uses 'i' (a C int,
# 4 bytes on all common platforms) rather than 'l': a C long is 8 bytes
# on LP64 platforms such as 64-bit Linux, which would swap the wrong
# number of bytes.  Widths 0 and 3 have no matching typecode.
_array_fmts = None, 'b', 'h', None, 'i'

# Determine endian-ness: 1 if this host stores integers big-endian.
import struct
if struct.pack("h", 1) == b"\000\001":
    big_endian = 1
else:
    big_endian = 0

from chunk import Chunk
class Wave_read:
    """Reader for WAVE files (uncompressed PCM only).

    These variables are available to the user through appropriate
    methods of this class:
    _file -- a RIFF Chunk wrapping the open file given to __init__()
              (which needs read() and close(); seek() only for
              setpos()/rewind())
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE', since only PCM
              is accepted)
              available through the getcomptype() method
    _compname -- the human-readable compression type
              available through the getcompname() method
    _soundpos -- the position in the audio stream
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- 1 iff the 'fmt ' chunk has been read
    _data_seek_needed -- 1 iff the data chunk must be repositioned to
              _soundpos before the next readframes()
    _data_chunk -- instantiation of a chunk class for the 'data' chunk
    _framesize -- size of one frame in the file (nchannels * sampwidth)
    """

    def initfp(self, file):
        """Parse the RIFF/WAVE header of *file* and locate the data chunk.

        Sub-chunks are scanned until a 'data' chunk is found; a 'fmt '
        chunk must come before it.  Raises Error on malformed input.
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian=0)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while 1:
            self._data_seek_needed = 1
            try:
                chunk = Chunk(self._file, bigendian=0)
            except EOFError:
                # No further sub-chunk header could be read.
                break
            chunkname = chunk.getname()
            if chunkname == b'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = 1
            elif chunkname == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                # _read_fmt_chunk guarantees a non-zero _framesize.
                self._nframes = chunk.chunksize // self._framesize
                self._data_seek_needed = 0
                break
            chunk.skip()
        if not self._fmt_chunk_read or self._data_chunk is None:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open *f* for reading.

        *f* is either a filename (str), which is opened here and closed by
        close(), or an already-open binary file object.
        """
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except BaseException:
            # Do not leak a file we opened ourselves; always re-raise.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        return self._file

    def rewind(self):
        """Reset the read position to the first frame."""
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Close the file if this instance opened it; make it unusable."""
        if self._i_opened_the_file:
            self._i_opened_the_file.close()
            self._i_opened_the_file = None
        self._file = None

    def tell(self):
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        return self.getnchannels(), self.getsampwidth(), \
               self.getframerate(), self.getnframes(), \
               self.getcomptype(), self.getcompname()

    def getmarkers(self):
        return None

    def getmark(self, id):
        raise Error('no marks')

    def setpos(self, pos):
        """Move to frame *pos* (0 <= pos <= nframes); the seek is deferred."""
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read and return at most *nframes* frames of audio as bytes.

        On big-endian hosts, multi-byte samples are byte-swapped before
        being returned.
        """
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        if self._sampwidth > 1 and big_endian:
            # unfortunately the fromfile() method does not take
            # something that only looks like a file object, so
            # we have to reach into the innards of the chunk object
            import array
            chunk = self._data_chunk
            data = array.array(_array_fmts[self._sampwidth])
            nitems = nframes * self._nchannels
            # Never read past the end of the data chunk.
            if nitems * self._sampwidth > chunk.chunksize - chunk.size_read:
                nitems = (chunk.chunksize - chunk.size_read) // self._sampwidth
            # chunk.file is the outer RIFF chunk; its .file is the raw file.
            data.fromfile(chunk.file.file, nitems)
            # "tell" data chunk how much was read
            chunk.size_read = chunk.size_read + nitems * self._sampwidth
            # do the same for the outermost chunk
            chunk = chunk.file
            chunk.size_read = chunk.size_read + nitems * self._sampwidth
            data.byteswap()
            data = data.tobytes()
        else:
            data = self._data_chunk.read(nframes * self._framesize)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Decode the 'fmt ' chunk and set the stream parameters.

        Raises EOFError if the chunk is too short and Error for non-PCM
        data or a zero channel count / sample width (which would make
        the frame size zero).
        """
        try:
            wFormatTag, self._nchannels, self._framerate, _, _ = \
                struct.unpack_from('<HHLLH', chunk.read(14))
        except struct.error:
            raise EOFError from None
        if wFormatTag == WAVE_FORMAT_PCM:
            try:
                sampwidth = struct.unpack_from('<H', chunk.read(2))[0]
            except struct.error:
                raise EOFError from None
            # Bits per sample -> whole bytes per sample.
            self._sampwidth = (sampwidth + 7) // 8
            if not self._sampwidth:
                raise Error('bad sample width')
        else:
            raise Error('unknown format: %r' % (wFormatTag,))
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'

class Wave_write:
    """Writer for uncompressed PCM WAVE files.

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), close(), tell(), seek()
              set through the __init__() method
    _comptype -- the compression type (only 'NONE' is accepted)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- True once the RIFF/fmt/data header has been written
    _form_length_pos, _data_length_pos -- file offsets of the two size
              fields that _patchheader() rewrites
    """

    def __init__(self, f):
        """Open *f* for writing.

        *f* is either a filename (str), which is opened here and closed by
        close(), or an already-open binary file object.
        """
        # Set first so close()/__del__ are safe even if builtins.open fails.
        self._file = None
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            # Do not leak a file we opened ourselves; always re-raise.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Only uncompressed PCM is written; default these so that
        # getcomptype(), getcompname() and getparams() work even if
        # setcomptype() is never called.
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if framerate <= 0:
            raise Error('bad frame rate')
        # The header stores an integer rate.
        self._framerate = int(round(framerate))

    def getframerate(self):
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def setparams(self, params):
        """Set (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return self._nchannels, self._sampwidth, self._framerate, \
              self._nframes, self._comptype, self._compname

    def setmark(self, id, pos, name):
        raise Error('setmark() not supported')

    def getmark(self, id):
        raise Error('no marks')

    def getmarkers(self):
        return None

    def tell(self):
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching the header's size fields.

        *data* may be any bytes-like object; its length is measured in
        bytes, so arrays and memoryviews of wider item types count
        correctly.
        """
        if not isinstance(data, (bytes, bytearray)):
            data = memoryview(data).cast('B')
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        if self._sampwidth > 1 and big_endian:
            import array
            samples = array.array(_array_fmts[self._sampwidth])
            samples.frombytes(data)
            samples.byteswap()
            samples.tofile(self._file)
            # Count the bytes actually emitted by tofile().
            self._datawritten = self._datawritten + len(samples) * samples.itemsize
        else:
            self._file.write(data)
            self._datawritten = self._datawritten + len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames, then patch the header if sizes changed."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Finish the header, flush, and release the output file.

        A file this instance opened is closed even when the header cannot
        be written (e.g. required parameters were never set); the error
        is still propagated.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            opened = self._i_opened_the_file
            if opened:
                self._i_opened_the_file = None
                opened.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, after checking that required parameters are set."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, 'fmt ' and 'data' headers for PCM output.

        If nframes was not set, it is derived from *initlength* bytes.
        The offsets of both size fields are remembered for _patchheader().
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        self._form_length_pos = self._file.tell()
        # 36 = 'WAVE' (4) + fmt header and body (8 + 16) + data header (8).
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite both size fields to match the bytes actually written."""
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten


def open(f, mode=None):
    """Return a WAVE reader or writer for *f*.

    *f* is a filename or an open file object.  *mode* selects reading
    ('r' or 'rb' -> Wave_read) or writing ('w' or 'wb' -> Wave_write);
    when omitted, f.mode is used if present, otherwise 'rb'.  Any other
    mode raises Error.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")


openfp = open  # backward-compatible alias