Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

247 Zeilen
9.1 KiB

import argparse
import asyncio
import base64
import collections
import concurrent.futures
import hmac
import json
import logging
import signal
import ssl

import aiohttp
import aiohttp.web
  10. logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')
  11. class MessageQueue:
  12. # An object holding onto the messages received from nodeping
  13. # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
  14. # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
  15. # Differences to asyncio.Queue include:
  16. # - No maxsize
  17. # - No put coroutine (not necessary since the queue can never be full)
  18. # - Only one concurrent getter
  19. # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)
  20. def __init__(self):
  21. self._getter = None # None | asyncio.Future
  22. self._queue = collections.deque()
  23. async def get(self):
  24. if self._getter is not None:
  25. raise RuntimeError('Cannot get concurrently')
  26. if len(self._queue) == 0:
  27. self._getter = asyncio.get_running_loop().create_future()
  28. logging.debug('Awaiting getter')
  29. try:
  30. await self._getter
  31. except asyncio.CancelledError:
  32. logging.debug('Cancelled getter')
  33. self._getter = None
  34. raise
  35. logging.debug('Awaited getter')
  36. self._getter = None
  37. # For testing the cancellation/putting back onto the queue
  38. #logging.debug('Delaying message queue get')
  39. #await asyncio.sleep(3)
  40. #logging.debug('Done delaying')
  41. return self.get_nowait()
  42. def get_nowait(self):
  43. if len(self._queue) == 0:
  44. raise asyncio.QueueEmpty
  45. return self._queue.popleft()
  46. def put_nowait(self, item):
  47. self._queue.append(item)
  48. if self._getter is not None:
  49. self._getter.set_result(None)
  50. def putleft_nowait(self, item):
  51. self._queue.appendleft(item)
  52. if self._getter is not None:
  53. self._getter.set_result(None)
  54. def qsize(self):
  55. return len(self._queue)
  56. class IRCClientProtocol(asyncio.Protocol):
  57. def __init__(self, messageQueue, stopEvent, loop, nick, real, channel):
  58. logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}')
  59. self.messageQueue = messageQueue
  60. self.stopEvent = stopEvent
  61. self.loop = loop
  62. self.nick = nick
  63. self.real = real
  64. self.channel = channel
  65. self.channelb = channel.encode('utf-8')
  66. self.buffer = b''
  67. self.connected = False
  68. def send(self, data):
  69. logging.info(f'Send: {data!r}')
  70. self.transport.write(data + b'\r\n')
  71. def connection_made(self, transport):
  72. logging.info('Connected')
  73. self.transport = transport
  74. self.connected = True
  75. nickb = self.nick.encode('utf-8')
  76. self.send(b'NICK ' + nickb)
  77. self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.real.encode('utf-8'))
  78. self.send(b'JOIN ' + self.channelb)
  79. asyncio.create_task(self.send_messages())
  80. async def _get_message(self):
  81. logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
  82. messageFuture = asyncio.create_task(self.messageQueue.get())
  83. done, pending = await asyncio.wait((messageFuture, self.stopEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
  84. if self.stopEvent.is_set():
  85. if messageFuture in pending:
  86. logging.debug('Cancelling messageFuture')
  87. messageFuture.cancel()
  88. try:
  89. await messageFuture
  90. except asyncio.CancelledError:
  91. logging.debug('Cancelled messageFuture')
  92. pass
  93. else:
  94. # messageFuture is already done but we're stopping, so put the result back onto the queue
  95. self.messageQueue.putleft_nowait(messageFuture.result())
  96. return None
  97. assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
  98. return messageFuture.result()
  99. async def send_messages(self):
  100. while self.connected:
  101. logging.debug(f'{id(self)}: trying to get a message')
  102. message = await self._get_message()
  103. logging.debug(f'{id(self)}: got message: {message!r}')
  104. if message is None:
  105. break
  106. self.send(b'PRIVMSG ' + self.channelb + b' :' + message.encode('utf-8'))
  107. #TODO self.messageQueue.putleft_nowait if delivery fails
  108. await asyncio.sleep(1) # Rate limit
  109. def data_received(self, data):
  110. logging.debug(f'Data received: {data!r}')
  111. # Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that.
  112. # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
  113. # If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
  114. messages = data.split(b'\r\n')
  115. if self.buffer:
  116. self.message_received(self.buffer + messages[0])
  117. messages = messages[1:]
  118. for message in messages[:-1]:
  119. self.message_received(message)
  120. self.buffer = messages[-1]
  121. def message_received(self, message):
  122. logging.info(f'Message received: {message!r}')
  123. if message.startswith(b'PING '):
  124. self.send(b'PONG ' + message[5:])
  125. def connection_lost(self, exc):
  126. logging.info('The server closed the connection')
  127. self.connected = False
  128. self.stopEvent.set()
  129. class WebServer:
  130. def __init__(self, messageQueue, host, port, auth):
  131. self.messageQueue = messageQueue
  132. self.host = host
  133. self.port = port
  134. self.auth = auth
  135. if auth:
  136. self.authHeader = f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'
  137. self._app = aiohttp.web.Application()
  138. self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)])
  139. async def run(self, stopEvent):
  140. runner = aiohttp.web.AppRunner(self._app)
  141. await runner.setup()
  142. site = aiohttp.web.TCPSite(runner, self.host, self.port)
  143. await site.start()
  144. await stopEvent.wait()
  145. await runner.cleanup()
  146. async def nodeping_post(self, request):
  147. logging.info(f'Received request with data: {await request.read()!r}')
  148. authHeader = request.headers.get('Authorization')
  149. if self.auth and (not authHeader or authHeader != self.authHeader):
  150. return aiohttp.web.HTTPForbidden()
  151. try:
  152. data = await request.json()
  153. except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
  154. logging.error(f'Received invalid data: {await request.read()!r}')
  155. return aiohttp.web.HTTPBadRequest()
  156. if 'message' not in data:
  157. logging.error(f'Received invalid data: {await request.read()!r}')
  158. return aiohttp.web.HTTPBadRequest()
  159. if '\r' in data['message'] or '\n' in data['message']:
  160. logging.error(f'Received invalid data: {await request.read()!r}')
  161. return aiohttp.web.HTTPBadRequest()
  162. logging.debug(f'Putting to message queue {id(self.messageQueue)}')
  163. self.messageQueue.put_nowait(data['message'])
  164. return aiohttp.web.HTTPOk()
  165. async def run_irc(loop, messageQueue, sigintEvent, host, port, ssl, nick, real, channel):
  166. stopEvent = asyncio.Event()
  167. while True:
  168. stopEvent.clear()
  169. try:
  170. transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop, nick = nick, real = real, channel = channel), host, port, ssl = ssl)
  171. try:
  172. await asyncio.wait((stopEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
  173. finally:
  174. transport.close()
  175. except (ConnectionRefusedError, asyncio.TimeoutError) as e:
  176. logging.error(str(e))
  177. await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
  178. if sigintEvent.is_set():
  179. break
  180. async def run_webserver(loop, messageQueue, sigintEvent, host, port, auth):
  181. server = WebServer(messageQueue, host, port, auth)
  182. await server.run(sigintEvent)
  183. def parse_args():
  184. parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
  185. parser.add_argument('--irchost', type = str, help = 'IRC server hostname', default = 'irc.hackint.org')
  186. parser.add_argument('--ircport', type = int, help = 'IRC server port', default = 6697)
  187. parser.add_argument('--ircssl', choices = ['yes', 'no', 'insecure'], help = 'enable, disable, or use insecure SSL/TLS', default = 'yes')
  188. parser.add_argument('--ircnick', help = 'IRC nickname', default = 'npbot')
  189. parser.add_argument('--ircreal', help = 'IRC realname', default = 'I am a bot.')
  190. parser.add_argument('--ircchannel', help = 'IRC channel to join and post messages', default = '#nodeping')
  191. parser.add_argument('--webhost', type = str, help = 'web server host to bind to', default = '127.0.0.1')
  192. parser.add_argument('--webport', type = int, help = 'web server port to bind to', default = 8080)
  193. parser.add_argument('--webauth', type = str, help = 'basic auth data (user:pass, or None to disable the check)', default = None)
  194. return parser.parse_args()
  195. async def main():
  196. args = parse_args()
  197. ssl = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()}[args.ircssl]
  198. loop = asyncio.get_running_loop()
  199. messageQueue = MessageQueue()
  200. sigintEvent = asyncio.Event()
  201. def sigint_callback():
  202. logging.info('Got SIGINT')
  203. nonlocal sigintEvent
  204. sigintEvent.set()
  205. loop.add_signal_handler(signal.SIGINT, sigint_callback)
  206. irc = run_irc(loop, messageQueue, sigintEvent, host = args.irchost, port = args.ircport, ssl = ssl, nick = args.ircnick, real = args.ircreal, channel = args.ircchannel)
  207. webserver = run_webserver(loop, messageQueue, sigintEvent, host = args.webhost, port = args.webport, auth = args.webauth)
  208. await asyncio.gather(irc, webserver)
  209. asyncio.run(main())