%PDF- %PDF-
Direktori : /lib/calibre/calibre/srv/ |
Current File : //lib/calibre/calibre/srv/loop.py |
#!/usr/bin/env python3


__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'

import ipaddress
import os
import select
import socket
import ssl
import traceback
from contextlib import suppress
from functools import partial, lru_cache
from io import BytesIO

from calibre import as_unicode
from calibre.constants import iswindows
from calibre.ptempfile import TemporaryDirectory
from calibre.srv.errors import JobQueueFull
from calibre.srv.jobs import JobsManager
from calibre.srv.opts import Options
from calibre.srv.pool import PluginPool, ThreadPool
from calibre.srv.utils import (
    DESIRED_SEND_BUFFER_SIZE, HandleInterrupt, create_sock_pair, socket_errors_eintr,
    socket_errors_nonblocking, socket_errors_socket_closed, start_cork, stop_cork
)
from calibre.utils.logging import ThreadSafeLog
from calibre.utils.mdns import get_external_ip
from calibre.utils.monotonic import monotonic
from calibre.utils.socket_inheritance import set_socket_inherit
from polyglot.builtins import iteritems
from polyglot.queue import Empty, Full

# States a connection can be waiting in. They are compared by identity
# (``is``) throughout this module.
READ, WRITE, RDWR, WAIT = 'READ', 'WRITE', 'RDWR', 'WAIT'
# Single byte messages written to the control connection
WAKEUP, JOB_DONE = b'\0', b'\x01'
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)


class ReadBuffer:  # {{{

    ''' A ring buffer used to speed up the readline() implementation by minimizing recv() calls '''

    __slots__ = ('ba', 'buf', 'read_pos', 'write_pos', 'full_state')

    def __init__(self, size=4096):
        self.ba = bytearray(size)
        self.buf = memoryview(self.ba)
        self.read_pos = 0
        self.write_pos = 0
        # When read_pos == write_pos the buffer is either completely empty
        # (full_state is WRITE) or completely full (full_state is READ)
        self.full_state = WRITE

    @property
    def has_data(self):
        ''' True if there are unread bytes in the buffer '''
        return self.read_pos != self.write_pos or self.full_state is READ

    @property
    def has_space(self):
        ''' True if more bytes can be written into the buffer '''
        return self.read_pos != self.write_pos or self.full_state is WRITE

    def read(self, size):
        ''' Read at most size bytes from this buffer, returning the read
        bytes as a bytestring. Returns b'' if the buffer is empty. '''
        if self.read_pos == self.write_pos and self.full_state is WRITE:
            return b''
        if self.read_pos < self.write_pos:
            sz = min(self.write_pos - self.read_pos, size)
            npos = self.read_pos + sz
            ans = self.buf[self.read_pos:npos].tobytes()
            self.read_pos = npos
            if self.read_pos == self.write_pos:
                self.full_state = WRITE
        else:
            # The unread data wraps around the end of the buffer: first read
            # up to the end, then (if more is wanted) from the start
            sz = min(size, len(self.buf) - self.read_pos)
            ans = self.buf[self.read_pos:self.read_pos + sz].tobytes()
            self.read_pos = (self.read_pos + sz) % len(self.buf)
            if self.read_pos == self.write_pos:
                self.full_state = WRITE
            if size > sz and self.read_pos < self.write_pos:
                ans += self.read(size - len(ans))
        return ans

    def recv_from(self, socket):
        ''' Write into this buffer from socket, return number of bytes
        written. Returns 0 without touching the socket if the buffer is
        full. '''
        if self.read_pos == self.write_pos and self.full_state is READ:
            return 0
        if self.write_pos < self.read_pos:
            num = socket.recv_into(self.buf[self.write_pos:self.read_pos])
            self.write_pos += num
        else:
            num = socket.recv_into(self.buf[self.write_pos:])
            self.write_pos = (self.write_pos + num) % len(self.buf)
        # Only an actual write can fill the buffer. Without the num > 0 check
        # a zero byte read into an empty buffer would leave write_pos ==
        # read_pos and wrongly mark the (still empty) buffer as full.
        if num > 0 and self.write_pos == self.read_pos:
            self.full_state = READ
        return num

    def readline(self):
        ''' Return whatever is in the buffer up to (and including) the first
        \\n. If no \\n is present, returns everything. '''
        if self.read_pos == self.write_pos and self.full_state is WRITE:
            return b''
        if self.read_pos < self.write_pos:
            pos = self.ba.find(b'\n', self.read_pos, self.write_pos)
            if pos < 0:
                pos = self.write_pos - 1
            ans = self.buf[self.read_pos:pos + 1].tobytes()
            self.read_pos = (pos + 1) % len(self.buf)
            if self.read_pos == self.write_pos:
                self.full_state = WRITE
        else:
            # Unread data wraps around: look in the tail first, then the head
            pos = self.ba.find(b'\n', self.read_pos)
            if pos < 0:
                pos = self.ba.find(b'\n', 0, self.write_pos)
                if pos < 0:
                    pos = self.write_pos - 1
                ans = self.buf[self.read_pos:].tobytes() + self.buf[:pos+1].tobytes()
                self.read_pos = (pos + 1) % len(self.buf)
                if self.read_pos == self.write_pos:
                    self.full_state = WRITE
            else:
                ans = self.buf[self.read_pos:pos + 1].tobytes()
                self.read_pos = (pos + 1) % len(self.buf)
                if self.read_pos == self.write_pos:
                    self.full_state = WRITE
        return ans
