This document describes the current stable version of py-amqp (5.0). For the development version, see the latest py-amqp documentation.

Source code for amqp.transport

"""Transport implementation."""
# Copyright (C) 2009 Barry Pederson <bp@barryp.org>

import errno
import os
import re
import socket
import ssl
from contextlib import contextmanager
from ssl import SSLError
from struct import pack, unpack

from .exceptions import UnexpectedFrame
from .platform import KNOWN_TCP_OPTS, SOL_TCP
from .utils import set_cloexec

# errno values treated as transient: read_frame() and write() do not mark
# the transport as disconnected when an OSError carries one of these.
_UNAVAIL = {errno.EAGAIN, errno.EINTR, errno.ENOENT, errno.EWOULDBLOCK}

# Port used when a host string carries no explicit port.
AMQP_PORT = 5672

EMPTY_BUFFER = bytes()

# Largest signed 32-bit integer (2**31 - 1).  Frame payloads larger than this
# are read in two parts because the size argument of recv() is signed.
SIGNED_INT_MAX = 0x7FFFFFFF

# Yes, Advanced Message Queuing Protocol Protocol is redundant
AMQP_PROTOCOL_HEADER = b'AMQP\x00\x00\x09\x01'

# Match things like: [fe80::1]:5432, from RFC 2732.  Group 1 is the address
# without brackets, group 2 the optional port.  Hex digits may be written in
# either case ([FE80::1] is as valid as [fe80::1]).
IPV6_LITERAL = re.compile(r'\[([\.0-9a-fA-F:]+)\](?::(\d+))?')

# TCP options applied to every new socket, keyed by the name of the
# ``socket`` module constant; user-supplied socket_settings are applied on
# top of these.
DEFAULT_SOCKET_SETTINGS = {
    'TCP_NODELAY': 1,
    'TCP_USER_TIMEOUT': 1000,
    'TCP_KEEPIDLE': 60,
    'TCP_KEEPINTVL': 10,
    'TCP_KEEPCNT': 9,
}


def to_host_port(host, default=AMQP_PORT):
    """Convert hostname:port string to host, port tuple.

    Accepted forms:

    * ``"host"`` -> ``("host", default)``
    * ``"host:5673"`` -> ``("host", 5673)``
    * ``"[fe80::1]"`` / ``"[fe80::1]:5673"`` -- bracketed IPv6 literal
      (RFC 2732); the brackets are stripped.
    * ``"fe80::1"`` -- bare IPv6 address (more than one ``:``).  It is
      returned unchanged with *default* as the port, because a trailing
      ``:N`` cannot be told apart from the last address group; use the
      bracketed form to give a port.

    :param host: host string in one of the forms above.
    :param default: port returned when *host* does not specify one.
    :returns: ``(host, port)`` tuple.
    :raises ValueError: if the port part of ``host:port`` is not an integer.
    """
    port = default
    m = IPV6_LITERAL.match(host)
    if m:
        host = m.group(1)
        if m.group(2):
            port = int(m.group(2))
    elif host.count(':') == 1:
        # Exactly one colon: plain ``host:port``.  Several colons mean an
        # unbracketed IPv6 address, which must not be split.
        host, port = host.rsplit(':', 1)
        port = int(port)
    return host, port
