diff --git a/http2irc.py b/http2irc.py
index a3e7ef0..f36d179 100644
--- a/http2irc.py
+++ b/http2irc.py
@@ -6,6 +6,7 @@ import collections
 import concurrent.futures
 import importlib.util
 import inspect
+import itertools
 import logging
 import os.path
 import signal
@@ -84,8 +85,12 @@ class Config(dict):
 				raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
 			if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
 				raise InvalidConfig('Invalid IRC nick')
+			if 'nick' in obj['irc'] and len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510: # Guarded like the checks around it so that an omitted nick does not raise KeyError
+				raise InvalidConfig('Invalid IRC nick: NICK command too long')
 			if 'real' in obj['irc'] and not isinstance(obj['irc']['real'], str):
 				raise InvalidConfig('Invalid IRC realname')
+			if 'nick' in obj['irc'] and 'real' in obj['irc'] and len(IRCClientProtocol.user_command(obj['irc']['nick'], obj['irc']['real'])) > 510: # Only checkable here when both values are present in the config
+				raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
 			if ('certfile' in obj['irc']) != ('certkeyfile' in obj['irc']):
 				raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
 			if 'certfile' in obj['irc']:
@@ -137,6 +142,8 @@ class Config(dict):
 						raise InvalidConfig(f'Invalid map {key!r} IRC channel: does not start with # or &')
 					if any(x in map_['ircchannel'][1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
 						raise InvalidConfig(f'Invalid map {key!r} IRC channel: contains forbidden characters')
+					if 14 + len(map_['ircchannel'].encode('utf-8')) > 510: # 14 = prefix 'PRIVMSG ' + suffix ' :' + at least one UTF-8 character; implicitly also covers the shorter JOIN/PART messages. Measured on the encoded name because the limit applies to bytes, not codepoints.
+						raise InvalidConfig(f'Invalid map {key!r} IRC channel: too long')
 
 				if 'auth' in map_:
 					if map_['auth'] is not False and not isinstance(map_['auth'], str):
@@ -282,15 +289,58 @@ class IRCClientProtocol(asyncio.Protocol):
 		self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
 		self.authenticated = False
 
+	@staticmethod
+	def nick_command(nick: str):
+		'''Build the NICK line for the given nickname as UTF-8 bytes; no line terminator is appended here'''
+		return b'NICK ' + nick.encode('utf-8')
+
+	@staticmethod
+	def user_command(nick: str, real: str):
+		'''Build the USER line as UTF-8 bytes; nick fills all three middle parameters and real is the trailing parameter'''
+		nickb = nick.encode('utf-8')
+		return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')
+
 	def connection_made(self, transport):
 		self.logger.info('IRC connected')
 		self.transport = transport
 		self.connected = True
-		nickb = self.config['irc']['nick'].encode('utf-8')
 		if self.sasl:
 			self.send(b'CAP REQ :sasl')
-		self.send(b'NICK ' + nickb)
-		self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config['irc']['real'].encode('utf-8'))
+		self.send(self.nick_command(self.config['irc']['nick']))
+		self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))
+
+	def _send_join_part(self, command, channels):
+		'''Split a JOIN or PART into multiple messages as necessary'''
+		# command: b'JOIN' or b'PART'; channels: set[str]
+
+		channels = [x.encode('utf-8') for x in channels]
+		if len(command) + sum(1 + len(x) for x in channels) <= 510: # Total length = command + (separator + channel name for each channel, where the separator is a space for the first and then a comma)
+			# Everything fits into one command.
+			self.send(command + b' ' + b','.join(channels))
+			return
+
+		# List too long, need to split.
+		limit = 510 - len(command)
+		lengths = [1 + len(x) for x in channels] # separator + channel name
+		chanLengthAcceptable = [l <= limit for l in lengths]
+		if not all(chanLengthAcceptable):
+			# There are channel names that are too long to even fit into one message on their own; filter them out and warn about them.
+			# This should never happen since the config reader would already filter it out.
+			tooLongChannels = [x for x, a in zip(channels, chanLengthAcceptable) if not a]
+			channels = [x for x, a in zip(channels, chanLengthAcceptable) if a]
+			lengths = [l for l, a in zip(lengths, chanLengthAcceptable) if a]
+			for channel in tooLongChannels:
+				self.logger.warning(f'Cannot {command.decode("ascii")} {channel.decode("utf-8")}: name too long')
+		runningLengths = list(itertools.accumulate(lengths)) # entry N = length of all entries up to and including channel N, including separators
+		offset = 0
+		while channels:
+			i = next((x[0] for x in enumerate(runningLengths) if x[1] - offset > limit), -1)
+			if i == -1: # Last batch
+				i = len(channels)
+			self.send(command + b' ' + b','.join(channels[:i]))
+			offset = runningLengths[i-1]
+			channels = channels[i:]
+			runningLengths = runningLengths[i:]
 
 	def update_channels(self, channels: set):
 		channelsToPart = self.channels - channels
