diff --git a/irclog.py b/irclog.py index 74b6958..975b061 100644 --- a/irclog.py +++ b/irclog.py @@ -3,7 +3,6 @@ import aiohttp.web import asyncio import base64 import collections -import concurrent.futures import importlib.util import inspect import itertools @@ -20,6 +19,8 @@ import toml logger = logging.getLogger('irclog') SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()} +messageConnectionClosed = object() # Signals that the connection was closed by either the bot or the server +messageEOF = object() # Special object to signal the end of messages to Storage class InvalidConfig(Exception): @@ -82,6 +83,7 @@ class Config(dict): if 'path' in obj['storage']: obj['storage']['path'] = os.path.abspath(os.path.join(os.path.dirname(self._filename), obj['storage']['path'])) try: + #TODO This doesn't seem to work correctly; doesn't fail when the dir is -w f = tempfile.TemporaryFile(dir = obj['storage']['path']) f.close() except (OSError, IOError) as e: @@ -194,6 +196,7 @@ class IRCClientProtocol(asyncio.Protocol): self.buffer = b'' self.connected = False self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str) + self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...} self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile']) self.authenticated = False self.usermask = None @@ -207,6 +210,25 @@ class IRCClientProtocol(asyncio.Protocol): nickb = nick.encode('utf-8') return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8') + @staticmethod + def valid_channel(channel: str): + return channel[0] in ('#', '&') and not any(x in channel for x in (' ', '\x00', '\x07', '\r', '\n', ',')) + + @staticmethod + def valid_nick(nick: str): + # According to RFC 1459, a nick must be '<letter> { <letter> | <number> | <special> }'. This is obviously not true in practice because <special> doesn't include underscores, for example.
+ # So instead, just do a sanity check similar to the channel one to disallow obvious bullshit. + return not any(x in nick for x in (' ', '\x00', '\x07', '\r', '\n', ',')) + + @staticmethod + def prefix_to_nick(prefix: str): + nick = prefix[1:] + if '!' in nick: + nick = nick.split('!', 1)[0] + if '@' in nick: # nick@host is also legal + nick = nick.split('@', 1)[0] + return nick + def _maybe_set_usermask(self, usermask): if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')): self.usermask = usermask @@ -271,7 +293,7 @@ class IRCClientProtocol(asyncio.Protocol): raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}') time_ = time.time() self.transport.write(data + b'\r\n') - self.messageQueue.put_nowait((time_, b'> ' + data)) + self.messageQueue.put_nowait((time_, b'> ' + data, None)) def data_received(self, data): self.logger.debug(f'Data received: {data!r}') @@ -290,12 +312,15 @@ class IRCClientProtocol(asyncio.Protocol): def message_received(self, time_, message): self.logger.debug(f'Message received at {time_}: {message!r}') rawMessage = message + hasPrefix = False if message.startswith(b':') and b' ' in message: # Prefixed message, extract command + parameters (the prefix cannot contain a space) - message = message.split(b' ', 1)[1] + prefix, message = message.split(b' ', 1) + hasPrefix = True - # Queue message for storage - self.messageQueue.put_nowait((time_, b'< ' + rawMessage)) + # Queue message for storage, except QUITs and NICKs which are handled below with user tracking + if not message.startswith(b'QUIT ') and message != b'QUIT' and not message.startswith(b'NICK '): + self.messageQueue.put_nowait((time_, b'< ' + rawMessage, None)) # PING/PONG if message.startswith(b'PING '): @@ -382,10 +407,94 @@ class IRCClientProtocol(asyncio.Protocol): user = b'~' + self.config['irc']['nick'].encode('utf-8') self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' 
+ user + b'@' + words[2]) + # User tracking (for NICK and QUIT) + decoded = False + if any(message.startswith(x) for x in (b'353 ', b'JOIN ', b'PART ', b'KICK ', b'NICK ', b'QUIT ')) or message == b'QUIT': + try: + if hasPrefix: + prefixStr = prefix.decode('utf-8') + messageStr = message.decode('utf-8') + except UnicodeDecodeError as e: + self.logger.warning(f'Could not decode prefix/message {prefix!r}/{message!r} ({e!s}), user tracking may be wrong') + else: + decoded = True + if message.startswith(b'353 ') and decoded: # RPL_NAMREPLY + _, channel, nicksStr = messageStr.split(' ', 2) + if nicksStr.startswith(':'): # It always should, but who knows... + nicksStr = nicksStr[1:] + nicks = nicksStr.split(' ') + for nick in nicks: + if nick[0] in ('@', '+'): + nick = nick[1:] + if self.valid_channel(channel) and self.valid_nick(nick): + self.userChannels[nick].add(channel) + if (message.startswith(b'JOIN ') or message.startswith(b'PART ')) and decoded and hasPrefix: + nick = self.prefix_to_nick(prefixStr) + channels = messageStr[5:] # Could be more than one channel in theory + for channel in channels.split(','): + if self.valid_channel(channel) and self.valid_nick(nick): + if message.startswith(b'JOIN '): + self.userChannels[nick].add(channel) + else: + self.userChannels[nick].discard(channel) + if message.startswith(b'KICK ') and decoded: # Prefix is supposed to indicate who kicked the user, but we don't care about that for the user tracking. + _, channel, nick = messageStr.split(' ', 2) + if ' ' in nick: # There might be a kick reason after the nick + nick = nick.split(' ', 1)[0] + if self.valid_channel(channel) and self.valid_nick(nick): + self.userChannels[nick].discard(channel) + if message.startswith(b'NICK '): + # If something can't be processed, just send it to storage without user tracking. 
+ sendGeneric = True + if decoded and hasPrefix: + oldNick = self.prefix_to_nick(prefixStr) + newNick = message[5:] + if self.valid_nick(oldNick) and self.valid_nick(newNick) and oldNick in self.userChannels: + self.userChannels[newNick] = self.userChannels[oldNick] + del self.userChannels[oldNick] + if self.userChannels[newNick]: + sendGeneric = False + self.messageQueue.put_nowait((time_, rawMessage, self.userChannels[newNick])) + if sendGeneric: + self.logger.warning(f'Could not process nick change {rawMessage!r}, user tracking may be wrong') + self.messageQueue.put_nowait((time_, rawMessage, None)) + if message.startswith(b'QUIT ') or message == b'QUIT': + # Technically a simple 'QUIT' is not legal per RFC 1459. That's because there must always be a space after the command due to how <params> is defined. + # In practice, it is accepted by ircds though, so it can presumably also be received by a client. + sendGeneric = True + if decoded and hasPrefix: + nick = self.prefix_to_nick(prefixStr) + if nick != self.config['irc']['nick'] and nick in self.userChannels: + if self.userChannels[nick]: + sendGeneric = False + self.messageQueue.put_nowait((time_, rawMessage, self.userChannels[nick])) + del self.userChannels[nick] + if not hasPrefix or (decoded and hasPrefix and nick == self.config['irc']['nick']): + # Oh no, *I* am getting disconnected! :-( + # I'm not actually sure whether the prefix version can happen, but better safe than sorry... + # In this case, it should be logged to all channels as well as the general log. The extra 'general' entry triggers Storage's code to write a message to the general log. + # Side effect: if the connection dies before any channels were joined, this causes the quit to be logged everywhere. However, there won't be a JOIN in the log, so it would still be unambiguous. + # Also, the connection loss after the disconnect triggers another message to be written to the logs.
¯\_(ツ)_/¯ + sendGeneric = False + self.messageQueue.put_nowait((time_, rawMessage, list(self.channels) + ['general'])) + if sendGeneric: + self.logger.warning(f'Could not process quit message {rawMessage!r}, user tracking may be wrong') + self.messageQueue.put_nowait((time_, rawMessage, None)) + + async def quit(self): + # It appears to be hard to impossible to send a clean quit, wait for it to be actually sent, and only then close the transport. + # This is because asyncio.sslproto.SSLTransport doesn't support explicit draining nor waiting for an empty write queue nor write_eof. + # So instead, just close the transport and wait until connection_lost is triggered (which also puts a message in the logs). + self.logger.info('Quitting') + self.transport.close() + await self.connectionClosedEvent.wait() + def connection_lost(self, exc): + time_ = time.time() self.logger.info('IRC connection lost') self.connected = False self.connectionClosedEvent.set() + self.messageQueue.put_nowait((time_, messageConnectionClosed, list(self.channels) + ['general'])) class IRCClient: @@ -425,13 +534,16 @@ class IRCClient: try: self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context()) try: - await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = asyncio.FIRST_COMPLETED) finally: - self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C? 
+ if not connectionClosedEvent.is_set(): + await self._protocol.quit() except (ConnectionRefusedError, asyncio.TimeoutError) as e: self.logger.error(str(e)) - await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = asyncio.FIRST_COMPLETED) if sigintEvent.is_set(): + self.logger.debug('Got SIGINT, putting EOF and breaking') + self.messageQueue.put_nowait(messageEOF) break @@ -456,28 +568,40 @@ class Storage: del self.files[channel] #TODO mkdir as required + #TODO month for channel in self.config['channels'].values(): if channel['ircchannel'] not in self.files and channel['active']: self.files[channel['ircchannel']] = open(os.path.join(self.config['storage']['path'], channel['ircchannel'], '2020-10.log'), 'ab') if None not in self.files: - self.files[None] = open(os.path.join(self.config['storage']['path'], 'general', '2020-10.log'), 'ab') #TODO Month + self.files[None] = open(os.path.join(self.config['storage']['path'], 'general', '2020-10.log'), 'ab') async def run(self, loop, sigintEvent): self.update_config(self.config) # Ensure that files are open etc. 
#TODO Task to rotate log files at the beginning of a new month storageTask = asyncio.create_task(self.store_messages(sigintEvent)) - flushTask = asyncio.create_task(self.flush_files(sigintEvent)) + flushTask = asyncio.create_task(self.flush_files()) await sigintEvent.wait() + self.logger.debug('Got SIGINT, waiting for remaining messages to be stored') + await storageTask # Wait until everything's stored self.active = False - #TODO Wait for tasks + self.logger.debug('Waiting for flush task') + await flushTask self.close() async def store_messages(self, sigintEvent): while self.active: - #TODO wait for sigint as well - time_, rawMessage = await self.messageQueue.get() + self.logger.debug('Waiting for message') + res = await self.messageQueue.get() + self.logger.debug(f'Got {res!r} from message queue') + if res is messageEOF: + self.logger.debug('Message EOF, breaking store_messages loop') + break + + time_, rawMessage, channels = res + if rawMessage is messageConnectionClosed: + rawMessage = b'- Connection closed' message = rawMessage[2:] # Remove leading > or < if message.startswith(b':') and b' ' in message: prefix, message = message.split(b' ', 1) @@ -492,9 +616,15 @@ class Storage: for channel in channels: self.store_message(time_, rawMessage, channel) continue - if message.startswith(b'QUIT '): - #TODO Need to keep track of users to figure out in which channels they were... Ugh - pass + if message.startswith(b'QUIT ') or message == b'QUIT' or message.startswith(b'NICK '): + # If channels is not None, IRCClientProtocol managed to track the user and identify the channels this needs to be logged to. + # If it isn't, there might be channels in there (for some odd reason?) that are not being logged. In that case, emit one and only one message to the general log as well. 
+ if channels is not None: + for channel in channels: + self.store_message(time_, rawMessage, channel, redirectToGeneral = False) + if channels is None or any(channel not in self.files for channel in channels): + self.store_message(time_, rawMessage, None) + continue if message.startswith(b'MODE #') or message.startswith(b'MODE &') or message.startswith(b'KICK '): channel = message.split(b' ', 2)[1] channel = self.decode_channel(time_, rawMessage, channel) @@ -502,10 +632,18 @@ class Storage: continue self.store_message(time_, rawMessage, channel) continue - self.store_message(time_, rawMessage, None) + if channels is not None: + for channel in channels: + self.store_message(time_, rawMessage, channel) + else: + self.store_message(time_, rawMessage, None) - def store_message(self, time_, rawMessage, targetChannel): + def store_message(self, time_, rawMessage, targetChannel, redirectToGeneral = True): + self.logger.debug(f'Logging {rawMessage!r} at {time_} for {targetChannel!r}') if targetChannel is not None and targetChannel not in self.files: + self.logger.debug(f'Target channel {targetChannel!r} not opened, redirecting to general log is {redirectToGeneral}') + if not redirectToGeneral: + return targetChannel = None self.files[targetChannel].write(str(time_).encode('ascii') + b' ' + rawMessage + b'\r\n') @@ -519,9 +657,10 @@ class Storage: self.store_message(time_, rawMessage, None) return None - async def flush_files(self, sigintEvent): + async def flush_files(self): while self.active: - await sigintEvent.wait() + await asyncio.sleep(1) + self.logger.debug('Exiting flush_files') def close(self): for f in self.files.values(): @@ -556,7 +695,7 @@ class WebServer: await runner.setup() site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) await site.start() - await asyncio.wait((stopEvent.wait(), self._configChanged.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await asyncio.wait((stopEvent.wait(), 
self._configChanged.wait()), return_when = asyncio.FIRST_COMPLETED) await runner.cleanup() if stopEvent.is_set(): break @@ -637,6 +776,9 @@ async def main(): loop = asyncio.get_running_loop() messageQueue = asyncio.Queue() + # tuple(time: float, message: bytes or messageConnectionClosed, channels: list[str] or None) + # message = messageConnectionClosed indicates a connection loss + # channels = None indicates that IRCClientProtocol did not identify which channels are affected; it is only not None for QUIT or NICK messages and for connection loss. irc = IRCClient(messageQueue, config) webserver = WebServer(config) @@ -652,7 +794,7 @@ async def main(): def sigusr1_callback(): global logger - nonlocal config, irc, webserver + nonlocal config, irc, webserver, storage logger.info('Got SIGUSR1, reloading config') try: newConfig = config.reread()