commit 112285d3342f915b5f14f02500ebaa0b28ae7563 Author: JustAnotherArchivist Date: Fri Oct 2 00:58:39 2020 +0000 Initial commit diff --git a/config.example.toml b/config.example.toml new file mode 100644 index 0000000..d97ea97 --- /dev/null +++ b/config.example.toml @@ -0,0 +1,34 @@ +[logging] + # level must be one of logging's defined levels (NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL) + #level = 'INFO' + # format must use the '{' style of logging + #format = '{asctime} {levelname} {name} {message}' + +[storage] + # Must point to an existing directory that is writable by the user running irclog. Relative paths are evaluated relative to the config file. + #path = '.' + +[irc] + #host = 'irc.hackint.org' + #port = 6697 + # Possible values: 'yes' (connect over SSL/TLS and verify certificates), 'no' (connect without SSL/TLS), and 'insecure' (connect over SSL/TLS but disable all certificate checks) + #ssl = 'yes' + #nick = 'irclogbot' + #real = 'I am an irclog bot.' + # Certificate and key for SASL EXTERNAL authentication with NickServ; certfile is a string containing the path to a .pem file which has the certificate and the key, certkeyfile similarly for one containing only the key; default values are empty (None in Python) to disable authentication; the connection is terminated if authentication fails; relative paths are evaluated relative to the config file. + #certfile = + #certkeyfile = + +[web] + #host = '127.0.0.1' + #port = 8080 + +[channels] + # No channels are logged by default. + #[channels.spam] + # If ircchannel isn't specified, it corresponds to '#' followed by the map key. + #ircchannel = '#spam' + # auth can be either 'user:pass' for basic authentication or false to disable auth + #auth = false + # Whether this channel should still be actively logged. Set this to false to stop logging the channel but keep serving the previous logs. 
logger = logging.getLogger('irclog')


def _make_insecure_ssl_context():
    '''Create a TLS client context with all certificate checks disabled (used for ssl = 'insecure')'''
    # Calling ssl.SSLContext() without a protocol is deprecated; spell out the intended settings instead.
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be switched off before verify_mode may be set to CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx


# Maps the irc.ssl config value to the `ssl` argument of loop.create_connection.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}


class InvalidConfig(Exception):
    '''Error in configuration file'''


def is_valid_pem(path, withCert):
    '''
    Very basic check whether something looks like a valid PEM file

    path: file to inspect
    withCert: if true, the file must consist of one certificate immediately followed by one private key; otherwise, it must consist of exactly one private key.

    Only the BEGIN/END marker lines and the base64 payload between them are checked; the decoded data is not parsed.
    Returns True if the file looks valid and False otherwise (including when it cannot be read). Never raises for ordinary errors.
    '''
    certBegin = b'-----BEGIN CERTIFICATE-----\n'
    certEnd = b'-----END CERTIFICATE-----\n'
    keyBegin = b'-----BEGIN PRIVATE KEY-----\n'
    keyEnd = b'-----END PRIVATE KEY-----\n'

    try:
        with open(path, 'rb') as fp:
            contents = fp.read()

        keyStart = 0  # Offset at which the private key block has to begin
        if withCert:
            if not contents.startswith(certBegin):
                return False
            certEndPos = contents.index(certEnd)  # ValueError if the end marker is missing
            base64.b64decode(contents[len(certBegin):certEndPos].replace(b'\n', b''), validate = True)
            keyStart = certEndPos + len(certEnd)

        # Explicit checks rather than assert statements: asserts are stripped under `python -O`, which would make every file pass.
        if not contents.startswith(keyBegin, keyStart):
            return False
        keyEndPos = contents.index(keyEnd, keyStart)
        base64.b64decode(contents[keyStart + len(keyBegin):keyEndPos].replace(b'\n', b''), validate = True)
        # Nothing may follow the key block.
        return contents[keyEndPos + len(keyEnd):] == b''
    except Exception:
        # Deliberately broad (unreadable file, missing marker, invalid base64, ...): any failure means 'not valid'.
        # Unlike a bare except, this lets KeyboardInterrupt and SystemExit through.
        return False
