This document is for py-amqp's development version, which can be significantly different from previous releases. Get the stable docs here: 2.4.
Source code for amqp.transport
"""Transport implementation."""
# Copyright (C) 2009 Barry Pederson <bp@barryp.org>
import errno
import re
import socket
import ssl
from contextlib import contextmanager
from ssl import SSLError
from struct import pack, unpack
from .exceptions import UnexpectedFrame
from .platform import KNOWN_TCP_OPTS, SOL_TCP
from .utils import get_errno, set_cloexec
# errno values for which read_frame()/write() re-raise the error without
# flagging the transport as disconnected.
_UNAVAIL = {errno.EAGAIN, errno.EINTR, errno.ENOENT, errno.EWOULDBLOCK}

# Port used by to_host_port() when the host string does not carry one.
AMQP_PORT = 5672

EMPTY_BUFFER = b''

# Frame payloads larger than this are fetched with two reads (see read_frame).
SIGNED_INT_MAX = 0x7FFFFFFF

# Yes, Advanced Message Queuing Protocol Protocol is redundant
AMQP_PROTOCOL_HEADER = b'AMQP\x00\x00\x09\x01'

# Match things like: [fe80::1]:5432, from RFC 2732
# Hex digits in an IPv6 literal are case-insensitive, so '[FE80::1]:5432'
# has to match as well; otherwise the brackets leak into the host name.
IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?', re.IGNORECASE)

# TCP-level options applied to every socket unless the caller overrides them
# through ``socket_settings``.  Keys are names of ``socket`` module constants.
DEFAULT_SOCKET_SETTINGS = {
    'TCP_NODELAY': 1,
    'TCP_USER_TIMEOUT': 1000,
    'TCP_KEEPIDLE': 60,
    'TCP_KEEPINTVL': 10,
    'TCP_KEEPCNT': 9,
}
def to_host_port(host, default=AMQP_PORT):
    """Split a ``host[:port]`` string into a ``(host, port)`` tuple.

    A bracketed IPv6 literal such as ``[fe80::1]:5432`` is returned without
    its brackets.  When no port is present, ``default`` is used.  A port
    that is not a valid integer raises ``ValueError``.
    """
    bracketed = IPV6_LITERAL.match(host)
    if bracketed:
        address, port_text = bracketed.groups()
        return address, (int(port_text) if port_text else default)

    if ':' not in host:
        return host, default

    # Split on the last colon only; everything before it is the host.
    name, _, port_text = host.rpartition(':')
    return name, int(port_text)