# }}}


class BadIPSpec(ValueError):
    ''' Raised by parse_trusted_ips() for an entry that is not a valid IP
    address or network '''


def parse_trusted_ips(spec):
    ''' Generator yielding an ipaddress address or network object for every
    comma separated entry in spec. Entries containing a / are parsed as
    networks, all others as single addresses.

    :raises BadIPSpec: if an entry cannot be parsed '''
    for part in as_unicode(spec).split(','):
        part = part.strip()
        try:
            if '/' in part:
                yield ipaddress.ip_network(part)
            else:
                yield ipaddress.ip_address(part)
        except Exception as e:
            raise BadIPSpec(_('{0} is not a valid IP address/network, with error: {1}').format(part, e)) from e


def is_ip_trusted(remote_addr, trusted_ips):
    ''' Return True if remote_addr equals one of the addresses or is contained
    in one of the networks in trusted_ips '''
    for tip in trusted_ips:
        if hasattr(tip, 'hosts'):
            # Only network objects have a hosts() method
            if remote_addr in tip:
                return True
        else:
            if tip == remote_addr:
                return True
    return False


class Connection:  # {{{

    ''' Base class for a single client connection managed by ServerLoop.
    Subclasses must implement connection_ready() and, if they use
    queue_job(), job_done(). '''

    def __init__(self, socket, opts, ssl_context, tdir, addr, pool, log, access_log, wakeup):
        self.opts, self.pool, self.log, self.wakeup, self.access_log = opts, pool, log, wakeup, access_log
        try:
            self.remote_addr = addr[0]
            self.remote_port = addr[1]
            self.parsed_remote_addr = ipaddress.ip_address(as_unicode(self.remote_addr))
        except Exception:
            # In case addr is None, which can occasionally happen
            self.remote_addr = self.remote_port = self.parsed_remote_addr = None
        self.is_trusted_ip = bool(self.opts.local_write and getattr(self.parsed_remote_addr, 'is_loopback', False))
        if not self.is_trusted_ip and self.opts.trusted_ips and self.parsed_remote_addr is not None:
            self.is_trusted_ip = is_ip_trusted(self.parsed_remote_addr, parsed_trusted_ips(self.opts.trusted_ips))
        self.orig_send_bufsize = self.send_bufsize = 4096
        self.tdir = tdir
        self.wait_for = READ
        self.response_started = False
        self.read_buffer = ReadBuffer()
        self.handle_event = None
        self.ssl_context = ssl_context
        self.ssl_handshake_done = False
        self.ssl_terminated = False
        if self.ssl_context is not None:
            self.ready = False
            # The handshake is performed later, from the event loop, via
            # do_ssl_handshake()
            self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
            self.set_state(RDWR, self.do_ssl_handshake)
        else:
            self.socket = socket
            self.connection_ready()
        self.last_activity = monotonic()
        self.ready = True

    def optimize_for_sending_packet(self):
        ''' Cork the socket and try to raise its send buffer size to
        DESIRED_SEND_BUFFER_SIZE. Undone by end_send_optimization(). '''
        start_cork(self.socket)
        self.orig_send_bufsize = self.send_bufsize = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
        if self.send_bufsize < DESIRED_SEND_BUFFER_SIZE:
            try:
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, DESIRED_SEND_BUFFER_SIZE)
            except OSError:
                pass
            else:
                self.send_bufsize = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)

    def end_send_optimization(self):
        ''' Uncork the socket and restore the original send buffer size '''
        stop_cork(self.socket)
        if self.send_bufsize != self.orig_send_bufsize:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.orig_send_bufsize)

    def set_state(self, wait_for, func, *args, **kwargs):
        ''' Set the event this connection is waiting for (one of READ, WRITE,
        RDWR, WAIT) and the callable invoked with the event when it occurs.
        Any extra arguments are bound to func. '''
        self.wait_for = wait_for
        if args or kwargs:
            pfunc = partial(func, *args, **kwargs)
            pfunc.__name__ = func.__name__
            func = pfunc
        self.handle_event = func

    def do_ssl_handshake(self, event):
        ''' Advance the non-blocking SSL handshake, calling connection_ready()
        once it completes '''
        try:
            self.socket._sslobj.do_handshake()
        except ssl.SSLWantReadError:
            self.set_state(READ, self.do_ssl_handshake)
        except ssl.SSLWantWriteError:
            self.set_state(WRITE, self.do_ssl_handshake)
        else:
            self.ssl_handshake_done = True
            self.connection_ready()

    def send(self, data):
        ''' Send data, returning the number of bytes sent. Returns 0 if the
        socket would block, the call was interrupted or the socket was closed
        (in which case self.ready is also set to False). '''
        try:
            ret = self.socket.send(data) if self.ssl_context is None else self.socket.write(data)
            self.last_activity = monotonic()
            return ret
        except ssl.SSLWantWriteError:
            return 0
        except OSError as e:
            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
                return 0
            elif e.errno in socket_errors_socket_closed:
                self.log.error('Failed to send all data in state:', self.state_description, 'with error:', e)
                self.ready = False
                return 0
            raise

    def recv(self, amt):
        ''' Return at most amt bytes, from the read buffer if it has data,
        otherwise from the socket. Returns b'' if nothing is available; if
        that is because the peer closed the connection self.ready is set to
        False. '''
        # If there is data in the read buffer we have to return only that,
        # since we dont know if the socket has signalled it is ready for
        # reading
        if self.read_buffer.has_data:
            return self.read_buffer.read(amt)
        # read buffer is empty, so read directly from socket
        try:
            data = self.socket.recv(amt)
            self.last_activity = monotonic()
            if not data:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.ready = False
                return b''
            return data
        except ssl.SSLWantReadError:
            return b''
        except OSError as e:
            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
                return b''
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return b''
            raise

    def recv_into(self, buf, amt=0):
        ''' Same as recv() but writes into buf and returns the number of
        bytes written. amt defaults to len(buf). '''
        amt = amt or len(buf)
        if self.read_buffer.has_data:
            data = self.read_buffer.read(amt)
            buf[0:len(data)] = data
            return len(data)
        try:
            bytes_read = self.socket.recv_into(buf, amt)
            self.last_activity = monotonic()
            if bytes_read == 0:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.ready = False
                return 0
            return bytes_read
        except ssl.SSLWantReadError:
            return 0
        except OSError as e:
            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
                return 0
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return 0
            raise

    def fill_read_buffer(self):
        ''' Read from the socket into the read buffer. Sets self.ready to
        False if no bytes were read or the socket was closed. '''
        try:
            num = self.read_buffer.recv_from(self.socket)
            self.last_activity = monotonic()
            if not num:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.ready = False
        except ssl.SSLWantReadError:
            return
        except OSError as e:
            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
                return
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return
            raise

    def drain_ssl_buffer(self):
        ''' Read whatever is immediately available from the (SSL) socket into
        the read buffer. Called by ServerLoop.tick() before select() when SSL
        is in use and the read buffer is empty. Sets self.ready to False on
        SSL errors or if the socket was closed. '''
        # NOTE(review): presumably needed because the SSL layer can hold
        # already decrypted data that select() on the raw fd does not report
        try:
            self.read_buffer.recv_from(self.socket)
        except ssl.SSLWantReadError:
            return
        except ssl.SSLError as e:
            self.log.error('Error while reading SSL data from client: %s' % as_unicode(e))
            self.ready = False
            return
        except OSError as e:
            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
                return
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return
            raise

    def close(self):
        ''' Shutdown and close the socket, ignoring socket errors '''
        self.ready = False
        self.handle_event = None  # prevent reference cycles
        try:
            self.socket.shutdown(socket.SHUT_WR)
        except OSError:
            pass
        try:
            self.socket.close()
        except OSError:
            pass

    def queue_job(self, func, *args):
        ''' Queue func(*args) in the pool and put this connection into the
        WAIT state until the result is delivered to job_done().

        :raises JobQueueFull: if the pool's queue is full '''
        if args:
            func = partial(func, *args)
        try:
            self.pool.put_nowait(self.socket.fileno(), func)
        except Full:
            raise JobQueueFull()
        self.set_state(WAIT, self._job_done)

    def _job_done(self, event):
        # event is the (ok, result) tuple from ServerLoop.dispatch_job_results()
        self.job_done(*event)

    def job_done(self, ok, result):
        ''' Called with the result of a job queued via queue_job(). Must be
        implemented in subclasses that queue jobs. '''
        raise NotImplementedError()

    @property
    def state_description(self):
        ''' Description of the current state, used in log messages '''
        return ''

    def report_unhandled_exception(self, e, formatted_traceback):
        ''' Called by the loop when handle_event() raised an exception before
        a response was started. Does nothing by default. '''
        pass

    def report_busy(self):
        ''' Called by the loop when the job queue was full before a response
        was started. Does nothing by default. '''
        pass

    def connection_ready(self):
        ''' Called when the connection is ready for use (after the SSL
        handshake, if any). Must be implemented in subclasses. '''
        raise NotImplementedError()

    def handle_timeout(self):
        ''' Called by the loop when the connection has been inactive for
        longer than the timeout. Return True to keep the connection open. '''
        return False
