Kaydet (Commit) 39eb8fa0 authored tarafından Guido van Rossum's avatar Guido van Rossum

This is roughly socket2.diff from issue 1378, with a few changes applied

to ssl.py (no need to test whether we can dup any more).
Regular sockets no longer have a _base, but we still have explicit
reference counting of socket objects for the benefit of makefile();
using duplicate sockets won't work for SSLSocket.
üst dd9e3b87
...@@ -26,6 +26,15 @@ PyAPI_FUNC(size_t) PyLong_AsSize_t(PyObject *); ...@@ -26,6 +26,15 @@ PyAPI_FUNC(size_t) PyLong_AsSize_t(PyObject *);
PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *);
PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *);
/* Used by socketmodule.c */
#if SIZEOF_SOCKET_T <= SIZEOF_LONG
#define PyLong_FromSocket_t(fd) PyLong_FromLong((SOCKET_T)(fd))
#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLong(fd)
#else
#define PyLong_FromSocket_t(fd) PyLong_FromLongLong(((SOCKET_T)(fd));
#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLongLong(fd)
#endif
/* For use by intobject.c only */ /* For use by intobject.c only */
PyAPI_DATA(int) _PyLong_DigitValue[256]; PyAPI_DATA(int) _PyLong_DigitValue[256];
......
...@@ -79,28 +79,14 @@ if sys.platform.lower().startswith("win"): ...@@ -79,28 +79,14 @@ if sys.platform.lower().startswith("win"):
__all__.append("errorTab") __all__.append("errorTab")
# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
class socket(_socket.socket): class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method.""" """A subclass of _socket.socket adding the makefile() method."""
__slots__ = ["__weakref__", "_io_refs", "_closed"] __slots__ = ["__weakref__", "_io_refs", "_closed"]
if not _can_dup_socket:
__slots__.append("_base")
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
if fileno is None: _socket.socket.__init__(self, family, type, proto, fileno)
_socket.socket.__init__(self, family, type, proto)
else:
_socket.socket.__init__(self, family, type, proto, fileno)
self._io_refs = 0 self._io_refs = 0
self._closed = False self._closed = False
...@@ -114,23 +100,29 @@ class socket(_socket.socket): ...@@ -114,23 +100,29 @@ class socket(_socket.socket):
s[7:]) s[7:])
return s return s
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
"""
fd = dup(self.fileno())
sock = self.__class__(self.family, self.type, self.proto, fileno=fd)
sock.settimeout(self.gettimeout())
return sock
def accept(self): def accept(self):
"""Wrap accept() to give the connection the right type.""" """accept() -> (socket object, address info)
conn, addr = _socket.socket.accept(self)
fd = conn.fileno() Wait for an incoming connection. Return a new socket
nfd = fd representing the connection, and the address of the client.
if _can_dup_socket: For IP sockets, the address info is a pair (hostaddr, port).
nfd = os.dup(fd) """
wrapper = socket(self.family, self.type, self.proto, fileno=nfd) fd, addr = self._accept()
if fd == nfd: return socket(self.family, self.type, self.proto, fileno=fd), addr
wrapper._base = conn # Keep the base alive
else:
conn.close()
return wrapper, addr
def makefile(self, mode="r", buffering=None, *, def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None): encoding=None, newline=None):
"""Return an I/O stream connected to the socket. """makefile(...) -> an I/O stream connected to the socket
The arguments are as for io.open() after the filename, The arguments are as for io.open() after the filename,
except the only mode characters supported are 'r', 'w' and 'b'. except the only mode characters supported are 'r', 'w' and 'b'.
...@@ -184,21 +176,18 @@ class socket(_socket.socket): ...@@ -184,21 +176,18 @@ class socket(_socket.socket):
def close(self): def close(self):
self._closed = True self._closed = True
if self._io_refs < 1: if self._io_refs <= 0:
self._real_close() _socket.socket.close(self)
# _real_close calls close on the _socket.socket base class.
if not _can_dup_socket: def fromfd(fd, family, type, proto=0):
def _real_close(self): """ fromfd(fd, family, type[, proto]) -> socket object
_socket.socket.close(self)
base = getattr(self, "_base", None) Create a socket object from a duplicate of the given file
if base is not None: descriptor. The remaining arguments are the same as for socket().
self._base = None """
base.close() nfd = dup(fd)
else: return socket(family, type, proto, nfd)
def _real_close(self):
_socket.socket.close(self)
class SocketIO(io.RawIOBase): class SocketIO(io.RawIOBase):
......
...@@ -78,8 +78,8 @@ from _ssl import ( ...@@ -78,8 +78,8 @@ from _ssl import (
from socket import socket, AF_INET, SOCK_STREAM, error from socket import socket, AF_INET, SOCK_STREAM, error
from socket import getnameinfo as _getnameinfo from socket import getnameinfo as _getnameinfo
from socket import error as socket_error from socket import error as socket_error
from socket import dup as _dup
import base64 # for DER-to-PEM translation import base64 # for DER-to-PEM translation
_can_dup_socket = hasattr(socket, "dup")
class SSLSocket(socket): class SSLSocket(socket):
...@@ -99,20 +99,11 @@ class SSLSocket(socket): ...@@ -99,20 +99,11 @@ class SSLSocket(socket):
if sock is not None: if sock is not None:
# copied this code from socket.accept() # copied this code from socket.accept()
fd = sock.fileno() fd = sock.fileno()
nfd = fd nfd = _dup(fd)
if _can_dup_socket: socket.__init__(self, family=sock.family, type=sock.type,
nfd = os.dup(fd) proto=sock.proto, fileno=nfd)
try: sock.close()
socket.__init__(self, family=sock.family, type=sock.type, sock = None
proto=sock.proto, fileno=nfd)
except:
if nfd != fd:
os.close(nfd)
else:
if fd != nfd:
sock.close()
sock = None
elif fileno is not None: elif fileno is not None:
socket.__init__(self, fileno=fileno) socket.__init__(self, fileno=fileno)
else: else:
......
...@@ -575,6 +575,15 @@ class BasicTCPTest(SocketConnectedTest): ...@@ -575,6 +575,15 @@ class BasicTCPTest(SocketConnectedTest):
def _testFromFd(self): def _testFromFd(self):
self.serv_conn.send(MSG) self.serv_conn.send(MSG)
def testDup(self):
# Testing dup()
sock = self.cli_conn.dup()
msg = sock.recv(1024)
self.assertEqual(msg, MSG)
def _testDup(self):
self.serv_conn.send(MSG)
def testShutdown(self): def testShutdown(self):
# Testing shutdown() # Testing shutdown()
msg = self.cli_conn.recv(1024) msg = self.cli_conn.recv(1024)
......
...@@ -89,12 +89,12 @@ A socket object represents one endpoint of a network connection.\n\ ...@@ -89,12 +89,12 @@ A socket object represents one endpoint of a network connection.\n\
\n\ \n\
Methods of socket objects (keyword arguments not allowed):\n\ Methods of socket objects (keyword arguments not allowed):\n\
\n\ \n\
accept() -- accept a connection, returning new socket and client address\n\ _accept() -- accept connection, returning new socket fd and client address\n\
bind(addr) -- bind the socket to a local address\n\ bind(addr) -- bind the socket to a local address\n\
close() -- close the socket\n\ close() -- close the socket\n\
connect(addr) -- connect the socket to a remote address\n\ connect(addr) -- connect the socket to a remote address\n\
connect_ex(addr) -- connect, return an error code instead of an exception\n\ connect_ex(addr) -- connect, return an error code instead of an exception\n\
dup() -- return a new socket object identical to the current one [*]\n\ _dup() -- return a new socket fd duplicated from fileno()\n\
fileno() -- return underlying file descriptor\n\ fileno() -- return underlying file descriptor\n\
getpeername() -- return remote address [*]\n\ getpeername() -- return remote address [*]\n\
getsockname() -- return local address\n\ getsockname() -- return local address\n\
...@@ -327,10 +327,26 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size); ...@@ -327,10 +327,26 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size);
#include "getnameinfo.c" #include "getnameinfo.c"
#endif #endif
#if defined(MS_WINDOWS) #ifdef MS_WINDOWS
/* seem to be a few differences in the API */ /* On Windows a socket is really a handle not an fd */
static SOCKET
dup_socket(SOCKET handle)
{
HANDLE newhandle;
if (!DuplicateHandle(GetCurrentProcess(), (HANDLE)handle,
GetCurrentProcess(), &newhandle,
0, FALSE, DUPLICATE_SAME_ACCESS))
{
WSASetLastError(GetLastError());
return INVALID_SOCKET;
}
return (SOCKET)newhandle;
}
#define SOCKETCLOSE closesocket #define SOCKETCLOSE closesocket
#define NO_DUP /* Actually it exists on NT 3.5, but what the heck... */ #else
/* On Unix we can use dup to duplicate the file descriptor of a socket*/
#define dup_socket(fd) dup(fd)
#endif #endif
#ifdef MS_WIN32 #ifdef MS_WIN32
...@@ -628,7 +644,7 @@ internal_select(PySocketSockObject *s, int writing) ...@@ -628,7 +644,7 @@ internal_select(PySocketSockObject *s, int writing)
pollfd.events = writing ? POLLOUT : POLLIN; pollfd.events = writing ? POLLOUT : POLLIN;
/* s->sock_timeout is in seconds, timeout in ms */ /* s->sock_timeout is in seconds, timeout in ms */
timeout = (int)(s->sock_timeout * 1000 + 0.5); timeout = (int)(s->sock_timeout * 1000 + 0.5);
n = poll(&pollfd, 1, timeout); n = poll(&pollfd, 1, timeout);
} }
#else #else
...@@ -648,7 +664,7 @@ internal_select(PySocketSockObject *s, int writing) ...@@ -648,7 +664,7 @@ internal_select(PySocketSockObject *s, int writing)
n = select(s->sock_fd+1, &fds, NULL, NULL, &tv); n = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
} }
#endif #endif
if (n < 0) if (n < 0)
return -1; return -1;
if (n == 0) if (n == 0)
...@@ -1423,7 +1439,7 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret) ...@@ -1423,7 +1439,7 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret)
} }
/* s.accept() method */ /* s._accept() -> (fd, address) */
static PyObject * static PyObject *
sock_accept(PySocketSockObject *s) sock_accept(PySocketSockObject *s)
...@@ -1457,17 +1473,12 @@ sock_accept(PySocketSockObject *s) ...@@ -1457,17 +1473,12 @@ sock_accept(PySocketSockObject *s)
if (newfd == INVALID_SOCKET) if (newfd == INVALID_SOCKET)
return s->errorhandler(); return s->errorhandler();
/* Create the new object with unspecified family, sock = PyLong_FromSocket_t(newfd);
to avoid calls to bind() etc. on it. */
sock = (PyObject *) new_sockobject(newfd,
s->sock_family,
s->sock_type,
s->sock_proto);
if (sock == NULL) { if (sock == NULL) {
SOCKETCLOSE(newfd); SOCKETCLOSE(newfd);
goto finally; goto finally;
} }
addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf), addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf),
addrlen, s->sock_proto); addrlen, s->sock_proto);
if (addr == NULL) if (addr == NULL)
...@@ -1482,11 +1493,11 @@ finally: ...@@ -1482,11 +1493,11 @@ finally:
} }
PyDoc_STRVAR(accept_doc, PyDoc_STRVAR(accept_doc,
"accept() -> (socket object, address info)\n\ "_accept() -> (integer, address info)\n\
\n\ \n\
Wait for an incoming connection. Return a new socket representing the\n\ Wait for an incoming connection. Return a new socket file descriptor\n\
connection, and the address of the client. For IP sockets, the address\n\ representing the connection, and the address of the client.\n\
info is a pair (hostaddr, port)."); For IP sockets, the address info is a pair (hostaddr, port).");
/* s.setblocking(flag) method. Argument: /* s.setblocking(flag) method. Argument:
False -- non-blocking mode; same as settimeout(0) False -- non-blocking mode; same as settimeout(0)
...@@ -1882,11 +1893,7 @@ instead of raising an exception when an error occurs."); ...@@ -1882,11 +1893,7 @@ instead of raising an exception when an error occurs.");
static PyObject * static PyObject *
sock_fileno(PySocketSockObject *s) sock_fileno(PySocketSockObject *s)
{ {
#if SIZEOF_SOCKET_T <= SIZEOF_LONG return PyLong_FromSocket_t(s->sock_fd);
return PyInt_FromLong((long) s->sock_fd);
#else
return PyLong_FromLongLong((PY_LONG_LONG)s->sock_fd);
#endif
} }
PyDoc_STRVAR(fileno_doc, PyDoc_STRVAR(fileno_doc,
...@@ -1895,35 +1902,6 @@ PyDoc_STRVAR(fileno_doc, ...@@ -1895,35 +1902,6 @@ PyDoc_STRVAR(fileno_doc,
Return the integer file descriptor of the socket."); Return the integer file descriptor of the socket.");
#ifndef NO_DUP
/* s.dup() method */
static PyObject *
sock_dup(PySocketSockObject *s)
{
SOCKET_T newfd;
PyObject *sock;
newfd = dup(s->sock_fd);
if (newfd < 0)
return s->errorhandler();
sock = (PyObject *) new_sockobject(newfd,
s->sock_family,
s->sock_type,
s->sock_proto);
if (sock == NULL)
SOCKETCLOSE(newfd);
return sock;
}
PyDoc_STRVAR(dup_doc,
"dup() -> socket object\n\
\n\
Return a new socket object connected to the same system resource.");
#endif
/* s.getsockname() method */ /* s.getsockname() method */
static PyObject * static PyObject *
...@@ -2542,7 +2520,7 @@ of the socket (flag == SHUT_WR), or both ends (flag == SHUT_RDWR)."); ...@@ -2542,7 +2520,7 @@ of the socket (flag == SHUT_WR), or both ends (flag == SHUT_RDWR).");
/* List of methods for socket objects */ /* List of methods for socket objects */
static PyMethodDef sock_methods[] = { static PyMethodDef sock_methods[] = {
{"accept", (PyCFunction)sock_accept, METH_NOARGS, {"_accept", (PyCFunction)sock_accept, METH_NOARGS,
accept_doc}, accept_doc},
{"bind", (PyCFunction)sock_bind, METH_O, {"bind", (PyCFunction)sock_bind, METH_O,
bind_doc}, bind_doc},
...@@ -2552,10 +2530,6 @@ static PyMethodDef sock_methods[] = { ...@@ -2552,10 +2530,6 @@ static PyMethodDef sock_methods[] = {
connect_doc}, connect_doc},
{"connect_ex", (PyCFunction)sock_connect_ex, METH_O, {"connect_ex", (PyCFunction)sock_connect_ex, METH_O,
connect_ex_doc}, connect_ex_doc},
#ifndef NO_DUP
{"dup", (PyCFunction)sock_dup, METH_NOARGS,
dup_doc},
#endif
{"fileno", (PyCFunction)sock_fileno, METH_NOARGS, {"fileno", (PyCFunction)sock_fileno, METH_NOARGS,
fileno_doc}, fileno_doc},
#ifdef HAVE_GETPEERNAME #ifdef HAVE_GETPEERNAME
...@@ -2672,8 +2646,8 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -2672,8 +2646,8 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds)
&family, &type, &proto, &fdobj)) &family, &type, &proto, &fdobj))
return -1; return -1;
if (fdobj != NULL) { if (fdobj != NULL && fdobj != Py_None) {
fd = PyLong_AsLongLong(fdobj); fd = PyLong_AsSocket_t(fdobj);
if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
return -1; return -1;
if (fd == INVALID_SOCKET) { if (fd == INVALID_SOCKET) {
...@@ -3172,6 +3146,38 @@ PyDoc_STRVAR(getprotobyname_doc, ...@@ -3172,6 +3146,38 @@ PyDoc_STRVAR(getprotobyname_doc,
Return the protocol number for the named protocol. (Rarely used.)"); Return the protocol number for the named protocol. (Rarely used.)");
#ifndef NO_DUP
/* dup() function for socket fds */
static PyObject *
socket_dup(PyObject *self, PyObject *fdobj)
{
SOCKET_T fd, newfd;
PyObject *newfdobj;
fd = PyLong_AsSocket_t(fdobj);
if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
return NULL;
newfd = dup_socket(fd);
if (newfd == INVALID_SOCKET)
return set_error();
newfdobj = PyLong_FromSocket_t(newfd);
if (newfdobj == NULL)
SOCKETCLOSE(newfd);
return newfdobj;
}
PyDoc_STRVAR(dup_doc,
"dup(integer) -> integer\n\
\n\
Duplicate an integer socket file descriptor. This is like os.dup(), but for\n\
sockets; on some platforms os.dup() won't work for socket file descriptors.");
#endif
#ifdef HAVE_SOCKETPAIR #ifdef HAVE_SOCKETPAIR
/* Create a pair of sockets using the socketpair() function. /* Create a pair of sockets using the socketpair() function.
Arguments as for socket() except the default family is AF_UNIX if Arguments as for socket() except the default family is AF_UNIX if
...@@ -3811,6 +3817,10 @@ static PyMethodDef socket_methods[] = { ...@@ -3811,6 +3817,10 @@ static PyMethodDef socket_methods[] = {
METH_VARARGS, getservbyport_doc}, METH_VARARGS, getservbyport_doc},
{"getprotobyname", socket_getprotobyname, {"getprotobyname", socket_getprotobyname,
METH_VARARGS, getprotobyname_doc}, METH_VARARGS, getprotobyname_doc},
#ifndef NO_DUP
{"dup", socket_dup,
METH_O, dup_doc},
#endif
#ifdef HAVE_SOCKETPAIR #ifdef HAVE_SOCKETPAIR
{"socketpair", socket_socketpair, {"socketpair", socket_socketpair,
METH_VARARGS, socketpair_doc}, METH_VARARGS, socketpair_doc},
...@@ -4105,7 +4115,7 @@ init_socket(void) ...@@ -4105,7 +4115,7 @@ init_socket(void)
PyModule_AddIntConstant(m, "NETLINK_IP6_FW", NETLINK_IP6_FW); PyModule_AddIntConstant(m, "NETLINK_IP6_FW", NETLINK_IP6_FW);
#ifdef NETLINK_DNRTMSG #ifdef NETLINK_DNRTMSG
PyModule_AddIntConstant(m, "NETLINK_DNRTMSG", NETLINK_DNRTMSG); PyModule_AddIntConstant(m, "NETLINK_DNRTMSG", NETLINK_DNRTMSG);
#endif #endif
#ifdef NETLINK_TAPBASE #ifdef NETLINK_TAPBASE
PyModule_AddIntConstant(m, "NETLINK_TAPBASE", NETLINK_TAPBASE); PyModule_AddIntConstant(m, "NETLINK_TAPBASE", NETLINK_TAPBASE);
#endif #endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment