import asyncio
import base64
import collections
import collections.abc
import concurrent.futures
import hmac
import logging
import os.path
import signal
import ssl
import string
import sys

import aiohttp
import aiohttp.web
import toml


def _make_insecure_ssl_context():
	'''Create a TLS client context that performs no certificate or hostname verification.'''
	ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
	ctx.check_hostname = False  # Must be disabled before verify_mode may be set to CERT_NONE
	ctx.verify_mode = ssl.CERT_NONE
	return ctx


# Maps the irc.ssl config value to the `ssl` argument of loop.create_connection.
# Kept for compatibility; IRCClient builds a fresh context for 'insecure' on every connection so that
# loading a client certificate never modifies this shared object.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}


class InvalidConfig(Exception):
	'''Error in configuration file'''


async def _wait_first(*aws):
	'''Wait until at least one of the awaitables has completed, then cancel the others.

	Coroutines are wrapped in tasks because asyncio.wait no longer accepts bare coroutines (Python 3.11+).
	Cancelling the leftovers keeps e.g. Event.wait() waiters from piling up.
	Returns the set of completed tasks.'''
	tasks = [asyncio.ensure_future(aw) for aw in aws]
	try:
		done, _ = await asyncio.wait(tasks, return_when = concurrent.futures.FIRST_COMPLETED)
	finally:
		# Also runs if we are cancelled ourselves
		for task in tasks:
			if not task.done():
				task.cancel()
	return done


def _strip_pem_section(data, label):
	'''Check that data starts with one PEM section with the given label and return the bytes following it.

	Only LF line endings are accepted. Raises ValueError (binascii.Error included) if the BEGIN/END lines are missing
	or the body is not valid base64.'''
	begin = b'-----BEGIN ' + label + b'-----\n'
	end = b'-----END ' + label + b'-----\n'
	if not data.startswith(begin):
		raise ValueError(f'PEM data does not start with {begin!r}')
	endPos = data.index(end, len(begin))
	base64.b64decode(data[len(begin):endPos].replace(b'\n', b''), validate = True)
	return data[endPos + len(end):]


def is_valid_pem(path, withCert):
	'''Very basic check whether a file looks like a valid PEM file.

	With withCert, the file must consist of a CERTIFICATE section immediately followed by a PRIVATE KEY section;
	without it, of a single PRIVATE KEY section. Nothing may follow the last section.
	Returns True or False; unreadable or malformed files yield False rather than an exception.'''
	try:
		with open(path, 'rb') as fp:
			rest = fp.read()
		if withCert:
			rest = _strip_pem_section(rest, b'CERTIFICATE')
		rest = _strip_pem_section(rest, b'PRIVATE KEY')
	except (OSError, ValueError):
		return False
	return rest == b''


def _is_valid_port(value):
	'''Whether value is an integer TCP port in 1..65535 (booleans are rejected even though bool subclasses int).'''
	return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535


