|
import argparse
import asyncio
import base64
import collections
import concurrent.futures
import hmac
import json
import logging
import signal
import ssl

import aiohttp
import aiohttp.web
-
-
# Root logger: everything from DEBUG upwards, rendered as "<asctime> <levelname> <message>" via str.format-style fields.
logging.basicConfig(
    style = '{',
    format = '{asctime} {levelname} {message}',
    level = logging.DEBUG,
)
-
-
class MessageQueue:
    """A FIFO holding onto the messages received from NodePing until they are delivered to IRC.

    This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    asyncio.Queue's extension hooks (_init, _put, and _get) are undocumented, so they are not relied upon.
    Differences to asyncio.Queue:
    - No maxsize; the queue can never be full, so there is no put coroutine.
    - Only one concurrent getter is supported.
    - putleft_nowait puts an item at the front of the queue (so that the IRC client can put a message back
      when delivery fails or is aborted).
    """

    def __init__(self):
        self._getter = None  # None | asyncio.Future; set only while a get() call is waiting for an item
        self._queue = collections.deque()

    async def get(self):
        """Remove and return the first item, waiting until one is available.

        Raises RuntimeError if another get() is already waiting. If the call is cancelled, nothing is removed.
        """
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        # The getter future is only a wake-up signal, so re-check the queue after every wake-up
        # (get_nowait() is public and could have emptied it in the meantime).
        while not self._queue:
            self._getter = asyncio.get_running_loop().create_future()
            logging.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                logging.debug('Cancelled getter')
                raise
            finally:
                self._getter = None
            logging.debug('Awaited getter')
        return self.get_nowait()

    def get_nowait(self):
        """Remove and return the first item; raise asyncio.QueueEmpty if the queue is empty."""
        if not self._queue:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        """Append *item* at the end of the queue and wake a waiting getter."""
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, item):
        """Insert *item* at the front of the queue (e.g. to re-queue an undelivered message)."""
        self._queue.appendleft(item)
        self._wake_getter()

    def qsize(self):
        """Return the number of queued items."""
        return len(self._queue)

    def _wake_getter(self):
        # The getter future may already be done: a second put can arrive before the getter resumes, or the
        # getter may have been cancelled. Calling set_result() again would raise asyncio.InvalidStateError.
        if self._getter is not None and not self._getter.done():
            self._getter.set_result(None)
-
-
class IRCClientProtocol(asyncio.Protocol):
    """asyncio protocol for one IRC connection.

    On connect it sends NICK/USER/JOIN, then relays messages taken from *messageQueue* to *channel* as
    PRIVMSGs, pausing one second after each. It answers server PINGs. When the connection is lost it sets
    *stopEvent*; the sending loop also ends when *stopEvent* is set by someone else.
    """

    def __init__(self, messageQueue, stopEvent, loop, nick, real, channel):
        logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}')
        self.messageQueue = messageQueue
        self.stopEvent = stopEvent
        self.loop = loop
        self.nick = nick
        self.real = real
        self.channel = channel
        self.channelb = channel.encode('utf-8')
        self.buffer = b''  # unterminated tail of the received data, completed by a later data_received call
        self.connected = False
        self.transport = None
        self._sendTask = None  # strong reference to the send_messages task (see connection_made)

    def send(self, data):
        """Write one raw IRC line; *data* is bytes without the trailing CRLF."""
        logging.info(f'Send: {data!r}')
        self.transport.write(data + b'\r\n')

    def connection_made(self, transport):
        logging.info('Connected')
        self.transport = transport
        self.connected = True
        nickb = self.nick.encode('utf-8')
        self.send(b'NICK ' + nickb)
        self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.real.encode('utf-8'))
        self.send(b'JOIN ' + self.channelb)
        # The event loop only keeps weak references to tasks; keep one so the task can't be collected mid-run.
        self._sendTask = asyncio.create_task(self.send_messages())

    async def _get_message(self):
        """Wait for the next queued message.

        Returns the message, or None if the stop event fires first. A message that was already taken from the
        queue when stopping is put back at the front, so it is not lost.
        """
        logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        messageFuture = asyncio.create_task(self.messageQueue.get())
        stopFuture = asyncio.create_task(self.stopEvent.wait())
        try:
            done, pending = await asyncio.wait((messageFuture, stopFuture), return_when = asyncio.FIRST_COMPLETED)
        except asyncio.CancelledError:
            # Don't leave an orphaned getter behind: MessageQueue allows only one concurrent get().
            messageFuture.cancel()
            raise
        finally:
            # Cancel the stop waiter if it's still pending; otherwise one would pile up per message.
            stopFuture.cancel()
        # Check stopFuture as well as is_set(): the event may already have been cleared again by now.
        if stopFuture not in done and not self.stopEvent.is_set():
            return messageFuture.result()
        if messageFuture in pending:
            logging.debug('Cancelling messageFuture')
            messageFuture.cancel()
            try:
                await messageFuture
            except asyncio.CancelledError:
                logging.debug('Cancelled messageFuture')
                return None
        # messageFuture is done but we're stopping, so put the result back onto the queue
        self.messageQueue.putleft_nowait(messageFuture.result())
        return None

    async def send_messages(self):
        """Send queued messages to the channel as PRIVMSGs until the connection stops."""
        while self.connected:
            logging.debug(f'{id(self)}: trying to get a message')
            message = await self._get_message()
            logging.debug(f'{id(self)}: got message: {message!r}')
            if message is None:
                break
            self.send(b'PRIVMSG ' + self.channelb + b' :' + message.encode('utf-8'))
            #TODO self.messageQueue.putleft_nowait if delivery fails
            await asyncio.sleep(1)  # Rate limit

    def data_received(self, data):
        logging.debug(f'Data received: {data!r}')
        # Join the new data onto the buffered tail and dispatch every complete CRLF-terminated line.
        # The last split element is the (possibly empty) unterminated remainder and stays buffered; this also
        # handles chunks without any CRLF and a CRLF split across two chunks.
        *lines, self.buffer = (self.buffer + data).split(b'\r\n')
        for line in lines:
            self.message_received(line)

    def message_received(self, message):
        """Handle one complete IRC line (without CRLF); only PING is acted upon."""
        logging.info(f'Message received: {message!r}')
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

    def connection_lost(self, exc):
        logging.info('The server closed the connection')
        self.connected = False
        self.stopEvent.set()
-
-
class WebServer:
    """HTTP server accepting NodePing webhook POSTs on /nodeping and queueing their messages for IRC.

    *auth*, if truthy, is the 'user:pass' string that requests must present via HTTP Basic authentication.
    """

    def __init__(self, messageQueue, host, port, auth):
        self.messageQueue = messageQueue
        self.host = host
        self.port = port
        self.auth = auth
        if auth:
            self.authHeader = f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'
        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)])

    async def run(self, stopEvent):
        """Serve on host:port until *stopEvent* is set."""
        runner = aiohttp.web.AppRunner(self._app)
        await runner.setup()
        try:
            site = aiohttp.web.TCPSite(runner, self.host, self.port)
            await site.start()
            await stopEvent.wait()
        finally:
            # Release the listening socket even if start() fails or this coroutine is cancelled.
            await runner.cleanup()

    async def nodeping_post(self, request):
        """Handle POST /nodeping; the body must be a JSON object whose "message" is a single-line string.

        Raises HTTPForbidden on bad credentials and HTTPBadRequest on invalid data (aiohttp turns the raised
        exceptions into responses; returning them is deprecated).
        """
        body = await request.read()
        logging.info(f'Received request with data: {body!r}')
        if self.auth:
            authHeader = request.headers.get('Authorization') or ''
            # Constant-time comparison, so response timing does not leak how much of the header matched.
            if not hmac.compare_digest(authHeader.encode('utf-8', 'surrogatepass'), self.authHeader.encode('utf-8')):
                raise aiohttp.web.HTTPForbidden()
        try:
            data = await request.json()
        except (aiohttp.ContentTypeError, json.JSONDecodeError, UnicodeDecodeError):
            logging.error(f'Received invalid data: {body!r}')
            raise aiohttp.web.HTTPBadRequest()
        message = data.get('message') if isinstance(data, dict) else None
        # The message is later UTF-8 encoded and sent as a single IRC line, so it must be a str without CR/LF.
        if not isinstance(message, str) or '\r' in message or '\n' in message:
            logging.error(f'Received invalid data: {body!r}')
            raise aiohttp.web.HTTPBadRequest()
        logging.debug(f'Putting to message queue {id(self.messageQueue)}')
        self.messageQueue.put_nowait(message)
        return aiohttp.web.Response(status = 200)
-
-
async def _wait_first_completed(*awaitables):
    """Wait until any one of *awaitables* has finished, then cancel the ones still pending."""
    tasks = [asyncio.ensure_future(aw) for aw in awaitables]
    try:
        await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()


async def run_irc(loop, messageQueue, sigintEvent, host, port, ssl, nick, real, channel):
    """Keep an IRC connection up, reconnecting after failures or disconnects, until *sigintEvent* is set.

    Each connection attempt after the first is preceded by a 5-second pause, which ends early on SIGINT.
    """
    while True:
        # A fresh event per connection: a protocol left over from an earlier connection can never observe
        # this connection's event being set or cleared.
        stopEvent = asyncio.Event()
        try:
            transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop, nick = nick, real = real, channel = channel), host, port, ssl = ssl)
        except (OSError, asyncio.TimeoutError) as e:
            # OSError also covers DNS failures (socket.gaierror) and TLS errors (ssl.SSLError).
            logging.error(f'Could not connect to {host}:{port}: {e!r}')
        else:
            try:
                await _wait_first_completed(stopEvent.wait(), sigintEvent.wait())
            finally:
                transport.close()
        if sigintEvent.is_set():
            break
        # Back off before reconnecting so an immediately-closing server isn't hammered with connections.
        await _wait_first_completed(asyncio.sleep(5), sigintEvent.wait())
        if sigintEvent.is_set():
            break
-
-
async def run_webserver(loop, messageQueue, sigintEvent, host, port, auth):
    """Serve NodePing webhooks on host:port, feeding *messageQueue*, until *sigintEvent* is set.

    *loop* is not used here.
    """
    await WebServer(messageQueue, host, port, auth).run(sigintEvent)
-
-
def parse_args():
    """Define the command-line options and parse sys.argv.

    Returns an argparse.Namespace holding the IRC (irc*) and web server (web*) settings.
    """
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument
    # IRC connection
    add('--irchost', type = str, default = 'irc.hackint.org', help = 'IRC server hostname')
    add('--ircport', type = int, default = 6697, help = 'IRC server port')
    add('--ircssl', choices = ['yes', 'no', 'insecure'], default = 'yes', help = 'enable, disable, or use insecure SSL/TLS')
    add('--ircnick', default = 'npbot', help = 'IRC nickname')
    add('--ircreal', default = 'I am a bot.', help = 'IRC realname')
    add('--ircchannel', default = '#nodeping', help = 'IRC channel to join and post messages')
    # Web server
    add('--webhost', type = str, default = '127.0.0.1', help = 'web server host to bind to')
    add('--webport', type = int, default = 8080, help = 'web server port to bind to')
    add('--webauth', type = str, default = None, help = 'basic auth data (user:pass, or None to disable the check)')
    namespace = parser.parse_args()
    return namespace
-
-
async def main():
    """Parse arguments, then run the IRC client and the web server until SIGINT."""
    args = parse_args()
    # This must not be called `ssl`: assigning to that name would make it local to the whole function and
    # shadow the ssl module (UnboundLocalError). The insecure context is also only built when requested.
    if args.ircssl == 'insecure':
        # TLS without certificate or hostname verification
        sslContext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslContext.check_hostname = False
        sslContext.verify_mode = ssl.CERT_NONE
    else:
        sslContext = args.ircssl == 'yes'  # True: verified TLS with asyncio's default context; False: plaintext

    loop = asyncio.get_running_loop()

    messageQueue = MessageQueue()
    sigintEvent = asyncio.Event()

    def sigint_callback():
        logging.info('Got SIGINT')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    irc = run_irc(loop, messageQueue, sigintEvent, host = args.irchost, port = args.ircport, ssl = sslContext, nick = args.ircnick, real = args.ircreal, channel = args.ircchannel)
    webserver = run_webserver(loop, messageQueue, sigintEvent, host = args.webhost, port = args.webport, auth = args.webauth)
    await asyncio.gather(irc, webserver)
-
-
# Only start the bot when executed as a script, not when the module is imported.
if __name__ == '__main__':
    asyncio.run(main())
|