# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert the time string used in the certificate
                          notBefore and notAfter fields to integer
                          seconds past the Epoch (the time values
                          returned from time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided by the
                          server at ADDR, a (HOST, PORT) pair, and return
                          it as a PEM-encoded string.  No validation of
                          the certificate is performed unless ca_certs
                          is given.

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines the certificate requirements that one side
allows or requires from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

57
import textwrap
58 59

import _ssl             # if we can't import it, let the error propagate
60 61

from _ssl import SSLError
62
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
63 64
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
                  PROTOCOL_TLSv1)
65
from _ssl import RAND_status, RAND_egd, RAND_add
66 67 68 69 70 71 72 73 74 75 76
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
    )
77 78

from socket import getnameinfo as _getnameinfo
79
from socket import error as socket_error
80
from socket import dup as _dup
81
from socket import socket, AF_INET, SOCK_STREAM
82
import base64        # for DER-to-PEM translation
83
import traceback
84

85
class SSLSocket(socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel."""

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True):
        """Create an SSL-capable socket.

        If SOCK is given, its descriptor is duplicated and SOCK itself is
        closed; otherwise the socket is built from FILENO, or from
        FAMILY/TYPE/PROTO.  If the resulting socket is already connected,
        it is wrapped with _ssl.sslwrap right away and, when
        DO_HANDSHAKE_ON_CONNECT is true, the handshake is performed.
        Otherwise wrapping is deferred until connect().

        Raises ValueError if a handshake on connect is requested for a
        non-blocking (timeout 0.0) socket; the socket is closed first.
        """
        if sock is not None:
            socket.__init__(self,
                            family=sock.family,
                            type=sock.type,
                            proto=sock.proto,
                            fileno=_dup(sock.fileno()))
            sock.close()
        elif fileno is not None:
            socket.__init__(self, fileno=fileno)
        else:
            socket.__init__(self, family=family, type=type, proto=proto)

        self._closed = False

        # Without an explicit keyfile, read the private key from certfile.
        if certfile and not keyfile:
            keyfile = certfile

        # See if it's connected: getpeername() fails on an unconnected socket.
        try:
            socket.getpeername(self)
        except socket_error:
            # No connection yet; connect() will create the SSL object.
            self._sslobj = None
        else:
            # Already connected: create the SSL object now.
            try:
                self._sslobj = _ssl.sslwrap(self, server_side,
                                            keyfile, certfile,
                                            cert_reqs, ssl_version, ca_certs)
                if do_handshake_on_connect:
                    if self.gettimeout() == 0.0:
                        # non-blocking
                        raise ValueError(
                            "do_handshake_on_connect should not be "
                            "specified for non-blocking sockets")
                    self.do_handshake()
            except (socket_error, ValueError):
                # Don't leak the descriptor when construction fails.
                self.close()
                raise

        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

    def dup(self):
        """Not supported for SSLSocket instances; always raises."""
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def read(self, len=0, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        If BUFFER is given (and non-empty), data is read into it and the
        result of the underlying SSL read is returned instead; a
        suppressed ragged EOF then yields 0 rather than b''.  Any other
        SSLError propagates."""

        self._checkClosed()
        try:
            if buffer:
                return self._sslobj.read(buffer, len)
            return self._sslobj.read(len or 1024)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # The peer closed without a proper SSL shutdown; treat
                # it as an ordinary EOF.
                if buffer:
                    return 0
                return b''
            raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        self._checkClosed()
        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):
        """Return the SSL object's cipher() result, or None if unwrapped."""
        self._checkClosed()
        if not self._sslobj:
            return None
        return self._sslobj.cipher()

    def send(self, data, flags=0):
        """Send DATA and return the number of bytes written.

        Over SSL, FLAGS must be 0, and 0 is returned when the SSL layer
        reports SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.  Without an
        SSL object this is plain socket.send()."""
        self._checkClosed()
        if not self._sslobj:
            return socket.send(self, data, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to send() on %s" %
                self.__class__)
        try:
            return self._sslobj.write(data)
        except SSLError as x:
            if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                # Nothing could be written yet.
                return 0
            raise

    def sendto(self, data, addr, flags=0):
        """Plain socket.sendto(); not allowed once the socket is wrapped."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        return socket.sendto(self, data, addr, flags)

    def sendall(self, data, flags=0):
        """Send all of DATA.  Over SSL, FLAGS must be 0 (as for send());
        returns len(DATA).  Without an SSL object this is plain
        socket.sendall()."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            while count < amount:
                # send() returns 0 while the SSL layer wants to read or
                # write first, so keep retrying until everything is out.
                count += self.send(data[count:])
            return amount
        return socket.sendall(self, data, flags)

    def recv(self, buflen=1024, flags=0):
        """Receive up to BUFLEN bytes.  Over SSL, FLAGS must be 0 and
        SSL_ERROR_WANT_READ is retried; otherwise plain socket.recv()."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv() on %s" %
                  self.__class__)
            while True:
                try:
                    return self.read(buflen)
                except SSLError as x:
                    if x.args[0] == SSL_ERROR_WANT_READ:
                        continue
                    raise
        return socket.recv(self, buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into BUFFER.  NBYTES defaults to len(BUFFER) (or 1024
        when BUFFER is empty/None).  Over SSL, FLAGS must be 0 and
        SSL_ERROR_WANT_READ is retried."""
        self._checkClosed()
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            while True:
                try:
                    return self.read(nbytes, buffer)
                except SSLError as x:
                    if x.args[0] == SSL_ERROR_WANT_READ:
                        continue
                    raise
        return socket.recv_into(self, buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        """Plain socket.recvfrom(); not allowed once the socket is wrapped."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom(self, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Plain socket.recvfrom_into(); not allowed once wrapped.
        NBYTES=None means "fill up to the size of BUFFER"."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        if nbytes is None:
            # socket.recvfrom_into() needs an integer; 0 selects the
            # buffer's own size.
            nbytes = 0
        return socket.recvfrom_into(self, buffer, nbytes, flags)

    def pending(self):
        """Return _sslobj.pending() (presumably the number of already
        decrypted, buffered bytes), or 0 when the socket is not wrapped."""
        self._checkClosed()
        if self._sslobj:
            return self._sslobj.pending()
        return 0

    def shutdown(self, how):
        """Drop the SSL object and shut down the socket with HOW."""
        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def unwrap(self):
        """Shut down the SSL layer, forget the SSL object and return the
        result of its shutdown(); ValueError if the socket is not wrapped."""
        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        raise ValueError("No SSL wrapper around " + str(self))

    def _real_close(self):
        # Release the SSL object before closing the OS-level socket.
        self._sslobj = None
        socket._real_close(self)

    def do_handshake(self, block=False):
        """Perform a TLS/SSL handshake.

        If BLOCK is true and the socket is non-blocking (timeout 0.0), it
        is switched to blocking mode for the handshake; the original
        timeout is always restored afterwards."""

        timeout = self.gettimeout()
        try:
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""

        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs)
        try:
            if self.do_handshake_on_connect:
                self.do_handshake()
        except BaseException:
            # A failed handshake leaves the socket unwrapped.
            self._sslobj = None
            raise

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        return (SSLSocket(sock=newsock,
                          keyfile=self.keyfile, certfile=self.certfile,
                          server_side=True,
                          cert_reqs=self.cert_reqs,
                          ssl_version=self.ssl_version,
                          ca_certs=self.ca_certs,
                          do_handshake_on_connect=
                              self.do_handshake_on_connect,
                          # Propagate the listener's EOF policy as well.
                          suppress_ragged_eofs=self.suppress_ragged_eofs),
                addr)

    def __del__(self):
        self._real_close()

370

371 372
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True):
    """Wrap the existing socket SOCK in a new SSLSocket and return it.

    Every option is forwarded unchanged to the SSLSocket constructor;
    see SSLSocket.__init__ for how SOCK itself is consumed."""
    options = dict(keyfile=keyfile, certfile=certfile,
                   server_side=server_side, cert_reqs=cert_reqs,
                   ssl_version=ssl_version, ca_certs=ca_certs,
                   do_handshake_on_connect=do_handshake_on_connect,
                   suppress_ragged_eofs=suppress_ragged_eofs)
    return SSLSocket(sock=sock, **options)
382

383 384 385
# some utility functions

def cert_time_to_seconds(cert_time):
    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE", e.g.
    "Jan  5 09:34:43 2018 GMT") and return the corresponding time
    as integer seconds past the epoch.

    The time is interpreted as GMT regardless of the local timezone,
    and the English month abbreviation is matched independently of the
    current locale.  Raises ValueError if the string is malformed."""

    import time
    import calendar

    months = ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
              "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")
    # strptime's %b depends on the locale, so match the month ourselves.
    try:
        month = months.index(cert_time[:3].title()) + 1
    except ValueError:
        raise ValueError("time data %r does not match format "
                         "'%%b %%d %%H:%%M:%%S %%Y GMT'" % (cert_time,))
    # Substitute the numeric month so strptime still validates the full
    # date (e.g. rejects "Feb 30").
    tt = time.strptime("%02d%s" % (month, cert_time[3:]),
                       "%m %d %H:%M:%S %Y GMT")
    # The value is in GMT: timegm() treats the tuple as UTC, whereas the
    # former mktime() interpreted it as local time.
    return calendar.timegm(tt)