class Config(dict):
    '''
    Validated irclog configuration

    Reads a TOML file, validates every section, fills in default values, and exposes the result as a dict with the keys 'logging', 'storage', 'irc', 'web', and 'channels'.
    Relative paths in the file (storage.path, irc.certfile, irc.certkeyfile) are resolved relative to the config file's directory.
    Raises InvalidConfig if the contents are invalid; errors from opening or parsing the file propagate unchanged.
    '''

    _SECTIONS = ('logging', 'storage', 'irc', 'web', 'channels')
    # All of logging's named levels, as documented in the example config
    _LOG_LEVELS = ('NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')
    _DEFAULT_NICK = 'irclogbot'
    _DEFAULT_REAL = 'I am an irclog bot.'

    def __init__(self, filename):
        super().__init__()
        self._filename = filename

        with open(self._filename, 'r') as fp:
            obj = toml.load(fp)

        # Sanity checks
        if any(x not in self._SECTIONS for x in obj):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not self._is_mapping(x) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')
        checks = {
            'logging': self._check_logging,
            'storage': self._check_storage,
            'irc': self._check_irc,
            'web': self._check_web,
            'channels': self._check_channels,
        }
        for key in self._SECTIONS:
            if key in obj:
                checks[key](obj[key])

        # Default values
        finalObj = {
            'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
            'storage': {'path': os.path.abspath(os.path.dirname(self._filename))},
            'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': self._DEFAULT_NICK, 'real': self._DEFAULT_REAL, 'certfile': None, 'certkeyfile': None},
            'web': {'host': '127.0.0.1', 'port': 8080},
            'channels': {},
        }
        # Default values for channels are set by _check_channels.

        # Merge in what was read from the config file and set keys on self
        for key in self._SECTIONS:
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    @staticmethod
    def _is_mapping(value):
        '''Whether value is a dict-like object'''
        import collections.abc  # A bare `import collections` does not guarantee that the abc submodule is loaded.
        return isinstance(value, collections.abc.Mapping)

    @staticmethod
    def _is_port(value):
        '''Whether value is a valid TCP port number (bool is excluded although it is an int subclass)'''
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    def _resolve_path(self, path):
        '''Make path absolute, interpreting relative paths relative to the config file's directory'''
        return os.path.abspath(os.path.join(os.path.dirname(self._filename), path))

    def _check_logging(self, section):
        '''Validate the [logging] section'''
        if any(x not in ('level', 'format') for x in section):
            raise InvalidConfig('Unknown key found in log section')
        if 'level' in section and section['level'] not in self._LOG_LEVELS:
            raise InvalidConfig('Invalid log level')
        if 'format' in section:
            if not isinstance(section['format'], str):
                raise InvalidConfig('Invalid log format')
            try:
                #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
                # This counts the number of replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                fieldCount = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
            except ValueError as e:
                raise InvalidConfig('Invalid log format: parsing failed') from e
            if fieldCount == 0:
                raise InvalidConfig('Invalid log format: parsing failed')

    def _check_storage(self, section):
        '''Validate the [storage] section and make its path absolute'''
        if any(x != 'path' for x in section):
            raise InvalidConfig('Unknown key found in storage section')
        if 'path' in section:
            if not isinstance(section['path'], str):
                raise InvalidConfig('Invalid storage path: not a string')
            section['path'] = self._resolve_path(section['path'])
            try:
                # Creating (and immediately discarding) a temporary file proves that the directory exists and is writable.
                with tempfile.TemporaryFile(dir = section['path']):
                    pass
            except OSError as e:
                raise InvalidConfig('Invalid storage path: not writable') from e

    def _check_irc(self, section):
        '''Validate the [irc] section and make the certificate paths absolute'''
        if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in section and not self._is_port(section['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')
        if 'nick' in section and not isinstance(section['nick'], str): #TODO: Check whether it's a valid nickname, username, etc.
            raise InvalidConfig('Invalid IRC nick')
        if 'real' in section and not isinstance(section['real'], str):
            raise InvalidConfig('Invalid IRC realname')
        # The length checks need the effective values, i.e. the defaults for keys that are not configured.
        nick = section.get('nick', self._DEFAULT_NICK)
        real = section.get('real', self._DEFAULT_REAL)
        if len(IRCClientProtocol.nick_command(nick)) > 510:
            raise InvalidConfig('Invalid IRC nick: NICK command too long')
        if len(IRCClientProtocol.user_command(nick, real)) > 510:
            raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
        if ('certfile' in section) != ('certkeyfile' in section):
            raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
        for key, label, withCert, kind in (('certfile', 'certificate file', True, 'cert'), ('certkeyfile', 'certificate key file', False, 'key')):
            if key not in section:
                continue
            if not isinstance(section[key], str):
                raise InvalidConfig(f'Invalid {label}: not a string')
            section[key] = self._resolve_path(section[key])
            if not os.path.isfile(section[key]):
                raise InvalidConfig(f'Invalid {label}: not a regular file')
            if not is_valid_pem(section[key], withCert):
                raise InvalidConfig(f'Invalid {label}: not a valid PEM {kind}')

    def _check_web(self, section):
        '''Validate the [web] section'''
        if any(x not in ('host', 'port') for x in section):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in section and not self._is_port(section['port']):
            raise InvalidConfig('Invalid web port')

    def _check_channels(self, channels):
        '''Validate the [channels] section and fill in the per-channel defaults (ircchannel, auth, active)'''
        seenChannels = {}  # IRC channel name -> config key that claimed it
        for key, channel in channels.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid channel key {key!r}')
            if not self._is_mapping(channel):
                raise InvalidConfig(f'Invalid channel for {key!r}')
            if any(x not in ('ircchannel', 'auth', 'active') for x in channel):
                raise InvalidConfig(f'Unknown key(s) found in channel {key!r}')

            if 'ircchannel' not in channel:
                channel['ircchannel'] = f'#{key}'
            ircchannel = channel['ircchannel']
            if not isinstance(ircchannel, str):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: not a string')
            if not ircchannel.startswith('#') and not ircchannel.startswith('&'):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: does not start with # or &')
            if any(x in ircchannel[1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: contains forbidden characters')
            if len(ircchannel) > 200:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: too long')
            if ircchannel in seenChannels:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: collides with channel {seenChannels[ircchannel]!r}')
            seenChannels[ircchannel] = key

            if 'auth' in channel:
                if channel['auth'] is not False and not isinstance(channel['auth'], str):
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must be false or a string')
                if isinstance(channel['auth'], str) and ':' not in channel['auth']:
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must contain a colon')
            else:
                channel['auth'] = False

            if 'active' in channel:
                if channel['active'] is not True and channel['active'] is not False:
                    raise InvalidConfig(f'Invalid channel {key!r} active: must be true or false')
            else:
                channel['active'] = True

    def __repr__(self):
        return f'Config({self._filename!r})'

    def reread(self):
        '''Load the same file again and return a new Config; self is left untouched'''
        return Config(self._filename)
class IRCClientProtocol(asyncio.Protocol):
    '''
    asyncio protocol implementing the IRC client side of the logger

    Registers with the configured nick, optionally authenticates via SASL EXTERNAL, joins the configured channels, answers PINGs, and puts every sent and received line into messageQueue as (timestamp, b'> ' + line) or (timestamp, b'< ' + line).
    connectionClosedEvent is set once the connection is lost.
    '''

    logger = logging.getLogger('irclog.IRCClientProtocol')

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b''  # Received bytes that do not form a complete line yet
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
        self.authenticated = False
        self.usermask = None

    @staticmethod
    def nick_command(nick: str):
        '''Build the NICK command (without CRLF) for nick'''
        return b'NICK ' + nick.encode('utf-8')

    @staticmethod
    def user_command(nick: str, real: str):
        '''Build the USER command (without CRLF); nick is used for the user, mode, and unused fields'''
        nickb = nick.encode('utf-8')
        return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')

    def _maybe_set_usermask(self, usermask):
        '''Store usermask (bytes, nick!user@host) if it looks plausible; silently ignore it otherwise'''
        if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
            self.usermask = usermask
            self.logger.debug(f'Usermask is now {usermask!r}')

    def connection_made(self, transport):
        self.logger.info('IRC connected')
        self.transport = transport
        self.connected = True
        if self.sasl:
            self.send(b'CAP REQ :sasl')
        self.send(self.nick_command(self.config['irc']['nick']))
        self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))

    def _send_join_part(self, command, channels):
        '''
        Send a JOIN or PART for channels, split into multiple messages as necessary

        command: b'JOIN' or b'PART'; channels: iterable of str
        Nothing is sent if there are no channels: a JOIN/PART without parameters is answered with 461, which is treated as a fatal registration error.
        '''
        limit = 510 - len(command)  # Space left for the channel list including its separators
        commandName = command.decode('ascii')
        batch = []
        batchLength = 0
        for channel in (x.encode('utf-8') for x in channels):
            length = 1 + len(channel)  # Separator (a space for the first channel, a comma afterwards) + channel name
            if length > limit:
                # This should never happen since the config reader would already filter it out.
                self.logger.warning(f'Cannot {commandName} {channel!r}: name too long')
                continue
            if batch and batchLength + length > limit:
                self.send(command + b' ' + b','.join(batch))
                batch = []
                batchLength = 0
            batch.append(channel)
            batchLength += length
        if batch:
            self.send(command + b' ' + b','.join(batch))

    def update_channels(self, channels: set):
        '''Replace the set of channels; if connected, PART the ones that were removed and JOIN the new ones'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                self._send_join_part(b'PART', channelsToPart)
            if channelsToJoin:
                self._send_join_part(b'JOIN', channelsToJoin)

    def send(self, data):
        '''Send one line (bytes, without CRLF) and queue it for storage; raises RuntimeError if it exceeds 510 bytes'''
        self.logger.debug(f'Send: {data!r}')
        if len(data) > 510:
            raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
        time_ = time.time()
        self.transport.write(data + b'\r\n')
        self.messageQueue.put_nowait((time_, b'> ' + data))

    def data_received(self, data):
        self.logger.debug(f'Data received: {data!r}')
        time_ = time.time()
        # Prepend whatever was left over from the previous chunk, then split on CRLF.
        # The last element is the (possibly empty) incomplete remainder and goes back into the buffer; everything before it is a complete message.
        # Joining before splitting also covers chunks without any CRLF and a CRLF that is torn across two chunks.
        *messages, self.buffer = (self.buffer + data).split(b'\r\n')
        for message in messages:
            self.message_received(time_, message)

    def message_received(self, time_, message):
        '''Process one complete line (bytes, without CRLF) that was received at time_'''
        self.logger.debug(f'Message received at {time_}: {message!r}')
        rawMessage = message
        if message.startswith(b':') and b' ' in message:
            # Prefixed message, extract command + parameters (the prefix cannot contain a space)
            message = message.split(b' ', 1)[1]

        # Queue message for storage
        self.messageQueue.put_nowait((time_, b'< ' + rawMessage))

        # PING/PONG
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

        # SASL
        elif message.startswith(b'CAP ') and self.sasl:
            # Skip the second word (the CAP target) and look at the rest.
            if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
                self.send(b'AUTHENTICATE EXTERNAL')
            else:
                self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
                self.transport.close()
        elif message == b'AUTHENTICATE +':
            self.send(b'AUTHENTICATE +')
        elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
            words = message.split(b' ')
            if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
                if b'!~' not in words[2]:
                    # At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
                    words[2] = words[2].replace(b'!', b'!~', 1)
                self._maybe_set_usermask(words[2])
        elif message.startswith(b'903 '): # SASL auth successful
            self.authenticated = True
            self.send(b'CAP END')
        elif message.startswith((b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
            self.logger.error('SASL error, terminating connection')
            self.transport.close()

        # NICK errors
        elif message.startswith((b'431 ', b'432 ', b'433 ', b'436 ')):
            self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
            self.transport.close()

        # USER errors
        elif message.startswith((b'461 ', b'462 ')):
            self.logger.error(f'Failed to register: {message!r}, terminating connection')
            self.transport.close()

        # JOIN errors
        elif message.startswith((b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
            self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
            self.transport.close()

        # PART errors
        elif message.startswith(b'442 '):
            self.logger.error(f'Failed to part channel: {message!r}')

        # JOIN/PART errors
        elif message.startswith(b'403 '):
            self.logger.error(f'Failed to join or part channel: {message!r}')

        # PRIVMSG errors
        elif message.startswith((b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')):
            self.logger.error(f'Failed to send message: {message!r}')

        # Connection registration reply
        elif message.startswith(b'001 '):
            self.logger.info('IRC connection registered')
            if self.sasl and not self.authenticated:
                self.logger.error('IRC connection registered but not authenticated, terminating connection')
                self.transport.close()
                return
            self._send_join_part(b'JOIN', self.channels)

        # JOIN success
        elif message.startswith(b'JOIN ') and not self.usermask:
            # If this is my own join message, it should contain the usermask in the prefix
            if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
                usermask = rawMessage.split(b' ', 1)[0][1:]
                self._maybe_set_usermask(usermask)

        # Services host change
        elif message.startswith(b'396 '):
            words = message.split(b' ')
            if len(words) >= 3:
                # Sanity check inspired by irssi src/irc/core/irc-servers.c
                if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
                    nickb = self.config['irc']['nick'].encode('utf-8')
                    if b'@' in words[2]: # user@host
                        self._maybe_set_usermask(nickb + b'!' + words[2])
                    else: # host (get user from previous mask or settings)
                        if self.usermask:
                            user = self.usermask.split(b'@')[0].split(b'!')[1]
                        else:
                            user = b'~' + nickb
                        self._maybe_set_usermask(nickb + b'!' + user + b'@' + words[2])

    def connection_lost(self, exc):
        self.logger.info('IRC connection lost')
        self.connected = False
        self.connectionClosedEvent.set()
class IRCClient:
    '''
    Keeps an IRC connection alive

    run() connects using an IRCClientProtocol and reconnects (after a 5 second pause) whenever the connection is lost or cannot be established, until sigintEvent is set.
    '''

    logger = logging.getLogger('irclog.IRCClient')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = {channel['ircchannel'] for channel in config['channels'].values()}

        self._transport = None
        self._protocol = None

    def update_config(self, config):
        '''Apply a new config: reconnect if the irc section changed, otherwise only adjust the joined channels'''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        # Always refresh the channel set: run() passes it to the protocol on the next (re)connect.
        self.channels = {channel['ircchannel'] for channel in config['channels'].values()}
        if self._transport: # if currently connected:
            if needReconnect:
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Return the `ssl` argument for create_connection, with the client certificate loaded if one is configured'''
        ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
        if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
            if ctx is True:
                ctx = ssl.create_default_context()
            if isinstance(ctx, ssl.SSLContext):
                ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
        return ctx

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of awaitables completes, then cancel the remaining ones'''
        # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap them in tasks explicitly.
        tasks = [asyncio.ensure_future(x) for x in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # No-op for tasks that are already done
            # Let the cancellations finish so that no pending waiter is left behind on every loop iteration.
            await asyncio.gather(*tasks, return_exceptions = True)

    async def run(self, loop, sigintEvent):
        '''Connect and keep reconnecting until sigintEvent is set'''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers refused connections as well as DNS and TLS failures; none of them should end the reconnect loop.
                self.logger.error(str(e))
            await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break
class Storage:
    '''
    Writes the queued IRC lines to log files

    Every active channel gets a file STORAGEPATH/CHANNEL/YYYY-MM.log; everything that cannot be attributed to a logged channel goes to STORAGEPATH/general/YYYY-MM.log.
    Each stored line has the form b'TIMESTAMP DIRECTION RAWLINE\\r\\n' where DIRECTION is '>' (sent) or '<' (received).
    '''

    logger = logging.getLogger('irclog.Storage')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.files = {} # channel|None -> fileobj; None = general log for anything that wasn't recognised as a message for the channel log
        self.active = True

    def _open_log(self, name):
        '''Open (creating directories as required) the log file of the current UTC month below the directory name'''
        directory = os.path.join(self.config['storage']['path'], name)
        os.makedirs(directory, exist_ok = True)
        #TODO The month is only evaluated when the file is opened; rotation at the beginning of a new month is still missing.
        return open(os.path.join(directory, time.strftime('%Y-%m', time.gmtime()) + '.log'), 'ab')

    def update_config(self, config):
        '''Apply config: close the files of channels that were removed or deactivated and open the missing ones'''
        self.config = config
        wanted = {channel['ircchannel'] for channel in config['channels'].values() if channel['active']}

        for channel in [x for x in self.files if x is not None and x not in wanted]:
            self.files.pop(channel).close()

        for channel in wanted:
            if channel not in self.files:
                self.files[channel] = self._open_log(channel)

        if None not in self.files:
            self.files[None] = self._open_log('general')

    async def run(self, loop, sigintEvent):
        '''Store messages until sigintEvent is set, then write what is still queued and close all files'''
        self.update_config(self.config) # Ensure that files are open etc.
        #TODO Task to rotate log files at the beginning of a new month
        tasks = [asyncio.create_task(self.store_messages(sigintEvent)), asyncio.create_task(self.flush_files(sigintEvent))]
        try:
            await sigintEvent.wait()
        finally:
            self.active = False
            # store_messages blocks on the queue and would otherwise never notice the shutdown (or write to closed files later).
            for task in tasks:
                task.cancel()
            results = await asyncio.gather(*tasks, return_exceptions = True)
            for result in results:
                if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                    self.logger.error(f'Storage task failed: {type(result).__name__}: {result!s}')
            # Don't lose what was queued but not processed yet.
            while True:
                try:
                    time_, rawMessage = self.messageQueue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                self._route_message(time_, rawMessage)
            self.close()

    async def store_messages(self, sigintEvent):
        '''Take messages from the queue and store them; runs until cancelled or self.active is cleared'''
        while self.active:
            time_, rawMessage = await self.messageQueue.get()
            self._route_message(time_, rawMessage)

    @staticmethod
    def _first_param(message):
        '''Return the first parameter of message (bytes starting with 'COMMAND '), without the colon of a trailing parameter'''
        param = message.split(b' ', 2)[1]
        # Some servers send e.g. 'JOIN :#channel'; channel names never start with a colon, so dropping it is safe.
        return param[1:] if param.startswith(b':') else param

    def _route_message(self, time_, rawMessage):
        '''Write rawMessage (including its leading b'> ' or b'< ') to the log(s) of the channel(s) it belongs to, or to the general log'''
        message = rawMessage[2:] # Remove leading > or <
        if message.startswith(b':') and b' ' in message:
            message = message.split(b' ', 1)[1]  # Drop the prefix

        # Identify channel-bound messages: JOIN, PART, QUIT, MODE, KICK, PRIVMSG, NOTICE (see https://tools.ietf.org/html/rfc1459#section-4.2.1)
        if message.startswith((b'JOIN ', b'PART ', b'PRIVMSG ', b'NOTICE ')):
            # The first parameter may be a comma-separated list (it certainly is for our own JOIN/PART commands).
            channels = self.decode_channel(time_, rawMessage, self._first_param(message).split(b','))
            if channels is None:
                return
            for channel in channels:
                self.store_message(time_, rawMessage, channel)
            return
        #TODO QUIT: Need to keep track of users to figure out in which channels they were; until then, it ends up in the general log.
        if message.startswith((b'MODE #', b'MODE &', b'KICK ')):
            channel = self.decode_channel(time_, rawMessage, self._first_param(message))
            if channel is None:
                return
            self.store_message(time_, rawMessage, channel)
            return
        self.store_message(time_, rawMessage, None)

    def store_message(self, time_, rawMessage, targetChannel):
        '''Append one line to targetChannel's log; unknown channels and None go to the general log'''
        if targetChannel is not None and targetChannel not in self.files:
            targetChannel = None
        self.files[targetChannel].write(str(time_).encode('ascii') + b' ' + rawMessage + b'\r\n')

    def decode_channel(self, time_, rawMessage, channel):
        '''
        Decode a channel name (bytes) or a list of channel names to str

        If decoding fails, a warning is logged, the message is written to the general log, and None is returned.
        '''
        try:
            if isinstance(channel, list):
                return [c.decode('utf-8') for c in channel]
            return channel.decode('utf-8')
        except UnicodeDecodeError as e:
            self.logger.warning(f'Failed to decode channel name {channel!r} from {rawMessage!r} at {time_}: {e!s}')
            self.store_message(time_, rawMessage, None)
            return None

    async def flush_files(self, sigintEvent, interval = 5):
        '''Flush all open log files every interval seconds until sigintEvent is set or self.active is cleared'''
        while self.active:
            try:
                await asyncio.wait_for(sigintEvent.wait(), timeout = interval)
            except asyncio.TimeoutError:
                pass
            for f in list(self.files.values()):
                f.flush()
            if sigintEvent.is_set():
                # Waiting on a set event returns immediately; leave instead of spinning.
                break

    def close(self):
        '''Close all log files'''
        for f in self.files.values():
            f.close()
        self.files = {}
class WebServer:
    '''
    HTTP server (aiohttp) accepting POST requests on configured paths

    NOTE(review): no paths are registered at the moment (the code filling self._paths is commented out in update_config), so every request is answered with 404.
    The POST handler also needs a message queue, which callers can pass as the optional messageQueue argument; without one, accepted requests are answered with 500.
    '''

    logger = logging.getLogger('irclog.WebServer')

    def __init__(self, config, messageQueue = None):
        self.config = config
        self.messageQueue = messageQueue

        self._paths = {} # '/path' => ('#channel', auth, module, moduleargs, overlongmode) where auth is either False (no authentication) or the HTTP header value for basic auth
        # Has to exist before update_config is called since that method may set it.
        self._configChanged = asyncio.Event()

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        self.update_config(config)

    def update_config(self, config):
        '''Apply a new config; the server is restarted by run() if the web section changed'''
#        self._paths = {channel['webpath']: (channel['ircchannel'], f'Basic {base64.b64encode(channel["auth"].encode("utf-8")).decode("utf-8")}' if channel['auth'] else False) for channel in config['channels'].values()}
        needRebind = self.config['web'] != config['web'] #TODO only if there are changes to web.host or web.port; everything else can be updated without rebinding
        self.config = config
        if needRebind:
            self._configChanged.set()

    async def run(self, stopEvent):
        '''Serve until stopEvent is set; rebind whenever update_config signals a change of the web section'''
        while True:
            runner = aiohttp.web.AppRunner(self._app)
            await runner.setup()
            # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap them in tasks explicitly.
            waiters = [asyncio.ensure_future(stopEvent.wait()), asyncio.ensure_future(self._configChanged.wait())]
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                await asyncio.wait(waiters, return_when = asyncio.FIRST_COMPLETED)
            finally:
                for waiter in waiters:
                    waiter.cancel()  # No-op for the one that completed
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process
# https://stackoverflow.com/questions/1180606/using-subprocess-popen-for-process-with-large-output
# -> https://stackoverflow.com/questions/57730010/python-asyncio-subprocess-write-stdin-and-read-stdout-stderr-continuously

    async def post(self, request):
        '''
        Handle a POST request: look up the path, check the authentication, turn the body into a message, and queue it

        Always finishes by raising an aiohttp.web.HTTPException (HTTPOk on success).
        '''
        import hmac

        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.path!r} with body {(await request.read())!r}')
        try:
            channel, auth, module, moduleargs, overlongmode = self._paths[request.path]
        except KeyError:
            self.logger.info(f'Bad request {id(request)}: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            # Constant-time comparison; neither the received nor the expected credentials are written to the log.
            if not authHeader or not hmac.compare_digest(authHeader.encode('utf-8'), auth.encode('utf-8')):
                self.logger.info(f'Bad request {id(request)}: authentication failed')
                raise aiohttp.web.HTTPForbidden()
        if module is not None:
            self.logger.debug(f'Processing request {id(request)} using {module!r}')
            try:
                message = await module.process(request, *moduleargs)
            except aiohttp.web.HTTPException:
                raise
            except Exception as e:
                self.logger.error(f'Bad request {id(request)}: exception in module process function: {type(e).__module__}.{type(e).__name__}: {e!s}')
                raise aiohttp.web.HTTPBadRequest()
            if '\r' in message or '\n' in message:
                self.logger.error(f'Bad request {id(request)}: module process function returned message with linebreaks: {message!r}')
                raise aiohttp.web.HTTPBadRequest()
        else:
            self.logger.debug(f'Processing request {id(request)} using default processor')
            message = await self._default_process(request)
        if self.messageQueue is None:
            self.logger.error(f'Cannot accept request {id(request)}: no message queue configured')
            raise aiohttp.web.HTTPInternalServerError()
        self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message, overlongmode))
        raise aiohttp.web.HTTPOk()

    async def _default_process(self, request):
        '''Return the request body as a single-line str (one trailing [CR] LF is stripped); raises HTTPBadRequest if it cannot be read or contains further linebreaks'''
        try:
            message = await request.text()
        except Exception as e:
            self.logger.info(f'Bad request {id(request)}: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        self.logger.debug(f'Request {id(request)} payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            self.logger.info(f'Bad request {id(request)}: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        return message
def configure_logging(config):
    '''(Re)configure the root logger: level and format from config['logging'], output to stderr only'''
    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    root = logging.getLogger()
    root.setLevel(getattr(logging, config['logging']['level']))
    # Detach previously installed handlers through the public API rather than by assigning to root.handlers.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    formatter = logging.Formatter(config['logging']['format'], style = '{')
    stderrHandler = logging.StreamHandler()
    stderrHandler.setFormatter(formatter)
    root.addHandler(stderrHandler)


async def main():
    '''
    Entry point: load the config given on the command line and run the IRC client, web server, and storage until SIGINT

    SIGUSR1 reloads the config file; if the reload fails, the old config stays active.
    '''
    if len(sys.argv) != 2:
        print('Usage: irclog.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    configFile = sys.argv[1]
    config = Config(configFile)
    configure_logging(config)

    loop = asyncio.get_running_loop()

    messageQueue = asyncio.Queue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(config)
    storage = Storage(messageQueue, config)

    sigintEvent = asyncio.Event()

    def sigint_callback():
        logger.info('Got SIGINT, stopping')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    def sigusr1_callback():
        nonlocal config  # Rebound below; everything else is only read.
        logger.info('Got SIGUSR1, reloading config')
        try:
            newConfig = config.reread()
        except (InvalidConfig, OSError, ValueError) as e:
            # OSError: file missing/unreadable; ValueError: TOML syntax error. Neither may escape from the signal handler.
            logger.error(f'Config reload failed: {e!s} (old config remains active)')
            return
        config = newConfig
        configure_logging(config)
        irc.update_config(config)
        webserver.update_config(config)
        storage.update_config(config)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent), storage.run(loop, sigintEvent))


if __name__ == '__main__':
    asyncio.run(main())