Преглед изворни кода

Merge changes from irclog

Port to ircstates/irctokens, more capabilities, IRC family config, fix various small bugs
master
JustAnotherArchivist пре 2 година
родитељ
комит
bdb396caff
1 измењених фајлова са 224 додато и 98 уклоњено
  1. +224
    -98
      http2irc.py

+ 224
- 98
http2irc.py Прегледај датотеку

@@ -3,13 +3,17 @@ import aiohttp.web
import asyncio import asyncio
import base64 import base64
import collections import collections
import concurrent.futures
import functools
import importlib.util import importlib.util
import inspect import inspect
import ircstates
import irctokens
import itertools import itertools
import json
import logging import logging
import os.path import os.path
import signal import signal
import socket
import ssl import ssl
import string import string
import sys import sys
@@ -53,14 +57,20 @@ async def wait_cancel_pending(aws, paws = None, **kwargs):
if paws is None: if paws is None:
paws = set() paws = set()
tasks = aws | paws tasks = aws | paws
logger.debug(f'waiting for {tasks!r}')
done, pending = await asyncio.wait(tasks, **kwargs) done, pending = await asyncio.wait(tasks, **kwargs)
logger.debug(f'done waiting for {tasks!r}; cancelling pending non-persistent tasks: {pending!r}')
for task in pending: for task in pending:
if task not in paws: if task not in paws:
logger.debug(f'cancelling {task!r}')
task.cancel() task.cancel()
logger.debug(f'awaiting cancellation of {task!r}')
try: try:
await task await task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
logger.debug(f'done cancelling {task!r}')
logger.debug(f'done wait_cancel_pending {tasks!r}')
return done, pending return done, pending




@@ -92,7 +102,7 @@ class Config(dict):
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:
raise InvalidConfig('Invalid log format: parsing failed') from e raise InvalidConfig('Invalid log format: parsing failed') from e
if 'irc' in obj: if 'irc' in obj:
if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
if any(x not in ('host', 'port', 'ssl', 'family', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
raise InvalidConfig('Unknown key found in irc section') raise InvalidConfig('Unknown key found in irc section')
if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname
raise InvalidConfig('Invalid IRC host') raise InvalidConfig('Invalid IRC host')
@@ -100,6 +110,10 @@ class Config(dict):
raise InvalidConfig('Invalid IRC port') raise InvalidConfig('Invalid IRC port')
if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'): if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'):
raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}') raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
if 'family' in obj['irc']:
if obj['irc']['family'] not in ('inet', 'INET', 'inet6', 'INET6'):
raise InvalidConfig('Invalid IRC family')
obj['irc']['family'] = getattr(socket, f'AF_{obj["irc"]["family"].upper()}')
if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
raise InvalidConfig('Invalid IRC nick') raise InvalidConfig('Invalid IRC nick')
if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510: if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510:
@@ -192,7 +206,12 @@ class Config(dict):
raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value') raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value')


# Default values # Default values
finalObj = {'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
finalObj = {
'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'family': 0, 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
'web': {'host': '127.0.0.1', 'port': 8080},
'maps': {}
}


# Fill in default values for the maps # Fill in default values for the maps
for key, map_ in obj['maps'].items(): for key, map_ in obj['maps'].items():
@@ -253,7 +272,7 @@ class Config(dict):




class MessageQueue: class MessageQueue:
# An object holding onto the messages received from nodeping
# An object holding onto the messages received over HTTP for sending to IRC
# This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code. # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
# Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that. # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
# Differences to asyncio.Queue include: # Differences to asyncio.Queue include:
@@ -310,12 +329,14 @@ class MessageQueue:
class IRCClientProtocol(asyncio.Protocol): class IRCClientProtocol(asyncio.Protocol):
logger = logging.getLogger('http2irc.IRCClientProtocol') logger = logging.getLogger('http2irc.IRCClientProtocol')


def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
self.messageQueue = messageQueue
def __init__(self, http2ircMessageQueue, connectionClosedEvent, loop, config, channels):
self.http2ircMessageQueue = http2ircMessageQueue
self.connectionClosedEvent = connectionClosedEvent self.connectionClosedEvent = connectionClosedEvent
self.loop = loop self.loop = loop
self.config = config self.config = config
self.lastRecvTime = None self.lastRecvTime = None
self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit
self.sendQueue = asyncio.Queue()
self.buffer = b'' self.buffer = b''
self.connected = False self.connected = False
self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str) self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
@@ -323,7 +344,14 @@ class IRCClientProtocol(asyncio.Protocol):
self.pongReceivedEvent = asyncio.Event() self.pongReceivedEvent = asyncio.Event()
self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile']) self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
self.authenticated = False self.authenticated = False
self.usermask = None
self.server = ircstates.Server(self.config['irc']['host'])
self.capReqsPending = set() # Capabilities requested from the server but not yet ACKd or NAKd
self.caps = set() # Capabilities acknowledged by the server
self.whoxQueue = collections.deque() # Names of channels that were joined successfully but for which no WHO (WHOX) query was sent yet
self.whoxChannel = None # Name of channel for which a WHO query is currently running
self.whoxReply = [] # List of (nickname, account) tuples from the currently running WHO query
self.whoxStartTime = None
self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...}


