|
|
@@ -68,6 +68,22 @@ def is_valid_pem(path, withCert): |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
async def wait_cancel_pending(aws, paws = None, **kwargs): |
|
|
|
'''asyncio.wait but with automatic cancellation of non-completed tasks. Tasks in paws (persistent awaitables) are not automatically cancelled.''' |
|
|
|
if paws is None: |
|
|
|
paws = set() |
|
|
|
tasks = aws | paws |
|
|
|
done, pending = await asyncio.wait(tasks, **kwargs) |
|
|
|
for task in pending: |
|
|
|
if task not in paws: |
|
|
|
task.cancel() |
|
|
|
try: |
|
|
|
await task |
|
|
|
except asyncio.CancelledError: |
|
|
|
pass |
|
|
|
return done, pending |
|
|
|
|
|
|
|
|
|
|
|
class Config(dict): |
|
|
|
def __init__(self, filename): |
|
|
|
super().__init__() |
|
|
@@ -398,7 +414,7 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
while True: |
|
|
|
self.logger.debug(f'Trying to get data from send queue') |
|
|
|
t = asyncio.create_task(self.sendQueue.get()) |
|
|
|
done, pending = await asyncio.wait({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
if self.connectionClosedEvent.is_set(): |
|
|
|
break |
|
|
|
assert t in done, f'{t!r} is not in {done!r}' |
|
|
@@ -407,7 +423,7 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
now = time.time() |
|
|
|
if self.lastSentTime is not None and now - self.lastSentTime < 1: |
|
|
|
self.logger.debug(f'Rate limited') |
|
|
|
await asyncio.wait({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now) |
|
|
|
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now) |
|
|
|
if self.connectionClosedEvent.is_set(): |
|
|
|
break |
|
|
|
time_ = self._direct_send(data) |
|
|
@@ -635,7 +651,7 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
self.logger.info('Quitting') |
|
|
|
self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue |
|
|
|
self._direct_send(b'QUIT :Bye') |
|
|
|
await asyncio.wait({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10) |
|
|
|
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10) |
|
|
|
if not self.connectionClosedEvent.is_set(): |
|
|
|
self.logger.error('Quitting cleanly did not work, closing connection forcefully') |
|
|
|
# Event will be set implicitly in connection_lost. |
|
|
@@ -693,26 +709,33 @@ class IRCClient: |
|
|
|
port = self.config['irc']['port'], |
|
|
|
ssl = self._get_ssl_context(), |
|
|
|
)) |
|
|
|
done, _ = await asyncio.wait({t, asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 30) |
|
|
|
# No automatic cancellation of t because it's handled manually below. |
|
|
|
done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30) |
|
|
|
if t not in done: |
|
|
|
t.cancel() |
|
|
|
await t # Raises the CancelledError |
|
|
|
self._transport, self._protocol = t.result() |
|
|
|
self.logger.debug('Starting send queue processing') |
|
|
|
asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent |
|
|
|
sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent |
|
|
|
self.logger.debug('Waiting for connection closure or SIGINT') |
|
|
|
try: |
|
|
|
await asyncio.wait({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
finally: |
|
|
|
self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}') |
|
|
|
if not connectionClosedEvent.is_set(): |
|
|
|
self.logger.debug('Quitting connection') |
|
|
|
await self._protocol.quit() |
|
|
|
if not sendTask.done(): |
|
|
|
sendTask.cancel() |
|
|
|
try: |
|
|
|
await sendTask |
|
|
|
except asyncio.CancelledError: |
|
|
|
pass |
|
|
|
self._transport = None |
|
|
|
self._protocol = None |
|
|
|
except (ConnectionRefusedError, ssl.SSLError, asyncio.TimeoutError, asyncio.CancelledError) as e: |
|
|
|
self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}') |
|
|
|
await asyncio.wait({asyncio.create_task(sigintEvent.wait())}, timeout = 5) |
|
|
|
await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5) |
|
|
|
if sigintEvent.is_set(): |
|
|
|
self.logger.debug('Got SIGINT, putting EOF and breaking') |
|
|
|
self.messageQueue.put_nowait(messageEOF) |
|
|
@@ -807,7 +830,7 @@ class Storage: |
|
|
|
async def flush_files(self, flushExitEvent): |
|
|
|
lastFlushTime = 0 |
|
|
|
while True: |
|
|
|
await asyncio.wait({asyncio.create_task(flushExitEvent.wait())}, timeout = self.config['storage']['flushTime']) |
|
|
|
await wait_cancel_pending({asyncio.create_task(flushExitEvent.wait())}, timeout = self.config['storage']['flushTime']) |
|
|
|
self.logger.debug('Flushing files') |
|
|
|
flushedFiles = [] |
|
|
|
for channel, (fn, f, fLastWriteTime) in self.files.items(): |
|
|
@@ -883,7 +906,7 @@ class WebServer: |
|
|
|
await runner.setup() |
|
|
|
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) |
|
|
|
await site.start() |
|
|
|
await asyncio.wait({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED) |
|
|
|
await runner.cleanup() |
|
|
|
if stopEvent.is_set(): |
|
|
|
break |
|
|
@@ -1140,17 +1163,20 @@ class WebServer: |
|
|
|
stderrTask = asyncio.create_task(process_stderr()) |
|
|
|
await asyncio.wait({stdoutTask, stderrTask}, timeout = self.config['web']['search']['maxTime'] if self.config['web']['search']['maxTime'] != 0 else None) |
|
|
|
# The stream readers may quit before the process is done even on a successful grep. Wait a tiny bit longer for the process to exit. |
|
|
|
await asyncio.wait({asyncio.create_task(proc.wait())}, timeout = 0.1) |
|
|
|
procTask = asyncio.create_task(proc.wait()) |
|
|
|
await asyncio.wait({procTask}, timeout = 0.1) |
|
|
|
if proc.returncode is None: |
|
|
|
# Process hasn't finished yet after maxTime. Murder it and wait for it to die. |
|
|
|
assert not procTask.done(), 'procTask is done but proc.returncode is None' |
|
|
|
self.logger.warning(f'Request {id(request)} grep took more than the time limit') |
|
|
|
proc.kill() |
|
|
|
await asyncio.wait({stdoutTask, stderrTask, asyncio.create_task(proc.wait())}, timeout = 1) # This really shouldn't take longer. |
|
|
|
await asyncio.wait({stdoutTask, stderrTask, procTask}, timeout = 1) # This really shouldn't take longer. |
|
|
|
if proc.returncode is None: |
|
|
|
# Still not done?! Cancel tasks and bail. |
|
|
|
self.logger.error(f'Request {id(request)} grep did not exit after getting killed!') |
|
|
|
stdoutTask.cancel() |
|
|
|
stderrTask.cancel() |
|
|
|
procTask.cancel() |
|
|
|
return aiohttp.web.HTTPInternalServerError() |
|
|
|
stdout, incomplete = stdoutTask.result() |
|
|
|
self.logger.info(f'Request {id(request)} grep exited with {proc.returncode} and produced {len(stdout)} bytes (incomplete: {incomplete})') |
|
|
|