# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""\
This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert time string used for certificate
                          notBefore and notAfter fields to seconds
                          past the Epoch (the time values returned
                          from time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided by the
                          server listening at ADDR, a (HOST, PORT) pair.
                          The certificate is validated only when CA
                          certificates are supplied.

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines the certificate requirements that one side
is allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

58
import os, sys, textwrap
59 60

import _ssl             # if we can't import it, let the error propagate
Bill Janssen's avatar
Bill Janssen committed
61 62

from _ssl import SSLError
63 64
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
Bill Janssen's avatar
Bill Janssen committed
65
from _ssl import RAND_status, RAND_egd, RAND_add
66 67 68 69 70 71 72 73 74 75 76 77 78
from _ssl import \
     SSL_ERROR_ZERO_RETURN, \
     SSL_ERROR_WANT_READ, \
     SSL_ERROR_WANT_WRITE, \
     SSL_ERROR_WANT_X509_LOOKUP, \
     SSL_ERROR_SYSCALL, \
     SSL_ERROR_SSL, \
     SSL_ERROR_WANT_CONNECT, \
     SSL_ERROR_EOF, \
     SSL_ERROR_INVALID_ERROR_CODE

from socket import socket
from socket import getnameinfo as _getnameinfo
79
import base64        # for DER-to-PEM translation
80

Bill Janssen's avatar
Bill Janssen committed
81
class SSLSocket (socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel.

    While self._sslobj is None the instance behaves like a plain
    socket; once it is set, stream I/O goes through the SSL object."""

    def __init__(self, sock, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None):

        """Take over the OS-level socket of SOCK.  If SOCK is already
        connected the SSL object is created immediately; otherwise it
        is created later by connect().  When only CERTFILE is given
        it is also used as the key file."""

        socket.__init__(self, _sock=sock._sock)
        if certfile and not keyfile:
            keyfile = certfile
        # see if it's connected
        try:
            socket.getpeername(self)
        except Exception:
            # no, no connection yet
            self._sslobj = None
        else:
            # yes, create the SSL object
            self._sslobj = _ssl.sslwrap(self._sock, server_side,
                                        keyfile, certfile,
                                        cert_reqs, ssl_version, ca_certs)
        # remembered so connect() and accept() can build SSL objects later
        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs

    def read(self, len=1024):

        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""

        return self._sslobj.read(len)

    def write(self, data):

        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):

        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        return self._sslobj.peer_certificate(binary_form)

    def cipher (self):

        """Return the SSL object's cipher description, or None when
        no SSL channel has been set up."""

        if not self._sslobj:
            return None
        else:
            return self._sslobj.cipher()

    def send (self, data, flags=0):

        """Send DATA, through the SSL channel when one exists.
        Non-zero FLAGS are rejected with ValueError on an SSL channel.
        Returns the number of bytes written."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            return self._sslobj.write(data)
        else:
            return socket.send(self, data, flags)

    def send_to (self, data, addr, flags=0):

        """Datagram send; raises ValueError once an SSL channel exists,
        otherwise forwards to the underlying socket's sendto()."""

        if self._sslobj:
            raise ValueError("send_to not allowed on instances of %s" %
                             self.__class__)
        else:
            # the socket-level call is spelled sendto(data, flags, address)
            return self._sock.sendto(data, flags, addr)

    def sendall (self, data, flags=0):

        """Send all of DATA.  On an SSL channel, keep writing until
        every byte has been accepted and return the total count;
        non-zero FLAGS are rejected with ValueError."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            # a single write may accept only part of the data
            while count < amount:
                count += self._sslobj.write(data[count:])
            return amount
        else:
            return socket.sendall(self, data, flags)

    def recv (self, buflen=1024, flags=0):

        """Receive up to BUFLEN bytes, through the SSL channel when one
        exists.  Non-zero FLAGS are rejected with ValueError on an SSL
        channel."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self._sslobj.read(buflen)
        else:
            return socket.recv(self, buflen, flags)

    def recv_from (self, addr, buflen=1024, flags=0):

        """Datagram receive; raises ValueError once an SSL channel
        exists, otherwise forwards to the underlying socket's
        recvfrom().  ADDR is accepted for compatibility and unused."""

        if self._sslobj:
            raise ValueError("recv_from not allowed on instances of %s" %
                             self.__class__)
        else:
            # the socket-level call is spelled recvfrom(bufsize, flags)
            return self._sock.recvfrom(buflen, flags)

    def shutdown(self, how):

        """Drop the SSL object, then shut down the socket."""

        self._sslobj = None
        socket.shutdown(self, how)

    def close(self):

        """Drop the SSL object, then close the socket."""

        self._sslobj = None
        socket.close(self)

    def connect(self, addr):

        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""

        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs)

    def accept(self):

        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        # keyword arguments keep each value in the intended parameter
        return (SSLSocket(newsock,
                          keyfile=self.keyfile,
                          certfile=self.certfile,
                          server_side=True,
                          cert_reqs=self.cert_reqs,
                          ssl_version=self.ssl_version,
                          ca_certs=self.ca_certs), addr)

    def makefile(self, mode='r', bufsize=-1):

        """Ouch.  Need to make and return a file-like object that
        works with the SSL connection."""

        if self._sslobj:
            return SSLFileStream(self._sslobj, mode, bufsize)
        else:
            return socket.makefile(self, mode, bufsize)


class SSLFileStream:

    """A class to simulate a file stream on top of a socket.
    Most of this is just lifted from the socket module, and
    adjusted to work with an SSL stream instead of a socket.

    Reads and writes go through the read() and write() methods of
    the SSL object handed to the constructor."""


    default_bufsize = 8192
    name = "<SSL stream>"

    __slots__ = ["mode", "bufsize", "softspace",
                 # "closed" is a property, see below
                 "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
                 "_close", "_fileno"]

    def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
        """Set up buffering over SSLOBJ.

        BUFSIZE < 0 selects default_bufsize; 0 makes writes flush
        immediately and reads fetch one byte at a time; 1 makes writes
        flush when the data contains a newline; any other value is
        used as the buffer size.  If CLOSE is true, close() also
        calls SSLOBJ.close()."""
        self._sslobj = sslobj
        self.mode = mode # Not actually used in this version
        if bufsize < 0:
            bufsize = self.default_bufsize
        self.bufsize = bufsize
        self.softspace = False
        if bufsize == 0:
            self._rbufsize = 1
        elif bufsize == 1:
            self._rbufsize = self.default_bufsize
        else:
            self._rbufsize = bufsize
        self._wbufsize = bufsize
        self._rbuf = "" # A string
        self._wbuf = [] # A list of strings
        self._close = close
        self._fileno = -1

    def _getclosed(self):
        return self._sslobj is None
    closed = property(_getclosed, doc="True if the file is closed")

    def fileno(self):
        """Return the stored file number (always set to -1 here)."""
        return self._fileno

    def close(self):
        """Flush pending output and detach from the SSL object."""
        try:
            if self._sslobj:
                self.flush()
        finally:
            if self._close and self._sslobj:
                self._sslobj.close()
            self._sslobj = None

    def __del__(self):
        try:
            self.close()
        except Exception:
            # close() may fail if __init__ didn't complete
            pass

    def flush(self):
        """Write out everything in the write buffer."""
        if self._wbuf:
            pending = "".join(self._wbuf)
            self._wbuf = []
            # write() may accept only part of the data; keep going
            # until nothing is left.
            while pending:
                written = self._sslobj.write(pending)
                pending = pending[written:]

    def write(self, data):
        """Buffer DATA, flushing according to the buffering mode."""
        data = str(data) # XXX Should really reject non-string non-buffers
        if not data:
            return
        self._wbuf.append(data)
        if (self._wbufsize == 0 or
            (self._wbufsize == 1 and '\n' in data) or
            self._get_wbuf_len() >= self._wbufsize):
            self.flush()

    def writelines(self, list):
        """Buffer every non-empty item of LIST (converted with str)."""
        # XXX We could do better here for very long lists
        # XXX Should really reject non-string non-buffers
        self._wbuf.extend(filter(None, map(str, list)))
        if (self._wbufsize <= 1 or
            self._get_wbuf_len() >= self._wbufsize):
            self.flush()

    def _get_wbuf_len(self):
        """Return the total length of the buffered output strings."""
        buf_len = 0
        for x in self._wbuf:
            buf_len += len(x)
        return buf_len

    def read(self, size=-1):
        """Read SIZE bytes, or everything up to EOF when SIZE < 0.

        Data read beyond SIZE is kept in the read buffer for the
        next call."""
        data = self._rbuf
        if size < 0:
            # Read until EOF
            buffers = []
            if data:
                buffers.append(data)
            self._rbuf = ""
            if self._rbufsize <= 1:
                recv_size = self.default_bufsize
            else:
                recv_size = self._rbufsize
            while True:
                data = self._sslobj.read(recv_size)
                if not data:
                    break
                buffers.append(data)
            return "".join(buffers)
        else:
            # Read until size bytes or EOF seen, whichever comes first
            buf_len = len(data)
            if buf_len >= size:
                self._rbuf = data[size:]
                return data[:size]
            buffers = []
            if data:
                buffers.append(data)
            self._rbuf = ""
            while True:
                left = size - buf_len
                recv_size = max(self._rbufsize, left)
                data = self._sslobj.read(recv_size)
                if not data:
                    break
                buffers.append(data)
                n = len(data)
                if n >= left:
                    # keep the surplus for the next read
                    self._rbuf = data[left:]
                    buffers[-1] = data[:left]
                    break
                buf_len += n
            return "".join(buffers)

    def readline(self, size=-1):
        """Read one line, including its trailing newline if present.

        Stops at a newline, at EOF, or (when SIZE >= 0) after SIZE
        bytes, whichever comes first."""
        data = self._rbuf
        if size < 0:
            # Read until \n or EOF, whichever comes first
            if self._rbufsize <= 1:
                # Speed up unbuffered case
                assert data == ""
                buffers = []
                while data != "\n":
                    data = self._sslobj.read(1)
                    if not data:
                        break
                    buffers.append(data)
                return "".join(buffers)
            nl = data.find('\n')
            if nl >= 0:
                nl += 1
                self._rbuf = data[nl:]
                return data[:nl]
            buffers = []
            if data:
                buffers.append(data)
            self._rbuf = ""
            while True:
                data = self._sslobj.read(self._rbufsize)
                if not data:
                    break
                buffers.append(data)
                nl = data.find('\n')
                if nl >= 0:
                    nl += 1
                    self._rbuf = data[nl:]
                    buffers[-1] = data[:nl]
                    break
            return "".join(buffers)
        else:
            # Read until size bytes or \n or EOF seen, whichever comes first
            nl = data.find('\n', 0, size)
            if nl >= 0:
                nl += 1
                self._rbuf = data[nl:]
                return data[:nl]
            buf_len = len(data)
            if buf_len >= size:
                self._rbuf = data[size:]
                return data[:size]
            buffers = []
            if data:
                buffers.append(data)
            self._rbuf = ""
            while True:
                data = self._sslobj.read(self._rbufsize)
                if not data:
                    break
                buffers.append(data)
                left = size - buf_len
                nl = data.find('\n', 0, left)
                if nl >= 0:
                    nl += 1
                    self._rbuf = data[nl:]
                    buffers[-1] = data[:nl]
                    break
                n = len(data)
                if n >= left:
                    self._rbuf = data[left:]
                    buffers[-1] = data[:left]
                    break
                buf_len += n
            return "".join(buffers)

    def readlines(self, sizehint=0):
        """Return a list of the remaining lines.

        When SIZEHINT is non-zero, stop once the lines collected so
        far total at least SIZEHINT bytes."""
        total = 0
        lines = []
        while True:
            line = self.readline()
            if not line:
                break
            lines.append(line)
            total += len(line)
            if sizehint and total >= sizehint:
                break
        return lines

    # Iterator protocols

    def __iter__(self):
        return self

    def next(self):
        line = self.readline()
        if not line:
            raise StopIteration
        return line




Bill Janssen's avatar
Bill Janssen committed
460 461 462
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None):

    """Return an SSLSocket built from SOCK.

    All arguments are handed unchanged to the SSLSocket constructor,
    which takes the same parameters."""

    return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
                     server_side=server_side, cert_reqs=cert_reqs,
                     ssl_version=ssl_version, ca_certs=ca_certs)