@staticmethod @staticmethod
def nick_command(nick: str): def nick_command(nick: str):
@@ -334,17 +362,16 @@ class IRCClientProtocol(asyncio.Protocol):
nickb = nick.encode('utf-8') nickb = nick.encode('utf-8')
return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8') return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')


def _maybe_set_usermask(self, usermask):
if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
self.usermask = usermask
self.logger.debug(f'Usermask is now {usermask!r}')

def connection_made(self, transport): def connection_made(self, transport):
self.logger.info('IRC connected') self.logger.info('IRC connected')
self.transport = transport self.transport = transport
self.connected = True self.connected = True
caps = [b'multi-prefix', b'userhost-in-names', b'away-notify', b'account-notify', b'extended-join']
if self.sasl: if self.sasl:
self.send(b'CAP REQ :sasl')
caps.append(b'sasl')
for cap in caps:
self.capReqsPending.add(cap.decode('ascii'))
self.send(b'CAP REQ :' + cap)
self.send(self.nick_command(self.config['irc']['nick'])) self.send(self.nick_command(self.config['irc']['nick']))
self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real'])) self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))


@@ -393,15 +420,41 @@ class IRCClientProtocol(asyncio.Protocol):
self._send_join_part(b'JOIN', channelsToJoin) self._send_join_part(b'JOIN', channelsToJoin)


def send(self, data): def send(self, data):
self.logger.debug(f'Send: {data!r}')
self.logger.debug(f'Queueing for send: {data!r}')
if len(data) > 510: if len(data) > 510:
raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}') raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
self.sendQueue.put_nowait(data)

def _direct_send(self, data):
self.logger.debug(f'Send: {data!r}')
time_ = time.time()
self.transport.write(data + b'\r\n') self.transport.write(data + b'\r\n')
return time_

async def send_queue(self):
while True:
self.logger.debug('Trying to get data from send queue')
t = asyncio.create_task(self.sendQueue.get())
done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
if self.connectionClosedEvent.is_set():
break
assert t in done, f'{t!r} is not in {done!r}'
data = t.result()
self.logger.debug(f'Got {data!r} from send queue')
now = time.time()
if self.lastSentTime is not None and now - self.lastSentTime < 1:
self.logger.debug(f'Rate limited')
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now)
if self.connectionClosedEvent.is_set():
break
time_ = self._direct_send(data)
if self.lastSentTime is not None:
self.lastSentTime = time_


async def _get_message(self): async def _get_message(self):
self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
messageFuture = asyncio.create_task(self.messageQueue.get())
done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = concurrent.futures.FIRST_COMPLETED)
self.logger.debug(f'Message queue {id(self.http2ircMessageQueue)} length: {self.http2ircMessageQueue.qsize()}')
messageFuture = asyncio.create_task(self.http2ircMessageQueue.get())
done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = asyncio.FIRST_COMPLETED)
if self.connectionClosedEvent.is_set(): if self.connectionClosedEvent.is_set():
if messageFuture in pending: if messageFuture in pending:
self.logger.debug('Cancelling messageFuture') self.logger.debug('Cancelling messageFuture')
@@ -413,11 +466,16 @@ class IRCClientProtocol(asyncio.Protocol):
pass pass
else: else:
# messageFuture is already done but we're stopping, so put the result back onto the queue # messageFuture is already done but we're stopping, so put the result back onto the queue
self.messageQueue.putleft_nowait(messageFuture.result())
return None, None
self.http2ircMessageQueue.putleft_nowait(messageFuture.result())
return None, None, None
assert messageFuture in done, 'Invalid state: messageFuture not in done futures' assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
return messageFuture.result() return messageFuture.result()


