Index: Lib/socket.py =================================================================== --- Lib/socket.py (revision 3591) +++ Lib/socket.py (working copy) @@ -75,7 +75,6 @@ import java.nio.channels.UnsupportedAddressTypeException import javax.net.ssl.SSLSocketFactory -import org.python.core.PyFile class error(Exception): pass class herror(error): pass @@ -360,7 +359,7 @@ # Same situation as above raise NotImplementedError("getprotobyname not yet supported on jython.") -def socket(family = AF_INET, type = SOCK_STREAM, flags=0): +def _realsocket(family = AF_INET, type = SOCK_STREAM, flags=0): assert family == AF_INET assert type in (SOCK_DGRAM, SOCK_STREAM) assert flags == 0 @@ -409,6 +408,8 @@ timeout = _defaulttimeout mode = MODE_BLOCKING + reference_count = 0 + close_lock = threading.Lock() def gettimeout(self): return self.timeout @@ -464,7 +465,6 @@ ostream = None local_addr = None server = 0 - file_count = 0 reuse_addr = 0 def bind(self, addr): @@ -610,66 +610,6 @@ if optname == SO_REUSEADDR: return self.reuse_addr - def makefile(self, mode="r", bufsize=-1): - file = None - if self.istream: - if self.ostream: - file = org.python.core.PyFile(self.istream, self.ostream, - "", mode) - else: - file = org.python.core.PyFile(self.istream, "", mode) - elif self.ostream: - file = org.python.core.PyFile(self.ostream, "", mode) - else: - raise IOError, "both istream and ostream have been shut down" - if file: - return _tcpsocket.FileWrapper(self, file) - - class FileWrapper: - def __init__(self, socket, file): - self.socket = socket - self.istream = socket.istream - self.ostream = socket.ostream - - self.file = file - self.read = file.read - self.readline = file.readline - self.readlines = file.readlines - self.write = file.write - self.writelines = file.writelines - self.flush = file.flush - self.seek = file.seek - self.tell = file.tell - self.closed = file.closed - - self.socket.file_count += 1 - - def close(self): - if self.closed: - # Already closed - return - - self.socket.file_count -= 1 - # AMAK: 20070715: Cannot close the PyFile, because closing - # it causes the InputStream and OutputStream to be closed. - # This in turn causes the underlying socket to be closed. - # This was always true for java.net sockets - # And continues to be true for java.nio sockets - # http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=4717638 -# self.file.close() - istream = self.istream - ostream = self.ostream - self.istream = None - self.ostream = None -# self.closed = self.file.closed - self.closed = 1 - - if self.socket.file_count == 0 and self.socket.sock_impl is None: - # This is the last file Only close the socket and streams - # if there are no outstanding files left. - istream.close() - ostream.close() - def shutdown(self, how): assert how in (SHUT_RD, SHUT_WR, SHUT_RDWR) assert self.sock_impl @@ -680,23 +620,12 @@ def close(self): try: - if not self.sock_impl: - return - sock_impl = self.sock_impl - istream = self.istream - ostream = self.ostream - self.sock_impl = None - self.istream = None - self.ostream = None - # Only close the socket and streams if there are no - # outstanding files left. - if self.file_count == 0: - if istream: - istream.close() - if ostream: - ostream.close() - if sock_impl: - sock_impl.close() + if self.istream: + self.istream.close() + if self.ostream: + self.ostream.close() + if self.sock_impl: + self.sock_impl.close() except java.lang.Exception, jlx: raise _map_exception(jlx) @@ -814,11 +743,8 @@ def close(self): try: - if not self.sock_impl: - return - sock = self.sock_impl - self.sock_impl = None - sock.close() + if self.sock_impl: + self.sock_impl.close() except java.lang.Exception, jlx: raise _map_exception(jlx) @@ -839,9 +765,323 @@ except java.lang.Exception, jlx: raise _map_exception(jlx) -SocketType = _tcpsocket -SocketTypes = [_tcpsocket, _udpsocket] +_socketmethods = ( + 'bind', 'connect', 'connect_ex', 'fileno', 'listen', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'sendall', 'setblocking', + 'settimeout', 'gettimeout', 'shutdown') +class _closedsocket(object): + __slots__ = [] + def _dummy(*args): + raise error(errno.EBADF, 'Bad file descriptor') + send = recv = sendto = recvfrom = __getattr__ = _dummy + +class _socketobject(object): + + __doc__ = _realsocket.__doc__ + + __slots__ = ["_sock", "send", "recv", "sendto", "recvfrom", + "__weakref__"] + + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): + if _sock is None: + _sock = _realsocket(family, type, proto) + _sock.reference_count += 1 + elif not isinstance(_sock, _closedsocket): + _sock.reference_count += 1 + self._sock = _sock + self.send = self._sock.send + self.recv = self._sock.recv + if hasattr(self._sock, 'sendto'): + self.sendto = self._sock.sendto + self.recvfrom = self._sock.recvfrom + + def close(self): + _sock = self._sock + if not isinstance(_sock, _closedsocket): + _sock.close_lock.acquire() + try: + _sock.reference_count -=1 + if not _sock.reference_count: + _sock.close() + self._sock = _closedsocket() + self.send = self.recv = self.sendto = self.recvfrom = \ + self._sock._dummy + finally: + _sock.close_lock.release() + #close.__doc__ = _realsocket.close.__doc__ + + def accept(self): + sock, addr = self._sock.accept() + return _socketobject(_sock=sock), addr + #accept.__doc__ = _realsocket.accept.__doc__ + + def dup(self): + """dup() -> socket object + + Return a new socket object connected to the same system resource.""" + _sock = self._sock + if isinstance(_sock, _closedsocket): + return _socketobject(_sock=_sock) + + _sock.close_lock.acquire() + try: + duped = _socketobject(_sock=_sock) + finally: + _sock.close_lock.release() + return duped + + def makefile(self, mode='r', bufsize=-1): + """makefile([mode[, bufsize]]) -> file object + + Return a regular file object corresponding to the socket. The mode + and bufsize arguments are as for the built-in open() function.""" + _sock = self._sock + if isinstance(_sock, _closedsocket): + return _fileobject(_sock, mode, bufsize) + + _sock.close_lock.acquire() + try: + fileobject = _fileobject(_sock, mode, bufsize) + finally: + _sock.close_lock.release() + return fileobject + + _s = ("def %s(self, *args): return self._sock.%s(*args)\n\n" + #"%s.__doc__ = _realsocket.%s.__doc__\n") + ) + for _m in _socketmethods: + #exec _s % (_m, _m, _m, _m) + exec _s % (_m, _m) + del _m, _s + +socket = SocketType = _socketobject + +class _fileobject(object): + """Faux file object attached to a socket object.""" + + default_bufsize = 8192 + name = "" + + __slots__ = ["mode", "bufsize", "softspace", + # "closed" is a property, see below + "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", + "_close"] + + def __init__(self, sock, mode='rb', bufsize=-1): + self._sock = sock + if not isinstance(sock, _closedsocket): + sock.reference_count += 1 + self.mode = mode # Not actually used in this version + if bufsize < 0: + bufsize = self.default_bufsize + self.bufsize = bufsize + self.softspace = False + if bufsize == 0: + self._rbufsize = 1 + elif bufsize == 1: + self._rbufsize = self.default_bufsize + else: + self._rbufsize = bufsize + self._wbufsize = bufsize + self._rbuf = "" # A string + self._wbuf = [] # A list of strings + + def _getclosed(self): + return self._sock is None + closed = property(_getclosed, doc="True if the file is closed") + + def close(self): + try: + if self._sock: + self.flush() + finally: + if self._sock and not isinstance(self._sock, _closedsocket): + self._sock.reference_count -= 1 + if not self._sock.reference_count: + self._sock.close() + self._sock = None + + def __del__(self): + try: + self.close() + except: + # close() may fail if __init__ didn't complete + pass + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self._sock.sendall(buffer) + + def fileno(self): + return self._sock.fileno() + + def write(self, data): + data = str(data) # XXX Should really reject non-string non-buffers + if not data: + return + self._wbuf.append(data) + if (self._wbufsize == 0 or + self._wbufsize == 1 and '\n' in data or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def writelines(self, list): + # XXX We could do better here for very long lists + # XXX Should really reject non-string non-buffers + self._wbuf.extend(filter(None, map(str, list))) + if (self._wbufsize <= 1 or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def _get_wbuf_len(self): + buf_len = 0 + for x in self._wbuf: + buf_len += len(x) + return buf_len + + def read(self, size=-1): + data = self._rbuf + if size < 0: + # Read until EOF + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + while True: + data = self._sock.recv(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self._sock.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + recv = self._sock.recv + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sock.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sock.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readlines(self, sizehint=0): + total = 0 + list = [] + while True: + line = self.readline() + if not line: + break + list.append(line) + total += len(line) + if sizehint and total >= sizehint: + break + return list + + # Iterator protocols + + def __iter__(self): + return self + + def next(self): + line = self.readline() + if not line: + raise StopIteration + return line + + # Define the SSL support class ssl: Index: Lib/test/test_socket.py =================================================================== --- Lib/test/test_socket.py (revision 3592) +++ Lib/test/test_socket.py (working copy) @@ -583,6 +583,18 @@ time.sleep(0.5) self.fail("Sending on remotely closed socket should have raised exception") + def testDup(self): + msg = self.cli_conn.recv(len(MSG)) + self.assertEqual(msg, MSG) + + dup_conn = self.cli_conn.dup() + msg = dup_conn.recv(len('and ' + MSG)) + self.assertEqual(msg, 'and ' + MSG) + + def _testDup(self): + self.serv_conn.send(MSG) + self.serv_conn.send('and ' + MSG) + class BasicUDPTest(ThreadedUDPSocketTest): def __init__(self, methodName='runTest'): @@ -762,7 +774,7 @@ # TODO: Write some non-blocking UDP tests # -class FileObjectClassOpenCloseTests(SocketConnectedTest): +class TCPFileObjectClassOpenCloseTests(SocketConnectedTest): def testCloseFileDoesNotCloseSocket(self): # This test is necessary on java/jython @@ -790,6 +802,71 @@ except Exception, x: self.fail("Closing socket appears to have closed file wrapper: %s" % str(x)) +class UDPFileObjectClassOpenCloseTests(ThreadedUDPSocketTest): + + def testCloseFileDoesNotCloseSocket(self): + # This test is necessary on java/jython + msg = self.serv.recv(1024) + self.assertEqual(msg, MSG) + + def _testCloseFileDoesNotCloseSocket(self): + self.cli_file = self.cli.makefile('wb') + self.cli_file.close() + try: + self.cli.sendto(MSG, 0, (HOST, PORT)) + except Exception, x: + self.fail("Closing file wrapper appears to have closed underlying socket: %s" % str(x)) + + def testCloseSocketDoesNotCloseFile(self): + self.serv_file = self.serv.makefile('rb') + self.serv.close() + msg = self.serv_file.readline() + self.assertEqual(msg, MSG) + + def _testCloseSocketDoesNotCloseFile(self): + try: + self.cli.sendto(MSG, 0, (HOST, PORT)) + except Exception, x: + self.fail("Closing file wrapper appears to have closed underlying socket: %s" % str(x)) + +class FileAndDupOpenCloseTests(SocketConnectedTest): + + def testCloseDoesNotCloseOthers(self): + msg = self.cli_conn.recv(len(MSG)) + self.assertEqual(msg, MSG) + + msg = self.cli_conn.recv(len('and ' + MSG)) + self.assertEqual(msg, 'and ' + MSG) + + def _testCloseDoesNotCloseOthers(self): + self.dup_conn1 = self.serv_conn.dup() + self.dup_conn2 = self.serv_conn.dup() + self.cli_file = self.serv_conn.makefile('wb') + self.serv_conn.close() + self.dup_conn1.close() + + try: + self.serv_conn.send(MSG) + except socket.error, se: + self.failUnlessEqual(se[0], errno.EBADF) + else: + self.fail("Original socket did not close") + try: + self.dup_conn1.send(MSG) + except socket.error, se: + self.failUnlessEqual(se[0], errno.EBADF) + else: + self.fail("Duplicate socket 1 did not close") + + self.dup_conn2.send(MSG) + self.dup_conn2.close() + + try: + self.cli_file.write('and ' + MSG) + except Exception, x: + self.fail("Closing others appears to have closed the socket file: %s" % str(x)) + self.cli_file.close() + class FileObjectClassTestCase(SocketConnectedTest): bufsize = -1 # Use default buffer size @@ -1021,7 +1098,6 @@ else: self.fail("Binding to already bound host/port should have raised exception") - def testUnresolvedAddress(self): try: self.s.connect( ('non.existent.server', PORT) ) @@ -1050,6 +1126,26 @@ else: self.fail("Receive on unconnected socket raised exception") + def testClosedSocket(self): + self.s.close() + try: + self.s.send(MSG) + except socket.error, se: + self.failUnlessEqual(se[0], errno.EBADF) + + dup = self.s.dup() + try: + dup.send(MSG) + except socket.error, se: + self.failUnlessEqual(se[0], errno.EBADF) + + fp = self.s.makefile() + try: + fp.write(MSG) + fp.flush() + except socket.error, se: + self.failUnlessEqual(se[0], errno.EBADF) + class TestAddressParameters: def testBindNonTupleEndpointRaisesTypeError(self): @@ -1099,7 +1195,9 @@ UDPTimeoutTest, NonBlockingTCPTests, NonBlockingUDPTests, - FileObjectClassOpenCloseTests, + TCPFileObjectClassOpenCloseTests, + UDPFileObjectClassOpenCloseTests, + FileAndDupOpenCloseTests, FileObjectClassTestCase, UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase,