diff --git a/joycontrol/command_line_interface.py b/joycontrol/command_line_interface.py index b7aead9..5e13e6c 100644 --- a/joycontrol/command_line_interface.py +++ b/joycontrol/command_line_interface.py @@ -4,6 +4,7 @@ import logging from aioconsole import ainput from joycontrol.controller_state import button_push, ControllerState +from joycontrol.transport import NotConnectedError logger = logging.getLogger(__name__) @@ -116,4 +117,8 @@ class ControllerCLI: if buttons_to_push: await button_push(self.controller_state, *buttons_to_push) else: - await self.controller_state.send() + try: + await self.controller_state.send() + except NotConnectedError: + logger.info('Connection was lost.') + return diff --git a/joycontrol/controller_state.py b/joycontrol/controller_state.py index 6e2adf7..3764f2a 100644 --- a/joycontrol/controller_state.py +++ b/joycontrol/controller_state.py @@ -45,8 +45,11 @@ class ControllerState: return self._spi_flash async def send(self): - self.sig_is_send.clear() - await self.sig_is_send.wait() + """ + Invokes protocol.send_controller_state(). Returns after the controller state was send. + Raises NotConnected exception if the connection was lost. + """ + await self._protocol.send_controller_state() async def connect(self): """ diff --git a/joycontrol/protocol.py b/joycontrol/protocol.py index e41d155..9606211 100644 --- a/joycontrol/protocol.py +++ b/joycontrol/protocol.py @@ -1,12 +1,15 @@ import asyncio import logging from asyncio import BaseTransport, BaseProtocol +from contextlib import suppress from typing import Optional, Union, Tuple, Text +from joycontrol import utils from joycontrol.controller import Controller from joycontrol.controller_state import ControllerState from joycontrol.memory import FlashMemory from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID +from joycontrol.transport import NotConnectedError logger = logging.getLogger(__name__) @@ -17,6 +20,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None): def create_controller_protocol(): return ControllerProtocol(controller, spi_flash=spi_flash) + return create_controller_protocol @@ -27,23 +31,48 @@ class ControllerProtocol(BaseProtocol): self.transport = None - # Increases for each input report send, overflows at 0x100 + # Increases for each input report send, should overflow at 0x100 self._input_report_timer = 0x00 self._data_received = asyncio.Event() self._controller_state = ControllerState(self, controller, spi_flash=spi_flash) + self._controller_state_sender = None - self._0x30_input_report_sender = None + # None = Just answer to sub commands + self._input_report_mode = None # This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs self.sig_set_player_lights = asyncio.Event() + async def send_controller_state(self): + """ + Waits for the controller state to be send. + + Raises NotConnected exception if the transport is not connected or the connection was lost. + """ + # TODO: Call write directly if not in 0x30 input report mode + + if self.transport is None: + raise NotConnectedError('Transport not registered.') + + self._controller_state.sig_is_send.clear() + + # wrap into a future to be able to set an exception in case of a disconnect + self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait()) + await self._controller_state_sender + self._controller_state_sender = None + async def write(self, input_report: InputReport): """ Sets timer byte and current button state in the input report and sends it. - Fires sig_is_send event afterwards. + Fires sig_is_send event in the controller state afterwards. + + Raises NotConnected exception if the transport is not connected or the connection was lost. """ + if self.transport is None: + raise NotConnectedError('Transport not registered.') + # set button and stick data of input report input_report.set_button_status(self._controller_state.button_state) if self._controller_state.l_stick_state is None: @@ -61,6 +90,7 @@ class ControllerProtocol(BaseProtocol): self._input_report_timer = (self._input_report_timer + 1) % 0x100 await self.transport.write(input_report) + self._controller_state.sig_is_send.set() def get_controller_state(self) -> ControllerState: @@ -68,7 +98,7 @@ class ControllerProtocol(BaseProtocol): async def wait_for_output_report(self): """ - Blocks until an output report from the Switch is received. + Waits until an output report from the Switch is received. """ self._data_received.clear() await self._data_received.wait() @@ -77,10 +107,17 @@ class ControllerProtocol(BaseProtocol): logger.debug('Connection established.') self.transport = transport - def connection_lost(self, exc: Optional[Exception]) -> None: - raise NotImplementedError() + def connection_lost(self, exc: Optional[Exception] = None) -> None: + if self.transport is not None: + logger.error('Connection lost.') + asyncio.ensure_future(self.transport.close()) + self.transport = None + + if self._controller_state_sender is not None: + self._controller_state_sender.set_exception(NotConnectedError) def error_received(self, exc: Exception) -> None: + # TODO? raise NotImplementedError() async def input_report_mode_0x30(self): @@ -97,52 +134,58 @@ class ControllerProtocol(BaseProtocol): reader = asyncio.ensure_future(self.transport.read()) - while True: - if self.controller == Controller.PRO_CONTROLLER: - # send state at 120Hz - await asyncio.sleep(1 / 120) - else: - # send state at 60Hz - await asyncio.sleep(1 / 60) + try: + while True: + # TODO: improve timing + if self.controller == Controller.PRO_CONTROLLER: + # send state at 120Hz + await asyncio.sleep(1 / 120) + else: + # send state at 60Hz + await asyncio.sleep(1 / 60) - reply_send = False - if reader.done(): - data = await reader - if not data: - # disconnect happened - logger.error('No data received (most likely due to a disconnect).') - break + reply_send = False + if reader.done(): + data = await reader - reader = asyncio.ensure_future(self.transport.read()) + reader = asyncio.ensure_future(self.transport.read()) - try: - report = OutputReport(list(data)) - output_report_id = report.get_output_report_id() + try: + report = OutputReport(list(data)) + output_report_id = report.get_output_report_id() - if output_report_id == OutputReportID.RUMBLE_ONLY: - # TODO - pass - elif output_report_id == OutputReportID.SUB_COMMAND: - reply_send = await self._reply_to_sub_command(report) - except ValueError as v_err: - logger.warning(f'Report parsing error "{v_err}" - IGNORE') - except NotImplementedError as err: - logger.warning(err) + if output_report_id == OutputReportID.RUMBLE_ONLY: + # TODO + pass + elif output_report_id == OutputReportID.SUB_COMMAND: + reply_send = await self._reply_to_sub_command(report) + except ValueError as v_err: + logger.warning(f'Report parsing error "{v_err}" - IGNORE') + except NotImplementedError as err: + logger.warning(err) - if reply_send: - # Hack: Adding a delay here to avoid flooding during pairing - await asyncio.sleep(0.3) - else: - # write 0x30 input report. TODO: set some sensor data - input_report.set_6axis_data() - await self.write(input_report) + if reply_send: + # Hack: Adding a delay here to avoid flooding during pairing + await asyncio.sleep(0.3) + else: + # write 0x30 input report. + # TODO: set some sensor data + input_report.set_6axis_data() + + await self.write(input_report) + + except NotConnectedError as err: + # Stop 0x30 input report mode if disconnected. + logger.error(err) + finally: + # cleanup + self._input_report_mode = None + # cancel the reader + with suppress(asyncio.CancelledError, NotConnectedError): + if reader.cancel(): + await reader async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None: - if not data: - # disconnect happened - logger.error('No data received (most likely due to a disconnect).') - return - self._data_received.set() try: @@ -159,7 +202,7 @@ class ControllerProtocol(BaseProtocol): if output_report_id == OutputReportID.SUB_COMMAND: await self._reply_to_sub_command(report) - #elif output_report_id == OutputReportID.RUMBLE_ONLY: + # elif output_report_id == OutputReportID.RUMBLE_ONLY: # pass else: logger.warning(f'Output report {output_report_id} not implemented - ignoring') @@ -214,7 +257,7 @@ class ControllerProtocol(BaseProtocol): else: logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring') return False - except Exception as err: + except NotImplementedError as err: logger.error(f'Failed to answer {sub_command} - {err}') return False return True @@ -264,7 +307,7 @@ class ControllerProtocol(BaseProtocol): size = sub_command_data[4] if self.spi_flash is not None: - spi_flash_data = self.spi_flash[offset: offset+size] + spi_flash_data = self.spi_flash[offset: offset + size] input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data) else: spi_flash_data = size * [0x00] @@ -286,18 +329,20 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) # start sending 0x30 input reports - if self._0x30_input_report_sender is None: + if self._input_report_mode != 0x30: + self._input_report_mode = 0x30 + self.transport.pause_reading() - self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30()) + new_reader = asyncio.ensure_future(self.input_report_mode_0x30()) - # create callback to check for exceptions - def callback(future): - try: - future.result() - except Exception as err: - logger.exception(err) + # We need to swap the reader in the future because this function was probably called by it + async def set_reader(): + await self.transport.set_reader(new_reader) + self.transport.resume_reading() - self._0x30_input_report_sender.add_done_callback(callback) + asyncio.ensure_future(set_reader()).add_done_callback( + utils.create_error_check_callback() + ) else: logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') @@ -350,7 +395,7 @@ class ControllerProtocol(BaseProtocol): # TODO data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200] for i in range(len(data)): - input_report.data[16+i] = data[i] + input_report.data[16 + i] = data[i] await self.write(input_report) diff --git a/joycontrol/server.py b/joycontrol/server.py index f3e10ae..e4d9fae 100644 --- a/joycontrol/server.py +++ b/joycontrol/server.py @@ -111,7 +111,7 @@ async def create_hid_server(protocol_factory, ctl_psm=17, itr_psm=19, device_id= client_itr.setblocking(False) # create transport for the established connection and activate the HID protocol - transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, 50, capture_file=capture_file) + transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, client_ctl, 50, capture_file=capture_file) protocol.connection_made(transport) # send some empty input reports until the Switch decides to reply diff --git a/joycontrol/transport.py b/joycontrol/transport.py index d2e9c27..e271637 100644 --- a/joycontrol/transport.py +++ b/joycontrol/transport.py @@ -4,55 +4,106 @@ import struct import time from typing import Any -from joycontrol.report import InputReport +from joycontrol import utils logger = logging.getLogger(__name__) +class NotConnectedError(ConnectionResetError): + pass + + class L2CAP_Transport(asyncio.Transport): - def __init__(self, loop, protocol, l2cap_socket, read_buffer_size, capture_file=None) -> None: + def __init__(self, loop, protocol, itr_sock, ctr_sock, read_buffer_size, capture_file=None) -> None: + super(L2CAP_Transport, self).__init__() + self._loop = loop self._protocol = protocol - self._sock = l2cap_socket + self._itr_sock = itr_sock + self._ctr_sock = ctr_sock + self._read_buffer_size = read_buffer_size self._extra_info = { - 'peername': self._sock.getpeername(), - 'sockname': self._sock.getsockname(), - 'socket': self._sock + 'peername': self._itr_sock.getpeername(), + 'sockname': self._itr_sock.getsockname(), + 'socket': self._itr_sock } - self._read_thread = asyncio.ensure_future(self._reader()) - - # create callback to check for exceptions - def callback(future): - try: - future.result() - except Exception as err: - logger.exception(err) - - self._read_thread.add_done_callback(callback) - self._is_closing = False self._is_reading = asyncio.Event() - self._is_reading.set() - - self._input_report_timer = 0x00 self._capture_file = capture_file + # start underlying reader + self._read_thread = None + self._is_reading.set() + self.start_reader() + async def _reader(self): while True: - await self._is_reading.wait() - - data = await self.read() + try: + data = await self.read() + except NotConnectedError: + self._read_thread = None + break #logger.debug(f'received "{list(data)}"') - await self._protocol.report_received(data, self._sock.getpeername()) + await self._protocol.report_received(data, self._itr_sock.getpeername()) + + def start_reader(self): + """ + Starts the transport reader which calls the protocols report_received function for every incoming message + """ + if self._read_thread is not None: + raise ValueError('Reader is already running.') + + self._read_thread = asyncio.ensure_future(self._reader()) + + # Create callback in case the reader is failing + callback = utils.create_error_check_callback(ignore=asyncio.CancelledError) + self._read_thread.add_done_callback(callback) + + async def set_reader(self, reader: asyncio.Future): + """ + Cancel the currently running reader and register the new one. + A reader is a coroutine that calls this transports 'read' function. + The 'read' function calls can be paused by calling pause_reading of this transport. + :param reader: future reader + """ + if self._read_thread is not None: + # cancel currently running reader + if self._read_thread.cancel(): + try: + await self._read_thread + except asyncio.CancelledError: + pass + + # Create callback for debugging in case the reader is failing + err_callback = utils.create_error_check_callback(ignore=asyncio.CancelledError) + reader.add_done_callback(err_callback) + + self._read_thread = reader + + def get_reader(self): + return self._read_thread async def read(self): - data = await self._loop.sock_recv(self._sock, self._read_buffer_size) + """ + Read data from the underlying socket. This function waits, + if reading is paused using the pause_reading function. + + :returns bytes + """ + await self._is_reading.wait() + data = await self._loop.sock_recv(self._itr_sock, self._read_buffer_size) + + if not data: + # disconnect happened + logger.error('No data received.') + self._protocol.connection_lost() + raise NotConnectedError('No data received.') if self._capture_file is not None: # write data to log file @@ -66,17 +117,17 @@ class L2CAP_Transport(asyncio.Transport): """ :returns True if the reader is running """ - return self._is_reading.is_set() + return self._reader is not None and self._is_reading.is_set() def pause_reading(self) -> None: """ - Pauses the reader + Pauses any 'read' function calls. """ self._is_reading.clear() def resume_reading(self) -> None: """ - Resumes the reader + Resumes all 'read' function calls. """ self._is_reading.set() @@ -96,10 +147,19 @@ class L2CAP_Transport(asyncio.Transport): self._capture_file.write(_time + size + _bytes) #logger.debug(f'sending "{_bytes}"') - await self._loop.sock_sendall(self._sock, _bytes) + try: + await self._loop.sock_sendall(self._itr_sock, _bytes) + except OSError as err: + logger.error(err) + self._protocol.connection_lost() + raise NotConnectedError(err) + except ConnectionResetError as err: + logger.error(err) + self._protocol.connection_lost() + raise err def abort(self) -> None: - super().abort() + raise NotImplementedError def get_extra_info(self, name: Any, default=None) -> Any: return self._extra_info.get(name, default) @@ -109,16 +169,21 @@ class L2CAP_Transport(asyncio.Transport): async def close(self): """ - Stops socket reader and closes socket + Stops reader and closes underlying socket """ - self._is_closing = True - self._read_thread.cancel() - # wait for reader to cancel - try: - await self._read_thread - except asyncio.CancelledError: - pass - self._sock.close() + if not self._is_closing: + # was not already closed + self._is_closing = True + if self._read_thread.cancel(): + # wait for reader to cancel + try: + await self._read_thread + except asyncio.CancelledError: + pass + + # interrupt connection should be closed first + self._itr_sock.close() + self._ctr_sock.close() def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: self._protocol = protocol diff --git a/joycontrol/utils.py b/joycontrol/utils.py index 4aa33c8..77c0498 100644 --- a/joycontrol/utils.py +++ b/joycontrol/utils.py @@ -12,6 +12,25 @@ def flip_bit(value, n): return value ^ (1 << n) +def create_error_check_callback(ignore=None): + """ + Creates callback causing errors of a finished future to be raised. + Useful for debugging futures that are never awaited. + :param ignore: Any number of errors to ignore. + :returns callback which can be added to a future with future.add_done_callback(...) + """ + def callback(future): + if ignore: + try: + future.result() + except ignore: + # ignore suppressed errors + pass + else: + future.result() + return callback + + async def run_system_command(cmd): proc = await asyncio.create_subprocess_shell( cmd, diff --git a/run_test_controller_buttons.py b/run_test_controller_buttons.py index d9f0a8b..7bcfad2 100644 --- a/run_test_controller_buttons.py +++ b/run_test_controller_buttons.py @@ -2,12 +2,13 @@ import argparse import asyncio import logging import os -from contextlib import contextmanager +from contextlib import contextmanager, suppress from joycontrol import logging_default as log from joycontrol.controller_state import ControllerState, button_push from joycontrol.protocol import controller_protocol_factory, Controller from joycontrol.server import create_hid_server +from joycontrol.transport import NotConnectedError logger = logging.getLogger(__name__) @@ -65,24 +66,26 @@ async def test_controller_buttons(controller_state: ControllerState): if 'home' in button_list: button_list.remove('home') - # push all buttons consecutively until KeyboardInterrupt - try: - while True: - for button in button_list: - await button_push(controller_state, button) - await asyncio.sleep(0.1) - except KeyboardInterrupt: - pass + # push all buttons consecutively + while True: + for button in button_list: + await button_push(controller_state, button) + await asyncio.sleep(0.1) async def _main(controller, capture_file=None, spi_flash=None, device_id=None): factory = controller_protocol_factory(controller, spi_flash=spi_flash) transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id) - await test_controller_buttons(protocol.get_controller_state()) - - logger.info('Stopping communication...') - await transport.close() + try: + await test_controller_buttons(protocol.get_controller_state()) + except KeyboardInterrupt: + pass + except NotConnectedError: + logger.error('Connection was lost.') + finally: + logger.info('Stopping communication...') + await transport.close() if __name__ == '__main__': @@ -131,6 +134,20 @@ if __name__ == '__main__': with get_output(args.log) as capture_file: loop = asyncio.get_event_loop() - loop.run_until_complete( + + main_function = asyncio.ensure_future( _main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id) ) + + # run main function until keyboard interrupt + try: + loop.run_until_complete(main_function) + except KeyboardInterrupt: + pass + finally: + # make sure main function has a chance to clean up + with suppress(asyncio.CancelledError): + main_function.cancel() + loop.run_until_complete( + main_function + )