def _self_usermask_length(self):
if not self.server.nickname or not self.server.username or not self.server.hostname:
return 100
return len(self.server.nickname) + len(self.server.username) + len(self.server.hostname)

async def send_messages(self): async def send_messages(self):
while self.connected: while self.connected:
self.logger.debug(f'Trying to get a message') self.logger.debug(f'Trying to get a message')
@@ -427,7 +485,7 @@ class IRCClientProtocol(asyncio.Protocol):
break break
channelB = channel.encode('utf-8') channelB = channel.encode('utf-8')
messageB = message.encode('utf-8') messageB = message.encode('utf-8')
usermaskPrefixLength = 1 + (len(self.usermask) if self.usermask else 100) + 1
usermaskPrefixLength = 1 + self._self_usermask_length() + 1
if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510: if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
# Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated. # Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
self.logger.debug(f'Message too long, overlongmode = {overlongmode}') self.logger.debug(f'Message too long, overlongmode = {overlongmode}')
@@ -466,20 +524,19 @@ class IRCClientProtocol(asyncio.Protocol):
messageB = message.encode('utf-8') messageB = message.encode('utf-8')
if overlongmode == 'split': if overlongmode == 'split':
for msg in reversed(messages): for msg in reversed(messages):
self.messageQueue.putleft_nowait((channel, msg, overlongmode))
self.http2ircMessageQueue.putleft_nowait((channel, msg, overlongmode))
elif overlongmode == 'truncate': elif overlongmode == 'truncate':
self.messageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
self.http2ircMessageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
else: else:
self.logger.info(f'Sending {message!r} to {channel!r}') self.logger.info(f'Sending {message!r} to {channel!r}')
self.unconfirmedMessages.append((channel, message, overlongmode)) self.unconfirmedMessages.append((channel, message, overlongmode))
self.send(b'PRIVMSG ' + channelB + b' :' + messageB) self.send(b'PRIVMSG ' + channelB + b' :' + messageB)
await asyncio.sleep(1) # Rate limit


async def confirm_messages(self): async def confirm_messages(self):
while self.connected: while self.connected:
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly
self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
self.unconfirmedMessages = [] self.unconfirmedMessages = []
break break
if not self.unconfirmedMessages: if not self.unconfirmedMessages:
@@ -488,18 +545,19 @@ class IRCClientProtocol(asyncio.Protocol):
self.logger.debug('Trying to confirm message delivery') self.logger.debug('Trying to confirm message delivery')
self.pongReceivedEvent.clear() self.pongReceivedEvent.clear()
self.send(b'PING :42') self.send(b'PING :42')
await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 5)
await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 5)
self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}') self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}')
if not self.pongReceivedEvent.is_set(): if not self.pongReceivedEvent.is_set():
# No PONG received in five seconds, assume connection's dead # No PONG received in five seconds, assume connection's dead
self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue') self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue')
self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
self.transport.close() self.transport.close()
self.unconfirmedMessages = [] self.unconfirmedMessages = []


def data_received(self, data): def data_received(self, data):
time_ = time.time()
self.logger.debug(f'Data received: {data!r}') self.logger.debug(f'Data received: {data!r}')
self.lastRecvTime = time.time()
self.lastRecvTime = time_
# If there's any data left in the buffer, prepend it to the data. Split on CRLF. # If there's any data left in the buffer, prepend it to the data. Split on CRLF.
# Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer. # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
# If data does end with CRLF, all messages will have been processed and the buffer will be empty again. # If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
@@ -507,104 +565,146 @@ class IRCClientProtocol(asyncio.Protocol):
data = self.buffer + data data = self.buffer + data
messages = data.split(b'\r\n') messages = data.split(b'\r\n')
for message in messages[:-1]: for message in messages[:-1]:
self.message_received(message)
lines = self.server.recv(message + b'\r\n')
assert len(lines) == 1, f'recv did not return exactly one line: {message!r} -> {lines!r}'
self.message_received(time_, message, lines[0])
self.server.parse_tokens(lines[0])
self.buffer = messages[-1] self.buffer = messages[-1]