class _AbstractTransport:
    """Common superclass for TCP and SSL transports.

    Subclasses provide the byte-level I/O by overriding :meth:`_read` and
    :meth:`_write`, and may hook :meth:`_setup_transport` (called once the
    socket is configured) and :meth:`_shutdown_transport` (called before the
    socket is shut down).
    """

    def __init__(self, host, connect_timeout=None,
                 read_timeout=None, write_timeout=None,
                 socket_settings=None, raise_on_initial_eintr=True, **kwargs):
        """Store connection parameters; no socket is opened here.

        :param host: ``host``, ``host:port`` or ``[ipv6]:port`` string,
            parsed with :func:`to_host_port`.
        :param connect_timeout: timeout set on the socket while connecting.
        :param read_timeout: if not None, installed as ``SO_RCVTIMEO``.
        :param write_timeout: if not None, installed as ``SO_SNDTIMEO``.
        :param socket_settings: optional mapping of TCP option -> value
            applied on top of :meth:`_get_tcp_socket_defaults`.
        :param raise_on_initial_eintr: stored for subclass ``_read``
            implementations, which turn a retryable error on the first read
            of a frame into ``socket.timeout`` when this is true.
        :param kwargs: extra keyword arguments are accepted and ignored.
        """
        self.connected = False
        self.sock = None
        self.raise_on_initial_eintr = raise_on_initial_eintr
        self._read_buffer = EMPTY_BUFFER
        self.host, self.port = to_host_port(host)
        self.connect_timeout = connect_timeout
        self.read_timeout = read_timeout
        self.write_timeout = write_timeout
        self.socket_settings = socket_settings

    def connect(self):
        """Open the socket, configure it and send the protocol header.

        Does nothing when already connected.  If any step fails with
        ``OSError``/``SSLError`` the half-open socket is closed and the
        exception is re-raised.
        """
        try:
            # are we already connected?
            if self.connected:
                return
            self._connect(self.host, self.port, self.connect_timeout)
            self._init_socket(
                self.socket_settings, self.read_timeout, self.write_timeout,
            )
            # we've sent the banner; signal connect
            # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner
            # has _not_ been sent
            self.connected = True
        except (OSError, SSLError):
            # if not fully connected, close socket, and reraise error
            if self.sock and not self.connected:
                self.sock.close()
                self.sock = None
            raise

    @contextmanager
    def having_timeout(self, timeout):
        """Temporarily apply *timeout* to the socket and yield the socket.

        With ``timeout=None`` the socket is yielded untouched.  Otherwise the
        previous timeout is restored on exit, and SSL "timed out" / "The
        operation did not complete" errors as well as ``EWOULDBLOCK`` are
        re-raised as ``socket.timeout``.
        """
        if timeout is None:
            yield self.sock
        else:
            sock = self.sock
            prev = sock.gettimeout()
            if prev != timeout:
                sock.settimeout(timeout)
            try:
                yield self.sock
            except SSLError as exc:
                if 'timed out' in str(exc):
                    # http://bugs.python.org/issue10272
                    raise socket.timeout()
                elif 'The operation did not complete' in str(exc):
                    # Non-blocking SSL sockets can throw SSLError
                    raise socket.timeout()
                raise
            except OSError as exc:
                if exc.errno == errno.EWOULDBLOCK:
                    raise socket.timeout()
                raise
            finally:
                if timeout != prev:
                    sock.settimeout(prev)

    def _connect(self, host, port, timeout):
        """Connect ``self.sock`` to the first reachable address of *host*.

        IPv4 addresses are tried before IPv6 ones.  On success ``self.sock``
        is the connected socket; otherwise the most relevant error is raised
        and ``self.sock`` is left as None.
        """
        e = None

        # Below we are trying to avoid additional DNS requests for AAAA if A
        # succeeds. This helps a lot in case when a hostname has an IPv4 entry
        # in /etc/hosts but not IPv6. Without the (arguably somewhat twisted)
        # logic below, getaddrinfo would attempt to resolve the hostname for
        # both IP versions, which would make the resolver talk to configured
        # DNS servers. If those servers are for some reason not available
        # during resolution attempt (either because of system misconfiguration,
        # or network connectivity problem), resolution process locks the
        # _connect call for extended time.
        addr_types = (socket.AF_INET, socket.AF_INET6)
        addr_types_num = len(addr_types)
        for n, family in enumerate(addr_types):
            # first, resolve the address for a single address family
            try:
                entries = socket.getaddrinfo(
                    host, port, family, socket.SOCK_STREAM, SOL_TCP)
                entries_num = len(entries)
            except socket.gaierror:
                # we may have depleted all our options
                if n + 1 >= addr_types_num:
                    # if getaddrinfo succeeded before for another address
                    # family, reraise the previous socket.error since it's more
                    # relevant to users
                    raise (e
                           if e is not None
                           else socket.error(
                               "failed to resolve broker hostname"))
                continue  # pragma: no cover

            # now that we have address(es) for the hostname, connect to broker
            for i, res in enumerate(entries):
                af, socktype, proto, _, sa = res
                try:
                    self.sock = socket.socket(af, socktype, proto)
                    try:
                        set_cloexec(self.sock, True)
                    except NotImplementedError:
                        pass
                    self.sock.settimeout(timeout)
                    self.sock.connect(sa)
                except OSError as ex:
                    e = ex
                    if self.sock is not None:
                        self.sock.close()
                        self.sock = None
                    # we may have depleted all our options
                    if i + 1 >= entries_num and n + 1 >= addr_types_num:
                        raise
                else:
                    # hurray, we established connection
                    return

        # Only reachable when the last family resolved to no entries at all.
        # Never return without a connected socket; callers would otherwise
        # fail later with an AttributeError on ``self.sock``.
        raise (e if e is not None
               else socket.error("failed to resolve broker hostname"))

    def _init_socket(self, socket_settings, read_timeout, write_timeout):
        """Configure the connected socket and send the AMQP protocol header.

        Puts the socket back in blocking mode, enables SO_KEEPALIVE, applies
        TCP options, installs SO_SNDTIMEO/SO_RCVTIMEO when the corresponding
        timeout is not None, runs :meth:`_setup_transport` and finally writes
        ``AMQP_PROTOCOL_HEADER``.
        """
        self.sock.settimeout(None)  # set socket back to blocking mode
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        self._set_socket_options(socket_settings)

        # set socket timeouts
        for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout),
                                  (socket.SO_RCVTIMEO, read_timeout)):
            if interval is not None:
                sec = int(interval)
                usec = int((interval - sec) * 1000000)
                # (seconds, microseconds) packed as two native C longs --
                # presumably the platform's struct timeval layout.
                self.sock.setsockopt(
                    socket.SOL_SOCKET, timeout,
                    pack('ll', sec, usec),
                )
        self._setup_transport()

        self._write(AMQP_PROTOCOL_HEADER)

    def _get_tcp_socket_defaults(self, sock):
        """Return ``{option constant: value}`` for the known TCP options.

        For each name in ``KNOWN_TCP_OPTS`` that this platform exposes, the
        value comes from ``DEFAULT_SOCKET_SETTINGS`` when listed there and
        otherwise from the socket's current setting.
        """
        tcp_opts = {}
        for opt in KNOWN_TCP_OPTS:
            enum = None
            if opt == 'TCP_USER_TIMEOUT':
                try:
                    from socket import TCP_USER_TIMEOUT as enum
                except ImportError:
                    # should be in Python 3.6+ on Linux.
                    # NOTE(review): 18 is presumably the Linux constant.
                    enum = 18
            elif hasattr(socket, opt):
                enum = getattr(socket, opt)

            if enum:
                if opt in DEFAULT_SOCKET_SETTINGS:
                    tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt]
                elif hasattr(socket, opt):
                    tcp_opts[enum] = sock.getsockopt(
                        SOL_TCP, getattr(socket, opt))
        return tcp_opts

    def _set_socket_options(self, socket_settings):
        """Apply the default TCP options overridden by *socket_settings*."""
        tcp_opts = self._get_tcp_socket_defaults(self.sock)
        if socket_settings:
            tcp_opts.update(socket_settings)
        for opt, val in tcp_opts.items():
            self.sock.setsockopt(SOL_TCP, opt, val)

    def _read(self, n, initial=False):
        """Read exactly n bytes from the peer."""
        raise NotImplementedError('Must be overriden in subclass')

    def _setup_transport(self):
        """Do any additional initialization of the class."""
        pass

    def _shutdown_transport(self):
        """Do any preliminary work in shutting down the connection."""
        pass

    def _write(self, s):
        """Completely write a string to the peer."""
        raise NotImplementedError('Must be overriden in subclass')

    def close(self):
        """Shut down and close the socket, if any, and mark as disconnected.

        Teardown is best-effort.  A failing transport shutdown (e.g. SSL
        unwrap on a dead peer) or ``shutdown()`` (e.g. ENOTCONN) must not
        prevent the socket from being closed and released.
        """
        if self.sock is not None:
            try:
                self._shutdown_transport()
            except OSError:
                pass
            # Call shutdown first to make sure that pending messages
            # reach the AMQP broker if the program exits after
            # calling this method.
            try:
                self.sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            try:
                self.sock.close()
            except OSError:
                pass
            self.sock = None
        self.connected = False

    def read_frame(self, unpack=unpack):
        """Read one frame and return ``(frame_type, channel, payload)``.

        A frame is a 7-byte big-endian header (unsigned byte type, unsigned
        short channel, unsigned int size), *size* payload bytes and a 1-byte
        frame-end marker that must be ``0xCE``.

        On a timeout, the bytes consumed so far are pushed back onto the read
        buffer so that a retry resumes the same frame.

        :raises socket.timeout: on read timeouts, including SSL "timed out"
            errors and EWOULDBLOCK on Windows.
        :raises UnexpectedFrame: if the frame-end marker is not ``0xCE``.
        """
        read = self._read
        read_frame_buffer = EMPTY_BUFFER
        try:
            frame_header = read(7, True)
            read_frame_buffer += frame_header
            frame_type, channel, size = unpack('>BHI', frame_header)
            # >I is an unsigned int, but the argument to sock.recv is signed,
            # so we know the size can be at most 2 * SIGNED_INT_MAX
            if size > SIGNED_INT_MAX:
                part1 = read(SIGNED_INT_MAX)

                try:
                    part2 = read(size - SIGNED_INT_MAX)
                except (socket.timeout, OSError, SSLError):
                    # In case this read times out, we need to make sure to not
                    # lose part1 when we retry the read
                    read_frame_buffer += part1
                    raise

                payload = b''.join([part1, part2])
            else:
                payload = read(size)
            read_frame_buffer += payload
            ch = ord(read(1))
        except socket.timeout:
            self._read_buffer = read_frame_buffer + self._read_buffer
            raise
        except (OSError, SSLError) as exc:
            if (
                isinstance(exc, socket.error) and os.name == 'nt'
                and exc.errno == errno.EWOULDBLOCK  # noqa
            ):
                # On windows we can get a read timeout with a winsock error
                # code instead of a proper socket.timeout() error, see
                # https://github.com/celery/py-amqp/issues/320
                self._read_buffer = read_frame_buffer + self._read_buffer
                raise socket.timeout()

            if isinstance(exc, SSLError) and 'timed out' in str(exc):
                # Don't disconnect for ssl read time outs
                # http://bugs.python.org/issue10272
                self._read_buffer = read_frame_buffer + self._read_buffer
                raise socket.timeout()

            if exc.errno not in _UNAVAIL:
                self.connected = False
            raise
        if ch == 206:  # '\xce'
            return frame_type, channel, payload
        else:
            raise UnexpectedFrame(
                f'Received {ch:#04x} while expecting 0xce')

    def write(self, s):
        """Write *s* completely to the peer.

        Timeouts propagate unchanged.  Any other ``OSError`` whose errno is
        not a transient one marks the transport as disconnected before being
        re-raised.
        """
        try:
            self._write(s)
        except socket.timeout:
            raise
        except OSError as exc:
            if exc.errno not in _UNAVAIL:
                self.connected = False
            raise