class _AbstractTransport(object):
    """Common superclass for TCP and SSL transports.

    Owns the socket life cycle (connect, option setup, close) and the
    frame-level ``read_frame`` / ``write`` entry points.  Subclasses must
    provide ``_read`` and ``_write`` and may override ``_setup_transport``
    and ``_shutdown_transport``.
    """

    def __init__(self, host, connect_timeout=None,
                 read_timeout=None, write_timeout=None,
                 socket_settings=None, raise_on_initial_eintr=True, **kwargs):
        """Store connection parameters; no I/O happens here.

        :param host: ``host`` or ``host:port`` string, parsed by
            ``to_host_port``.
        :param connect_timeout: timeout applied to the socket while
            connecting.
        :param read_timeout: if not None, installed as ``SO_RCVTIMEO``.
        :param write_timeout: if not None, installed as ``SO_SNDTIMEO``.
        :param socket_settings: optional mapping of TCP option -> value that
            overrides the defaults computed for the socket.
        :param raise_on_initial_eintr: stored on the instance; the ``_read``
            implementations raise ``socket.timeout`` on a retryable errno
            during the initial read when this is true.
        :param kwargs: accepted and ignored.
        """
        self.connected = False
        self.sock = None
        self.raise_on_initial_eintr = raise_on_initial_eintr
        self._read_buffer = EMPTY_BUFFER
        self.host, self.port = to_host_port(host)
        self.connect_timeout = connect_timeout
        self.read_timeout = read_timeout
        self.write_timeout = write_timeout
        self.socket_settings = socket_settings

    def connect(self):
        """Connect, configure the socket and send the protocol header.

        Does nothing when already connected.  On failure the half-open
        socket is closed and the original error is re-raised.
        """
        try:
            # are we already connected?
            if self.connected:
                return
            self._connect(self.host, self.port, self.connect_timeout)
            self._init_socket(
                self.socket_settings, self.read_timeout, self.write_timeout,
            )
            # we've sent the banner; signal connect
            # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner
            # has _not_ been sent
            self.connected = True
        except (OSError, IOError, SSLError):
            # if not fully connected, close socket, and reraise error
            if self.sock and not self.connected:
                self.sock.close()
                self.sock = None
            raise

    @contextmanager
    def having_timeout(self, timeout):
        """Context manager that temporarily sets the socket timeout.

        Yields the socket.  With ``timeout=None`` the socket is left
        untouched.  Otherwise the previous timeout is restored on exit, and
        SSL "timed out" / "did not complete" errors as well as
        ``EWOULDBLOCK`` are translated to ``socket.timeout``.
        """
        if timeout is None:
            yield self.sock
        else:
            sock = self.sock
            prev = sock.gettimeout()
            if prev != timeout:
                sock.settimeout(timeout)
            try:
                yield self.sock
            except SSLError as exc:
                if 'timed out' in str(exc):
                    # http://bugs.python.org/issue10272
                    raise socket.timeout()
                elif 'The operation did not complete' in str(exc):
                    # Non-blocking SSL sockets can throw SSLError
                    raise socket.timeout()
                raise
            except socket.error as exc:
                if get_errno(exc) == errno.EWOULDBLOCK:
                    raise socket.timeout()
                raise
            finally:
                if timeout != prev:
                    sock.settimeout(prev)

    def _connect(self, host, port, timeout):
        """Resolve ``host`` and connect ``self.sock`` to the first address
        that accepts the connection, trying IPv4 before IPv6.

        Raises ``socket.error`` when no address could be resolved or
        connected to.
        """
        e = None

        # Below we are trying to avoid additional DNS requests for AAAA if A
        # succeeds. This helps a lot in case when a hostname has an IPv4 entry
        # in /etc/hosts but not IPv6. Without the (arguably somewhat twisted)
        # logic below, getaddrinfo would attempt to resolve the hostname for
        # both IP versions, which would make the resolver talk to configured
        # DNS servers. If those servers are for some reason not available
        # during resolution attempt (either because of system misconfiguration,
        # or network connectivity problem), resolution process locks the
        # _connect call for extended time.
        addr_types = (socket.AF_INET, socket.AF_INET6)
        addr_types_num = len(addr_types)
        for n, family in enumerate(addr_types):
            # first, resolve the address for a single address family
            try:
                entries = socket.getaddrinfo(
                    host, port, family, socket.SOCK_STREAM, SOL_TCP)
                entries_num = len(entries)
            except socket.gaierror:
                # we may have depleted all our options
                if n + 1 >= addr_types_num:
                    # if getaddrinfo succeeded before for another address
                    # family, reraise the previous socket.error since it's more
                    # relevant to users
                    raise (e
                           if e is not None
                           else socket.error(
                               "failed to resolve broker hostname"))
                continue  # pragma: no cover

            # now that we have address(es) for the hostname, connect to broker
            for i, res in enumerate(entries):
                af, socktype, proto, _, sa = res
                try:
                    self.sock = socket.socket(af, socktype, proto)
                    try:
                        set_cloexec(self.sock, True)
                    except NotImplementedError:
                        pass
                    self.sock.settimeout(timeout)
                    self.sock.connect(sa)
                except socket.error as ex:
                    e = ex
                    if self.sock is not None:
                        self.sock.close()
                        self.sock = None
                    # we may have depleted all our options
                    if i + 1 >= entries_num and n + 1 >= addr_types_num:
                        raise
                else:
                    # hurray, we established connection
                    return

    def _init_socket(self, socket_settings, read_timeout, write_timeout):
        """Configure the connected socket and send the protocol header."""
        self.sock.settimeout(None)  # set socket back to blocking mode
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        self._set_socket_options(socket_settings)

        # set socket timeouts
        for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout),
                                  (socket.SO_RCVTIMEO, read_timeout)):
            if interval is not None:
                # split into whole seconds and microseconds, packed as two
                # native longs
                sec = int(interval)
                usec = int((interval - sec) * 1000000)
                self.sock.setsockopt(
                    socket.SOL_SOCKET, timeout,
                    pack('ll', sec, usec),
                )
        self._setup_transport()

        self._write(AMQP_PROTOCOL_HEADER)

    def _get_tcp_socket_defaults(self, sock):
        """Return ``{option constant: value}`` for every known TCP option.

        Values come from ``DEFAULT_SOCKET_SETTINGS`` when listed there,
        otherwise from the socket's current setting.
        """
        tcp_opts = {}
        for opt in KNOWN_TCP_OPTS:
            enum = None
            if opt == 'TCP_USER_TIMEOUT':
                try:
                    from socket import TCP_USER_TIMEOUT as enum
                except ImportError:
                    # should be in Python 3.6+ on Linux.
                    enum = 18
            elif hasattr(socket, opt):
                enum = getattr(socket, opt)

            if enum:
                if opt in DEFAULT_SOCKET_SETTINGS:
                    tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt]
                elif hasattr(socket, opt):
                    tcp_opts[enum] = sock.getsockopt(
                        SOL_TCP, getattr(socket, opt))
        return tcp_opts

    def _set_socket_options(self, socket_settings):
        """Apply the default TCP options, overridden by ``socket_settings``."""
        tcp_opts = self._get_tcp_socket_defaults(self.sock)
        if socket_settings:
            tcp_opts.update(socket_settings)
        for opt, val in tcp_opts.items():
            self.sock.setsockopt(SOL_TCP, opt, val)

    def _read(self, n, initial=False):
        """Read exactly n bytes from the peer."""
        raise NotImplementedError('Must be overriden in subclass')

    def _setup_transport(self):
        """Do any additional initialization of the class."""
        pass

    def _shutdown_transport(self):
        """Do any preliminary work in shutting down the connection."""
        pass

    def _write(self, s):
        """Completely write a string to the peer."""
        raise NotImplementedError('Must be overriden in subclass')

    def close(self):
        """Shut down and close the socket, then mark the transport closed.

        Every step is best effort: the peer may already have dropped the
        connection, in which case the shutdown calls fail.  Such socket
        errors are ignored so that the descriptor is always released and
        ``sock`` / ``connected`` are always reset.
        """
        socket_errors = (OSError, IOError, socket.error)
        try:
            if self.sock is not None:
                try:
                    self._shutdown_transport()
                except socket_errors:
                    pass
                # Call shutdown first to make sure that pending messages
                # reach the AMQP broker if the program exits after
                # calling this method.
                try:
                    self.sock.shutdown(socket.SHUT_RDWR)
                except socket_errors:
                    # e.g. the connection is already gone
                    pass
                try:
                    self.sock.close()
                except socket_errors:
                    pass
        finally:
            # Reset state even if an unexpected error escapes above.
            self.sock = None
            self.connected = False

    def read_frame(self, unpack=unpack):
        """Read one AMQP frame.

        :returns: ``(frame_type, channel, payload)``.
        :raises socket.timeout: on a read timeout; bytes already consumed
            for this frame are pushed back onto the read buffer so the call
            can be retried.
        :raises UnexpectedFrame: when the frame-end byte is not ``0xce``.

        Other socket errors are re-raised; unless their errno is in
        ``_UNAVAIL`` the transport is marked as disconnected first.
        """
        read = self._read
        read_frame_buffer = EMPTY_BUFFER
        try:
            frame_header = read(7, True)
            read_frame_buffer += frame_header
            frame_type, channel, size = unpack('>BHI', frame_header)
            # >I is an unsigned int, but the argument to sock.recv is signed,
            # so we know the size can be at most 2 * SIGNED_INT_MAX
            if size > SIGNED_INT_MAX:
                part1 = read(SIGNED_INT_MAX)
                try:
                    part2 = read(size - SIGNED_INT_MAX)
                except (socket.timeout, OSError, IOError, SSLError):
                    # part1 has already left the read buffer; keep it so a
                    # retry after a timeout does not lose it
                    read_frame_buffer += part1
                    raise
                payload = b''.join([part1, part2])
            else:
                payload = read(size)
            read_frame_buffer += payload
            ch = ord(read(1))
        except socket.timeout:
            self._read_buffer = read_frame_buffer + self._read_buffer
            raise
        except (OSError, IOError, SSLError, socket.error) as exc:
            # Don't disconnect for ssl read time outs
            # http://bugs.python.org/issue10272
            if isinstance(exc, SSLError) and 'timed out' in str(exc):
                # Reported to the caller as a timeout, so push the consumed
                # bytes back exactly like the socket.timeout branch does.
                self._read_buffer = read_frame_buffer + self._read_buffer
                raise socket.timeout()
            if get_errno(exc) not in _UNAVAIL:
                self.connected = False
            raise
        if ch == 206:  # '\xce'
            return frame_type, channel, payload
        else:
            raise UnexpectedFrame(
                'Received {0:#04x} while expecting 0xce'.format(ch))

    def write(self, s):
        """Write ``s`` to the peer.

        Timeouts are re-raised unchanged.  Other socket errors are
        re-raised too; unless their errno is in ``_UNAVAIL`` the transport
        is marked as disconnected first.
        """
        try:
            self._write(s)
        except socket.timeout:
            raise
        except (OSError, IOError, socket.error) as exc:
            if get_errno(exc) not in _UNAVAIL:
                self.connected = False
            raise