def message_received(self, message):
self.logger.debug(f'Message received: {message!r}')
rawMessage = message
if message.startswith(b':') and b' ' in message:
# Prefixed message, extract command + parameters (the prefix cannot contain a space)
message = message.split(b' ', 1)[1]
def message_received(self, time_, message, line):
self.logger.debug(f'Message received at {time_}: {message!r}')

maybeTriggerWhox = False


# PING/PONG # PING/PONG
if message.startswith(b'PING '):
self.send(b'PONG ' + message[5:])
elif message.startswith(b'PONG '):
if line.command == 'PING':
self._direct_send(irctokens.build('PONG', line.params).format().encode('utf-8'))
elif line.command == 'PONG':
self.pongReceivedEvent.set() self.pongReceivedEvent.set()


# SASL
elif message.startswith(b'CAP ') and self.sasl:
if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
self.send(b'AUTHENTICATE EXTERNAL')
else:
self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
self.transport.close()
elif message == b'AUTHENTICATE +':
# IRCv3 and SASL
elif line.command == 'CAP':
if line.params[1] == 'ACK':
for cap in line.params[2].split(' '):
self.logger.debug(f'CAP ACK: {cap}')
self.caps.add(cap)
if cap == 'sasl' and self.sasl:
self.send(b'AUTHENTICATE EXTERNAL')
else:
self.capReqsPending.remove(cap)
elif line.params[1] == 'NAK':
self.logger.warning(f'Failed to activate CAP(s): {line.params[2]}')
for cap in line.params[2].split(' '):
self.capReqsPending.remove(cap)
if len(self.capReqsPending) == 0:
self.send(b'CAP END')
elif line.command == 'AUTHENTICATE' and line.params == ['+']:
self.send(b'AUTHENTICATE +') self.send(b'AUTHENTICATE +')
elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
words = message.split(b' ')
if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
if b'!~' not in words[2]:
# At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
words[2] = words[2].replace(b'!', b'!~', 1)
self._maybe_set_usermask(words[2])
elif message.startswith(b'903 '): # SASL auth successful
elif line.command == ircstates.numerics.RPL_SASLSUCCESS:
self.authenticated = True self.authenticated = True
self.send(b'CAP END')
elif any(message.startswith(x) for x in (b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
self.capReqsPending.remove('sasl')
if len(self.capReqsPending) == 0:
self.send(b'CAP END')
elif line.command in ('902', ircstates.numerics.ERR_SASLFAIL, ircstates.numerics.ERR_SASLTOOLONG, ircstates.numerics.ERR_SASLABORTED, ircstates.numerics.RPL_SASLMECHS):
self.logger.error('SASL error, terminating connection') self.logger.error('SASL error, terminating connection')
self.transport.close() self.transport.close()


# NICK errors # NICK errors
elif any(message.startswith(x) for x in (b'431 ', b'432 ', b'433 ', b'436 ')):
elif line.command in ('431', ircstates.numerics.ERR_ERRONEUSNICKNAME, ircstates.numerics.ERR_NICKNAMEINUSE, '436'):
self.logger.error(f'Failed to set nickname: {message!r}, terminating connection') self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
self.transport.close() self.transport.close()


# USER errors # USER errors
elif any(message.startswith(x) for x in (b'461 ', b'462 ')):
elif line.command in ('461', '462'):
self.logger.error(f'Failed to register: {message!r}, terminating connection') self.logger.error(f'Failed to register: {message!r}, terminating connection')
self.transport.close() self.transport.close()


# JOIN errors # JOIN errors
elif any(message.startswith(x) for x in (b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
elif line.command in (
ircstates.numerics.ERR_TOOMANYCHANNELS,
ircstates.numerics.ERR_CHANNELISFULL,
ircstates.numerics.ERR_INVITEONLYCHAN,
ircstates.numerics.ERR_BANNEDFROMCHAN,
ircstates.numerics.ERR_BADCHANNELKEY,
):
self.logger.error(f'Failed to join channel: {message!r}, terminating connection') self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
self.transport.close() self.transport.close()


# PART errors # PART errors
elif message.startswith(b'442 '):
elif line.command == '442':
self.logger.error(f'Failed to part channel: {message!r}') self.logger.error(f'Failed to part channel: {message!r}')


# JOIN/PART errors # JOIN/PART errors
elif message.startswith(b'403 '):
elif line.command == ircstates.numerics.ERR_NOSUCHCHANNEL:
self.logger.error(f'Failed to join or part channel: {message!r}') self.logger.error(f'Failed to join or part channel: {message!r}')


# PRIVMSG errors # PRIVMSG errors
elif any(message.startswith(x) for x in (b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')):
elif line.command in (ircstates.numerics.ERR_NOSUCHNICK, '404', '407', '411', '412', '413', '414'):
self.logger.error(f'Failed to send message: {message!r}') self.logger.error(f'Failed to send message: {message!r}')


# Connection registration reply # Connection registration reply
elif message.startswith(b'001 '):
elif line.command == ircstates.numerics.RPL_WELCOME:
self.logger.info('IRC connection registered') self.logger.info('IRC connection registered')
if self.sasl and not self.authenticated: if self.sasl and not self.authenticated:
self.logger.error('IRC connection registered but not authenticated, terminating connection') self.logger.error('IRC connection registered but not authenticated, terminating connection')
self.transport.close() self.transport.close()
return return
self.lastSentTime = time.time()
self._send_join_part(b'JOIN', self.channels) self._send_join_part(b'JOIN', self.channels)
asyncio.create_task(self.send_messages()) asyncio.create_task(self.send_messages())
asyncio.create_task(self.confirm_messages()) asyncio.create_task(self.confirm_messages())


# JOIN success
elif message.startswith(b'JOIN ') and not self.usermask:
# If this is my own join message, it should contain the usermask in the prefix
if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
usermask = rawMessage.split(b' ', 1)[0][1:]
self._maybe_set_usermask(usermask)

# Services host change
elif message.startswith(b'396 '):
words = message.split(b' ')
if len(words) >= 3:
# Sanity check inspired by irssi src/irc/core/irc-servers.c
if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
if b'@' in words[2]: # user@host
self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2])
else: # host (get user from previous mask or settings)
if self.usermask:
user = self.usermask.split(b'@')[0].split(b'!')[1]
else:
user = b'~' + self.config['irc']['nick'].encode('utf-8')
self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])
# Bot getting KICKed
elif line.command == 'KICK' and line.source and self.server.casefold(line.params[1]) == self.server.casefold(self.server.nickname):
self.logger.warning(f'Got kicked from {line.params[0]}')
kickedChannel = self.server.casefold(line.params[0])
for channel in self.channels:
if self.server.casefold(channel) == kickedChannel:
self.channels.remove(channel)
break

# WHOX on successful JOIN if supported to fetch account information
elif line.command == 'JOIN' and self.server.isupport.whox and line.source and self.server.casefold(line.hostmask.nickname) == self.server.casefold(self.server.nickname):
self.whoxQueue.extend(line.params[0].split(','))
maybeTriggerWhox = True

# WHOX response
elif line.command == ircstates.numerics.RPL_WHOSPCRPL and line.params[1] == '042':
self.whoxReply.append({'nick': line.params[4], 'hostmask': f'{line.params[4]}!{line.params[2]}@{line.params[3]}', 'account': line.params[5] if line.params[5] != '0' else None})

# End of WHOX response
elif line.command == ircstates.numerics.RPL_ENDOFWHO:
# Patch ircstates account info; ircstates does not parse the WHOX reply itself.
for entry in self.whoxReply:
if entry['account']:
self.server.users[self.server.casefold(entry['nick'])].account = entry['account']
self.whoxChannel = None
self.whoxReply = []
self.whoxStartTime = None
maybeTriggerWhox = True

# General fatal ERROR
elif line.command == 'ERROR':
self.logger.error(f'Server sent ERROR: {message!r}')
self.transport.close()

# Send next WHOX if appropriate
if maybeTriggerWhox and self.whoxChannel is None and self.whoxQueue:
self.whoxChannel = self.whoxQueue.popleft()
self.whoxReply = []
self.whoxStartTime = time.time() # Note, may not be the actual start time due to rate limiting
self.send(b'WHO ' + self.whoxChannel.encode('utf-8') + b' c%tuhna,042')

async def quit(self):
# The server acknowledges a QUIT by sending an ERROR and closing the connection. The latter triggers connection_lost, so just wait for the closure event.
self.logger.info('Quitting')
self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue
self._direct_send(b'QUIT :Bye')
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10)
if not self.connectionClosedEvent.is_set():
self.logger.error('Quitting cleanly did not work, closing connection forcefully')
# Event will be set implicitly in connection_lost.
self.transport.close()


def connection_lost(self, exc): def connection_lost(self, exc):
self.logger.info('IRC connection lost') self.logger.info('IRC connection lost')
@@ -615,8 +715,8 @@ class IRCClientProtocol(asyncio.Protocol):
class IRCClient: class IRCClient:
logger = logging.getLogger('http2irc.IRCClient') logger = logging.getLogger('http2irc.IRCClient')


def __init__(self, messageQueue, config):
self.messageQueue = messageQueue
def __init__(self, http2ircMessageQueue, config):
self.http2ircMessageQueue = http2ircMessageQueue
self.config = config self.config = config
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()} self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}