@@ -299,10 +349,9 @@ class IRCClientProtocol(asyncio.Protocol):
 
 		if self.connected:
 			if channelsToPart:
-				#TODO: Split if too long
-				self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
+				self._send_join_part(b'PART', channelsToPart)
 			if channelsToJoin:
-				self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))
+				self._send_join_part(b'JOIN', channelsToJoin)
 
 	def send(self, data):
 		self.logger.debug(f'Send: {data!r}')
@@ -335,11 +384,47 @@ class IRCClientProtocol(asyncio.Protocol):
 			self.logger.debug(f'Got message: {message!r}')
 			if message is None:
 				break
-			self.logger.info(f'Sending {message!r} to {channel!r}')
-			#TODO Split if the message is too long.
-			self.unconfirmedMessages.append((channel, message))
-			self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
-			await asyncio.sleep(1) # Rate limit
+			channelB = channel.encode('utf-8')
+			messageB = message.encode('utf-8')
+			if len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
+				self.logger.debug('Splitting up into smaller messages')
+				# Message too long, need to split. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
+				prefix = b'PRIVMSG ' + channelB + b' :'
+				prefixLength = len(prefix)
+				maxMessageLength = 510 - prefixLength # maximum length of the message part within each line
+				messages = []
+				while message:
+					if len(messageB) <= maxMessageLength:
+						messages.append(message)
+						break
+
+					spacePos = messageB.rfind(b' ', 0, maxMessageLength + 1)
+					if spacePos > 0: # A space at index 0 would yield an empty chunk, so that case falls through to the codepoint split below
+						messages.append(messageB[:spacePos].decode('utf-8'))
+						messageB = messageB[spacePos + 1:]
+						message = messageB.decode('utf-8')
+						continue
+
+					# No space found, need to search for a suitable codepoint location.
+					pMessage = message[:maxMessageLength] # at most 510 codepoints which expand to at least 510 bytes
+					pLengths = [len(x.encode('utf-8')) for x in pMessage] # byte size of each codepoint
+					pRunningLengths = list(itertools.accumulate(pLengths)) # byte size up to each codepoint
+					if pRunningLengths[-1] <= maxMessageLength: # Special case: entire pMessage is short enough
+						messages.append(pMessage)
+						message = message[maxMessageLength:]
+						messageB = message.encode('utf-8')
+						continue
+					cutoffIndex = next(x[0] for x in enumerate(pRunningLengths) if x[1] > maxMessageLength)
+					messages.append(message[:cutoffIndex])
+					message = message[cutoffIndex:]
+					messageB = message.encode('utf-8')
+				for msg in reversed(messages):
+					self.messageQueue.putleft_nowait((channel, msg))
+			else:
+				self.logger.info(f'Sending {message!r} to {channel!r}')
+				self.unconfirmedMessages.append((channel, message))
+				self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
+				await asyncio.sleep(1) # Rate limit
 
 	async def confirm_messages(self):
 		while self.connected:
@@ -438,7 +523,7 @@ class IRCClientProtocol(asyncio.Protocol):
 				self.logger.error('IRC connection registered but not authenticated, terminating connection')
 				self.transport.close()
 				return
-			self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
+			self._send_join_part(b'JOIN', self.channels)
 			asyncio.create_task(self.send_messages())
 			asyncio.create_task(self.confirm_messages())
 