class Config(dict):
	'''Validated configuration read from a TOML file, with defaults filled in.

	After construction, the instance maps each of 'logging', 'irc', 'web', and 'maps' to a dict.
	Every entry of 'maps' has 'webpath' (default '/KEY'), 'ircchannel' (default '#KEY'), and 'auth' (default False).
	'auth' is either False (no authentication) or a string which WebServer base64-encodes into the expected
	HTTP Basic Authorization header (presumably 'user:password').

	Raises InvalidConfig for invalid contents; file and TOML parsing errors propagate from open()/toml.load().'''

	def __init__(self, filename):
		super().__init__()
		self._filename = filename
		with open(self._filename, 'r') as fp:
			obj = toml.load(fp)

		logging.info(repr(obj))

		# Sanity checks
		if any(x not in ('logging', 'irc', 'web', 'maps') for x in obj.keys()):
			raise InvalidConfig('Unknown sections found in base object')
		if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
			raise InvalidConfig('Invalid section type(s), expected objects/dicts')
		if 'logging' in obj:
			self._check_logging_section(obj['logging'])
		if 'irc' in obj:
			self._check_irc_section(obj['irc'])
		if 'web' in obj:
			self._check_web_section(obj['web'])
		if 'maps' in obj:
			self._check_maps_section(obj['maps'])

		# Default values
		finalObj = {
			'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {message}'},
			'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
			'web': {'host': '127.0.0.1', 'port': 8080},
			'maps': {},
		}

		# Fill in default values for the maps (the [maps] section itself is optional)
		for key, map_ in obj.get('maps', {}).items():
			if 'webpath' not in map_:
				map_['webpath'] = f'/{key}'
			if 'ircchannel' not in map_:
				map_['ircchannel'] = f'#{key}'
			if 'auth' not in map_:
				map_['auth'] = False

		# Merge in what was read from the config file and set keys on self
		for key in ('logging', 'irc', 'web', 'maps'):
			if key in obj:
				finalObj[key].update(obj[key])
			self[key] = finalObj[key]

	@staticmethod
	def _check_logging_section(section):
		'''Validate the [logging] section; raises InvalidConfig.'''
		if any(x not in ('level', 'format') for x in section):
			raise InvalidConfig('Unknown key found in log section')
		if 'level' in section and section['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
			raise InvalidConfig('Invalid log level')
		if 'format' in section:
			if not isinstance(section['format'], str):
				raise InvalidConfig('Invalid log format')
			#TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
			# Count the replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
			try:
				fieldCount = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
			except ValueError as e:
				raise InvalidConfig('Invalid log format: parsing failed') from e
			if fieldCount == 0:
				raise InvalidConfig('Invalid log format: parsing failed')

	@staticmethod
	def _check_irc_section(section):
		'''Validate the [irc] section, including the client certificate files if given; raises InvalidConfig.'''
		if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
			raise InvalidConfig('Unknown key found in irc section')
		if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
			raise InvalidConfig('Invalid IRC host')
		if 'port' in section and not _is_valid_port(section['port']):
			raise InvalidConfig('Invalid IRC port')
		if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
			raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')
		if 'nick' in section and not isinstance(section['nick'], str): #TODO: Check whether it's a valid nickname
			raise InvalidConfig('Invalid IRC nick')
		if 'real' in section and not isinstance(section['real'], str):
			raise InvalidConfig('Invalid IRC realname')
		if ('certfile' in section) != ('certkeyfile' in section):
			raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
		if 'certfile' in section:
			if not isinstance(section['certfile'], str):
				raise InvalidConfig('Invalid certificate file: not a string')
			if not os.path.isfile(section['certfile']):
				raise InvalidConfig('Invalid certificate file: not a regular file')
			if not is_valid_pem(section['certfile'], True):
				raise InvalidConfig('Invalid certificate file: not a valid PEM cert')
		if 'certkeyfile' in section:
			if not isinstance(section['certkeyfile'], str):
				raise InvalidConfig('Invalid certificate key file: not a string')
			if not os.path.isfile(section['certkeyfile']):
				raise InvalidConfig('Invalid certificate key file: not a regular file')
			if not is_valid_pem(section['certkeyfile'], False):
				raise InvalidConfig('Invalid certificate key file: not a valid PEM key')

	@staticmethod
	def _check_web_section(section):
		'''Validate the [web] section; raises InvalidConfig.'''
		if any(x not in ('host', 'port') for x in section):
			raise InvalidConfig('Unknown key found in web section')
		if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
			raise InvalidConfig('Invalid web hostname')
		if 'port' in section and not _is_valid_port(section['port']):
			raise InvalidConfig('Invalid web port')

	@staticmethod
	def _check_maps_section(section):
		'''Validate the [maps] section (keys and value types of each map); raises InvalidConfig.'''
		for key, map_ in section.items():
			if not isinstance(key, str) or not key:
				raise InvalidConfig(f'Invalid map key {key!r}')
			if not isinstance(map_, collections.abc.Mapping):
				raise InvalidConfig(f'Invalid map for {key!r}')
			if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_):
				raise InvalidConfig(f'Unknown key(s) found in map {key!r}')
			# WebServer matches webpath against the request path, which always starts with a slash
			if 'webpath' in map_ and (not isinstance(map_['webpath'], str) or not map_['webpath'].startswith('/')):
				raise InvalidConfig(f'Invalid webpath in map {key!r}: must be a string starting with /')
			# Channels are comma-joined in JOIN/PART and sent in space-separated IRC lines
			if 'ircchannel' in map_ and (not isinstance(map_['ircchannel'], str) or not map_['ircchannel'] or any(c in map_['ircchannel'] for c in ' ,\x07\r\n\0')):
				raise InvalidConfig(f'Invalid ircchannel in map {key!r}')
			# WebServer calls .encode() on a truthy auth value, so anything but False or a string would crash it
			if 'auth' in map_ and map_['auth'] is not False and not isinstance(map_['auth'], str):
				raise InvalidConfig(f'Invalid auth in map {key!r}: must be false or a string')

	def __repr__(self):
		return f'<Config(filename={self._filename!r})>'

	def reread(self):
		'''Load and validate the same file again, returning a new Config (this instance is not modified).'''
		return Config(self._filename)


class MessageQueue:
	'''FIFO queue of (channel, message) tuples passed from the web server to the IRC client.

	This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
	Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
	Differences to asyncio.Queue include:
	- No maxsize
	- No put coroutine (not necessary since the queue can never be full)
	- Only one concurrent getter
	- putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)'''

	def __init__(self):
		self._getter = None  # None | asyncio.Future; set while get() is waiting for an item
		self._queue = collections.deque()

	async def get(self):
		'''Remove and return the first item, waiting until one is available.

		Raises RuntimeError if another get() is already waiting. If cancelled while waiting, no item is removed.'''
		if self._getter is not None:
			raise RuntimeError('Cannot get concurrently')
		while not self._queue:
			self._getter = asyncio.get_running_loop().create_future()
			logging.debug('Awaiting getter')
			try:
				await self._getter
			except asyncio.CancelledError:
				logging.debug('Cancelled getter')
				raise
			finally:
				self._getter = None
			logging.debug('Awaited getter')
		return self.get_nowait()

	def get_nowait(self):
		'''Remove and return the first item; raises asyncio.QueueEmpty if there is none.'''
		if len(self._queue) == 0:
			raise asyncio.QueueEmpty
		return self._queue.popleft()

	def _wake_getter(self):
		# done() rather than cancelled(): a second put before the getter has resumed must not call set_result again
		# (InvalidStateError). Waking with an empty queue is pointless, e.g. after putleft_nowait() with no items.
		if self._getter is not None and not self._getter.done() and self._queue:
			self._getter.set_result(None)

	def put_nowait(self, item):
		'''Append item to the end of the queue.'''
		self._queue.append(item)
		self._wake_getter()

	def putleft_nowait(self, *item):
		'''Put the items at the front of the queue, keeping their order (item[0] is returned next).'''
		self._queue.extendleft(reversed(item))
		self._wake_getter()

	def qsize(self):
		return len(self._queue)