@@ -647,17 +747,43 @@ class IRCClient:
while True: while True:
connectionClosedEvent.clear() connectionClosedEvent.clear()
try: try:
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
self.logger.debug('Creating IRC connection')
t = asyncio.create_task(loop.create_connection(
protocol_factory = lambda: IRCClientProtocol(self.http2ircMessageQueue, connectionClosedEvent, loop, self.config, self.channels),
host = self.config['irc']['host'],
port = self.config['irc']['port'],
ssl = self._get_ssl_context(),
family = self.config['irc']['family'],
))
# No automatic cancellation of t because it's handled manually below.
done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30)
if t not in done:
t.cancel()
await t # Raises the CancelledError
self._transport, self._protocol = t.result()
self.logger.debug('Starting send queue processing')
sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent
self.logger.debug('Waiting for connection closure or SIGINT')
try: try:
await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED)
await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
finally: finally:
self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}')
if not connectionClosedEvent.is_set():
self.logger.debug('Quitting connection')
await self._protocol.quit()
if not sendTask.done():
sendTask.cancel()
try:
await sendTask
except asyncio.CancelledError:
pass
self._transport = None self._transport = None
self._protocol = None self._protocol = None
except (ConnectionRefusedError, asyncio.TimeoutError) as e:
self.logger.error(str(e))
except (ConnectionError, ssl.SSLError, asyncio.TimeoutError, asyncio.CancelledError) as e:
self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}')
await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5) await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5)
if sigintEvent.is_set(): if sigintEvent.is_set():
self.logger.debug('Got SIGINT, breaking IRC loop')
break break


