forked from mirror/joycontrol
disconnect error handling
This commit is contained in:
+101
-68
@@ -1,12 +1,15 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import BaseTransport, BaseProtocol
|
||||
from contextlib import suppress
|
||||
from typing import Optional, Union, Tuple, Text
|
||||
|
||||
from joycontrol import utils
|
||||
from joycontrol.controller import Controller
|
||||
from joycontrol.controller_state import ControllerState
|
||||
from joycontrol.memory import FlashMemory
|
||||
from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID
|
||||
from joycontrol.transport import NotConnectedError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +20,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None):
|
||||
|
||||
def create_controller_protocol():
|
||||
return ControllerProtocol(controller, spi_flash=spi_flash)
|
||||
|
||||
return create_controller_protocol
|
||||
|
||||
|
||||
@@ -27,23 +31,48 @@ class ControllerProtocol(BaseProtocol):
|
||||
|
||||
self.transport = None
|
||||
|
||||
# Increases for each input report send, overflows at 0x100
|
||||
# Increases for each input report send, should overflow at 0x100
|
||||
self._input_report_timer = 0x00
|
||||
|
||||
self._data_received = asyncio.Event()
|
||||
|
||||
self._controller_state = ControllerState(self, controller, spi_flash=spi_flash)
|
||||
self._controller_state_sender = None
|
||||
|
||||
self._0x30_input_report_sender = None
|
||||
# None = Just answer to sub commands
|
||||
self._input_report_mode = None
|
||||
|
||||
# This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs
|
||||
self.sig_set_player_lights = asyncio.Event()
|
||||
|
||||
async def send_controller_state(self):
|
||||
"""
|
||||
Waits for the controller state to be send.
|
||||
|
||||
Raises NotConnected exception if the transport is not connected or the connection was lost.
|
||||
"""
|
||||
# TODO: Call write directly if not in 0x30 input report mode
|
||||
|
||||
if self.transport is None:
|
||||
raise NotConnectedError('Transport not registered.')
|
||||
|
||||
self._controller_state.sig_is_send.clear()
|
||||
|
||||
# wrap into a future to be able to set an exception in case of a disconnect
|
||||
self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait())
|
||||
await self._controller_state_sender
|
||||
self._controller_state_sender = None
|
||||
|
||||
async def write(self, input_report: InputReport):
|
||||
"""
|
||||
Sets timer byte and current button state in the input report and sends it.
|
||||
Fires sig_is_send event afterwards.
|
||||
Fires sig_is_send event in the controller state afterwards.
|
||||
|
||||
Raises NotConnected exception if the transport is not connected or the connection was lost.
|
||||
"""
|
||||
if self.transport is None:
|
||||
raise NotConnectedError('Transport not registered.')
|
||||
|
||||
# set button and stick data of input report
|
||||
input_report.set_button_status(self._controller_state.button_state)
|
||||
if self._controller_state.l_stick_state is None:
|
||||
@@ -61,6 +90,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
self._input_report_timer = (self._input_report_timer + 1) % 0x100
|
||||
|
||||
await self.transport.write(input_report)
|
||||
|
||||
self._controller_state.sig_is_send.set()
|
||||
|
||||
def get_controller_state(self) -> ControllerState:
|
||||
@@ -68,7 +98,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
|
||||
async def wait_for_output_report(self):
|
||||
"""
|
||||
Blocks until an output report from the Switch is received.
|
||||
Waits until an output report from the Switch is received.
|
||||
"""
|
||||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
@@ -77,12 +107,17 @@ class ControllerProtocol(BaseProtocol):
|
||||
logger.debug('Connection established.')
|
||||
self.transport = transport
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||
if self.transport is not None:
|
||||
logger.error('Connection lost.')
|
||||
asyncio.ensure_future(self.transport.close())
|
||||
self.transport = None
|
||||
|
||||
if self._controller_state_sender is not None:
|
||||
self._controller_state_sender.set_exception(NotConnectedError)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
# TODO
|
||||
# TODO?
|
||||
raise NotImplementedError()
|
||||
|
||||
async def input_report_mode_0x30(self):
|
||||
@@ -99,52 +134,58 @@ class ControllerProtocol(BaseProtocol):
|
||||
|
||||
reader = asyncio.ensure_future(self.transport.read())
|
||||
|
||||
while True:
|
||||
if self.controller == Controller.PRO_CONTROLLER:
|
||||
# send state at 120Hz
|
||||
await asyncio.sleep(1 / 120)
|
||||
else:
|
||||
# send state at 60Hz
|
||||
await asyncio.sleep(1 / 60)
|
||||
try:
|
||||
while True:
|
||||
# TODO: improve timing
|
||||
if self.controller == Controller.PRO_CONTROLLER:
|
||||
# send state at 120Hz
|
||||
await asyncio.sleep(1 / 120)
|
||||
else:
|
||||
# send state at 60Hz
|
||||
await asyncio.sleep(1 / 60)
|
||||
|
||||
reply_send = False
|
||||
if reader.done():
|
||||
data = await reader
|
||||
if not data:
|
||||
# disconnect happened
|
||||
logger.error('No data received (most likely due to a disconnect).')
|
||||
break
|
||||
reply_send = False
|
||||
if reader.done():
|
||||
data = await reader
|
||||
|
||||
reader = asyncio.ensure_future(self.transport.read())
|
||||
reader = asyncio.ensure_future(self.transport.read())
|
||||
|
||||
try:
|
||||
report = OutputReport(list(data))
|
||||
output_report_id = report.get_output_report_id()
|
||||
try:
|
||||
report = OutputReport(list(data))
|
||||
output_report_id = report.get_output_report_id()
|
||||
|
||||
if output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||
# TODO
|
||||
pass
|
||||
elif output_report_id == OutputReportID.SUB_COMMAND:
|
||||
reply_send = await self._reply_to_sub_command(report)
|
||||
except ValueError as v_err:
|
||||
logger.warning(f'Report parsing error "{v_err}" - IGNORE')
|
||||
except NotImplementedError as err:
|
||||
logger.warning(err)
|
||||
if output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||
# TODO
|
||||
pass
|
||||
elif output_report_id == OutputReportID.SUB_COMMAND:
|
||||
reply_send = await self._reply_to_sub_command(report)
|
||||
except ValueError as v_err:
|
||||
logger.warning(f'Report parsing error "{v_err}" - IGNORE')
|
||||
except NotImplementedError as err:
|
||||
logger.warning(err)
|
||||
|
||||
if reply_send:
|
||||
# Hack: Adding a delay here to avoid flooding during pairing
|
||||
await asyncio.sleep(0.3)
|
||||
else:
|
||||
# write 0x30 input report. TODO: set some sensor data
|
||||
input_report.set_6axis_data()
|
||||
await self.write(input_report)
|
||||
if reply_send:
|
||||
# Hack: Adding a delay here to avoid flooding during pairing
|
||||
await asyncio.sleep(0.3)
|
||||
else:
|
||||
# write 0x30 input report.
|
||||
# TODO: set some sensor data
|
||||
input_report.set_6axis_data()
|
||||
|
||||
await self.write(input_report)
|
||||
|
||||
except NotConnectedError as err:
|
||||
# Stop 0x30 input report mode if disconnected.
|
||||
logger.error(err)
|
||||
finally:
|
||||
# cleanup
|
||||
self._input_report_mode = None
|
||||
# cancel the reader
|
||||
with suppress(asyncio.CancelledError, NotConnectedError):
|
||||
if reader.cancel():
|
||||
await reader
|
||||
|
||||
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
|
||||
if not data:
|
||||
# disconnect happened
|
||||
logger.error('No data received (most likely due to a disconnect).')
|
||||
return
|
||||
|
||||
self._data_received.set()
|
||||
|
||||
try:
|
||||
@@ -161,7 +202,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
|
||||
if output_report_id == OutputReportID.SUB_COMMAND:
|
||||
await self._reply_to_sub_command(report)
|
||||
#elif output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||
# elif output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||
# pass
|
||||
else:
|
||||
logger.warning(f'Output report {output_report_id} not implemented - ignoring')
|
||||
@@ -216,7 +257,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
else:
|
||||
logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring')
|
||||
return False
|
||||
except Exception as err:
|
||||
except NotImplementedError as err:
|
||||
logger.error(f'Failed to answer {sub_command} - {err}')
|
||||
return False
|
||||
return True
|
||||
@@ -266,7 +307,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
size = sub_command_data[4]
|
||||
|
||||
if self.spi_flash is not None:
|
||||
spi_flash_data = self.spi_flash[offset: offset+size]
|
||||
spi_flash_data = self.spi_flash[offset: offset + size]
|
||||
input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data)
|
||||
else:
|
||||
spi_flash_data = size * [0x00]
|
||||
@@ -288,28 +329,20 @@ class ControllerProtocol(BaseProtocol):
|
||||
await self.write(input_report)
|
||||
|
||||
# start sending 0x30 input reports
|
||||
if self._0x30_input_report_sender is None:
|
||||
if self._input_report_mode != 0x30:
|
||||
self._input_report_mode = 0x30
|
||||
|
||||
self.transport.pause_reading()
|
||||
new_reader = asyncio.ensure_future(self.input_report_mode_0x30())
|
||||
|
||||
# create callback to check for exceptions
|
||||
def callback(future):
|
||||
try:
|
||||
future.result()
|
||||
except asyncio.CancelledError:
|
||||
# Future may be cancelled at anytime
|
||||
pass
|
||||
except Exception as err:
|
||||
logger.exception(err)
|
||||
|
||||
self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30())
|
||||
self._0x30_input_report_sender.add_done_callback(callback)
|
||||
|
||||
# We have to swap the reader in the future because this function was probably called by it
|
||||
# We need to swap the reader in the future because this function was probably called by it
|
||||
async def set_reader():
|
||||
await self.transport.set_reader(self._0x30_input_report_sender)
|
||||
await self.transport.set_reader(new_reader)
|
||||
self.transport.resume_reading()
|
||||
|
||||
asyncio.ensure_future(set_reader()).add_done_callback(callback)
|
||||
asyncio.ensure_future(set_reader()).add_done_callback(
|
||||
utils.create_error_check_callback()
|
||||
)
|
||||
else:
|
||||
logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request')
|
||||
|
||||
@@ -362,7 +395,7 @@ class ControllerProtocol(BaseProtocol):
|
||||
# TODO
|
||||
data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200]
|
||||
for i in range(len(data)):
|
||||
input_report.data[16+i] = data[i]
|
||||
input_report.data[16 + i] = data[i]
|
||||
|
||||
await self.write(input_report)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user