diff --git a/joycontrol/protocol.py b/joycontrol/protocol.py index e41d155..935cc1e 100644 --- a/joycontrol/protocol.py +++ b/joycontrol/protocol.py @@ -78,9 +78,11 @@ class ControllerProtocol(BaseProtocol): self.transport = transport def connection_lost(self, exc: Optional[Exception]) -> None: + # TODO raise NotImplementedError() def error_received(self, exc: Exception) -> None: + # TODO raise NotImplementedError() async def input_report_mode_0x30(self): @@ -288,16 +290,26 @@ class ControllerProtocol(BaseProtocol): # start sending 0x30 input reports if self._0x30_input_report_sender is None: self.transport.pause_reading() - self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30()) # create callback to check for exceptions def callback(future): try: future.result() + except asyncio.CancelledError: + # Future may be cancelled at anytime + pass except Exception as err: logger.exception(err) + self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30()) self._0x30_input_report_sender.add_done_callback(callback) + + # We have to swap the reader in the future because this function was probably called by it + async def set_reader(): + await self.transport.set_reader(self._0x30_input_report_sender) + self.transport.resume_reading() + + asyncio.ensure_future(set_reader()).add_done_callback(callback) else: logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') diff --git a/joycontrol/transport.py b/joycontrol/transport.py index d2e9c27..e8beddc 100644 --- a/joycontrol/transport.py +++ b/joycontrol/transport.py @@ -4,8 +4,6 @@ import struct import time from typing import Any -from joycontrol.report import InputReport - logger = logging.getLogger(__name__) @@ -23,35 +21,72 @@ class L2CAP_Transport(asyncio.Transport): 'socket': self._sock } + self._is_closing = False + self._is_reading = asyncio.Event() + + self._capture_file = capture_file + + # start underlying reader + self._read_thread = None + self._is_reading.set() + self.start_reader() + + async def _reader(self): + while True: + data = await self.read() + + #logger.debug(f'received "{list(data)}"') + await self._protocol.report_received(data, self._sock.getpeername()) + + def start_reader(self): + """ + Starts the transport reader which calls the protocols report_received function for every incoming message + """ + if self._read_thread is not None: + raise ValueError('Reader is already running.') + self._read_thread = asyncio.ensure_future(self._reader()) # create callback to check for exceptions def callback(future): try: future.result() + except asyncio.CancelledError: + # Future may be cancelled at anytime + pass except Exception as err: logger.exception(err) self._read_thread.add_done_callback(callback) - self._is_closing = False - self._is_reading = asyncio.Event() - self._is_reading.set() + async def set_reader(self, reader: asyncio.Future): + """ + Cancel the currently running reader and register the new one. + A reader is a coroutine that calls this transports 'read' function. + The 'read' function calls can be paused by calling pause_reading of this transport. + :param reader: future reader + """ + if self._read_thread is not None: + # cancel currently running reader + self._read_thread.cancel() + try: + await self._read_thread + except asyncio.CancelledError: + pass - self._input_report_timer = 0x00 + self._read_thread = reader - self._capture_file = capture_file - - async def _reader(self): - while True: - await self._is_reading.wait() - - data = await self.read() - - #logger.debug(f'received "{list(data)}"') - await self._protocol.report_received(data, self._sock.getpeername()) + def get_reader(self): + return self._read_thread async def read(self): + """ + Read data from the unterlying socket. This function "blocks", + if reading is paused using the pause_reading function. + + :returns bytes + """ + await self._is_reading.wait() data = await self._loop.sock_recv(self._sock, self._read_buffer_size) if self._capture_file is not None: @@ -66,7 +101,7 @@ class L2CAP_Transport(asyncio.Transport): """ :returns True if the reader is running """ - return self._is_reading.is_set() + return self._reader is not None and self._is_reading.is_set() def pause_reading(self) -> None: """ @@ -109,7 +144,7 @@ class L2CAP_Transport(asyncio.Transport): async def close(self): """ - Stops socket reader and closes socket + Stops reader and closes underlying socket """ self._is_closing = True self._read_thread.cancel()