class IRCClientProtocol(asyncio.Protocol):
	'''asyncio protocol for a single IRC connection.

	Registers with NICK/USER on connect. Once the server sends 001, it joins the channels and starts relaying
	messages from the MessageQueue as PRIVMSGs. Delivery is confirmed periodically with PING/PONG; unconfirmed
	messages are put back onto the queue when confirmation fails or the connection is lost.'''

	def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
		logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}')
		self.messageQueue = messageQueue
		self.connectionClosedEvent = connectionClosedEvent  # Set by connection_lost
		self.loop = loop
		self.config = config
		self.transport = None
		self.buffer = b''  # Incomplete line left over from the previous data_received call
		self.connected = False
		self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
		self.unconfirmedMessages = []  # (channel, message) tuples sent but not yet confirmed by a PING/PONG round trip
		self.pongReceivedEvent = asyncio.Event()
		self._tasks = set()  # Strong references to background tasks; the event loop only keeps weak references

	def connection_made(self, transport):
		logging.info('Connected')
		self.transport = transport
		self.connected = True
		nickb = self.config['irc']['nick'].encode('utf-8')
		self.send(b'NICK ' + nickb)
		self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config['irc']['real'].encode('utf-8'))

	def update_channels(self, channels: set):
		'''Replace the channel set; if connected, send PART/JOIN for the channels removed/added.'''
		channelsToPart = self.channels - channels
		channelsToJoin = channels - self.channels
		self.channels = channels
		if self.connected:
			if channelsToPart:
				#TODO: Split if too long
				self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
			if channelsToJoin:
				self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))

	def send(self, data):
		'''Write one IRC line (CRLF is appended).'''
		logging.info(f'Send: {data!r}')
		self.transport.write(data + b'\r\n')

	async def _get_message(self):
		'''Wait for the next (channel, message) from the queue or for the connection to close.

		Returns (None, None) if the connection closed; a message that had already been taken from the queue is then
		put back at its front.'''
		logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
		messageTask = asyncio.create_task(self.messageQueue.get())
		closedTask = asyncio.create_task(self.connectionClosedEvent.wait())
		try:
			await asyncio.wait((messageTask, closedTask), return_when = concurrent.futures.FIRST_COMPLETED)
		finally:
			# Don't leave a waiter on the event behind for every message
			closedTask.cancel()
		if self.connectionClosedEvent.is_set() or not messageTask.done():
			messageTask.cancel()  # No-op if the task has already finished
			try:
				item = await messageTask
			except asyncio.CancelledError:
				logging.debug('Cancelled messageTask')
			else:
				# The message was already taken from the queue but we're stopping, so put it back
				self.messageQueue.putleft_nowait(item)
			return None, None
		return messageTask.result()

	async def send_messages(self):
		'''Send queued messages as PRIVMSGs, pausing one second after each, until the connection is lost.'''
		while self.connected:
			logging.debug(f'{id(self)}: trying to get a message')
			channel, message = await self._get_message()
			logging.debug(f'{id(self)}: got message: {message!r}')
			if message is None:
				break
			#TODO Split if the message is too long.
			self.unconfirmedMessages.append((channel, message))
			self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
			await asyncio.sleep(1) # Rate limit

	async def confirm_messages(self):
		'''Confirm delivery of sent messages once per minute with a PING/PONG round trip.

		On a PONG within five seconds, the messages sent before the PING count as delivered. Otherwise all
		unconfirmed messages are put back at the front of the queue and the connection is closed. When the
		connection is lost, remaining unconfirmed messages are requeued and this coroutine returns.'''
		while True:
			await _wait_first(asyncio.sleep(60), self.connectionClosedEvent.wait()) # Confirm once per minute
			if not self.connected:
				# Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly
				self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
				self.unconfirmedMessages = []
				break
			if not self.unconfirmedMessages:
				logging.debug(f'{id(self)}: no messages to confirm')
				continue
			logging.debug(f'{id(self)}: trying to confirm message delivery')
			# The PONG only confirms messages sent before the PING; ones sent while waiting must stay unconfirmed
			confirmCount = len(self.unconfirmedMessages)
			self.pongReceivedEvent.clear()
			self.send(b'PING :42')
			await _wait_first(asyncio.sleep(5), self.pongReceivedEvent.wait())
			logging.debug(f'{id(self)}: message delivery success: {self.pongReceivedEvent.is_set()}')
			if self.pongReceivedEvent.is_set():
				del self.unconfirmedMessages[:confirmCount]
			else:
				# No PONG received in five seconds, assume connection's dead
				self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
				self.unconfirmedMessages = []
				self.transport.close()

	def data_received(self, data):
		'''Split incoming data into CRLF-terminated lines and handle each complete one.'''
		logging.debug(f'Data received: {data!r}')
		# Prepend the incomplete line left over from the previous chunk, then split on CRLF. Every element but the last
		# is a complete line; the last one (b'' if data ended on CRLF) is incomplete and kept for the next chunk.
		# Joining first also handles a CRLF that is split across two chunks.
		*messages, self.buffer = (self.buffer + data).split(b'\r\n')
		for message in messages:
			if message:
				self.message_received(message)

	def message_received(self, message):
		'''Handle one complete IRC line (without the trailing CRLF).'''
		logging.info(f'Message received: {message!r}')
		if message.startswith(b':'):
			# Prefixed message, extract command + parameters (the prefix cannot contain a space)
			_, sep, rest = message.partition(b' ')
			if not sep:
				logging.warning(f'Ignoring message without command: {message!r}')
				return
			message = rest
		if message.startswith(b'PING '):
			self.send(b'PONG ' + message[5:])
		elif message.startswith(b'PONG '):
			self.pongReceivedEvent.set()
		elif message.startswith(b'001 '):
			# Connection registered
			if self.channels:
				self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
			self._start_task(self.send_messages())
			self._start_task(self.confirm_messages())

	def _start_task(self, coro):
		'''Run coro as a background task, keeping a reference so that it cannot be garbage-collected mid-execution.'''
		task = asyncio.create_task(coro)
		self._tasks.add(task)
		task.add_done_callback(self._tasks.discard)
		return task

	def connection_lost(self, exc):
		logging.info('The server closed the connection')
		self.connected = False
		self.connectionClosedEvent.set()