# }}}


@lru_cache(maxsize=2)
def parsed_trusted_ips(raw):
    ''' Cached version of parse_trusted_ips() that returns a tuple (empty if
    raw is falsy) '''
    return tuple(parse_trusted_ips(raw)) if raw else ()


class ServerLoop:

    ''' A select() based event loop that accepts connections on a listening
    socket and dispatches read/write/job events to handler instances. '''

    LISTENING_MSG = 'calibre server listening on'

    def __init__(
        self,
        handler,
        opts=None,
        plugins=(),
        # A calibre logging object. If None, a default log that logs to
        # stdout is used
        log=None,
        # A calibre logging object for access logging, by default no access
        # logging is performed
        access_log=None
    ):
        self.ready = False
        self.handler = handler
        self.opts = opts or Options()
        self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG)
        self.jobs_manager = JobsManager(self.opts, self.log)
        self.access_log = access_log
        ba = (self.opts.listen_on, int(self.opts.port))
        if not ba[0]:
            # AI_PASSIVE does not work with host of '' or None
            ba = ('0.0.0.0', ba[1])
        self.bind_address = ba
        self.bound_address = None
        # Maps socket file descriptors to connection objects
        self.connection_map = {}

        self.ssl_context = None
        if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None:
            self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            self.ssl_context.load_cert_chain(certfile=self.opts.ssl_certfile, keyfile=self.opts.ssl_keyfile)
            self.ssl_context.set_servername_callback(self.on_ssl_servername)

        self.pre_activated_socket = None
        if self.opts.allow_socket_preallocation:
            from calibre.srv.pre_activated import pre_activated_socket
            self.pre_activated_socket = pre_activated_socket()
            if self.pre_activated_socket is not None:
                set_socket_inherit(self.pre_activated_socket, False)
                self.bind_address = self.pre_activated_socket.getsockname()

        self.create_control_connection()
        self.pool = ThreadPool(self.log, self.job_completed, count=self.opts.worker_count)
        self.plugin_pool = PluginPool(self, plugins)

    def on_ssl_servername(self, socket, server_name, ssl_context):
        ''' Servername callback for the SSL context. If the handshake for the
        connection has already completed, the connection is marked for
        termination and a no-renegotiation alert is returned. '''
        c = self.connection_map.get(socket.fileno())
        if getattr(c, 'ssl_handshake_done', False):
            c.ready = False
            c.ssl_terminated = True
            # We do not allow client initiated SSL renegotiation
            return ssl.ALERT_DESCRIPTION_NO_RENEGOTIATION

    def create_control_connection(self):
        ''' Create the control connection: control_in is written to by
        write_to_control() and control_out is watched by tick() '''
        if iswindows:
            self.control_in, self.control_out = create_sock_pair()
        else:
            r, w = os.pipe()
            os.set_blocking(r, False)
            os.set_blocking(w, True)
            self.control_in = open(w, 'wb')
            self.control_out = open(r, 'rb')

    def close_control_connection(self):
        ''' Close both ends of the control connection, ignoring errors '''
        with suppress(Exception):
            self.control_in.close()
        with suppress(Exception):
            self.control_out.close()

    def __str__(self):
        return f"{self.__class__.__name__}({self.bind_address!r})"
    __repr__ = __str__

    @property
    def num_active_connections(self):
        return len(self.connection_map)

    def do_bind(self):
        ''' Create the listening socket and bind it to self.bind_address,
        trying every address getaddrinfo() returns.

        :raises OSError: if no socket could be bound '''
        # Get the correct address family for our host (allows IPv6 addresses)
        host, port = self.bind_address
        try:
            info = socket.getaddrinfo(
                host, port, socket.AF_UNSPEC,
                socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
        except socket.gaierror:
            if ':' in host:
                info = [(socket.AF_INET6, socket.SOCK_STREAM,
                        0, "", self.bind_address + (0, 0))]
            else:
                info = [(socket.AF_INET, socket.SOCK_STREAM,
                        0, "", self.bind_address)]

        self.socket = None
        msg = "No socket could be created"
        for res in info:
            af, socktype, proto, canonname, sa = res
            try:
                self.bind(af, socktype, proto)
            except OSError as serr:
                msg = f"{msg} -- ({sa}: {as_unicode(serr)})"
                if self.socket:
                    self.socket.close()
                self.socket = None
                continue
            break
        if not self.socket:
            raise OSError(msg)

    def initialize_socket(self):
        ''' Setup self.socket, either by binding a new socket or by using the
        pre-activated socket, if there is one. If binding fails and
        opts.fallback_to_detected_interface is set, binding is retried on the
        address returned by get_external_ip(). '''
        if self.pre_activated_socket is None:
            try:
                self.do_bind()
            except OSError as err:
                if not self.opts.fallback_to_detected_interface:
                    raise
                ip = get_external_ip()
                if ip == self.bind_address[0]:
                    raise
                self.log.warn('Failed to bind to {} with error: {}. Trying to bind to the default interface: {} instead'.format(
                    self.bind_address[0], as_unicode(err), ip))
                self.bind_address = (ip, self.bind_address[1])
                self.do_bind()
        else:
            self.socket = self.pre_activated_socket
            self.pre_activated_socket = None
            self.setup_socket()

    def serve(self):
        ''' Start listening and run tick() until self.ready becomes False or
        a KeyboardInterrupt occurs, then shutdown. Exceptions raised by
        tick() are logged and the loop continues. '''
        self.connection_map = {}
        self.socket.listen(min(socket.SOMAXCONN, 128))
        self.bound_address = ba = self.socket.getsockname()
        if isinstance(ba, tuple):
            ba = ':'.join(map(str, ba))
        self.pool.start()
        with TemporaryDirectory(prefix='srv-') as tdir:
            self.tdir = tdir
            if self.LISTENING_MSG:
                self.log(self.LISTENING_MSG, ba)
            self.plugin_pool.start()
            self.ready = True

            while self.ready:
                try:
                    self.tick()
                except SystemExit:
                    self.shutdown()
                    raise
                except KeyboardInterrupt:
                    break
                except Exception:
                    self.log.exception('Error in ServerLoop.tick')
            self.shutdown()

    def serve_forever(self):
        """ Listen for incoming connections. """
        self.initialize_socket()
        self.serve()

    def setup_socket(self):
        ''' Set the socket options needed on the listening socket and make it
        non-blocking '''
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
        # activate dual-stack.
        if (hasattr(socket, 'AF_INET6') and self.socket.family == socket.AF_INET6 and
                self.bind_address[0] in ('::', '::0', '::0.0.0.0')):
            try:
                self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            except (AttributeError, OSError):
                # Apparently, the socket option is not available in
                # this machine's TCP stack
                pass
        self.socket.setblocking(0)

    def bind(self, family, atype, proto=0):
        '''Create (or recreate) the actual socket object.'''
        self.socket = socket.socket(family, atype, proto)
        set_socket_inherit(self.socket, False)
        self.setup_socket()
        self.socket.bind(self.bind_address)

    def tick(self):
        ''' Run one iteration of the event loop: close timed out and dead
        connections, wait for events and dispatch them to the connections '''
        now = monotonic()
        read_needed, write_needed, readable, remove, close_needed = [], [], [], [], []
        has_ssl = self.ssl_context is not None
        for s, conn in iteritems(self.connection_map):
            if now - conn.last_activity > self.opts.timeout:
                if conn.handle_timeout():
                    conn.last_activity = now
                else:
                    remove.append((s, conn))
                    continue
            wf = conn.wait_for
            if wf is READ or wf is RDWR:
                if wf is RDWR:
                    write_needed.append(s)
                if conn.read_buffer.has_data:
                    readable.append(s)
                else:
                    if has_ssl:
                        conn.drain_ssl_buffer()
                        if conn.ready:
                            (readable if conn.read_buffer.has_data else read_needed).append(s)
                        else:
                            close_needed.append((s, conn))
                    else:
                        read_needed.append(s)
            elif wf is WRITE:
                write_needed.append(s)

        for s, conn in remove:
            self.log('Closing connection because of extended inactivity: %s' % conn.state_description)
            self.close(s, conn)

        if close_needed:
            closed = set()
            for s, conn in close_needed:
                # Each connection must be removed under its own descriptor
                closed.add(s)
                self.close(s, conn)
            # A connection in the RDWR state was added to write_needed before
            # it was found to be dead, dont pass its closed fd to select()
            write_needed = [s for s in write_needed if s not in closed]

        if readable:
            # There is already buffered data to process, dont wait in select()
            writable = []
        else:
            try:
                readable, writable, _ = select.select([self.socket.fileno(), self.control_out.fileno()] + read_needed, write_needed, [], self.opts.timeout)
            except ValueError:  # self.socket.fileno() == -1
                self.ready = False
                self.log.error('Listening socket was unexpectedly terminated')
                return
            except OSError as e:
                # select.error has no errno attribute. errno is instead
                # e.args[0]
                if getattr(e, 'errno', e.args[0]) in socket_errors_eintr:
                    return
                # Probe every connection individually to find the bad ones
                for s, conn in tuple(iteritems(self.connection_map)):
                    try:
                        select.select([s], [], [], 0)
                    except OSError as err:
                        if getattr(err, 'errno', err.args[0]) not in socket_errors_eintr:
                            self.close(s, conn)  # Bad socket, discard
                return

        if not self.ready:
            return

        # Descriptors whose handler raised an exception, any further events
        # for them in this tick are skipped
        ignore = set()
        for s, conn, event in self.get_actions(readable, writable):
            if s in ignore:
                continue
            try:
                conn.handle_event(event)
                if not conn.ready:
                    self.close(s, conn)
            except JobQueueFull:
                self.log.exception('Server busy handling request: %s' % conn.state_description)
                if conn.ready:
                    if conn.response_started:
                        self.close(s, conn)
                    else:
                        try:
                            conn.report_busy()
                        except Exception:
                            self.close(s, conn)
            except Exception as e:
                ignore.add(s)
                ssl_terminated = getattr(conn, 'ssl_terminated', False)
                if ssl_terminated:
                    self.log.warn('Client tried to initiate SSL renegotiation, closing connection')
                    self.close(s, conn)
                else:
                    self.log.exception('Unhandled exception in state: %s' % conn.state_description)
                    if conn.ready:
                        if conn.response_started:
                            self.close(s, conn)
                        else:
                            try:
                                conn.report_unhandled_exception(e, traceback.format_exc())
                            except Exception:
                                self.close(s, conn)
                    else:
                        self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e))
                        self.close(s, conn)

    def write_to_control(self, what):
        ''' Write the bytes what to the control connection '''
        if iswindows:
            self.control_in.sendall(what)
        else:
            self.control_in.write(what)
            self.control_in.flush()

    def wakeup(self):
        ''' Make the select() call in tick() return '''
        self.write_to_control(WAKEUP)

    def job_completed(self):
        ''' Notify the loop that a job result is available in the pool '''
        self.write_to_control(JOB_DONE)

    def dispatch_job_results(self):
        ''' Yield (fd, connection, (ok, result)) for every finished job in
        the pool whose connection is still open '''
        while True:
            try:
                s, ok, result = self.pool.get_nowait()
            except Empty:
                break
            conn = self.connection_map.get(s)
            if conn is not None:
                yield s, conn, (ok, result)

    def close(self, s, conn):
        ''' Remove the connection registered under the descriptor s and
        close it '''
        self.connection_map.pop(s, None)
        conn.close()

    def get_actions(self, readable, writable):
        ''' Generator yielding (fd, connection, event) for every event to be
        handled. New connections are accepted and control messages are
        processed here as well. '''
        listener = self.socket.fileno()
        control = self.control_out.fileno()
        for s in readable:
            if s == listener:
                sock, addr = self.accept()
                if sock is not None:
                    s = sock.fileno()
                    if s > -1:
                        self.connection_map[s] = conn = self.handler(
                            sock, self.opts, self.ssl_context, self.tdir, addr, self.pool, self.log, self.access_log, self.wakeup)
                        if self.ssl_context is not None:
                            # Start the SSL handshake right away
                            yield s, conn, RDWR
            elif s == control:
                f = self.control_out.recv if iswindows else self.control_out.read
                try:
                    c = f(1)
                except OSError as e:
                    if not self.ready:
                        return
                    self.log.error('Control connection raised an error:', e)
                    raise
                if c == JOB_DONE:
                    for s, conn, event in self.dispatch_job_results():
                        yield s, conn, event
                elif c == WAKEUP:
                    pass
                elif not c:
                    if not self.ready:
                        return
                    self.log.error('Control connection failed to read after signalling ready')
                    raise Exception('Control connection failed to read, something bad happened')
            else:
                yield s, self.connection_map[s], READ
        for s in writable:
            try:
                conn = self.connection_map[s]
            except KeyError:
                continue  # Happens if connection was closed during read phase
            yield s, conn, WRITE

    def accept(self):
        ''' Accept a new connection, returning (socket, address) or
        (None, None) on failure. The returned socket is non-blocking and
        not inheritable. '''
        try:
            sock, addr = self.socket.accept()
        except OSError:
            return None, None
        try:
            set_socket_inherit(sock, False)
            sock.setblocking(False)
        except OSError:
            # Dont leak the accepted socket if it could not be configured
            with suppress(OSError):
                sock.close()
            return None, None
        return sock, addr

    def stop(self):
        ''' Ask the loop to exit '''
        self.ready = False
        self.wakeup()

    def shutdown(self):
        ''' Close the listening socket and all connections and stop the
        worker pools, waiting at most opts.shutdown_timeout for them '''
        self.jobs_manager.shutdown()
        with suppress(socket.error):
            if getattr(self, 'socket', None):
                self.socket.close()
                self.socket = None
        for s, conn in tuple(iteritems(self.connection_map)):
            self.close(s, conn)
        wait_till = monotonic() + self.opts.shutdown_timeout
        for pool in (self.plugin_pool, self.pool):
            pool.stop(wait_till)
            if pool.workers:
                self.log.warn('Failed to shutdown %d workers in %s cleanly' % (len(pool.workers), pool.__class__.__name__))
        self.jobs_manager.wait_for_shutdown(wait_till)