class SSLTransport(_AbstractTransport):
    """Transport that works over SSL."""

    def __init__(self, host, connect_timeout=None, ssl=None, **kwargs):
        """Store SSL options and delegate the rest to the base class.

        :param ssl: dict of keyword arguments for ``_wrap_socket``; any
            non-dict value is treated as "no options".
        """
        self.sslopts = ssl if isinstance(ssl, dict) else {}
        self._read_buffer = EMPTY_BUFFER
        super(SSLTransport, self).__init__(
            host, connect_timeout=connect_timeout, **kwargs)

    def _setup_transport(self):
        """Wrap the socket in an SSL object."""
        self.sock = self._wrap_socket(self.sock, **self.sslopts)
        self.sock.do_handshake()
        self._quick_recv = self.sock.read

    def _wrap_socket(self, sock, context=None, **sslopts):
        """Wrap ``sock`` using ``context`` options when given, else the
        individual ``sslopts``."""
        if context:
            return self._wrap_context(sock, sslopts, **context)
        return self._wrap_socket_sni(sock, **sslopts)

    def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options):
        """Wrap ``sock`` with a context from ``ssl.create_default_context``.

        ``ctx_options`` are passed to ``create_default_context`` and
        ``sslopts`` to the context's ``wrap_socket``.
        """
        ctx = ssl.create_default_context(**ctx_options)
        # NOTE(review): with the default of None this overrides whatever
        # hostname checking the default context came with -- confirm callers
        # pass check_hostname explicitly when they want verification.
        ctx.check_hostname = check_hostname
        return ctx.wrap_socket(sock, **sslopts)

    def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
                         server_side=False, cert_reqs=ssl.CERT_NONE,
                         ca_certs=None, do_handshake_on_connect=True,
                         suppress_ragged_eofs=True, server_hostname=None,
                         ciphers=None, ssl_version=None):
        """Socket wrap with SNI headers.

        Builds a single ``ssl.SSLContext`` from the given options and wraps
        ``sock`` exactly once, passing ``server_hostname`` so the SNI
        header is sent when the ssl module supports it.  Hostname
        verification is enabled only when ``cert_reqs`` asks for
        certificate verification, because it cannot be combined with
        ``ssl.CERT_NONE``.

        On ssl modules without ``SSLContext`` the plain ``ssl.wrap_socket``
        call is used and ``server_hostname`` is ignored.
        """
        if not hasattr(ssl, 'SSLContext'):
            # Legacy interpreters: no context API, hence no SNI either.
            # older versions of python 2.7 and python 2.6 do not have the
            # ssl.PROTOCOL_TLS defined the equivalent is ssl.PROTOCOL_SSLv23
            # TODO: Drop this once we drop Python 2.7 support
            if ssl_version is None:
                if hasattr(ssl, 'PROTOCOL_TLS'):
                    ssl_version = ssl.PROTOCOL_TLS
                else:
                    ssl_version = ssl.PROTOCOL_SSLv23
            opts = {
                'sock': sock,
                'keyfile': keyfile,
                'certfile': certfile,
                'server_side': server_side,
                'cert_reqs': cert_reqs,
                'ca_certs': ca_certs,
                'do_handshake_on_connect': do_handshake_on_connect,
                'suppress_ragged_eofs': suppress_ragged_eofs,
                'ciphers': ciphers,
                'ssl_version': ssl_version
            }
            return ssl.wrap_socket(**opts)

        # Setup the right SSL version; default to optimal versions across
        # ssl implementations
        if ssl_version is None:
            preferred = (
                'PROTOCOL_TLS_SERVER' if server_side
                else 'PROTOCOL_TLS_CLIENT'
            )
            for name in (preferred, 'PROTOCOL_TLS', 'PROTOCOL_SSLv23'):
                if hasattr(ssl, name):
                    ssl_version = getattr(ssl, name)
                    break

        # Only send SNI when the ssl module supports it
        use_sni = (
            server_hostname is not None and getattr(ssl, 'HAS_SNI', False)
        )

        context = ssl.SSLContext(ssl_version)
        if certfile is not None:
            context.load_cert_chain(certfile, keyfile)
        if ca_certs is not None:
            context.load_verify_locations(ca_certs)
        if ciphers is not None:
            context.set_ciphers(ciphers)

        # check_hostname must be off before verify_mode may be lowered to
        # CERT_NONE, so reset it first and enable it again afterwards.
        context.check_hostname = False
        context.verify_mode = cert_reqs
        if cert_reqs != ssl.CERT_NONE:
            if ca_certs is None:
                # verification requested without explicit CAs: fall back to
                # the system trust store
                purpose = (
                    ssl.Purpose.CLIENT_AUTH if server_side
                    else ssl.Purpose.SERVER_AUTH
                )
                context.load_default_certs(purpose)
            if use_sni:
                context.check_hostname = True

        wrap_opts = {
            'server_side': server_side,
            'do_handshake_on_connect': do_handshake_on_connect,
            'suppress_ragged_eofs': suppress_ragged_eofs,
        }
        if use_sni:
            wrap_opts['server_hostname'] = server_hostname
        return context.wrap_socket(sock, **wrap_opts)

    def _shutdown_transport(self):
        """Unwrap a SSL socket, so we can call shutdown()."""
        if self.sock is not None:
            self.sock = self.sock.unwrap()

    def _read(self, n, initial=False,
              _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)):
        """Read exactly n bytes from the SSL socket.

        Retries on the errnos in ``_errnos``; when ``initial`` is true and
        ``raise_on_initial_eintr`` is set, such an errno raises
        ``socket.timeout`` instead.  Raises ``IOError`` when the peer closes
        the connection.  Whatever was read so far is kept in the read
        buffer if an exception escapes.
        """
        # According to SSL_read(3), it can at most return 16kb of data.
        # Thus, we use an internal read buffer like TCPTransport._read
        # to get the exact number of bytes wanted.
        recv = self._quick_recv
        rbuf = self._read_buffer
        try:
            while len(rbuf) < n:
                try:
                    s = recv(n - len(rbuf))  # see note above
                except socket.error as exc:
                    # ssl.sock.read may cause a SSLerror without errno
                    # http://bugs.python.org/issue10272
                    if isinstance(exc, SSLError) and 'timed out' in str(exc):
                        raise socket.timeout()
                    # ssl.sock.read may cause ENOENT if the
                    # operation couldn't be performed (Issue celery#1414).
                    if exc.errno in _errnos:
                        if initial and self.raise_on_initial_eintr:
                            raise socket.timeout()
                        continue
                    raise
                if not s:
                    raise IOError('Server unexpectedly closed connection')
                rbuf += s
        except:  # noqa
            # keep partial data for the next attempt, whatever went wrong
            self._read_buffer = rbuf
            raise
        result, self._read_buffer = rbuf[:n], rbuf[n:]
        return result

    def _write(self, s):
        """Write a string out to the SSL socket fully."""
        write = self.sock.write
        while s:
            try:
                n = write(s)
            except ValueError:
                # AG: sock._sslobj might become null in the meantime if the
                # remote connection has hung up.
                # In python 3.4, a ValueError is raised is self._sslobj is
                # None.
                n = 0
            if not n:
                raise IOError('Socket closed')
            s = s[n:]