class IRCClient:
	'''Keeps an IRC connection alive and applies config changes.

	After a disconnect it reconnects immediately; after a failed connection attempt it waits five seconds.
	It stops once sigintEvent is set.'''

	def __init__(self, messageQueue, config):
		self.messageQueue = messageQueue
		self.config = config
		self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
		self._transport = None  # Set only while connected
		self._protocol = None

	def update_config(self, config):
		'''Apply a new configuration: reconnect if [irc] changed, otherwise update the joined channels.'''
		needReconnect = self.config['irc'] != config['irc']
		self.config = config
		# Always recompute, so that the next connection (created from self.channels) uses the new maps
		self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
		if self._transport: # if currently connected:
			if needReconnect:
				self._transport.close()
			else:
				self._protocol.update_channels(self.channels)

	def _get_ssl_context(self):
		'''Return the `ssl` argument for loop.create_connection according to config['irc'].

		'no' gives False. 'yes' gives True (default verification), or a default context with the client certificate
		loaded if one is configured. 'insecure' gives a new unverified context, with the client certificate loaded if
		configured.'''
		setting = self.config['irc']['ssl']
		if setting == 'no':
			return False
		# A fresh context per call: load_cert_chain on the shared SSL_CONTEXTS['insecure'] would leak into later connections
		ctx = _make_insecure_ssl_context() if setting == 'insecure' else True
		if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
			if ctx is True:
				ctx = ssl.create_default_context()
			ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
		return ctx

	async def run(self, loop, sigintEvent):
		while True:
			# A new event per connection: clearing a shared one on reconnect could hide the closure from the old
			# protocol's tasks that have not yet resumed.
			connectionClosedEvent = asyncio.Event()
			try:
				self._transport, self._protocol = await loop.create_connection(
					lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels),
					self.config['irc']['host'],
					self.config['irc']['port'],
					ssl = self._get_ssl_context(),
				)
				try:
					await _wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
				finally:
					self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
					self._transport = None
					self._protocol = None
			except (OSError, asyncio.TimeoutError) as e:
				# OSError covers refused connections, DNS failures, and SSL errors
				logging.error(f'IRC connection failed: {e!r}')
				await _wait_first(asyncio.sleep(5), sigintEvent.wait())
			if sigintEvent.is_set():
				break


