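# NOTE: the three header lines here were residue from a web file viewer
# (a stray "%PDF-" marker plus a directory banner), kept as comments only.
# Directory: /lib/python3/dist-packages/mitmproxy/net/
# Current file: /lib/python3/dist-packages/mitmproxy/net/tcp.py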
import os
import errno
import select
import socket
import sys
import threading
import time
import traceback
from typing import Optional # noqa
from mitmproxy.net import tls
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy import exceptions
from mitmproxy.coretypes import basethread
# Raw socket file object type. Reader.peek() uses it in an isinstance() check
# to recognise plain (non-TLS) sockets.
socket_fileobject = socket.SocketIO
# workaround for https://bugs.python.org/issue29515
# Python 3.8 for Windows is missing a constant, fixed in 3.9
# The fallback 41 is the IANA-assigned protocol number for IPv6.
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)
class _FileLike:
BLOCKSIZE = 1024 * 32
def __init__(self, o):
self.o = o
self._log = None
self.first_byte_timestamp = None
def set_descriptor(self, o):
self.o = o
def __getattr__(self, attr):
return getattr(self.o, attr)
def start_log(self):
"""
Starts or resets the log.
This will store all bytes read or written.
"""
self._log = []
def stop_log(self):
"""
Stops the log.
"""
self._log = None
def is_logging(self):
return self._log is not None
def get_log(self):
"""
Returns the log as a string.
"""
if not self.is_logging():
raise ValueError("Not logging!")
return b"".join(self._log)
def add_log(self, v):
if self.is_logging():
self._log.append(v)
def reset_timestamps(self):
self.first_byte_timestamp = None
class Writer(_FileLike):
    """
    Write side of a connection.

    Socket and TLS errors raised by the wrapped object are re-raised as
    exceptions.TcpDisconnect.
    """

    def flush(self):
        """
        Flush the wrapped object if it offers a flush() method.

        May raise exceptions.TcpDisconnect
        """
        if not hasattr(self.o, "flush"):
            return
        try:
            self.o.flush()
        except OSError as err:
            raise exceptions.TcpDisconnect(str(err))

    def write(self, v):
        """
        Send v through the wrapped object and record it in the log.

        Empty input is ignored and yields None. Objects with a sendall()
        method are written with it; anything else goes through write(), and
        only the part reported as written is logged.

        May raise exceptions.TcpDisconnect
        """
        if not v:
            return None
        if not self.first_byte_timestamp:
            self.first_byte_timestamp = time.time()
        try:
            if hasattr(self.o, "sendall"):
                self.add_log(v)
                return self.o.sendall(v)
            written = self.o.write(v)
            self.add_log(v[:written])
            return written
        except (SSL.Error, OSError) as err:
            raise exceptions.TcpDisconnect(str(err))
class Reader(_FileLike):
    """
    Read side of a connection.

    Translates socket and pyOpenSSL errors into mitmproxy exceptions and
    records everything that was read in the log.
    """

    def read(self, length):
        """
        Read up to length bytes, in pieces of at most BLOCKSIZE.

        If length is -1, we read until connection closes.

        Returns:
            The bytes read; shorter than length if the stream ended first.
        Raises:
            exceptions.TcpTimeout on a socket timeout, or when TLS keeps
                reporting WantRead/WantWrite past the timeout.
            exceptions.TcpDisconnect on any other OSError.
            exceptions.TlsException on other pyOpenSSL errors.
        """
        # Chunks are collected and joined once at the end; repeated
        # `bytes += bytes` copies the whole buffer on every iteration.
        chunks = []
        start = time.time()
        while length == -1 or length > 0:
            if length == -1 or length > self.BLOCKSIZE:
                rlen = self.BLOCKSIZE
            else:
                rlen = length
            try:
                data = self.o.read(rlen)
            except SSL.ZeroReturnError:
                # TLS connection was shut down cleanly
                break
            except (SSL.WantWriteError, SSL.WantReadError):
                # From the OpenSSL docs:
                # If the underlying BIO is non-blocking, SSL_read() will also return when the
                # underlying BIO could not satisfy the needs of SSL_read() to continue the
                # operation. In this case a call to SSL_get_error with the return value of
                # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
                # 300 is OpenSSL default timeout
                timeout = self.o.gettimeout() or 300
                if (time.time() - start) < timeout:
                    time.sleep(0.1)
                    continue
                else:
                    raise exceptions.TcpTimeout()
            except socket.timeout:
                # Must precede the OSError clause: socket.timeout is an OSError.
                raise exceptions.TcpTimeout()
            except OSError as e:
                raise exceptions.TcpDisconnect(str(e))
            except SSL.SysCallError as e:
                # An unexpected EOF is treated like a closed connection.
                if e.args == (-1, 'Unexpected EOF'):
                    break
                raise exceptions.TlsException(str(e))
            except SSL.Error as e:
                raise exceptions.TlsException(str(e))
            self.first_byte_timestamp = self.first_byte_timestamp or time.time()
            if not data:
                break
            chunks.append(data)
            if length != -1:
                length -= len(data)
        result = b"".join(chunks)
        self.add_log(result)
        return result

    def readline(self, size=None):
        """
        Read byte by byte up to and including the next newline.

        Stops early at end of stream or after size bytes, if size is given.
        """
        chunks = []
        bytes_read = 0
        while size is None or bytes_read < size:
            ch = self.read(1)
            bytes_read += 1
            if not ch:
                break
            chunks.append(ch)
            if ch == b'\n':
                break
        return b"".join(chunks)

    def safe_read(self, length):
        """
        Like .read, but is guaranteed to either return length bytes, or
        raise an exception.

        Raises:
            exceptions.TcpDisconnect if nothing could be read.
            exceptions.TcpReadIncomplete if fewer than length bytes arrived.
        """
        result = self.read(length)
        if length != -1 and len(result) != length:
            if not result:
                raise exceptions.TcpDisconnect()
            else:
                raise exceptions.TcpReadIncomplete(
                    "Expected {} bytes, got {}".format(length, len(result))
                )
        return result

    def peek(self, length):
        """
        Tries to peek into the underlying file object.

        Returns:
            Up to the next N bytes if peeking is successful.

        Raises:
            exceptions.TcpException if there was an error with the socket
            TlsException if there was an error with pyOpenSSL.
            NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
        """
        if isinstance(self.o, socket_fileobject):
            try:
                return self.o._sock.recv(length, socket.MSG_PEEK)
            except OSError as e:
                raise exceptions.TcpException(repr(e))
        elif isinstance(self.o, SSL.Connection):
            try:
                return self.o.recv(length, socket.MSG_PEEK)
            except SSL.Error as e:
                raise exceptions.TlsException(str(e))
        else:
            raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
def ssl_read_select(rlist, timeout):
    """
    select.select() for reading that is aware of SSL.Connection objects.

    pyOpenSSL connections may hold already-decrypted data that the OS-level
    select cannot see, so ssl_connection.pending() is checked first.

    Caveats:
        If .pending() > 0 for any of the connections in rlist, we avoid the select syscall
        and **will not include any other connections which may or may not be ready**.

    Args:
        rlist: wait until ready for reading
        timeout: passed through to select.select()

    Returns:
        subset of rlist which is ready for reading.
    """
    buffered = []
    for candidate in rlist:
        if isinstance(candidate, SSL.Connection) and candidate.pending() > 0:
            buffered.append(candidate)
    if buffered:
        return buffered
    readable = select.select(rlist, (), (), timeout)[0]
    return readable
def close_socket(sock):
    """
    Does a hard close of a socket, without emitting a RST.

    The write half is shut down first, then (on Windows) pending input is
    drained, then the read half is shut down. Any OSError along the way is
    ignored, and the socket is closed in every case.
    """
    recv_size = 4096
    # Upper bound on the number of drain reads so that we don't read
    # infinitely. NOTE(review): this comment used to say "a megabyte", but
    # 1024 ** 3 bytes is one gibibyte; the value is kept unchanged.
    max_reads = 1024 ** 3 // recv_size
    try:
        # We already indicate that we close our end.
        # may raise "Transport endpoint is not connected" on Linux
        sock.shutdown(socket.SHUT_WR)

        # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
        # readable data could lead to an immediate RST being sent (which is the
        # case on Windows).
        # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
        #
        # This in turn results in the following issue: If we send an error page
        # to the client and then close the socket, the RST may be received by
        # the client before the error page and the users sees a connection
        # error rather than the error page. Thus, we try to empty the read
        # buffer on Windows first. (see
        # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
        if os.name == "nt":  # pragma: no cover
            # We cannot rely on the shutdown()-followed-by-read()-eof technique
            # proposed by the page above: Some remote machines just don't send
            # a TCP FIN, which would leave us in the unfortunate situation that
            # recv() would block infinitely. As a workaround, we set a timeout
            # here even if we are in blocking mode.
            sock.settimeout(sock.gettimeout() or 20)
            reads_left = max_reads
            # recv() may raise a timeout/disconnect exception; an empty
            # result means the peer has closed its side.
            while reads_left > 0 and sock.recv(recv_size):
                reads_left -= 1

        # Now we can close the other half as well.
        sock.shutdown(socket.SHUT_RD)
    except OSError:
        pass
    sock.close()
class _Connection:
rbufsize = -1
wbufsize = -1
def _makefile(self):
"""
Set up .rfile and .wfile attributes from .connection
"""
# Ideally, we would use the Buffered IO in Python 3 by default.
# Unfortunately, the implementation of .peek() is broken for n>1 bytes,
# as it may just return what's left in the buffer and not all the bytes we want.
# As a workaround, we just use unbuffered sockets directly.
# https://mail.python.org/pipermail/python-dev/2009-June/089986.html
self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
def __init__(self, connection):
if connection:
self.connection = connection
self.ip_address = connection.getpeername()
self._makefile()
else:
self.connection = None
self.ip_address = None
self.rfile = None
self.wfile = None
self.tls_established = False
self.finished = False
def get_current_cipher(self):
if not self.tls_established:
return None
name = self.connection.get_cipher_name()
bits = self.connection.get_cipher_bits()
version = self.connection.get_cipher_version()
return name, bits, version
def finish(self):
self.finished = True
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
# Closing the socket is not our task, therefore we don't call close
# then.
if not isinstance(self.connection, SSL.Connection):
if not getattr(self.wfile, "closed", False):
try:
self.wfile.flush()
self.wfile.close()
except exceptions.TcpDisconnect:
pass
self.rfile.close()
else:
try:
self.connection.shutdown()
except SSL.Error:
pass
class ConnectionCloser:
    """
    Context manager that closes ``conn`` on exit unless it was canceled
    through pop().
    """

    def __init__(self, conn):
        self._canceled = False
        self.conn = conn

    def pop(self):
        """
        Cancel the current closer, and return a fresh one.
        """
        replacement = ConnectionCloser(self.conn)
        self._canceled = True
        return replacement

    def __enter__(self):
        return self

    def __exit__(self, *args):
        if self._canceled:
            return
        self.conn.close()
class TCPClient(_Connection):
    """
    Outgoing TCP connection that can optionally be upgraded to TLS.

    The object starts without a socket; connect() creates it.
    """

    def __init__(self, address, source_address=None, spoof_source_address=None):
        super().__init__(None)
        self.address = address
        self.source_address = source_address
        self.cert = None
        self.server_certs = []
        self.sni = None
        self.spoof_source_address = spoof_source_address

    @property
    def ssl_verification_error(self) -> Optional[exceptions.InvalidCertificateException]:
        """
        The ``cert_error`` attribute of the current connection, or None if
        the connection does not carry one.
        """
        return getattr(self.connection, "cert_error", None)

    def close(self):
        """
        Close the underlying socket, if there is a connection.
        """
        # Make sure to close the real socket, not the SSL proxy.
        # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
        # it tries to renegotiate...
        if self.connection:
            if isinstance(self.connection, SSL.Connection):
                close_socket(self.connection._socket)
            else:
                close_socket(self.connection)

    def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs):
        """
        Upgrade the established connection to TLS as a client.

        Additional keyword arguments are passed to
        tls.create_client_context(...).

        Raises:
            The recorded certificate verification error, if the handshake
                failed and one is present on the connection.
            exceptions.TlsException for any other handshake failure.
        """
        context = tls.create_client_context(
            alpn_protos=alpn_protos,
            sni=sni,
            **sslctx_kwargs
        )
        self.connection = SSL.Connection(context, self.connection)
        if sni:
            self.sni = sni
            self.connection.set_tlsext_host_name(sni.encode("idna"))
        self.connection.set_connect_state()
        try:
            self.connection.do_handshake()
        except SSL.Error as v:
            if self.ssl_verification_error:
                raise self.ssl_verification_error
            else:
                raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) from v

        self.cert = certs.Cert(self.connection.get_peer_certificate())

        # Keep all server certificates in a list.
        # pyOpenSSL documents get_peer_cert_chain() as returning None when the
        # peer has no chain; treat that as an empty chain instead of failing
        # with "NoneType is not iterable".
        for i in self.connection.get_peer_cert_chain() or []:
            self.server_certs.append(certs.Cert(i))

        self.tls_established = True
        self.rfile.set_descriptor(self.connection)
        self.wfile.set_descriptor(self.connection)

    def makesocket(self, family, type, proto):
        # some parties (cuckoo sandbox) need to hook this
        return socket.socket(family, type, proto)

    def create_connection(self, timeout=None):
        """
        Resolve self.address and return a socket connected to the first
        candidate that works.

        Raises:
            OSError from the last failed attempt if no candidate works.
            exceptions.TcpException if enabling source address spoofing fails.
        """
        # Based on the official socket.create_connection implementation of Python 3.6.
        # https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py
        err = None
        for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM):
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = self.makesocket(af, socktype, proto)
                if timeout:
                    sock.settimeout(timeout)
                if self.source_address:
                    sock.bind(self.source_address)
                # NOTE(review): IP_TRANSPARENT is enabled only after bind();
                # confirm this order works for binding non-local addresses.
                if self.spoof_source_address:
                    try:
                        if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT):
                            sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1)  # pragma: windows no cover pragma: osx no cover
                    except Exception as e:
                        # socket.IP_TRANSPARENT might not be available on every OS and Python version
                        if sock is not None:
                            sock.close()
                        raise exceptions.TcpException(
                            "Failed to spoof the source address: " + str(e)
                        ) from e
                sock.connect(sa)
                return sock
            except OSError as e:
                # Remember the failure and try the next address candidate.
                err = e
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        else:
            raise OSError("getaddrinfo returns an empty list")  # pragma: no cover

    def connect(self):
        """
        Open the connection and set up rfile/wfile.

        Returns:
            A ConnectionCloser for this client.
        Raises:
            exceptions.TcpException if the connection cannot be established.
        """
        try:
            connection = self.create_connection()
        except OSError as err:
            raise exceptions.TcpException(
                'Error connecting to "%s": %s' %
                (self.address[0], err)
            ) from err
        self.connection = connection
        self.source_address = connection.getsockname()
        self.ip_address = connection.getpeername()
        self._makefile()
        return ConnectionCloser(self)

    def settimeout(self, n):
        self.connection.settimeout(n)

    def gettimeout(self):
        return self.connection.gettimeout()

    def get_alpn_proto_negotiated(self):
        """
        Return the negotiated ALPN protocol, or b"" when TLS is not established.
        """
        if self.tls_established:
            return self.connection.get_alpn_proto_negotiated()
        else:
            return b""
class BaseHandler(_Connection):
    """
    Server-side view of one accepted client connection.

    The instantiator is expected to call the handle() and finish() methods.
    """

    def __init__(self, connection, address, server):
        super().__init__(connection)
        self.clientcert = None
        self.server = server
        self.address = address

    def convert_to_tls(self, cert, key, **sslctx_kwargs):
        """
        Convert connection to SSL.
        For a list of parameters, see tls.create_server_context(...)

        Raises:
            exceptions.TlsException if the handshake fails.
        """
        server_context = tls.create_server_context(cert=cert, key=key, **sslctx_kwargs)
        tls_conn = SSL.Connection(server_context, self.connection)
        self.connection = tls_conn
        self.connection.set_accept_state()
        try:
            self.connection.do_handshake()
        except SSL.Error as err:
            message = "SSL handshake error: %s" % repr(err)
            raise exceptions.TlsException(message)
        self.tls_established = True

        # Only remember a client certificate if the peer presented one.
        peer_cert = self.connection.get_peer_certificate()
        if peer_cert:
            self.clientcert = certs.Cert(peer_cert)

        for wrapper in (self.rfile, self.wfile):
            wrapper.set_descriptor(self.connection)

    def handle(self):  # pragma: no cover
        raise NotImplementedError

    def settimeout(self, n):
        self.connection.settimeout(n)

    def get_alpn_proto_negotiated(self):
        """
        Return the negotiated ALPN protocol, or b"" when TLS is not established.
        """
        if not self.tls_established:
            return b""
        return self.connection.get_alpn_proto_negotiated()
class Counter:
    """
    Lock-protected counter used as a context manager: entering adds one,
    leaving subtracts one.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._count = 0

    def _adjust(self, delta):
        # Apply delta to the counter while holding the lock.
        with self._lock:
            self._count = self._count + delta

    @property
    def count(self):
        """
        Current counter value, read under the lock.
        """
        with self._lock:
            current = self._count
        return current

    def __enter__(self):
        self._adjust(1)

    def __exit__(self, *args):
        self._adjust(-1)
class TCPServer:
    """
    Threaded TCP server: serve_forever() accepts connections and runs
    handle_client_connection() for each one in its own daemon thread.

    Subclasses implement handle_client_connection() and may override
    handle_error() and handle_shutdown().
    """

    def __init__(self, address):
        """
        Bind and listen on address.

        Tries a dual-stack IPv6 socket first, then IPv4 only, then IPv6 only.

        Raises:
            OSError if the host is 'localhost' or no socket could be bound.
        """
        self.address = address
        self.__is_shut_down = threading.Event()
        self.__is_shut_down.set()
        self.__shutdown_request = False

        if self.address[0] == 'localhost':
            raise OSError("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.")

        self.socket = None

        try:
            # First try to bind an IPv6 socket, attempting to enable IPv4 support if the OS supports it.
            # This allows us to accept connections for ::1 and 127.0.0.1 on the same socket.
            # Only works if self.address == ""
            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            self.socket.bind(self.address)
        except OSError:
            if self.socket:
                self.socket.close()
            self.socket = None

        if not self.socket:
            try:
                # Binding to an IPv6 + IPv4 socket failed, lets fall back to IPv4 only.
                self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                self.socket.bind(self.address)
            except OSError:
                if self.socket:
                    self.socket.close()
                self.socket = None

        if not self.socket:
            # Binding to an IPv4 only socket failed, lets fall back to IPv6 only.
            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
            try:
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                self.socket.bind(self.address)
            except OSError:
                # Last attempt failed: release the socket instead of leaking
                # it, then let the caller see the error.
                self.socket.close()
                self.socket = None
                raise

        self.address = self.socket.getsockname()
        self.socket.listen()
        self.handler_counter = Counter()

    def connection_thread(self, connection, client_address):
        """
        Thread body for one client: run the handler, report errors, and
        always close the client socket afterwards.
        """
        with self.handler_counter:
            try:
                self.handle_client_connection(connection, client_address)
            except OSError as e:  # pragma: no cover
                # This catches situations where the underlying connection is
                # closed beneath us. Syscalls on the connection object at this
                # point returns EINVAL. If this happens, we close the socket and
                # move on.
                if e.errno != errno.EINVAL:
                    raise
            except:  # noqa: E722 - every other failure is routed to handle_error
                self.handle_error(connection, client_address)
            finally:
                close_socket(connection)

    def serve_forever(self, poll_interval=0.1):
        """
        Accept connections until shutdown() is requested.

        Args:
            poll_interval: select() timeout in seconds; bounds how long a
                shutdown request can go unnoticed.
        """
        self.__is_shut_down.clear()
        try:
            while not self.__shutdown_request:
                r, w_, e_ = select.select([self.socket], [], [], poll_interval)
                if self.socket in r:
                    connection, client_address = self.socket.accept()
                    t = basethread.BaseThread(
                        "TCPConnectionHandler ({}: {}:{} -> {}:{})".format(
                            self.__class__.__name__,
                            client_address[0],
                            client_address[1],
                            self.address[0],
                            self.address[1],
                        ),
                        target=self.connection_thread,
                        args=(connection, client_address),
                    )
                    # Thread.setDaemon() is deprecated; the attribute is the
                    # supported spelling.
                    t.daemon = True
                    try:
                        t.start()
                    except threading.ThreadError:
                        self.handle_error(connection, client_address)
                        connection.close()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()

    def shutdown(self):
        """
        Ask serve_forever() to stop, wait until it has, then close the
        listening socket and call handle_shutdown().
        """
        self.__shutdown_request = True
        self.__is_shut_down.wait()
        self.socket.close()
        self.handle_shutdown()

    def handle_error(self, connection_, client_address, fp=sys.stderr):
        """
        Called when handle_client_connection raises an exception.

        Writes the current traceback, framed by dashed lines, to fp.
        """
        # If a thread has persisted after interpreter exit, the module might be
        # none.
        if traceback:
            exc = str(traceback.format_exc())
            print('-' * 40, file=fp)
            print(
                "Error in processing of request from %s" % repr(client_address), file=fp)
            print(exc, file=fp)
            print('-' * 40, file=fp)

    def handle_client_connection(self, conn, client_address):  # pragma: no cover
        """
        Called after client connection.
        """
        raise NotImplementedError

    def handle_shutdown(self):
        """
        Called after server shutdown.
        """

    def wait_for_silence(self, timeout=5):
        """
        Block until no handler threads are active.

        Raises:
            exceptions.Timeout if handlers are still running after timeout
                seconds.
        """
        start = time.time()
        while 1:
            if time.time() - start >= timeout:
                raise exceptions.Timeout(
                    "%s service threads still alive" %
                    self.handler_counter.count
                )
            if self.handler_counter.count == 0:
                return
            # Yield briefly between polls instead of spinning a CPU core.
            time.sleep(0.01)