393 394 395 396 397 398 399
# Armor lines that frame the base64 body of a PEM-encoded certificate.
PEM_HEADER, PEM_FOOTER = (
    "-----BEGIN CERTIFICATE-----",
    "-----END CERTIFICATE-----",
)

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Convert binary DER certificate bytes into a PEM string.

    The base64 body is wrapped at 64 columns between PEM_HEADER and
    PEM_FOOTER, and the result ends with a newline."""

    b64_text = base64.standard_b64encode(der_cert_bytes).decode('ASCII')
    pieces = [PEM_HEADER, textwrap.fill(b64_text, 64), PEM_FOOTER, '']
    return '\n'.join(pieces)
404 405 406 407 408 409 410 411 412 413 414 415

def PEM_cert_to_DER_cert(pem_cert_string):
    """Convert an ASCII PEM certificate string into DER-encoded bytes.

    Raises ValueError unless the text starts with PEM_HEADER and, once
    surrounding whitespace is stripped, ends with PEM_FOOTER."""

    if not pem_cert_string.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    trimmed = pem_cert_string.strip()
    if not trimmed.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    # Everything between the armor lines is the base64 payload.
    payload = trimmed[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(payload.encode('ASCII', 'strict'))
417

418
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    ADDR is a (HOST, PORT) pair.  The connection is always closed before
    returning, including when connecting or the handshake fails."""

    # Unpack first so a malformed address fails before a socket is created.
    _host, _port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # Previously the socket leaked whenever connect() raised.
        s.close()
    return DER_cert_to_PEM_cert(dercert)

436
def get_protocol_name(protocol_code):
    """Return a readable name for an SSL protocol constant.

    PROTOCOL_TLSv1, PROTOCOL_SSLv23, PROTOCOL_SSLv2 and PROTOCOL_SSLv3 map
    to "TLSv1", "SSLv23", "SSLv2" and "SSLv3"; any other value yields
    "<unknown>".  Constants are compared in that order and the first
    match wins."""
    return ("TLSv1" if protocol_code == PROTOCOL_TLSv1 else
            "SSLv23" if protocol_code == PROTOCOL_SSLv23 else
            "SSLv2" if protocol_code == PROTOCOL_SSLv2 else
            "SSLv3" if protocol_code == PROTOCOL_SSLv3 else
            "<unknown>")