class SSLTransport(_AbstractTransport):
    """Transport that works over SSL."""

    def __init__(self, host, connect_timeout=None, ssl=None, **kwargs):
        """Create an SSL transport.

        :param ssl: a dict of wrapping options (passed to
            :meth:`_wrap_socket`); any non-dict value (e.g. ``True``) means
            "no extra options".
        """
        self.sslopts = ssl if isinstance(ssl, dict) else {}
        self._read_buffer = EMPTY_BUFFER
        super().__init__(
            host, connect_timeout=connect_timeout, **kwargs)

    def _setup_transport(self):
        """Wrap the socket in an SSL object."""
        self.sock = self._wrap_socket(self.sock, **self.sslopts)
        self.sock.do_handshake()
        self._quick_recv = self.sock.read

    def _wrap_socket(self, sock, context=None, **sslopts):
        """Wrap *sock*, via ``create_default_context`` if *context* given."""
        if context:
            return self._wrap_context(sock, sslopts, **context)
        return self._wrap_socket_sni(sock, **sslopts)

    def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options):
        """Wrap *sock* using ``ssl.create_default_context(**ctx_options)``.

        NOTE: *check_hostname* is assigned as given, so leaving it at None
        turns hostname checking off on the default context.
        """
        ctx = ssl.create_default_context(**ctx_options)
        ctx.check_hostname = check_hostname
        return ctx.wrap_socket(sock, **sslopts)

    def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
                         server_side=False, cert_reqs=ssl.CERT_NONE,
                         do_handshake_on_connect=False,
                         suppress_ragged_eofs=True, server_hostname=None,
                         ssl_version=ssl.PROTOCOL_TLS, ca_certs=None,
                         ciphers=None):
        """Socket wrap with SNI headers.

        stdlib `ssl.SSLContext.wrap_socket` method augmented with support for
        setting the server_hostname field required for SNI hostname header.

        :param cert_reqs: ``ssl.CERT_NONE`` / ``CERT_OPTIONAL`` /
            ``CERT_REQUIRED``, always applied as the context's verify mode.
        :param server_hostname: name sent via SNI.  Hostname checking is
            enabled only when certificates are verified and this is given;
            without it the peer certificate is verified but not matched
            against a hostname.
        :param ca_certs: CA bundle file used for verification.  When omitted
            and verification is requested, the system default CAs are loaded.
        :param ciphers: optional OpenSSL cipher string.
        """
        opts = {
            'sock': sock,
            'server_side': server_side,
            'do_handshake_on_connect': do_handshake_on_connect,
            'suppress_ragged_eofs': suppress_ragged_eofs,
            'server_hostname': server_hostname,
        }

        context = ssl.SSLContext(ssl_version)
        if certfile is not None:
            context.load_cert_chain(certfile, keyfile)
        if ca_certs is not None:
            context.load_verify_locations(ca_certs)
        if ciphers is not None:
            context.set_ciphers(ciphers)

        verify = cert_reqs != ssl.CERT_NONE
        # Hostname checking needs a name to compare against (and SNI);
        # enabling it without server_hostname makes wrap_socket() raise
        # "check_hostname requires server_hostname".
        check_hostname = bool(
            verify
            and server_hostname is not None
            and getattr(ssl, 'HAS_SNI', False)
        )
        if not check_hostname:
            # Must be off before verify_mode may be set to CERT_NONE (some
            # protocols, e.g. PROTOCOL_TLS_CLIENT, enable it by default).
            context.check_hostname = False
        context.verify_mode = cert_reqs
        if check_hostname:
            context.check_hostname = True
        if verify and ca_certs is None:
            # A bare SSLContext has no trust anchors; without them
            # verification could never succeed.
            context.load_default_certs(
                ssl.Purpose.CLIENT_AUTH if server_side
                else ssl.Purpose.SERVER_AUTH)
        sock = context.wrap_socket(**opts)
        return sock

    def _shutdown_transport(self):
        """Unwrap a SSL socket, so we can call shutdown()."""
        if self.sock is not None:
            self.sock = self.sock.unwrap()

    def _read(self, n, initial=False,
              _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)):
        """Read exactly *n* bytes, buffering any surplus for the next call.

        A retryable errno during the first read of a frame (``initial``)
        is raised as ``socket.timeout`` when ``raise_on_initial_eintr`` is
        set; otherwise the read is retried.  On any error, data received so
        far is kept in the read buffer.
        """
        # According to SSL_read(3), it can at most return 16kb of data.
        # Thus, we use an internal read buffer like TCPTransport._read
        # to get the exact number of bytes wanted.
        recv = self._quick_recv
        rbuf = self._read_buffer
        try:
            while len(rbuf) < n:
                try:
                    s = recv(n - len(rbuf))  # see note above
                except OSError as exc:
                    # ssl.sock.read may cause ENOENT if the
                    # operation couldn't be performed (Issue celery#1414).
                    if exc.errno in _errnos:
                        if initial and self.raise_on_initial_eintr:
                            raise socket.timeout()
                        continue
                    raise
                if not s:
                    raise OSError('Server unexpectedly closed connection')
                rbuf += s
        except:  # noqa
            self._read_buffer = rbuf
            raise
        result, self._read_buffer = rbuf[:n], rbuf[n:]
        return result

    def _write(self, s):
        """Write a string out to the SSL socket fully."""
        write = self.sock.write
        while s:
            try:
                n = write(s)
            except ValueError:
                # AG: sock._sslobj might become null in the meantime if the
                # remote connection has hung up.
                # In python 3.4, a ValueError is raised if self._sslobj is
                # None.
                n = 0
            if not n:
                raise OSError('Socket closed')
            s = s[n:]