class WebServer:
	'''HTTP server that accepts POSTs on the configured web paths and queues their bodies for the IRC client.'''

	def __init__(self, messageQueue, config):
		self.messageQueue = messageQueue
		self.config = config
		self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth
		self._configChanged = asyncio.Event()  # Created before update_config, which may set it
		self._app = aiohttp.web.Application()
		self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])
		self.update_config(config)

	@staticmethod
	def _basic_auth_header(auth):
		'''Expected Authorization header value for auth, or False if auth is falsy (no authentication).'''
		if not auth:
			return False
		return f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'

	def update_config(self, config):
		'''Apply a new configuration: rebuild the path table and, if [web] changed, make run() rebind.'''
		self._paths = {map_['webpath']: (map_['ircchannel'], self._basic_auth_header(map_['auth'])) for map_ in config['maps'].values()}
		needRebind = self.config['web'] != config['web']
		self.config = config
		if needRebind:
			self._configChanged.set()

	async def run(self, stopEvent):
		'''Serve until stopEvent is set, rebinding whenever the [web] config changes.'''
		while True:
			runner = aiohttp.web.AppRunner(self._app)
			await runner.setup()
			try:
				site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
				await site.start()
				await _wait_first(stopEvent.wait(), self._configChanged.wait())
			finally:
				# Also release the runner if binding failed
				await runner.cleanup()
			if stopEvent.is_set():
				break
			self._configChanged.clear()

	async def post(self, request):
		'''Queue the request body as a message for the channel mapped to request.path.

		Responds 404 for unmapped paths, 403 on failed authentication, 400 for unreadable bodies, empty messages, or
		messages with line breaks, and 200 once the message has been queued.'''
		logging.info(f'Received request for {request.path!r}')
		try:
			channel, auth = self._paths[request.path]
		except KeyError:
			logging.info(f'Bad request: no path {request.path!r}')
			raise aiohttp.web.HTTPNotFound()
		if auth:
			authHeader = request.headers.get('Authorization')
			# Constant-time comparison so that the expected credentials can't be guessed via response timing
			if not authHeader or not hmac.compare_digest(authHeader.encode('utf-8', 'replace'), auth.encode('utf-8')):
				# Neither header nor expected value is logged; both contain credentials
				logging.info('Bad request: authentication failed')
				raise aiohttp.web.HTTPForbidden()
		try:
			message = await request.text()
		except Exception as e:
			logging.info(f'Bad request: exception while reading request data: {e!s}')
			raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
		logging.debug(f'Request payload: {message!r}')
		# Strip optional [CR] LF at the end of the payload
		if message.endswith('\r\n'):
			message = message[:-2]
		elif message.endswith('\n'):
			message = message[:-1]
		if '\r' in message or '\n' in message:
			logging.info('Bad request: linebreaks in message')
			raise aiohttp.web.HTTPBadRequest()
		if not message:
			# IRC servers reject a PRIVMSG without text, so it could never be delivered
			logging.info('Bad request: empty message')
			raise aiohttp.web.HTTPBadRequest()
		logging.debug(f'Putting message {message!r} for {channel} into message queue')
		self.messageQueue.put_nowait((channel, message))
		raise aiohttp.web.HTTPOk()


def configure_logging(config):
	'''(Re)configure the root logger: level and {}-style format from config['logging'], one stderr handler.'''
	#TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
	root = logging.getLogger()
	root.setLevel(getattr(logging, config['logging']['level']))
	root.handlers = [] #FIXME: Undocumented attribute of logging.Logger
	formatter = logging.Formatter(config['logging']['format'], style = '{')
	stderrHandler = logging.StreamHandler()
	stderrHandler.setFormatter(formatter)
	root.addHandler(stderrHandler)


async def main():
	'''Entry point: run the IRC client and web server until SIGINT; SIGUSR1 reloads the config file.'''
	if len(sys.argv) != 2:
		print('Usage: http2irc.py CONFIGFILE', file = sys.stderr)
		sys.exit(1)

	configFile = sys.argv[1]
	config = Config(configFile)
	configure_logging(config)

	loop = asyncio.get_running_loop()

	messageQueue = MessageQueue()

	irc = IRCClient(messageQueue, config)
	webserver = WebServer(messageQueue, config)

	sigintEvent = asyncio.Event()
	def sigint_callback():
		logging.info('Got SIGINT')
		sigintEvent.set()
	loop.add_signal_handler(signal.SIGINT, sigint_callback)

	def sigusr1_callback():
		nonlocal config
		logging.info('Got SIGUSR1, reloading config')
		try:
			newConfig = config.reread()
		except (InvalidConfig, ValueError, OSError) as e:
			# ValueError covers TOML syntax errors (toml.TomlDecodeError), OSError a missing/unreadable file;
			# in every case the old configuration stays active
			logging.error(f'Config reload failed: {e!s}')
			return
		config = newConfig
		configure_logging(config)
		irc.update_config(config)
		webserver.update_config(config)
	loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

	await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))


if __name__ == '__main__':
	asyncio.run(main())