467 468 469 470

# some utility functions

def cert_time_to_seconds(cert_time):

    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The string must end in "GMT"; any other input makes
    time.strptime raise ValueError."""

    import calendar
    import time
    # The timestamp is expressed in GMT, so it must be converted with
    # timegm(); mktime() would interpret the fields as local time.
    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    return float(calendar.timegm(parsed))

479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543
# Delimiter lines that frame the base64 payload of a PEM certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):

    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The result is the header line, the base64 payload wrapped at
    64 columns, and the footer line, each terminated by a newline."""

    if hasattr(base64, 'standard_b64encode'):
        # preferred because older API gets line-length wrong
        f = base64.standard_b64encode(der_cert_bytes)
        if not isinstance(f, str):
            # the encoder may hand back a bytes object; make it text so
            # it can be joined with the delimiter strings
            f = f.decode('ascii')
        # textwrap.fill() does not end its output with a newline, so one
        # is added to keep the footer on a line of its own
        return (PEM_HEADER + '\n' +
                textwrap.fill(f, 64) + '\n' +
                PEM_FOOTER + '\n')
    else:
        return (PEM_HEADER + '\n' +
                base64.encodestring(der_cert_bytes) +
                PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):

    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Leading and trailing whitespace is ignored.  Raises ValueError
    if the text does not start with PEM_HEADER or end with
    PEM_FOOTER."""

    import binascii

    # strip once so the header and footer checks see the same text
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    return binascii.a2b_base64(d)

def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):

    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    'addr' must be a (host, port) pair.  The connection is closed
    before returning, including when the attempt fails."""

    # unpacking rejects anything that is not a two-item address early
    host, port = addr
    if (ca_certs is not None):
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # release the socket even if connecting or the handshake failed
        s.close()
    return DER_cert_to_PEM_cert(dercert)

def get_protocol_name (protocol_code):

    """Return the display name for one of the PROTOCOL_* constants,
    or "<unknown>" when PROTOCOL_CODE matches none of them."""

    # checked in order; the first equal entry wins
    known_protocols = (
        (PROTOCOL_TLSv1, "TLSv1"),
        (PROTOCOL_SSLv23, "SSLv23"),
        (PROTOCOL_SSLv2, "SSLv2"),
        (PROTOCOL_SSLv3, "SSLv3"),
    )
    for code, label in known_protocols:
        if protocol_code == code:
            return label
    return "<unknown>"


544 545 546 547
# a replacement for the old socket.ssl function

def sslwrap_simple (sock, keyfile=None, certfile=None):

    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0.

    Returns the object created by _ssl.sslwrap for a client-side
    connection with no certificate requirements.  SOCK may be a
    socket wrapper exposing '_sock', or the low-level socket itself."""

    # unwrap when given a wrapper; otherwise use the object as-is
    raw_sock = getattr(sock, "_sock", sock)
    return _ssl.sslwrap(raw_sock, 0, keyfile, certfile, CERT_NONE,
                        PROTOCOL_SSLv23, None)