class TCPTransport(_AbstractTransport):
    """Transport that deals directly with TCP socket."""

    def _setup_transport(self):
        """Bind the socket's send/receive callables onto the transport."""
        # Start from an empty read buffer; reads are buffered by _read().
        self._read_buffer = EMPTY_BUFFER
        self._quick_recv = self.sock.recv
        # Writes go straight to the socket.
        self._write = self.sock.sendall

    def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
        """Read exactly n bytes from the socket.

        A receive failing with one of ``_errnos`` is retried, unless this is
        the ``initial`` read and ``raise_on_initial_eintr`` is set, in which
        case ``socket.timeout`` is raised.  ``IOError`` is raised when the
        peer closes the connection.  Data received before any exception is
        preserved in the read buffer.
        """
        buffered = self._read_buffer
        receive = self._quick_recv
        try:
            while len(buffered) < n:
                try:
                    chunk = receive(n - len(buffered))
                except socket.error as exc:
                    if exc.errno not in _errnos:
                        raise
                    if initial and self.raise_on_initial_eintr:
                        raise socket.timeout()
                    continue
                if not chunk:
                    raise IOError('Server unexpectedly closed connection')
                buffered += chunk
        except BaseException:
            self._read_buffer = buffered
            raise
        self._read_buffer = buffered[n:]
        return buffered[:n]
def Transport(host, connect_timeout=None, ssl=False, **kwargs):
    """Create transport.

    Picks ``SSLTransport`` when ``ssl`` is truthy and ``TCPTransport``
    otherwise, then instantiates it with the given arguments (``ssl`` is
    forwarded in both cases).
    """
    if ssl:
        transport_cls = SSLTransport
    else:
        transport_cls = TCPTransport
    return transport_cls(
        host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)