diff --git a/http2irc.py b/http2irc.py index 36022b1..c71e082 100644 --- a/http2irc.py +++ b/http2irc.py @@ -3,13 +3,17 @@ import aiohttp.web import asyncio import base64 import collections -import concurrent.futures +import functools import importlib.util import inspect +import ircstates +import irctokens import itertools +import json import logging import os.path import signal +import socket import ssl import string import sys @@ -53,14 +57,20 @@ async def wait_cancel_pending(aws, paws = None, **kwargs): if paws is None: paws = set() tasks = aws | paws + logger.debug(f'waiting for {tasks!r}') done, pending = await asyncio.wait(tasks, **kwargs) + logger.debug(f'done waiting for {tasks!r}; cancelling pending non-persistent tasks: {pending!r}') for task in pending: if task not in paws: + logger.debug(f'cancelling {task!r}') task.cancel() + logger.debug(f'awaiting cancellation of {task!r}') try: await task except asyncio.CancelledError: pass + logger.debug(f'done cancelling {task!r}') + logger.debug(f'done wait_cancel_pending {tasks!r}') return done, pending @@ -92,7 +102,7 @@ class Config(dict): except (ValueError, AssertionError) as e: raise InvalidConfig('Invalid log format: parsing failed') from e if 'irc' in obj: - if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']): + if any(x not in ('host', 'port', 'ssl', 'family', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']): raise InvalidConfig('Unknown key found in irc section') if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname raise InvalidConfig('Invalid IRC host') @@ -100,6 +110,10 @@ class Config(dict): raise InvalidConfig('Invalid IRC port') if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'): raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}') + if 'family' in obj['irc']: + if obj['irc']['family'] not in ('inet', 'INET', 'inet6', 'INET6'): + raise InvalidConfig('Invalid IRC family') + obj['irc']['family'] = getattr(socket, f'AF_{obj["irc"]["family"].upper()}') if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname raise InvalidConfig('Invalid IRC nick') if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510: @@ -192,7 +206,12 @@ class Config(dict): raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value') # Default values - finalObj = {'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}} + finalObj = { + 'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, + 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'family': 0, 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, + 'web': {'host': '127.0.0.1', 'port': 8080}, + 'maps': {} + } # Fill in default values for the maps for key, map_ in obj['maps'].items(): @@ -253,7 +272,7 @@ class Config(dict): class MessageQueue: - # An object holding onto the messages received from nodeping + # An object holding onto the messages received over HTTP for sending to IRC # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code. # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that. # Differences to asyncio.Queue include: @@ -310,12 +329,14 @@ class MessageQueue: class IRCClientProtocol(asyncio.Protocol): logger = logging.getLogger('http2irc.IRCClientProtocol') - def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels): - self.messageQueue = messageQueue + def __init__(self, http2ircMessageQueue, connectionClosedEvent, loop, config, channels): + self.http2ircMessageQueue = http2ircMessageQueue self.connectionClosedEvent = connectionClosedEvent self.loop = loop self.config = config self.lastRecvTime = None + self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit + self.sendQueue = asyncio.Queue() self.buffer = b'' self.connected = False self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str) @@ -323,7 +344,14 @@ class IRCClientProtocol(asyncio.Protocol): self.pongReceivedEvent = asyncio.Event() self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile']) self.authenticated = False - self.usermask = None + self.server = ircstates.Server(self.config['irc']['host']) + self.capReqsPending = set() # Capabilities requested from the server but not yet ACKd or NAKd + self.caps = set() # Capabilities acknowledged by the server + self.whoxQueue = collections.deque() # Names of channels that were joined successfully but for which no WHO (WHOX) query was sent yet + self.whoxChannel = None # Name of channel for which a WHO query is currently running + self.whoxReply = [] # List of (nickname, account) tuples from the currently running WHO query + self.whoxStartTime = None + self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...} @staticmethod def nick_command(nick: str): @@ -334,17 +362,16 @@ class IRCClientProtocol(asyncio.Protocol): nickb = nick.encode('utf-8') return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8') - def _maybe_set_usermask(self, usermask): - if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')): - self.usermask = usermask - self.logger.debug(f'Usermask is now {usermask!r}') - def connection_made(self, transport): self.logger.info('IRC connected') self.transport = transport self.connected = True + caps = [b'multi-prefix', b'userhost-in-names', b'away-notify', b'account-notify', b'extended-join'] if self.sasl: - self.send(b'CAP REQ :sasl') + caps.append(b'sasl') + for cap in caps: + self.capReqsPending.add(cap.decode('ascii')) + self.send(b'CAP REQ :' + cap) self.send(self.nick_command(self.config['irc']['nick'])) self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real'])) @@ -393,15 +420,41 @@ class IRCClientProtocol(asyncio.Protocol): self._send_join_part(b'JOIN', channelsToJoin) def send(self, data): - self.logger.debug(f'Send: {data!r}') + self.logger.debug(f'Queueing for send: {data!r}') if len(data) > 510: raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}') + self.sendQueue.put_nowait(data) + + def _direct_send(self, data): + self.logger.debug(f'Send: {data!r}') + time_ = time.time() self.transport.write(data + b'\r\n') + return time_ + + async def send_queue(self): + while True: + self.logger.debug('Trying to get data from send queue') + t = asyncio.create_task(self.sendQueue.get()) + done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) + if self.connectionClosedEvent.is_set(): + break + assert t in done, f'{t!r} is not in {done!r}' + data = t.result() + self.logger.debug(f'Got {data!r} from send queue') + now = time.time() + if self.lastSentTime is not None and now - self.lastSentTime < 1: + self.logger.debug(f'Rate limited') + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now) + if self.connectionClosedEvent.is_set(): + break + time_ = self._direct_send(data) + if self.lastSentTime is not None: + self.lastSentTime = time_ async def _get_message(self): - self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}') - messageFuture = asyncio.create_task(self.messageQueue.get()) - done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = concurrent.futures.FIRST_COMPLETED) + self.logger.debug(f'Message queue {id(self.http2ircMessageQueue)} length: {self.http2ircMessageQueue.qsize()}') + messageFuture = asyncio.create_task(self.http2ircMessageQueue.get()) + done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = asyncio.FIRST_COMPLETED) if self.connectionClosedEvent.is_set(): if messageFuture in pending: self.logger.debug('Cancelling messageFuture') @@ -413,11 +466,16 @@ class IRCClientProtocol(asyncio.Protocol): pass else: # messageFuture is already done but we're stopping, so put the result back onto the queue - self.messageQueue.putleft_nowait(messageFuture.result()) - return None, None + self.http2ircMessageQueue.putleft_nowait(messageFuture.result()) + return None, None, None assert messageFuture in done, 'Invalid state: messageFuture not in done futures' return messageFuture.result() + def _self_usermask_length(self): + if not self.server.nickname or not self.server.username or not self.server.hostname: + return 100 + return len(self.server.nickname) + len(self.server.username) + len(self.server.hostname) + async def send_messages(self): while self.connected: self.logger.debug(f'Trying to get a message') @@ -427,7 +485,7 @@ class IRCClientProtocol(asyncio.Protocol): break channelB = channel.encode('utf-8') messageB = message.encode('utf-8') - usermaskPrefixLength = 1 + (len(self.usermask) if self.usermask else 100) + 1 + usermaskPrefixLength = 1 + self._self_usermask_length() + 1 if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510: # Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated. self.logger.debug(f'Message too long, overlongmode = {overlongmode}') @@ -466,20 +524,19 @@ class IRCClientProtocol(asyncio.Protocol): messageB = message.encode('utf-8') if overlongmode == 'split': for msg in reversed(messages): - self.messageQueue.putleft_nowait((channel, msg, overlongmode)) + self.http2ircMessageQueue.putleft_nowait((channel, msg, overlongmode)) elif overlongmode == 'truncate': - self.messageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode)) + self.http2ircMessageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode)) else: self.logger.info(f'Sending {message!r} to {channel!r}') self.unconfirmedMessages.append((channel, message, overlongmode)) self.send(b'PRIVMSG ' + channelB + b' :' + messageB) - await asyncio.sleep(1) # Rate limit async def confirm_messages(self): while self.connected: - await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 60) # Confirm once per minute + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 60) # Confirm once per minute if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly - self.messageQueue.putleft_nowait(*self.unconfirmedMessages) + self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages) self.unconfirmedMessages = [] break if not self.unconfirmedMessages: @@ -488,18 +545,19 @@ class IRCClientProtocol(asyncio.Protocol): self.logger.debug('Trying to confirm message delivery') self.pongReceivedEvent.clear() self.send(b'PING :42') - await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 5) + await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 5) self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}') if not self.pongReceivedEvent.is_set(): # No PONG received in five seconds, assume connection's dead self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue') - self.messageQueue.putleft_nowait(*self.unconfirmedMessages) + self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages) self.transport.close() self.unconfirmedMessages = [] def data_received(self, data): + time_ = time.time() self.logger.debug(f'Data received: {data!r}') - self.lastRecvTime = time.time() + self.lastRecvTime = time_ # If there's any data left in the buffer, prepend it to the data. Split on CRLF. # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer. # If data does end with CRLF, all messages will have been processed and the buffer will be empty again. @@ -507,104 +565,146 @@ class IRCClientProtocol(asyncio.Protocol): data = self.buffer + data messages = data.split(b'\r\n') for message in messages[:-1]: - self.message_received(message) + lines = self.server.recv(message + b'\r\n') + assert len(lines) == 1, f'recv did not return exactly one line: {message!r} -> {lines!r}' + self.message_received(time_, message, lines[0]) + self.server.parse_tokens(lines[0]) self.buffer = messages[-1] - def message_received(self, message): - self.logger.debug(f'Message received: {message!r}') - rawMessage = message - if message.startswith(b':') and b' ' in message: - # Prefixed message, extract command + parameters (the prefix cannot contain a space) - message = message.split(b' ', 1)[1] + def message_received(self, time_, message, line): + self.logger.debug(f'Message received at {time_}: {message!r}') + + maybeTriggerWhox = False # PING/PONG - if message.startswith(b'PING '): - self.send(b'PONG ' + message[5:]) - elif message.startswith(b'PONG '): + if line.command == 'PING': + self._direct_send(irctokens.build('PONG', line.params).format().encode('utf-8')) + elif line.command == 'PONG': self.pongReceivedEvent.set() - # SASL - elif message.startswith(b'CAP ') and self.sasl: - if message[message.find(b' ', 4) + 1:] == b'ACK :sasl': - self.send(b'AUTHENTICATE EXTERNAL') - else: - self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection') - self.transport.close() - elif message == b'AUTHENTICATE +': + # IRCv3 and SASL + elif line.command == 'CAP': + if line.params[1] == 'ACK': + for cap in line.params[2].split(' '): + self.logger.debug(f'CAP ACK: {cap}') + self.caps.add(cap) + if cap == 'sasl' and self.sasl: + self.send(b'AUTHENTICATE EXTERNAL') + else: + self.capReqsPending.remove(cap) + elif line.params[1] == 'NAK': + self.logger.warning(f'Failed to activate CAP(s): {line.params[2]}') + for cap in line.params[2].split(' '): + self.capReqsPending.remove(cap) + if len(self.capReqsPending) == 0: + self.send(b'CAP END') + elif line.command == 'AUTHENTICATE' and line.params == ['+']: self.send(b'AUTHENTICATE +') - elif message.startswith(b'900 '): # "You are now logged in", includes the usermask - words = message.split(b' ') - if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]: - if b'!~' not in words[2]: - # At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde. - words[2] = words[2].replace(b'!', b'!~', 1) - self._maybe_set_usermask(words[2]) - elif message.startswith(b'903 '): # SASL auth successful + elif line.command == ircstates.numerics.RPL_SASLSUCCESS: self.authenticated = True - self.send(b'CAP END') - elif any(message.startswith(x) for x in (b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')): + self.capReqsPending.remove('sasl') + if len(self.capReqsPending) == 0: + self.send(b'CAP END') + elif line.command in ('902', ircstates.numerics.ERR_SASLFAIL, ircstates.numerics.ERR_SASLTOOLONG, ircstates.numerics.ERR_SASLABORTED, ircstates.numerics.RPL_SASLMECHS): self.logger.error('SASL error, terminating connection') self.transport.close() # NICK errors - elif any(message.startswith(x) for x in (b'431 ', b'432 ', b'433 ', b'436 ')): + elif line.command in ('431', ircstates.numerics.ERR_ERRONEUSNICKNAME, ircstates.numerics.ERR_NICKNAMEINUSE, '436'): self.logger.error(f'Failed to set nickname: {message!r}, terminating connection') self.transport.close() # USER errors - elif any(message.startswith(x) for x in (b'461 ', b'462 ')): + elif line.command in ('461', '462'): self.logger.error(f'Failed to register: {message!r}, terminating connection') self.transport.close() # JOIN errors - elif any(message.startswith(x) for x in (b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')): + elif line.command in ( + ircstates.numerics.ERR_TOOMANYCHANNELS, + ircstates.numerics.ERR_CHANNELISFULL, + ircstates.numerics.ERR_INVITEONLYCHAN, + ircstates.numerics.ERR_BANNEDFROMCHAN, + ircstates.numerics.ERR_BADCHANNELKEY, + ): self.logger.error(f'Failed to join channel: {message!r}, terminating connection') self.transport.close() # PART errors - elif message.startswith(b'442 '): + elif line.command == '442': self.logger.error(f'Failed to part channel: {message!r}') # JOIN/PART errors - elif message.startswith(b'403 '): + elif line.command == ircstates.numerics.ERR_NOSUCHCHANNEL: self.logger.error(f'Failed to join or part channel: {message!r}') # PRIVMSG errors - elif any(message.startswith(x) for x in (b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')): + elif line.command in (ircstates.numerics.ERR_NOSUCHNICK, '404', '407', '411', '412', '413', '414'): self.logger.error(f'Failed to send message: {message!r}') # Connection registration reply - elif message.startswith(b'001 '): + elif line.command == ircstates.numerics.RPL_WELCOME: self.logger.info('IRC connection registered') if self.sasl and not self.authenticated: self.logger.error('IRC connection registered but not authenticated, terminating connection') self.transport.close() return + self.lastSentTime = time.time() self._send_join_part(b'JOIN', self.channels) asyncio.create_task(self.send_messages()) asyncio.create_task(self.confirm_messages()) - # JOIN success - elif message.startswith(b'JOIN ') and not self.usermask: - # If this is my own join message, it should contain the usermask in the prefix - if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage: - usermask = rawMessage.split(b' ', 1)[0][1:] - self._maybe_set_usermask(usermask) - - # Services host change - elif message.startswith(b'396 '): - words = message.split(b' ') - if len(words) >= 3: - # Sanity check inspired by irssi src/irc/core/irc-servers.c - if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-': - if b'@' in words[2]: # user@host - self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2]) - else: # host (get user from previous mask or settings) - if self.usermask: - user = self.usermask.split(b'@')[0].split(b'!')[1] - else: - user = b'~' + self.config['irc']['nick'].encode('utf-8') - self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2]) + # Bot getting KICKed + elif line.command == 'KICK' and line.source and self.server.casefold(line.params[1]) == self.server.casefold(self.server.nickname): + self.logger.warning(f'Got kicked from {line.params[0]}') + kickedChannel = self.server.casefold(line.params[0]) + for channel in self.channels: + if self.server.casefold(channel) == kickedChannel: + self.channels.remove(channel) + break + + # WHOX on successful JOIN if supported to fetch account information + elif line.command == 'JOIN' and self.server.isupport.whox and line.source and self.server.casefold(line.hostmask.nickname) == self.server.casefold(self.server.nickname): + self.whoxQueue.extend(line.params[0].split(',')) + maybeTriggerWhox = True + + # WHOX response + elif line.command == ircstates.numerics.RPL_WHOSPCRPL and line.params[1] == '042': + self.whoxReply.append({'nick': line.params[4], 'hostmask': f'{line.params[4]}!{line.params[2]}@{line.params[3]}', 'account': line.params[5] if line.params[5] != '0' else None}) + + # End of WHOX response + elif line.command == ircstates.numerics.RPL_ENDOFWHO: + # Patch ircstates account info; ircstates does not parse the WHOX reply itself. + for entry in self.whoxReply: + if entry['account']: + self.server.users[self.server.casefold(entry['nick'])].account = entry['account'] + self.whoxChannel = None + self.whoxReply = [] + self.whoxStartTime = None + maybeTriggerWhox = True + + # General fatal ERROR + elif line.command == 'ERROR': + self.logger.error(f'Server sent ERROR: {message!r}') + self.transport.close() + + # Send next WHOX if appropriate + if maybeTriggerWhox and self.whoxChannel is None and self.whoxQueue: + self.whoxChannel = self.whoxQueue.popleft() + self.whoxReply = [] + self.whoxStartTime = time.time() # Note, may not be the actual start time due to rate limiting + self.send(b'WHO ' + self.whoxChannel.encode('utf-8') + b' c%tuhna,042') + + async def quit(self): + # The server acknowledges a QUIT by sending an ERROR and closing the connection. The latter triggers connection_lost, so just wait for the closure event. + self.logger.info('Quitting') + self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue + self._direct_send(b'QUIT :Bye') + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10) + if not self.connectionClosedEvent.is_set(): + self.logger.error('Quitting cleanly did not work, closing connection forcefully') + # Event will be set implicitly in connection_lost. + self.transport.close() def connection_lost(self, exc): self.logger.info('IRC connection lost') @@ -615,8 +715,8 @@ class IRCClientProtocol(asyncio.Protocol): class IRCClient: logger = logging.getLogger('http2irc.IRCClient') - def __init__(self, messageQueue, config): - self.messageQueue = messageQueue + def __init__(self, http2ircMessageQueue, config): + self.http2ircMessageQueue = http2ircMessageQueue self.config = config self.channels = {map_['ircchannel'] for map_ in config['maps'].values()} @@ -647,17 +747,43 @@ class IRCClient: while True: connectionClosedEvent.clear() try: - self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context()) + self.logger.debug('Creating IRC connection') + t = asyncio.create_task(loop.create_connection( + protocol_factory = lambda: IRCClientProtocol(self.http2ircMessageQueue, connectionClosedEvent, loop, self.config, self.channels), + host = self.config['irc']['host'], + port = self.config['irc']['port'], + ssl = self._get_ssl_context(), + family = self.config['irc']['family'], + )) + # No automatic cancellation of t because it's handled manually below. + done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30) + if t not in done: + t.cancel() + await t # Raises the CancelledError + self._transport, self._protocol = t.result() + self.logger.debug('Starting send queue processing') + sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent + self.logger.debug('Waiting for connection closure or SIGINT') try: - await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) finally: - self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C? + self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}') + if not connectionClosedEvent.is_set(): + self.logger.debug('Quitting connection') + await self._protocol.quit() + if not sendTask.done(): + sendTask.cancel() + try: + await sendTask + except asyncio.CancelledError: + pass self._transport = None self._protocol = None - except (ConnectionRefusedError, asyncio.TimeoutError) as e: - self.logger.error(str(e)) + except (ConnectionError, ssl.SSLError, asyncio.TimeoutError, asyncio.CancelledError) as e: + self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}') await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5) if sigintEvent.is_set(): + self.logger.debug('Got SIGINT, breaking IRC loop') break @property @@ -668,8 +794,8 @@ class IRCClient: class WebServer: logger = logging.getLogger('http2irc.WebServer') - def __init__(self, messageQueue, ircClient, config): - self.messageQueue = messageQueue + def __init__(self, http2ircMessageQueue, ircClient, config): + self.http2ircMessageQueue = http2ircMessageQueue self.ircClient = ircClient self.config = config @@ -697,7 +823,7 @@ class WebServer: await runner.setup() site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) await site.start() - await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED) await runner.cleanup() if stopEvent.is_set(): break @@ -735,7 +861,7 @@ class WebServer: self.logger.debug(f'Processing request {id(request)} using default processor') message = await self._default_process(request) self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue') - self.messageQueue.put_nowait((channel, message, overlongmode)) + self.http2ircMessageQueue.put_nowait((channel, message, overlongmode)) raise aiohttp.web.HTTPOk() async def _default_process(self, request): @@ -777,10 +903,10 @@ async def main(): loop = asyncio.get_running_loop() - messageQueue = MessageQueue() + http2ircMessageQueue = MessageQueue() - irc = IRCClient(messageQueue, config) - webserver = WebServer(messageQueue, irc, config) + irc = IRCClient(http2ircMessageQueue, config) + webserver = WebServer(http2ircMessageQueue, irc, config) sigintEvent = asyncio.Event() def sigint_callback():