# Copyright (C) 2009 Barry Pederson # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 from __future__ import absolute_import import errno import re import socket import ssl # Jython does not have this attribute try: from socket import SOL_TCP except ImportError: # pragma: no cover from socket import IPPROTO_TCP as SOL_TCP # noqa try: from ssl import SSLError except ImportError: class SSLError(Exception): # noqa pass from struct import pack, unpack from .exceptions import UnexpectedFrame from .utils import get_errno, set_cloexec _UNAVAIL = errno.EAGAIN, errno.EINTR, errno.ENOENT AMQP_PORT = 5672 EMPTY_BUFFER = bytes() # Yes, Advanced Message Queuing Protocol Protocol is redundant AMQP_PROTOCOL_HEADER = 'AMQP\x01\x01\x00\x09'.encode('latin_1') # Match things like: [fe80::1]:5432, from RFC 2732 IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?') class _AbstractTransport(object): """Common superclass for TCP and SSL transports""" connected = False def __init__(self, host, connect_timeout): self.connected = True msg = None port = AMQP_PORT m = IPV6_LITERAL.match(host) if m: host = m.group(1) if m.group(2): port = int(m.group(2)) else: if ':' in host: host, port = host.rsplit(':', 1) port = int(port) self.sock = None last_err = None for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, SOL_TCP): af, socktype, proto, canonname, sa = res try: self.sock = socket.socket(af, socktype, proto) try: set_cloexec(self.sock, True) except NotImplementedError: pass self.sock.settimeout(connect_timeout) self.sock.connect(sa) except socket.error as exc: msg = exc self.sock.close() self.sock = None last_err = msg continue break if not self.sock: # Didn't connect, return the most recent error message raise socket.error(last_err) try: self.sock.settimeout(None) self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1) self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self._setup_transport() self._write(AMQP_PROTOCOL_HEADER) except (OSError, IOError, socket.error) as exc: if get_errno(exc) not in _UNAVAIL: self.connected = False raise def __del__(self): try: # socket module may have been collected by gc # if this is called by a thread at shutdown. if socket is not None: try: self.close() except socket.error: pass finally: self.sock = None def _read(self, n, initial=False): """Read exactly n bytes from the peer""" raise NotImplementedError('Must be overriden in subclass') def _setup_transport(self): """Do any additional initialization of the class (used by the subclasses).""" pass def _shutdown_transport(self): """Do any preliminary work in shutting down the connection.""" pass def _write(self, s): """Completely write a string to the peer.""" raise NotImplementedError('Must be overriden in subclass') def close(self): if self.sock is not None: self._shutdown_transport() # Call shutdown first to make sure that pending messages # reach the AMQP broker if the program exits after # calling this method. self.sock.shutdown(socket.SHUT_RDWR) self.sock.close() self.sock = None self.connected = False def read_frame(self, unpack=unpack): read = self._read try: frame_type, channel, size = unpack('>BHI', read(7, True)) payload = read(size) ch = ord(read(1)) except socket.timeout: raise except (OSError, IOError, socket.error) as exc: # Don't disconnect for ssl read time outs # http://bugs.python.org/issue10272 if isinstance(exc, SSLError) and 'timed out' in str(exc): raise socket.timeout() if get_errno(exc) not in _UNAVAIL: self.connected = False raise if ch == 206: # '\xce' return frame_type, channel, payload else: raise UnexpectedFrame( 'Received 0x{0:02x} while expecting 0xce'.format(ch)) def write_frame(self, frame_type, channel, payload): size = len(payload) try: self._write(pack( '>BHI%dsB' % size, frame_type, channel, size, payload, 0xce, )) except socket.timeout: raise except (OSError, IOError, socket.error) as exc: if get_errno(exc) not in _UNAVAIL: self.connected = False raise class SSLTransport(_AbstractTransport): """Transport that works over SSL""" def __init__(self, host, connect_timeout, ssl): if isinstance(ssl, dict): self.sslopts = ssl self._read_buffer = EMPTY_BUFFER super(SSLTransport, self).__init__(host, connect_timeout) def _setup_transport(self): """Wrap the socket in an SSL object.""" if hasattr(self, 'sslopts'): self.sock = ssl.wrap_socket(self.sock, **self.sslopts) else: self.sock = ssl.wrap_socket(self.sock) self.sock.do_handshake() self._quick_recv = self.sock.read def _shutdown_transport(self): """Unwrap a Python 2.6 SSL socket, so we can call shutdown()""" if self.sock is not None: try: unwrap = self.sock.unwrap except AttributeError: return self.sock = unwrap() def _read(self, n, initial=False, _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): # According to SSL_read(3), it can at most return 16kb of data. # Thus, we use an internal read buffer like TCPTransport._read # to get the exact number of bytes wanted. recv = self._quick_recv rbuf = self._read_buffer try: while len(rbuf) < n: try: s = recv(n - len(rbuf)) # see note above except socket.error as exc: # ssl.sock.read may cause ENOENT if the # operation couldn't be performed (Issue celery#1414). if not initial and exc.errno in _errnos: continue raise if not s: raise IOError('Socket closed') rbuf += s except: self._read_buffer = rbuf raise result, self._read_buffer = rbuf[:n], rbuf[n:] return result def _write(self, s): """Write a string out to the SSL socket fully.""" try: write = self.sock.write except AttributeError: # Works around a bug in python socket library raise IOError('Socket closed') else: while s: n = write(s) if not n: raise IOError('Socket closed') s = s[n:] class TCPTransport(_AbstractTransport): """Transport that deals directly with TCP socket.""" def _setup_transport(self): """Setup to _write() directly to the socket, and do our own buffered reads.""" self._write = self.sock.sendall self._read_buffer = EMPTY_BUFFER self._quick_recv = self.sock.recv def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): """Read exactly n bytes from the socket""" recv = self._quick_recv rbuf = self._read_buffer try: while len(rbuf) < n: try: s = recv(n - len(rbuf)) except socket.error as exc: if not initial and exc.errno in _errnos: continue raise if not s: raise IOError('Socket closed') rbuf += s except: self._read_buffer = rbuf raise result, self._read_buffer = rbuf[:n], rbuf[n:] return result def create_transport(host, connect_timeout, ssl=False): """Given a few parameters from the Connection constructor, select and create a subclass of _AbstractTransport.""" if ssl: return SSLTransport(host, connect_timeout, ssl) else: return TCPTransport(host, connect_timeout)