disconnect error handling

This commit is contained in:
Robert Martin
2020-04-05 00:02:14 +09:00
parent a6f70588f4
commit af6f9152dd
7 changed files with 229 additions and 122 deletions
+5
View File
@@ -4,6 +4,7 @@ import logging
from aioconsole import ainput from aioconsole import ainput
from joycontrol.controller_state import button_push, ControllerState from joycontrol.controller_state import button_push, ControllerState
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -116,4 +117,8 @@ class ControllerCLI:
if buttons_to_push: if buttons_to_push:
await button_push(self.controller_state, *buttons_to_push) await button_push(self.controller_state, *buttons_to_push)
else: else:
try:
await self.controller_state.send() await self.controller_state.send()
except NotConnectedError:
logger.info('Connection was lost.')
return
+5 -2
View File
@@ -45,8 +45,11 @@ class ControllerState:
return self._spi_flash return self._spi_flash
async def send(self): async def send(self):
self.sig_is_send.clear() """
await self.sig_is_send.wait() Invokes protocol.send_controller_state(). Returns after the controller state was send.
Raises NotConnected exception if the connection was lost.
"""
await self._protocol.send_controller_state()
async def connect(self): async def connect(self):
""" """
+69 -36
View File
@@ -1,12 +1,15 @@
import asyncio import asyncio
import logging import logging
from asyncio import BaseTransport, BaseProtocol from asyncio import BaseTransport, BaseProtocol
from contextlib import suppress
from typing import Optional, Union, Tuple, Text from typing import Optional, Union, Tuple, Text
from joycontrol import utils
from joycontrol.controller import Controller from joycontrol.controller import Controller
from joycontrol.controller_state import ControllerState from joycontrol.controller_state import ControllerState
from joycontrol.memory import FlashMemory from joycontrol.memory import FlashMemory
from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,6 +20,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None):
def create_controller_protocol(): def create_controller_protocol():
return ControllerProtocol(controller, spi_flash=spi_flash) return ControllerProtocol(controller, spi_flash=spi_flash)
return create_controller_protocol return create_controller_protocol
@@ -27,23 +31,48 @@ class ControllerProtocol(BaseProtocol):
self.transport = None self.transport = None
# Increases for each input report send, overflows at 0x100 # Increases for each input report send, should overflow at 0x100
self._input_report_timer = 0x00 self._input_report_timer = 0x00
self._data_received = asyncio.Event() self._data_received = asyncio.Event()
self._controller_state = ControllerState(self, controller, spi_flash=spi_flash) self._controller_state = ControllerState(self, controller, spi_flash=spi_flash)
self._controller_state_sender = None
self._0x30_input_report_sender = None # None = Just answer to sub commands
self._input_report_mode = None
# This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs # This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs
self.sig_set_player_lights = asyncio.Event() self.sig_set_player_lights = asyncio.Event()
async def send_controller_state(self):
"""
Waits for the controller state to be send.
Raises NotConnected exception if the transport is not connected or the connection was lost.
"""
# TODO: Call write directly if not in 0x30 input report mode
if self.transport is None:
raise NotConnectedError('Transport not registered.')
self._controller_state.sig_is_send.clear()
# wrap into a future to be able to set an exception in case of a disconnect
self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait())
await self._controller_state_sender
self._controller_state_sender = None
async def write(self, input_report: InputReport): async def write(self, input_report: InputReport):
""" """
Sets timer byte and current button state in the input report and sends it. Sets timer byte and current button state in the input report and sends it.
Fires sig_is_send event afterwards. Fires sig_is_send event in the controller state afterwards.
Raises NotConnected exception if the transport is not connected or the connection was lost.
""" """
if self.transport is None:
raise NotConnectedError('Transport not registered.')
# set button and stick data of input report # set button and stick data of input report
input_report.set_button_status(self._controller_state.button_state) input_report.set_button_status(self._controller_state.button_state)
if self._controller_state.l_stick_state is None: if self._controller_state.l_stick_state is None:
@@ -61,6 +90,7 @@ class ControllerProtocol(BaseProtocol):
self._input_report_timer = (self._input_report_timer + 1) % 0x100 self._input_report_timer = (self._input_report_timer + 1) % 0x100
await self.transport.write(input_report) await self.transport.write(input_report)
self._controller_state.sig_is_send.set() self._controller_state.sig_is_send.set()
def get_controller_state(self) -> ControllerState: def get_controller_state(self) -> ControllerState:
@@ -68,7 +98,7 @@ class ControllerProtocol(BaseProtocol):
async def wait_for_output_report(self): async def wait_for_output_report(self):
""" """
Blocks until an output report from the Switch is received. Waits until an output report from the Switch is received.
""" """
self._data_received.clear() self._data_received.clear()
await self._data_received.wait() await self._data_received.wait()
@@ -77,12 +107,17 @@ class ControllerProtocol(BaseProtocol):
logger.debug('Connection established.') logger.debug('Connection established.')
self.transport = transport self.transport = transport
def connection_lost(self, exc: Optional[Exception]) -> None: def connection_lost(self, exc: Optional[Exception] = None) -> None:
# TODO if self.transport is not None:
raise NotImplementedError() logger.error('Connection lost.')
asyncio.ensure_future(self.transport.close())
self.transport = None
if self._controller_state_sender is not None:
self._controller_state_sender.set_exception(NotConnectedError)
def error_received(self, exc: Exception) -> None: def error_received(self, exc: Exception) -> None:
# TODO # TODO?
raise NotImplementedError() raise NotImplementedError()
async def input_report_mode_0x30(self): async def input_report_mode_0x30(self):
@@ -99,7 +134,9 @@ class ControllerProtocol(BaseProtocol):
reader = asyncio.ensure_future(self.transport.read()) reader = asyncio.ensure_future(self.transport.read())
try:
while True: while True:
# TODO: improve timing
if self.controller == Controller.PRO_CONTROLLER: if self.controller == Controller.PRO_CONTROLLER:
# send state at 120Hz # send state at 120Hz
await asyncio.sleep(1 / 120) await asyncio.sleep(1 / 120)
@@ -110,10 +147,6 @@ class ControllerProtocol(BaseProtocol):
reply_send = False reply_send = False
if reader.done(): if reader.done():
data = await reader data = await reader
if not data:
# disconnect happened
logger.error('No data received (most likely due to a disconnect).')
break
reader = asyncio.ensure_future(self.transport.read()) reader = asyncio.ensure_future(self.transport.read())
@@ -135,16 +168,24 @@ class ControllerProtocol(BaseProtocol):
# Hack: Adding a delay here to avoid flooding during pairing # Hack: Adding a delay here to avoid flooding during pairing
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
else: else:
# write 0x30 input report. TODO: set some sensor data # write 0x30 input report.
# TODO: set some sensor data
input_report.set_6axis_data() input_report.set_6axis_data()
await self.write(input_report) await self.write(input_report)
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None: except NotConnectedError as err:
if not data: # Stop 0x30 input report mode if disconnected.
# disconnect happened logger.error(err)
logger.error('No data received (most likely due to a disconnect).') finally:
return # cleanup
self._input_report_mode = None
# cancel the reader
with suppress(asyncio.CancelledError, NotConnectedError):
if reader.cancel():
await reader
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
self._data_received.set() self._data_received.set()
try: try:
@@ -216,7 +257,7 @@ class ControllerProtocol(BaseProtocol):
else: else:
logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring') logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring')
return False return False
except Exception as err: except NotImplementedError as err:
logger.error(f'Failed to answer {sub_command} - {err}') logger.error(f'Failed to answer {sub_command} - {err}')
return False return False
return True return True
@@ -288,28 +329,20 @@ class ControllerProtocol(BaseProtocol):
await self.write(input_report) await self.write(input_report)
# start sending 0x30 input reports # start sending 0x30 input reports
if self._0x30_input_report_sender is None: if self._input_report_mode != 0x30:
self._input_report_mode = 0x30
self.transport.pause_reading() self.transport.pause_reading()
new_reader = asyncio.ensure_future(self.input_report_mode_0x30())
# create callback to check for exceptions # We need to swap the reader in the future because this function was probably called by it
def callback(future):
try:
future.result()
except asyncio.CancelledError:
# Future may be cancelled at anytime
pass
except Exception as err:
logger.exception(err)
self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30())
self._0x30_input_report_sender.add_done_callback(callback)
# We have to swap the reader in the future because this function was probably called by it
async def set_reader(): async def set_reader():
await self.transport.set_reader(self._0x30_input_report_sender) await self.transport.set_reader(new_reader)
self.transport.resume_reading() self.transport.resume_reading()
asyncio.ensure_future(set_reader()).add_done_callback(callback) asyncio.ensure_future(set_reader()).add_done_callback(
utils.create_error_check_callback()
)
else: else:
logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request')
+1 -1
View File
@@ -111,7 +111,7 @@ async def create_hid_server(protocol_factory, ctl_psm=17, itr_psm=19, device_id=
client_itr.setblocking(False) client_itr.setblocking(False)
# create transport for the established connection and activate the HID protocol # create transport for the established connection and activate the HID protocol
transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, 50, capture_file=capture_file) transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, client_ctl, 50, capture_file=capture_file)
protocol.connection_made(transport) protocol.connection_made(transport)
# send some empty input reports until the Switch decides to reply # send some empty input reports until the Switch decides to reply
+55 -25
View File
@@ -4,21 +4,31 @@ import struct
import time import time
from typing import Any from typing import Any
from joycontrol import utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotConnectedError(ConnectionResetError):
pass
class L2CAP_Transport(asyncio.Transport): class L2CAP_Transport(asyncio.Transport):
def __init__(self, loop, protocol, l2cap_socket, read_buffer_size, capture_file=None) -> None: def __init__(self, loop, protocol, itr_sock, ctr_sock, read_buffer_size, capture_file=None) -> None:
super(L2CAP_Transport, self).__init__()
self._loop = loop self._loop = loop
self._protocol = protocol self._protocol = protocol
self._sock = l2cap_socket self._itr_sock = itr_sock
self._ctr_sock = ctr_sock
self._read_buffer_size = read_buffer_size self._read_buffer_size = read_buffer_size
self._extra_info = { self._extra_info = {
'peername': self._sock.getpeername(), 'peername': self._itr_sock.getpeername(),
'sockname': self._sock.getsockname(), 'sockname': self._itr_sock.getsockname(),
'socket': self._sock 'socket': self._itr_sock
} }
self._is_closing = False self._is_closing = False
@@ -33,10 +43,14 @@ class L2CAP_Transport(asyncio.Transport):
async def _reader(self): async def _reader(self):
while True: while True:
try:
data = await self.read() data = await self.read()
except NotConnectedError:
self._read_thread = None
break
#logger.debug(f'received "{list(data)}"') #logger.debug(f'received "{list(data)}"')
await self._protocol.report_received(data, self._sock.getpeername()) await self._protocol.report_received(data, self._itr_sock.getpeername())
def start_reader(self): def start_reader(self):
""" """
@@ -47,16 +61,8 @@ class L2CAP_Transport(asyncio.Transport):
self._read_thread = asyncio.ensure_future(self._reader()) self._read_thread = asyncio.ensure_future(self._reader())
# create callback to check for exceptions # Create callback in case the reader is failing
def callback(future): callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
try:
future.result()
except asyncio.CancelledError:
# Future may be cancelled at anytime
pass
except Exception as err:
logger.exception(err)
self._read_thread.add_done_callback(callback) self._read_thread.add_done_callback(callback)
async def set_reader(self, reader: asyncio.Future): async def set_reader(self, reader: asyncio.Future):
@@ -68,12 +74,16 @@ class L2CAP_Transport(asyncio.Transport):
""" """
if self._read_thread is not None: if self._read_thread is not None:
# cancel currently running reader # cancel currently running reader
self._read_thread.cancel() if self._read_thread.cancel():
try: try:
await self._read_thread await self._read_thread
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Create callback for debugging in case the reader is failing
err_callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
reader.add_done_callback(err_callback)
self._read_thread = reader self._read_thread = reader
def get_reader(self): def get_reader(self):
@@ -81,13 +91,19 @@ class L2CAP_Transport(asyncio.Transport):
async def read(self): async def read(self):
""" """
Read data from the unterlying socket. This function "blocks", Read data from the underlying socket. This function waits,
if reading is paused using the pause_reading function. if reading is paused using the pause_reading function.
:returns bytes :returns bytes
""" """
await self._is_reading.wait() await self._is_reading.wait()
data = await self._loop.sock_recv(self._sock, self._read_buffer_size) data = await self._loop.sock_recv(self._itr_sock, self._read_buffer_size)
if not data:
# disconnect happened
logger.error('No data received.')
self._protocol.connection_lost()
raise NotConnectedError('No data received.')
if self._capture_file is not None: if self._capture_file is not None:
# write data to log file # write data to log file
@@ -105,13 +121,13 @@ class L2CAP_Transport(asyncio.Transport):
def pause_reading(self) -> None: def pause_reading(self) -> None:
""" """
Pauses the reader Pauses any 'read' function calls.
""" """
self._is_reading.clear() self._is_reading.clear()
def resume_reading(self) -> None: def resume_reading(self) -> None:
""" """
Resumes the reader Resumes all 'read' function calls.
""" """
self._is_reading.set() self._is_reading.set()
@@ -131,10 +147,19 @@ class L2CAP_Transport(asyncio.Transport):
self._capture_file.write(_time + size + _bytes) self._capture_file.write(_time + size + _bytes)
#logger.debug(f'sending "{_bytes}"') #logger.debug(f'sending "{_bytes}"')
await self._loop.sock_sendall(self._sock, _bytes) try:
await self._loop.sock_sendall(self._itr_sock, _bytes)
except OSError as err:
logger.error(err)
self._protocol.connection_lost()
raise NotConnectedError(err)
except ConnectionResetError as err:
logger.error(err)
self._protocol.connection_lost()
raise err
def abort(self) -> None: def abort(self) -> None:
super().abort() raise NotImplementedError
def get_extra_info(self, name: Any, default=None) -> Any: def get_extra_info(self, name: Any, default=None) -> Any:
return self._extra_info.get(name, default) return self._extra_info.get(name, default)
@@ -146,14 +171,19 @@ class L2CAP_Transport(asyncio.Transport):
""" """
Stops reader and closes underlying socket Stops reader and closes underlying socket
""" """
if not self._is_closing:
# was not already closed
self._is_closing = True self._is_closing = True
self._read_thread.cancel() if self._read_thread.cancel():
# wait for reader to cancel # wait for reader to cancel
try: try:
await self._read_thread await self._read_thread
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
self._sock.close()
# interrupt connection should be closed first
self._itr_sock.close()
self._ctr_sock.close()
def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: def set_protocol(self, protocol: asyncio.BaseProtocol) -> None:
self._protocol = protocol self._protocol = protocol
+19
View File
@@ -12,6 +12,25 @@ def flip_bit(value, n):
return value ^ (1 << n) return value ^ (1 << n)
def create_error_check_callback(ignore=None):
"""
Creates callback causing errors of a finished future to be raised.
Useful for debugging futures that are never awaited.
:param ignore: Any number of errors to ignore.
:returns callback which can be added to a future with future.add_done_callback(...)
"""
def callback(future):
if ignore:
try:
future.result()
except ignore:
# ignore suppressed errors
pass
else:
future.result()
return callback
async def run_system_command(cmd): async def run_system_command(cmd):
proc = await asyncio.create_subprocess_shell( proc = await asyncio.create_subprocess_shell(
cmd, cmd,
+24 -7
View File
@@ -2,12 +2,13 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
from contextlib import contextmanager from contextlib import contextmanager, suppress
from joycontrol import logging_default as log from joycontrol import logging_default as log
from joycontrol.controller_state import ControllerState, button_push from joycontrol.controller_state import ControllerState, button_push
from joycontrol.protocol import controller_protocol_factory, Controller from joycontrol.protocol import controller_protocol_factory, Controller
from joycontrol.server import create_hid_server from joycontrol.server import create_hid_server
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -65,22 +66,24 @@ async def test_controller_buttons(controller_state: ControllerState):
if 'home' in button_list: if 'home' in button_list:
button_list.remove('home') button_list.remove('home')
# push all buttons consecutively until KeyboardInterrupt # push all buttons consecutively
try:
while True: while True:
for button in button_list: for button in button_list:
await button_push(controller_state, button) await button_push(controller_state, button)
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
async def _main(controller, capture_file=None, spi_flash=None, device_id=None): async def _main(controller, capture_file=None, spi_flash=None, device_id=None):
factory = controller_protocol_factory(controller, spi_flash=spi_flash) factory = controller_protocol_factory(controller, spi_flash=spi_flash)
transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id) transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id)
try:
await test_controller_buttons(protocol.get_controller_state()) await test_controller_buttons(protocol.get_controller_state())
except KeyboardInterrupt:
pass
except NotConnectedError:
logger.error('Connection was lost.')
finally:
logger.info('Stopping communication...') logger.info('Stopping communication...')
await transport.close() await transport.close()
@@ -131,6 +134,20 @@ if __name__ == '__main__':
with get_output(args.log) as capture_file: with get_output(args.log) as capture_file:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(
main_function = asyncio.ensure_future(
_main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id) _main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id)
) )
# run main function until keyboard interrupt
try:
loop.run_until_complete(main_function)
except KeyboardInterrupt:
pass
finally:
# make sure main function has a chance to clean up
with suppress(asyncio.CancelledError):
main_function.cancel()
loop.run_until_complete(
main_function
)