class TCPTransport(_AbstractTransport):
    """Transport that talks to the broker over a plain TCP socket.

    Writes go straight to ``socket.sendall``; reads are buffered locally so
    that :meth:`_read` can return exactly the number of bytes requested.
    """

    def _setup_transport(self):
        # Bind the raw socket methods once the socket exists.  Assigning
        # ``_write`` on the instance shadows the class method, so writes go
        # directly to ``sendall``.
        raw = self.sock
        self._quick_recv = raw.recv
        self._read_buffer = EMPTY_BUFFER
        self._write = raw.sendall

    def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
        """Read exactly n bytes from the socket.

        Surplus bytes stay in ``self._read_buffer`` for the next call.  A
        retryable errno on the first read of a frame (``initial``) becomes
        ``socket.timeout`` when ``raise_on_initial_eintr`` is set; otherwise
        the read is retried.  An empty read means the peer closed the
        connection.
        """
        receive = self._quick_recv
        pending = self._read_buffer
        try:
            while len(pending) < n:
                try:
                    chunk = receive(n - len(pending))
                except OSError as exc:
                    if exc.errno not in _errnos:
                        raise
                    if initial and self.raise_on_initial_eintr:
                        raise socket.timeout()
                    continue
                if not chunk:
                    raise OSError('Server unexpectedly closed connection')
                pending += chunk
        except BaseException:
            # Keep whatever was received so a retry does not lose data.
            self._read_buffer = pending
            raise
        self._read_buffer = pending[n:]
        return pending[:n]
def Transport(host, connect_timeout=None, ssl=False, **kwargs):
    """Create transport.

    Pick :class:`SSLTransport` when *ssl* is truthy and
    :class:`TCPTransport` otherwise, then instantiate it with *host*,
    *connect_timeout*, *ssl* and any remaining keyword arguments.
    """
    if ssl:
        transport_cls = SSLTransport
    else:
        transport_cls = TCPTransport
    return transport_cls(
        host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)