Browse Source

Ensure that everything fits into IRC's line length limit or split up accordingly

master
JustAnotherArchivist 4 years ago
parent
commit
8f8e7cb0ed
1 changed files with 95 additions and 12 deletions
  1. +95
    -12
      http2irc.py

+ 95
- 12
http2irc.py View File

@@ -6,6 +6,7 @@ import collections
import concurrent.futures
import importlib.util
import inspect
import itertools
import logging
import os.path
import signal
@@ -84,8 +85,12 @@ class Config(dict):
raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
raise InvalidConfig('Invalid IRC nick')
if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510:
raise InvalidConfig('Invalid IRC nick: NICK command too long')
if 'real' in obj['irc'] and not isinstance(obj['irc']['real'], str):
raise InvalidConfig('Invalid IRC realname')
if len(IRCClientProtocol.user_command(obj['irc']['nick'], obj['irc']['real'])) > 510:
raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
if ('certfile' in obj['irc']) != ('certkeyfile' in obj['irc']):
raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
if 'certfile' in obj['irc']:
@@ -137,6 +142,8 @@ class Config(dict):
raise InvalidConfig(f'Invalid map {key!r} IRC channel: does not start with # or &')
if any(x in map_['ircchannel'][1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
raise InvalidConfig(f'Invalid map {key!r} IRC channel: contains forbidden characters')
if 14 + len(map_['ircchannel']) > 510: # 14 = prefix 'PRIVMSG ' + suffix ' :' + at least one UTF-8 character; implicitly also covers the shorter JOIN/PART messages
raise InvalidConfig(f'Invalid map {key!r} IRC channel: too long')

if 'auth' in map_:
if map_['auth'] is not False and not isinstance(map_['auth'], str):
@@ -282,15 +289,56 @@ class IRCClientProtocol(asyncio.Protocol):
self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
self.authenticated = False

@staticmethod
def nick_command(nick: str):
    '''Build the raw NICK command (without line terminator) for the given nickname.

    The nickname is encoded as UTF-8; the result is a bytes object so that
    callers can check its length in bytes before sending it.
    '''
    encodedNick = nick.encode('utf-8')
    return b'NICK %b' % (encodedNick,)

@staticmethod
def user_command(nick: str, real: str):
    '''Build the raw USER command (without line terminator).

    The UTF-8 encoded nickname fills all three middle parameters, and the
    UTF-8 encoded real name is sent as the trailing parameter after ' :'.
    '''
    encodedNick = nick.encode('utf-8')
    trailing = b':' + real.encode('utf-8')
    return b' '.join((b'USER', encodedNick, encodedNick, encodedNick, trailing))

def connection_made(self, transport):
    '''Handle a newly established IRC connection.

    Stores the transport, marks the client as connected, and sends the
    registration commands: an optional 'CAP REQ :sasl' when SASL is enabled,
    followed by exactly one NICK and one USER command.

    The NICK and USER lines are built by nick_command/user_command so that the
    bytes sent here are the same ones whose length the config validation checks.
    '''
    self.logger.info('IRC connected')
    self.transport = transport
    self.connected = True
    nick = self.config['irc']['nick']
    if self.sasl:
        # Capability negotiation has to be requested before registration completes.
        self.send(b'CAP REQ :sasl')
    # Send each registration command once; sending NICK/USER a second time is at best redundant.
    self.send(self.nick_command(nick))
    self.send(self.user_command(nick, self.config['irc']['real']))

def _send_join_part(self, command, channels):
    '''Send a JOIN or PART for the given channels, split into multiple messages as necessary.

    command: b'JOIN' or b'PART'
    channels: iterable of channel names (str), typically a set

    Every message sent is at most 510 bytes long (the line terminator is not
    counted here). Nothing is sent when there are no channels. Channel names
    that cannot fit into a message even on their own are skipped with a warning.
    '''
    # Each channel costs its encoded name plus one separator byte:
    # a space before the first channel of a message, a comma before every other one.
    limit = 510 - len(command)

    sendable = []
    for channel in channels:
        channelB = channel.encode('utf-8')
        if 1 + len(channelB) > limit:
            # This should never happen since the config reader would already filter it out.
            self.logger.warning(f'Cannot {command.decode("ascii")} {channel!r}: name too long')
            continue
        sendable.append(channelB)

    if not sendable:
        # A JOIN/PART without any channel parameter is not a valid command, so send nothing.
        return

    # Greedily pack as many channels as fit into each message.
    batch = []
    batchLength = 0 # bytes used by the channels in the batch, including separators
    for channelB in sendable:
        cost = 1 + len(channelB)
        if batch and batchLength + cost > limit:
            self.send(command + b' ' + b','.join(batch))
            batch = []
            batchLength = 0
        batch.append(channelB)
        batchLength += cost
    self.send(command + b' ' + b','.join(batch))

def update_channels(self, channels: set):
channelsToPart = self.channels - channels
@@ -299,10 +347,9 @@ class IRCClientProtocol(asyncio.Protocol):

if self.connected:
if channelsToPart:
#TODO: Split if too long
self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
self._send_join_part(b'PART', channelsToPart)
if channelsToJoin:
self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))
self._send_join_part(b'JOIN', channelsToJoin)

