Kaydet (Commit) 01113fae authored tarafından Christian Heimes's avatar Christian Heimes

Issue #26470: Port ssl and hashlib module to OpenSSL 1.1.0.

This diff is collapsed.
...@@ -51,6 +51,7 @@ The following constants identify various SSL protocol variants: ...@@ -51,6 +51,7 @@ The following constants identify various SSL protocol variants:
PROTOCOL_SSLv2 PROTOCOL_SSLv2
PROTOCOL_SSLv3 PROTOCOL_SSLv3
PROTOCOL_SSLv23 PROTOCOL_SSLv23
PROTOCOL_TLS
PROTOCOL_TLSv1 PROTOCOL_TLSv1
PROTOCOL_TLSv1_1 PROTOCOL_TLSv1_1
PROTOCOL_TLSv1_2 PROTOCOL_TLSv1_2
...@@ -128,9 +129,10 @@ from _ssl import _OPENSSL_API_VERSION ...@@ -128,9 +129,10 @@ from _ssl import _OPENSSL_API_VERSION
_IntEnum._convert( _IntEnum._convert(
'_SSLMethod', __name__, '_SSLMethod', __name__,
lambda name: name.startswith('PROTOCOL_'), lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
source=_ssl) source=_ssl)
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
try: try:
...@@ -357,13 +359,13 @@ class SSLContext(_SSLContext): ...@@ -357,13 +359,13 @@ class SSLContext(_SSLContext):
__slots__ = ('protocol', '__weakref__') __slots__ = ('protocol', '__weakref__')
_windows_cert_stores = ("CA", "ROOT") _windows_cert_stores = ("CA", "ROOT")
def __new__(cls, protocol, *args, **kwargs): def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):
self = _SSLContext.__new__(cls, protocol) self = _SSLContext.__new__(cls, protocol)
if protocol != _SSLv2_IF_EXISTS: if protocol != _SSLv2_IF_EXISTS:
self.set_ciphers(_DEFAULT_CIPHERS) self.set_ciphers(_DEFAULT_CIPHERS)
return self return self
def __init__(self, protocol): def __init__(self, protocol=PROTOCOL_TLS):
self.protocol = protocol self.protocol = protocol
def wrap_socket(self, sock, server_side=False, def wrap_socket(self, sock, server_side=False,
...@@ -438,7 +440,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, ...@@ -438,7 +440,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
if not isinstance(purpose, _ASN1Object): if not isinstance(purpose, _ASN1Object):
raise TypeError(purpose) raise TypeError(purpose)
context = SSLContext(PROTOCOL_SSLv23) context = SSLContext(PROTOCOL_TLS)
# SSLv2 considered harmful. # SSLv2 considered harmful.
context.options |= OP_NO_SSLv2 context.options |= OP_NO_SSLv2
...@@ -475,7 +477,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, ...@@ -475,7 +477,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
context.load_default_certs(purpose) context.load_default_certs(purpose)
return context return context
def _create_unverified_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None, def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None,
check_hostname=False, purpose=Purpose.SERVER_AUTH, check_hostname=False, purpose=Purpose.SERVER_AUTH,
certfile=None, keyfile=None, certfile=None, keyfile=None,
cafile=None, capath=None, cadata=None): cafile=None, capath=None, cadata=None):
...@@ -666,7 +668,7 @@ class SSLSocket(socket): ...@@ -666,7 +668,7 @@ class SSLSocket(socket):
def __init__(self, sock=None, keyfile=None, certfile=None, def __init__(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_TLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
...@@ -1055,7 +1057,7 @@ class SSLSocket(socket): ...@@ -1055,7 +1057,7 @@ class SSLSocket(socket):
def wrap_socket(sock, keyfile=None, certfile=None, def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_TLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, suppress_ragged_eofs=True,
ciphers=None): ciphers=None):
...@@ -1124,7 +1126,7 @@ def PEM_cert_to_DER_cert(pem_cert_string): ...@@ -1124,7 +1126,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):
d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
return base64.decodebytes(d.encode('ASCII', 'strict')) return base64.decodebytes(d.encode('ASCII', 'strict'))
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
"""Retrieve the certificate from the server at the specified address, """Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string. and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it. If 'ca_certs' is specified, validate the server cert against it.
......
...@@ -30,6 +30,9 @@ else: ...@@ -30,6 +30,9 @@ else:
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST HOST = support.HOST
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
def data_file(*name): def data_file(*name):
return os.path.join(os.path.dirname(__file__), *name) return os.path.join(os.path.dirname(__file__), *name)
...@@ -150,8 +153,8 @@ class BasicSocketTests(unittest.TestCase): ...@@ -150,8 +153,8 @@ class BasicSocketTests(unittest.TestCase):
def test_str_for_enums(self): def test_str_for_enums(self):
# Make sure that the PROTOCOL_* constants have enum-like string # Make sure that the PROTOCOL_* constants have enum-like string
# reprs. # reprs.
proto = ssl.PROTOCOL_SSLv23 proto = ssl.PROTOCOL_TLS
self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23') self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
ctx = ssl.SSLContext(proto) ctx = ssl.SSLContext(proto)
self.assertIs(ctx.protocol, proto) self.assertIs(ctx.protocol, proto)
...@@ -319,8 +322,8 @@ class BasicSocketTests(unittest.TestCase): ...@@ -319,8 +322,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertGreaterEqual(status, 0) self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15) self.assertLessEqual(status, 15)
# Version string as returned by {Open,Libre}SSL, the format might change # Version string as returned by {Open,Libre}SSL, the format might change
if "LibreSSL" in s: if IS_LIBRESSL:
self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)), self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
(s, t, hex(n))) (s, t, hex(n)))
else: else:
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
...@@ -813,7 +816,8 @@ class ContextTests(unittest.TestCase): ...@@ -813,7 +816,8 @@ class ContextTests(unittest.TestCase):
def test_constructor(self): def test_constructor(self):
for protocol in PROTOCOLS: for protocol in PROTOCOLS:
ssl.SSLContext(protocol) ssl.SSLContext(protocol)
self.assertRaises(TypeError, ssl.SSLContext) ctx = ssl.SSLContext()
self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
self.assertRaises(ValueError, ssl.SSLContext, -1) self.assertRaises(ValueError, ssl.SSLContext, -1)
self.assertRaises(ValueError, ssl.SSLContext, 42) self.assertRaises(ValueError, ssl.SSLContext, 42)
...@@ -834,15 +838,15 @@ class ContextTests(unittest.TestCase): ...@@ -834,15 +838,15 @@ class ContextTests(unittest.TestCase):
def test_options(self): def test_options(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
# OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3, default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
ctx.options) if not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0):
default |= ssl.OP_NO_COMPRESSION
self.assertEqual(default, ctx.options)
ctx.options |= ssl.OP_NO_TLSv1 ctx.options |= ssl.OP_NO_TLSv1
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1, self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
ctx.options)
if can_clear_options(): if can_clear_options():
ctx.options = (ctx.options & ~ssl.OP_NO_SSLv2) | ssl.OP_NO_TLSv1 ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_TLSv1 | ssl.OP_NO_SSLv3, self.assertEqual(default, ctx.options)
ctx.options)
ctx.options = 0 ctx.options = 0
# Ubuntu has OP_NO_SSLv3 forced on by default # Ubuntu has OP_NO_SSLv3 forced on by default
self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3) self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
...@@ -1178,6 +1182,7 @@ class ContextTests(unittest.TestCase): ...@@ -1178,6 +1182,7 @@ class ContextTests(unittest.TestCase):
self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
@unittest.skipIf(sys.platform == "win32", "not-Windows specific") @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
@unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
def test_load_default_certs_env(self): def test_load_default_certs_env(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with support.EnvironmentVarGuard() as env: with support.EnvironmentVarGuard() as env:
...@@ -1668,13 +1673,13 @@ class SimpleBackgroundTests(unittest.TestCase): ...@@ -1668,13 +1673,13 @@ class SimpleBackgroundTests(unittest.TestCase):
sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost') sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost')
self.assertIs(sslobj._sslobj.owner, sslobj) self.assertIs(sslobj._sslobj.owner, sslobj)
self.assertIsNone(sslobj.cipher()) self.assertIsNone(sslobj.cipher())
self.assertIsNone(sslobj.shared_ciphers()) self.assertIsNotNone(sslobj.shared_ciphers())
self.assertRaises(ValueError, sslobj.getpeercert) self.assertRaises(ValueError, sslobj.getpeercert)
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertIsNone(sslobj.get_channel_binding('tls-unique')) self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
self.assertTrue(sslobj.cipher()) self.assertTrue(sslobj.cipher())
self.assertIsNone(sslobj.shared_ciphers()) self.assertIsNotNone(sslobj.shared_ciphers())
self.assertTrue(sslobj.getpeercert()) self.assertTrue(sslobj.getpeercert())
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertTrue(sslobj.get_channel_binding('tls-unique')) self.assertTrue(sslobj.get_channel_binding('tls-unique'))
...@@ -2988,7 +2993,7 @@ if _have_threads: ...@@ -2988,7 +2993,7 @@ if _have_threads:
with context.wrap_socket(socket.socket()) as s: with context.wrap_socket(socket.socket()) as s:
self.assertIs(s.version(), None) self.assertIs(s.version(), None)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
self.assertEqual(s.version(), "TLSv1") self.assertEqual(s.version(), 'TLSv1')
self.assertIs(s.version(), None) self.assertIs(s.version(), None)
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
...@@ -3130,24 +3135,36 @@ if _have_threads: ...@@ -3130,24 +3135,36 @@ if _have_threads:
(['http/3.0', 'http/4.0'], None) (['http/3.0', 'http/4.0'], None)
] ]
for client_protocols, expected in protocol_tests: for client_protocols, expected in protocol_tests:
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
server_context.load_cert_chain(CERTFILE) server_context.load_cert_chain(CERTFILE)
server_context.set_alpn_protocols(server_protocols) server_context.set_alpn_protocols(server_protocols)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
client_context.load_cert_chain(CERTFILE) client_context.load_cert_chain(CERTFILE)
client_context.set_alpn_protocols(client_protocols) client_context.set_alpn_protocols(client_protocols)
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True)
msg = "failed trying %s (s) and %s (c).\n" \ try:
"was expecting %s, but got %%s from the %%s" \ stats = server_params_test(client_context,
% (str(server_protocols), str(client_protocols), server_context,
str(expected)) chatty=True,
client_result = stats['client_alpn_protocol'] connectionchatty=True)
self.assertEqual(client_result, expected, msg % (client_result, "client")) except ssl.SSLError as e:
server_result = stats['server_alpn_protocols'][-1] \ stats = e
if len(stats['server_alpn_protocols']) else 'nothing'
self.assertEqual(server_result, expected, msg % (server_result, "server")) if expected is None and IS_OPENSSL_1_1:
# OpenSSL 1.1.0 raises handshake error
self.assertIsInstance(stats, ssl.SSLError)
else:
msg = "failed trying %s (s) and %s (c).\n" \
"was expecting %s, but got %%s from the %%s" \
% (str(server_protocols), str(client_protocols),
str(expected))
client_result = stats['client_alpn_protocol']
self.assertEqual(client_result, expected,
msg % (client_result, "client"))
server_result = stats['server_alpn_protocols'][-1] \
if len(stats['server_alpn_protocols']) else 'nothing'
self.assertEqual(server_result, expected,
msg % (server_result, "server"))
def test_selected_npn_protocol(self): def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used # selected_npn_protocol() is None unless NPN is used
...@@ -3295,13 +3312,23 @@ if _have_threads: ...@@ -3295,13 +3312,23 @@ if _have_threads:
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
client_context.verify_mode = ssl.CERT_REQUIRED client_context.verify_mode = ssl.CERT_REQUIRED
client_context.load_verify_locations(SIGNING_CA) client_context.load_verify_locations(SIGNING_CA)
client_context.set_ciphers("RC4") if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
server_context.set_ciphers("AES:RC4") client_context.set_ciphers("AES128:AES256")
server_context.set_ciphers("AES256")
alg1 = "AES256"
alg2 = "AES-256"
else:
client_context.set_ciphers("AES:3DES")
server_context.set_ciphers("3DES")
alg1 = "3DES"
alg2 = "DES-CBC3"
stats = server_params_test(client_context, server_context) stats = server_params_test(client_context, server_context)
ciphers = stats['server_shared_ciphers'][0] ciphers = stats['server_shared_ciphers'][0]
self.assertGreater(len(ciphers), 0) self.assertGreater(len(ciphers), 0)
for name, tls_version, bits in ciphers: for name, tls_version, bits in ciphers:
self.assertIn("RC4", name.split("-")) if not alg1 in name.split("-") and alg2 not in name:
self.fail(name)
def test_read_write_after_close_raises_valuerror(self): def test_read_write_after_close_raises_valuerror(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
......
...@@ -75,6 +75,8 @@ Core and Builtins ...@@ -75,6 +75,8 @@ Core and Builtins
Library Library
------- -------
- Issue #26470: Port ssl and hashlib module to OpenSSL 1.1.0.
- Issue #11620: Fix support for SND_MEMORY in winsound.PlaySound. Based on a - Issue #11620: Fix support for SND_MEMORY in winsound.PlaySound. Based on a
patch by Tim Lesher. patch by Tim Lesher.
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
/* EVP is the preferred interface to hashing in OpenSSL */ /* EVP is the preferred interface to hashing in OpenSSL */
#include <openssl/evp.h> #include <openssl/evp.h>
#include <openssl/hmac.h>
/* We use the object interface to discover what hashes OpenSSL supports. */ /* We use the object interface to discover what hashes OpenSSL supports. */
#include <openssl/objects.h> #include <openssl/objects.h>
#include "openssl/err.h" #include "openssl/err.h"
...@@ -32,11 +31,22 @@ ...@@ -32,11 +31,22 @@
#define HASH_OBJ_CONSTRUCTOR 0 #define HASH_OBJ_CONSTRUCTOR 0
#endif #endif
#if (OPENSSL_VERSION_NUMBER < 0x10100000L) || defined(LIBRESSL_VERSION_NUMBER)
/* OpenSSL < 1.1.0 */
#define EVP_MD_CTX_new EVP_MD_CTX_create
#define EVP_MD_CTX_free EVP_MD_CTX_destroy
#define HAS_FAST_PKCS5_PBKDF2_HMAC 0
#include <openssl/hmac.h>
#else
/* OpenSSL >= 1.1.0 */
#define HAS_FAST_PKCS5_PBKDF2_HMAC 1
#endif
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
PyObject *name; /* name of this hash algorithm */ PyObject *name; /* name of this hash algorithm */
EVP_MD_CTX ctx; /* OpenSSL message digest context */ EVP_MD_CTX *ctx; /* OpenSSL message digest context */
#ifdef WITH_THREAD #ifdef WITH_THREAD
PyThread_type_lock lock; /* OpenSSL context lock */ PyThread_type_lock lock; /* OpenSSL context lock */
#endif #endif
...@@ -48,7 +58,6 @@ static PyTypeObject EVPtype; ...@@ -48,7 +58,6 @@ static PyTypeObject EVPtype;
#define DEFINE_CONSTS_FOR_NEW(Name) \ #define DEFINE_CONSTS_FOR_NEW(Name) \
static PyObject *CONST_ ## Name ## _name_obj = NULL; \ static PyObject *CONST_ ## Name ## _name_obj = NULL; \
static EVP_MD_CTX CONST_new_ ## Name ## _ctx; \
static EVP_MD_CTX *CONST_new_ ## Name ## _ctx_p = NULL; static EVP_MD_CTX *CONST_new_ ## Name ## _ctx_p = NULL;
DEFINE_CONSTS_FOR_NEW(md5) DEFINE_CONSTS_FOR_NEW(md5)
...@@ -59,19 +68,57 @@ DEFINE_CONSTS_FOR_NEW(sha384) ...@@ -59,19 +68,57 @@ DEFINE_CONSTS_FOR_NEW(sha384)
DEFINE_CONSTS_FOR_NEW(sha512) DEFINE_CONSTS_FOR_NEW(sha512)
/* LCOV_EXCL_START */
static PyObject *
_setException(PyObject *exc)
{
unsigned long errcode;
const char *lib, *func, *reason;
errcode = ERR_peek_last_error();
if (!errcode) {
PyErr_SetString(exc, "unknown reasons");
return NULL;
}
ERR_clear_error();
lib = ERR_lib_error_string(errcode);
func = ERR_func_error_string(errcode);
reason = ERR_reason_error_string(errcode);
if (lib && func) {
PyErr_Format(exc, "[%s: %s] %s", lib, func, reason);
}
else if (lib) {
PyErr_Format(exc, "[%s] %s", lib, reason);
}
else {
PyErr_SetString(exc, reason);
}
return NULL;
}
/* LCOV_EXCL_STOP */
static EVPobject * static EVPobject *
newEVPobject(PyObject *name) newEVPobject(PyObject *name)
{ {
EVPobject *retval = (EVPobject *)PyObject_New(EVPobject, &EVPtype); EVPobject *retval = (EVPobject *)PyObject_New(EVPobject, &EVPtype);
if (retval == NULL) {
return NULL;
}
retval->ctx = EVP_MD_CTX_new();
if (retval->ctx == NULL) {
PyErr_NoMemory();
return NULL;
}
/* save the name for .name to return */ /* save the name for .name to return */
if (retval != NULL) { Py_INCREF(name);
Py_INCREF(name); retval->name = name;
retval->name = name;
#ifdef WITH_THREAD #ifdef WITH_THREAD
retval->lock = NULL; retval->lock = NULL;
#endif #endif
}
return retval; return retval;
} }
...@@ -86,7 +133,7 @@ EVP_hash(EVPobject *self, const void *vp, Py_ssize_t len) ...@@ -86,7 +133,7 @@ EVP_hash(EVPobject *self, const void *vp, Py_ssize_t len)
process = MUNCH_SIZE; process = MUNCH_SIZE;
else else
process = Py_SAFE_DOWNCAST(len, Py_ssize_t, unsigned int); process = Py_SAFE_DOWNCAST(len, Py_ssize_t, unsigned int);
EVP_DigestUpdate(&self->ctx, (const void*)cp, process); EVP_DigestUpdate(self->ctx, (const void*)cp, process);
len -= process; len -= process;
cp += process; cp += process;
} }
...@@ -101,16 +148,19 @@ EVP_dealloc(EVPobject *self) ...@@ -101,16 +148,19 @@ EVP_dealloc(EVPobject *self)
if (self->lock != NULL) if (self->lock != NULL)
PyThread_free_lock(self->lock); PyThread_free_lock(self->lock);
#endif #endif
EVP_MD_CTX_cleanup(&self->ctx); EVP_MD_CTX_free(self->ctx);
Py_XDECREF(self->name); Py_XDECREF(self->name);
PyObject_Del(self); PyObject_Del(self);
} }
static void locked_EVP_MD_CTX_copy(EVP_MD_CTX *new_ctx_p, EVPobject *self) static int
locked_EVP_MD_CTX_copy(EVP_MD_CTX *new_ctx_p, EVPobject *self)
{ {
int result;
ENTER_HASHLIB(self); ENTER_HASHLIB(self);
EVP_MD_CTX_copy(new_ctx_p, &self->ctx); result = EVP_MD_CTX_copy(new_ctx_p, self->ctx);
LEAVE_HASHLIB(self); LEAVE_HASHLIB(self);
return result;
} }
/* External methods for a hash object */ /* External methods for a hash object */
...@@ -126,7 +176,9 @@ EVP_copy(EVPobject *self, PyObject *unused) ...@@ -126,7 +176,9 @@ EVP_copy(EVPobject *self, PyObject *unused)
if ( (newobj = newEVPobject(self->name))==NULL) if ( (newobj = newEVPobject(self->name))==NULL)
return NULL; return NULL;
locked_EVP_MD_CTX_copy(&newobj->ctx, self); if (!locked_EVP_MD_CTX_copy(newobj->ctx, self)) {
return _setException(PyExc_ValueError);
}
return (PyObject *)newobj; return (PyObject *)newobj;
} }
...@@ -137,16 +189,24 @@ static PyObject * ...@@ -137,16 +189,24 @@ static PyObject *
EVP_digest(EVPobject *self, PyObject *unused) EVP_digest(EVPobject *self, PyObject *unused)
{ {
unsigned char digest[EVP_MAX_MD_SIZE]; unsigned char digest[EVP_MAX_MD_SIZE];
EVP_MD_CTX temp_ctx; EVP_MD_CTX *temp_ctx;
PyObject *retval; PyObject *retval;
unsigned int digest_size; unsigned int digest_size;
locked_EVP_MD_CTX_copy(&temp_ctx, self); temp_ctx = EVP_MD_CTX_new();
digest_size = EVP_MD_CTX_size(&temp_ctx); if (temp_ctx == NULL) {
EVP_DigestFinal(&temp_ctx, digest, NULL); PyErr_NoMemory();
return NULL;
}
if (!locked_EVP_MD_CTX_copy(temp_ctx, self)) {
return _setException(PyExc_ValueError);
}
digest_size = EVP_MD_CTX_size(temp_ctx);
EVP_DigestFinal(temp_ctx, digest, NULL);
retval = PyBytes_FromStringAndSize((const char *)digest, digest_size); retval = PyBytes_FromStringAndSize((const char *)digest, digest_size);
EVP_MD_CTX_cleanup(&temp_ctx); EVP_MD_CTX_free(temp_ctx);
return retval; return retval;
} }
...@@ -157,15 +217,23 @@ static PyObject * ...@@ -157,15 +217,23 @@ static PyObject *
EVP_hexdigest(EVPobject *self, PyObject *unused) EVP_hexdigest(EVPobject *self, PyObject *unused)
{ {
unsigned char digest[EVP_MAX_MD_SIZE]; unsigned char digest[EVP_MAX_MD_SIZE];
EVP_MD_CTX temp_ctx; EVP_MD_CTX *temp_ctx;
unsigned int digest_size; unsigned int digest_size;
temp_ctx = EVP_MD_CTX_new();
if (temp_ctx == NULL) {
PyErr_NoMemory();
return NULL;
}
/* Get the raw (binary) digest value */ /* Get the raw (binary) digest value */
locked_EVP_MD_CTX_copy(&temp_ctx, self); if (!locked_EVP_MD_CTX_copy(temp_ctx, self)) {
digest_size = EVP_MD_CTX_size(&temp_ctx); return _setException(PyExc_ValueError);
EVP_DigestFinal(&temp_ctx, digest, NULL); }
digest_size = EVP_MD_CTX_size(temp_ctx);
EVP_DigestFinal(temp_ctx, digest, NULL);
EVP_MD_CTX_cleanup(&temp_ctx); EVP_MD_CTX_free(temp_ctx);
return _Py_strhex((const char *)digest, digest_size); return _Py_strhex((const char *)digest, digest_size);
} }
...@@ -219,7 +287,7 @@ static PyObject * ...@@ -219,7 +287,7 @@ static PyObject *
EVP_get_block_size(EVPobject *self, void *closure) EVP_get_block_size(EVPobject *self, void *closure)
{ {
long block_size; long block_size;
block_size = EVP_MD_CTX_block_size(&self->ctx); block_size = EVP_MD_CTX_block_size(self->ctx);
return PyLong_FromLong(block_size); return PyLong_FromLong(block_size);
} }
...@@ -227,7 +295,7 @@ static PyObject * ...@@ -227,7 +295,7 @@ static PyObject *
EVP_get_digest_size(EVPobject *self, void *closure) EVP_get_digest_size(EVPobject *self, void *closure)
{ {
long size; long size;
size = EVP_MD_CTX_size(&self->ctx); size = EVP_MD_CTX_size(self->ctx);
return PyLong_FromLong(size); return PyLong_FromLong(size);
} }
...@@ -288,7 +356,7 @@ EVP_tp_init(EVPobject *self, PyObject *args, PyObject *kwds) ...@@ -288,7 +356,7 @@ EVP_tp_init(EVPobject *self, PyObject *args, PyObject *kwds)
PyBuffer_Release(&view); PyBuffer_Release(&view);
return -1; return -1;
} }
EVP_DigestInit(&self->ctx, digest); EVP_DigestInit(self->ctx, digest);
self->name = name_obj; self->name = name_obj;
Py_INCREF(self->name); Py_INCREF(self->name);
...@@ -385,9 +453,9 @@ EVPnew(PyObject *name_obj, ...@@ -385,9 +453,9 @@ EVPnew(PyObject *name_obj,
return NULL; return NULL;
if (initial_ctx) { if (initial_ctx) {
EVP_MD_CTX_copy(&self->ctx, initial_ctx); EVP_MD_CTX_copy(self->ctx, initial_ctx);
} else { } else {
EVP_DigestInit(&self->ctx, digest); EVP_DigestInit(self->ctx, digest);
} }
if (cp && len) { if (cp && len) {
...@@ -453,6 +521,7 @@ EVP_new(PyObject *self, PyObject *args, PyObject *kwdict) ...@@ -453,6 +521,7 @@ EVP_new(PyObject *self, PyObject *args, PyObject *kwdict)
#define PY_PBKDF2_HMAC 1 #define PY_PBKDF2_HMAC 1
#if !HAS_FAST_PKCS5_PBKDF2_HMAC
/* Improved implementation of PKCS5_PBKDF2_HMAC() /* Improved implementation of PKCS5_PBKDF2_HMAC()
* *
* PKCS5_PBKDF2_HMAC_fast() hashes the password exactly one time instead of * PKCS5_PBKDF2_HMAC_fast() hashes the password exactly one time instead of
...@@ -534,37 +603,8 @@ PKCS5_PBKDF2_HMAC_fast(const char *pass, int passlen, ...@@ -534,37 +603,8 @@ PKCS5_PBKDF2_HMAC_fast(const char *pass, int passlen,
HMAC_CTX_cleanup(&hctx_tpl); HMAC_CTX_cleanup(&hctx_tpl);
return 1; return 1;
} }
#endif
/* LCOV_EXCL_START */
static PyObject *
_setException(PyObject *exc)
{
unsigned long errcode;
const char *lib, *func, *reason;
errcode = ERR_peek_last_error();
if (!errcode) {
PyErr_SetString(exc, "unknown reasons");
return NULL;
}
ERR_clear_error();
lib = ERR_lib_error_string(errcode);
func = ERR_func_error_string(errcode);
reason = ERR_reason_error_string(errcode);
if (lib && func) {
PyErr_Format(exc, "[%s: %s] %s", lib, func, reason);
}
else if (lib) {
PyErr_Format(exc, "[%s] %s", lib, reason);
}
else {
PyErr_SetString(exc, reason);
}
return NULL;
}
/* LCOV_EXCL_STOP */
PyDoc_STRVAR(pbkdf2_hmac__doc__, PyDoc_STRVAR(pbkdf2_hmac__doc__,
"pbkdf2_hmac(hash_name, password, salt, iterations, dklen=None) -> key\n\ "pbkdf2_hmac(hash_name, password, salt, iterations, dklen=None) -> key\n\
...@@ -646,10 +686,17 @@ pbkdf2_hmac(PyObject *self, PyObject *args, PyObject *kwdict) ...@@ -646,10 +686,17 @@ pbkdf2_hmac(PyObject *self, PyObject *args, PyObject *kwdict)
key = PyBytes_AS_STRING(key_obj); key = PyBytes_AS_STRING(key_obj);
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
#if HAS_FAST_PKCS5_PBKDF2_HMAC
retval = PKCS5_PBKDF2_HMAC((char*)password.buf, (int)password.len,
(unsigned char *)salt.buf, (int)salt.len,
iterations, digest, dklen,
(unsigned char *)key);
#else
retval = PKCS5_PBKDF2_HMAC_fast((char*)password.buf, (int)password.len, retval = PKCS5_PBKDF2_HMAC_fast((char*)password.buf, (int)password.len,
(unsigned char *)salt.buf, (int)salt.len, (unsigned char *)salt.buf, (int)salt.len,
iterations, digest, dklen, iterations, digest, dklen,
(unsigned char *)key); (unsigned char *)key);
#endif
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
if (!retval) { if (!retval) {
...@@ -768,7 +815,7 @@ generate_hash_name_list(void) ...@@ -768,7 +815,7 @@ generate_hash_name_list(void)
if (CONST_ ## NAME ## _name_obj == NULL) { \ if (CONST_ ## NAME ## _name_obj == NULL) { \
CONST_ ## NAME ## _name_obj = PyUnicode_FromString(#NAME); \ CONST_ ## NAME ## _name_obj = PyUnicode_FromString(#NAME); \
if (EVP_get_digestbyname(#NAME)) { \ if (EVP_get_digestbyname(#NAME)) { \
CONST_new_ ## NAME ## _ctx_p = &CONST_new_ ## NAME ## _ctx; \ CONST_new_ ## NAME ## _ctx_p = EVP_MD_CTX_new(); \
EVP_DigestInit(CONST_new_ ## NAME ## _ctx_p, EVP_get_digestbyname(#NAME)); \ EVP_DigestInit(CONST_new_ ## NAME ## _ctx_p, EVP_get_digestbyname(#NAME)); \
} \ } \
} \ } \
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment