From 1548246b08011230a42fed2479a93f9b80148927 Mon Sep 17 00:00:00 2001 From: JustAnotherArchivist Date: Wed, 28 Apr 2021 04:55:26 +0000 Subject: [PATCH] Fix memory leak due to asyncio tasks not being cancelled (cf. irclog commit 50a8b798) --- http2irc.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/http2irc.py b/http2irc.py index d14d8cb..4a125a0 100644 --- a/http2irc.py +++ b/http2irc.py @@ -47,6 +47,22 @@ def is_valid_pem(path, withCert): return False +async def wait_cancel_pending(aws, paws = None, **kwargs): + '''asyncio.wait but with automatic cancellation of non-completed tasks. Tasks in paws (persistent awaitables) are not automatically cancelled.''' + if paws is None: + paws = set() + tasks = aws | paws + done, pending = await asyncio.wait(tasks, **kwargs) + for task in pending: + if task not in paws: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + return done, pending + + class Config(dict): def __init__(self, filename): super().__init__() @@ -381,7 +397,7 @@ class IRCClientProtocol(asyncio.Protocol): async def _get_message(self): self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}') messageFuture = asyncio.create_task(self.messageQueue.get()) - done, pending = await asyncio.wait((messageFuture, self.connectionClosedEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = concurrent.futures.FIRST_COMPLETED) if self.connectionClosedEvent.is_set(): if messageFuture in pending: self.logger.debug('Cancelling messageFuture') @@ -457,7 +473,7 @@ class IRCClientProtocol(asyncio.Protocol): async def confirm_messages(self): while self.connected: - await asyncio.wait((asyncio.sleep(60), self.connectionClosedEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) # Confirm once per minute + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 60) # Confirm once per minute if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly self.messageQueue.putleft_nowait(*self.unconfirmedMessages) self.unconfirmedMessages = [] @@ -468,7 +484,7 @@ class IRCClientProtocol(asyncio.Protocol): self.logger.debug('Trying to confirm message delivery') self.pongReceivedEvent.clear() self.send(b'PING :42') - await asyncio.wait((asyncio.sleep(5), self.pongReceivedEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 5) self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}') if not self.pongReceivedEvent.is_set(): # No PONG received in five seconds, assume connection's dead @@ -629,12 +645,12 @@ class IRCClient: try: self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context()) try: - await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED) finally: self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C? except (ConnectionRefusedError, asyncio.TimeoutError) as e: self.logger.error(str(e)) - await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5) if sigintEvent.is_set(): break @@ -667,7 +683,7 @@ class WebServer: await runner.setup() site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) await site.start() - await asyncio.wait((stopEvent.wait(), self._configChanged.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = concurrent.futures.FIRST_COMPLETED) await runner.cleanup() if stopEvent.is_set(): break