def send(self, data):
self.logger.debug(f'Send: {data!r}')
@@ -335,11 +382,47 @@ class IRCClientProtocol(asyncio.Protocol):
self.logger.debug(f'Got message: {message!r}')
if message is None:
break
self.logger.info(f'Sending {message!r} to {channel!r}')
#TODO Split if the message is too long.
self.unconfirmedMessages.append((channel, message))
self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
await asyncio.sleep(1) # Rate limit
channelB = channel.encode('utf-8')
messageB = message.encode('utf-8')
if len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
self.logger.debug(f'Splitting up into smaller messages')
# Message too long, need to split. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
prefix = b'PRIVMSG ' + channelB + b' :'
prefixLength = len(prefix)
maxMessageLength = 510 - prefixLength # maximum length of the message part within each line
messages = []
while message:
if len(messageB) <= maxMessageLength:
messages.append(message)
break

spacePos = messageB.rfind(b' ', 0, maxMessageLength + 1)
if spacePos != -1:
messages.append(messageB[:spacePos].decode('utf-8'))
messageB = messageB[spacePos + 1:]
message = messageB.decode('utf-8')
continue

# No space found, need to search for a suitable codepoint location.
pMessage = message[:maxMessageLength] # at most 510 codepoints which expand to at least 510 bytes
pLengths = [len(x.encode('utf-8')) for x in pMessage] # byte size of each codepoint
pRunningLengths = list(itertools.accumulate(pLengths)) # byte size up to each codepoint
if pRunningLengths[-1] <= maxMessageLength: # Special case: entire pMessage is short enough
messages.append(pMessage)
message = message[maxMessageLength:]
messageB = message.encode('utf-8')
continue
cutoffIndex = next(x[0] for x in enumerate(pRunningLengths) if x[1] > maxMessageLength)
messages.append(message[:cutoffIndex])
message = message[cutoffIndex:]
messageB = message.encode('utf-8')
for msg in reversed(messages):
self.messageQueue.putleft_nowait((channel, msg))
else:
self.logger.info(f'Sending {message!r} to {channel!r}')
self.unconfirmedMessages.append((channel, message))
self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
await asyncio.sleep(1) # Rate limit

async def confirm_messages(self):
while self.connected:
@@ -438,7 +521,7 @@ class IRCClientProtocol(asyncio.Protocol):
self.logger.error('IRC connection registered but not authenticated, terminating connection')
self.transport.close()
return
self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
self._send_join_part(b'JOIN', self.channels)
asyncio.create_task(self.send_messages())
asyncio.create_task(self.confirm_messages())



Loading…
Cancel
Save