@property @property
@@ -668,8 +794,8 @@ class IRCClient:
class WebServer: class WebServer:
logger = logging.getLogger('http2irc.WebServer') logger = logging.getLogger('http2irc.WebServer')


def __init__(self, messageQueue, ircClient, config):
self.messageQueue = messageQueue
def __init__(self, http2ircMessageQueue, ircClient, config):
self.http2ircMessageQueue = http2ircMessageQueue
self.ircClient = ircClient self.ircClient = ircClient
self.config = config self.config = config


@@ -697,7 +823,7 @@ class WebServer:
await runner.setup() await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start() await site.start()
await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = concurrent.futures.FIRST_COMPLETED)
await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED)
await runner.cleanup() await runner.cleanup()
if stopEvent.is_set(): if stopEvent.is_set():
break break
@@ -735,7 +861,7 @@ class WebServer:
self.logger.debug(f'Processing request {id(request)} using default processor') self.logger.debug(f'Processing request {id(request)} using default processor')
message = await self._default_process(request) message = await self._default_process(request)
self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue') self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
self.messageQueue.put_nowait((channel, message, overlongmode))
self.http2ircMessageQueue.put_nowait((channel, message, overlongmode))
raise aiohttp.web.HTTPOk() raise aiohttp.web.HTTPOk()


async def _default_process(self, request): async def _default_process(self, request):
@@ -777,10 +903,10 @@ async def main():


loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()


messageQueue = MessageQueue()
http2ircMessageQueue = MessageQueue()


irc = IRCClient(messageQueue, config)
webserver = WebServer(messageQueue, irc, config)
irc = IRCClient(http2ircMessageQueue, config)
webserver = WebServer(http2ircMessageQueue, irc, config)


sigintEvent = asyncio.Event() sigintEvent = asyncio.Event()
def sigint_callback(): def sigint_callback():


Loading…
Откажи
Сачувај