class EchoLine(Connection):  # {{{

    ''' Test connection handler that echoes back every line it receives.
    An empty line is answered with bye and the connection is closed. '''

    bye_after_echo = False

    def connection_ready(self):
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)

    def read_line(self, event):
        ''' Read a single byte, switching to echo() once a full line has
        been received '''
        data = self.recv(1)
        if data:
            self.rbuf.write(data)
            if b'\n' == data:
                if self.rbuf.tell() < 3:
                    # Empty line
                    self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue())
                    self.bye_after_echo = True
                self.set_state(WRITE, self.echo)
                self.rbuf.seek(0)

    def echo(self, event):
        ''' Send the next chunk (at most 512 bytes) of the buffered line,
        going back to read_line() once everything has been sent '''
        pos = self.rbuf.tell()
        self.rbuf.seek(0, os.SEEK_END)
        left = self.rbuf.tell() - pos
        self.rbuf.seek(pos)
        sent = self.send(self.rbuf.read(512))
        if sent == left:
            self.rbuf = BytesIO()
            self.set_state(READ, self.read_line)
            if self.bye_after_echo:
                self.ready = False
        else:
            # Partial send, continue from the first unsent byte
            self.rbuf.seek(pos + sent)
# }}}


def main():
    ''' Run the echo server until interrupted '''
    print('Starting Echo server')
    s = ServerLoop(EchoLine)
    with HandleInterrupt(s.stop):
        s.serve_forever()


if __name__ == '__main__':
    main()