Merge pull request #16 from mart1nro/pause_reading_fix

pause reading in transport now also affects input report modes, added connection lost error handling
This commit is contained in:
Robert Martin
2020-04-05 00:04:02 +09:00
committed by GitHub
7 changed files with 269 additions and 115 deletions
+5
View File
@@ -4,6 +4,7 @@ import logging
from aioconsole import ainput from aioconsole import ainput
from joycontrol.controller_state import button_push, ControllerState from joycontrol.controller_state import button_push, ControllerState
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -116,4 +117,8 @@ class ControllerCLI:
if buttons_to_push: if buttons_to_push:
await button_push(self.controller_state, *buttons_to_push) await button_push(self.controller_state, *buttons_to_push)
else: else:
try:
await self.controller_state.send() await self.controller_state.send()
except NotConnectedError:
logger.info('Connection was lost.')
return
+5 -2
View File
@@ -45,8 +45,11 @@ class ControllerState:
return self._spi_flash return self._spi_flash
async def send(self): async def send(self):
self.sig_is_send.clear() """
await self.sig_is_send.wait() Invokes protocol.send_controller_state(). Returns after the controller state was send.
Raises NotConnected exception if the connection was lost.
"""
await self._protocol.send_controller_state()
async def connect(self): async def connect(self):
""" """
+74 -29
View File
@@ -1,12 +1,15 @@
import asyncio import asyncio
import logging import logging
from asyncio import BaseTransport, BaseProtocol from asyncio import BaseTransport, BaseProtocol
from contextlib import suppress
from typing import Optional, Union, Tuple, Text from typing import Optional, Union, Tuple, Text
from joycontrol import utils
from joycontrol.controller import Controller from joycontrol.controller import Controller
from joycontrol.controller_state import ControllerState from joycontrol.controller_state import ControllerState
from joycontrol.memory import FlashMemory from joycontrol.memory import FlashMemory
from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,6 +20,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None):
def create_controller_protocol(): def create_controller_protocol():
return ControllerProtocol(controller, spi_flash=spi_flash) return ControllerProtocol(controller, spi_flash=spi_flash)
return create_controller_protocol return create_controller_protocol
@@ -27,23 +31,48 @@ class ControllerProtocol(BaseProtocol):
self.transport = None self.transport = None
# Increases for each input report send, overflows at 0x100 # Increases for each input report send, should overflow at 0x100
self._input_report_timer = 0x00 self._input_report_timer = 0x00
self._data_received = asyncio.Event() self._data_received = asyncio.Event()
self._controller_state = ControllerState(self, controller, spi_flash=spi_flash) self._controller_state = ControllerState(self, controller, spi_flash=spi_flash)
self._controller_state_sender = None
self._0x30_input_report_sender = None # None = Just answer to sub commands
self._input_report_mode = None
# This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs # This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs
self.sig_set_player_lights = asyncio.Event() self.sig_set_player_lights = asyncio.Event()
async def send_controller_state(self):
"""
Waits for the controller state to be send.
Raises NotConnected exception if the transport is not connected or the connection was lost.
"""
# TODO: Call write directly if not in 0x30 input report mode
if self.transport is None:
raise NotConnectedError('Transport not registered.')
self._controller_state.sig_is_send.clear()
# wrap into a future to be able to set an exception in case of a disconnect
self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait())
await self._controller_state_sender
self._controller_state_sender = None
async def write(self, input_report: InputReport): async def write(self, input_report: InputReport):
""" """
Sets timer byte and current button state in the input report and sends it. Sets timer byte and current button state in the input report and sends it.
Fires sig_is_send event afterwards. Fires sig_is_send event in the controller state afterwards.
Raises NotConnected exception if the transport is not connected or the connection was lost.
""" """
if self.transport is None:
raise NotConnectedError('Transport not registered.')
# set button and stick data of input report # set button and stick data of input report
input_report.set_button_status(self._controller_state.button_state) input_report.set_button_status(self._controller_state.button_state)
if self._controller_state.l_stick_state is None: if self._controller_state.l_stick_state is None:
@@ -61,6 +90,7 @@ class ControllerProtocol(BaseProtocol):
self._input_report_timer = (self._input_report_timer + 1) % 0x100 self._input_report_timer = (self._input_report_timer + 1) % 0x100
await self.transport.write(input_report) await self.transport.write(input_report)
self._controller_state.sig_is_send.set() self._controller_state.sig_is_send.set()
def get_controller_state(self) -> ControllerState: def get_controller_state(self) -> ControllerState:
@@ -68,7 +98,7 @@ class ControllerProtocol(BaseProtocol):
async def wait_for_output_report(self): async def wait_for_output_report(self):
""" """
Blocks until an output report from the Switch is received. Waits until an output report from the Switch is received.
""" """
self._data_received.clear() self._data_received.clear()
await self._data_received.wait() await self._data_received.wait()
@@ -77,10 +107,17 @@ class ControllerProtocol(BaseProtocol):
logger.debug('Connection established.') logger.debug('Connection established.')
self.transport = transport self.transport = transport
def connection_lost(self, exc: Optional[Exception]) -> None: def connection_lost(self, exc: Optional[Exception] = None) -> None:
raise NotImplementedError() if self.transport is not None:
logger.error('Connection lost.')
asyncio.ensure_future(self.transport.close())
self.transport = None
if self._controller_state_sender is not None:
self._controller_state_sender.set_exception(NotConnectedError)
def error_received(self, exc: Exception) -> None: def error_received(self, exc: Exception) -> None:
# TODO?
raise NotImplementedError() raise NotImplementedError()
async def input_report_mode_0x30(self): async def input_report_mode_0x30(self):
@@ -97,7 +134,9 @@ class ControllerProtocol(BaseProtocol):
reader = asyncio.ensure_future(self.transport.read()) reader = asyncio.ensure_future(self.transport.read())
try:
while True: while True:
# TODO: improve timing
if self.controller == Controller.PRO_CONTROLLER: if self.controller == Controller.PRO_CONTROLLER:
# send state at 120Hz # send state at 120Hz
await asyncio.sleep(1 / 120) await asyncio.sleep(1 / 120)
@@ -108,10 +147,6 @@ class ControllerProtocol(BaseProtocol):
reply_send = False reply_send = False
if reader.done(): if reader.done():
data = await reader data = await reader
if not data:
# disconnect happened
logger.error('No data received (most likely due to a disconnect).')
break
reader = asyncio.ensure_future(self.transport.read()) reader = asyncio.ensure_future(self.transport.read())
@@ -133,16 +168,24 @@ class ControllerProtocol(BaseProtocol):
# Hack: Adding a delay here to avoid flooding during pairing # Hack: Adding a delay here to avoid flooding during pairing
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
else: else:
# write 0x30 input report. TODO: set some sensor data # write 0x30 input report.
# TODO: set some sensor data
input_report.set_6axis_data() input_report.set_6axis_data()
await self.write(input_report) await self.write(input_report)
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None: except NotConnectedError as err:
if not data: # Stop 0x30 input report mode if disconnected.
# disconnect happened logger.error(err)
logger.error('No data received (most likely due to a disconnect).') finally:
return # cleanup
self._input_report_mode = None
# cancel the reader
with suppress(asyncio.CancelledError, NotConnectedError):
if reader.cancel():
await reader
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
self._data_received.set() self._data_received.set()
try: try:
@@ -159,7 +202,7 @@ class ControllerProtocol(BaseProtocol):
if output_report_id == OutputReportID.SUB_COMMAND: if output_report_id == OutputReportID.SUB_COMMAND:
await self._reply_to_sub_command(report) await self._reply_to_sub_command(report)
#elif output_report_id == OutputReportID.RUMBLE_ONLY: # elif output_report_id == OutputReportID.RUMBLE_ONLY:
# pass # pass
else: else:
logger.warning(f'Output report {output_report_id} not implemented - ignoring') logger.warning(f'Output report {output_report_id} not implemented - ignoring')
@@ -214,7 +257,7 @@ class ControllerProtocol(BaseProtocol):
else: else:
logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring') logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring')
return False return False
except Exception as err: except NotImplementedError as err:
logger.error(f'Failed to answer {sub_command} - {err}') logger.error(f'Failed to answer {sub_command} - {err}')
return False return False
return True return True
@@ -264,7 +307,7 @@ class ControllerProtocol(BaseProtocol):
size = sub_command_data[4] size = sub_command_data[4]
if self.spi_flash is not None: if self.spi_flash is not None:
spi_flash_data = self.spi_flash[offset: offset+size] spi_flash_data = self.spi_flash[offset: offset + size]
input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data) input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data)
else: else:
spi_flash_data = size * [0x00] spi_flash_data = size * [0x00]
@@ -286,18 +329,20 @@ class ControllerProtocol(BaseProtocol):
await self.write(input_report) await self.write(input_report)
# start sending 0x30 input reports # start sending 0x30 input reports
if self._0x30_input_report_sender is None: if self._input_report_mode != 0x30:
self._input_report_mode = 0x30
self.transport.pause_reading() self.transport.pause_reading()
self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30()) new_reader = asyncio.ensure_future(self.input_report_mode_0x30())
# create callback to check for exceptions # We need to swap the reader in the future because this function was probably called by it
def callback(future): async def set_reader():
try: await self.transport.set_reader(new_reader)
future.result() self.transport.resume_reading()
except Exception as err:
logger.exception(err)
self._0x30_input_report_sender.add_done_callback(callback) asyncio.ensure_future(set_reader()).add_done_callback(
utils.create_error_check_callback()
)
else: else:
logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request')
@@ -350,7 +395,7 @@ class ControllerProtocol(BaseProtocol):
# TODO # TODO
data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200] data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200]
for i in range(len(data)): for i in range(len(data)):
input_report.data[16+i] = data[i] input_report.data[16 + i] = data[i]
await self.write(input_report) await self.write(input_report)
+1 -1
View File
@@ -111,7 +111,7 @@ async def create_hid_server(protocol_factory, ctl_psm=17, itr_psm=19, device_id=
client_itr.setblocking(False) client_itr.setblocking(False)
# create transport for the established connection and activate the HID protocol # create transport for the established connection and activate the HID protocol
transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, 50, capture_file=capture_file) transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, client_ctl, 50, capture_file=capture_file)
protocol.connection_made(transport) protocol.connection_made(transport)
# send some empty input reports until the Switch decides to reply # send some empty input reports until the Switch decides to reply
+97 -32
View File
@@ -4,55 +4,106 @@ import struct
import time import time
from typing import Any from typing import Any
from joycontrol.report import InputReport from joycontrol import utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotConnectedError(ConnectionResetError):
pass
class L2CAP_Transport(asyncio.Transport): class L2CAP_Transport(asyncio.Transport):
def __init__(self, loop, protocol, l2cap_socket, read_buffer_size, capture_file=None) -> None: def __init__(self, loop, protocol, itr_sock, ctr_sock, read_buffer_size, capture_file=None) -> None:
super(L2CAP_Transport, self).__init__()
self._loop = loop self._loop = loop
self._protocol = protocol self._protocol = protocol
self._sock = l2cap_socket self._itr_sock = itr_sock
self._ctr_sock = ctr_sock
self._read_buffer_size = read_buffer_size self._read_buffer_size = read_buffer_size
self._extra_info = { self._extra_info = {
'peername': self._sock.getpeername(), 'peername': self._itr_sock.getpeername(),
'sockname': self._sock.getsockname(), 'sockname': self._itr_sock.getsockname(),
'socket': self._sock 'socket': self._itr_sock
} }
self._read_thread = asyncio.ensure_future(self._reader())
# create callback to check for exceptions
def callback(future):
try:
future.result()
except Exception as err:
logger.exception(err)
self._read_thread.add_done_callback(callback)
self._is_closing = False self._is_closing = False
self._is_reading = asyncio.Event() self._is_reading = asyncio.Event()
self._is_reading.set()
self._input_report_timer = 0x00
self._capture_file = capture_file self._capture_file = capture_file
# start underlying reader
self._read_thread = None
self._is_reading.set()
self.start_reader()
async def _reader(self): async def _reader(self):
while True: while True:
await self._is_reading.wait() try:
data = await self.read() data = await self.read()
except NotConnectedError:
self._read_thread = None
break
#logger.debug(f'received "{list(data)}"') #logger.debug(f'received "{list(data)}"')
await self._protocol.report_received(data, self._sock.getpeername()) await self._protocol.report_received(data, self._itr_sock.getpeername())
def start_reader(self):
"""
Starts the transport reader which calls the protocols report_received function for every incoming message
"""
if self._read_thread is not None:
raise ValueError('Reader is already running.')
self._read_thread = asyncio.ensure_future(self._reader())
# Create callback in case the reader is failing
callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
self._read_thread.add_done_callback(callback)
async def set_reader(self, reader: asyncio.Future):
"""
Cancel the currently running reader and register the new one.
A reader is a coroutine that calls this transports 'read' function.
The 'read' function calls can be paused by calling pause_reading of this transport.
:param reader: future reader
"""
if self._read_thread is not None:
# cancel currently running reader
if self._read_thread.cancel():
try:
await self._read_thread
except asyncio.CancelledError:
pass
# Create callback for debugging in case the reader is failing
err_callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
reader.add_done_callback(err_callback)
self._read_thread = reader
def get_reader(self):
return self._read_thread
async def read(self): async def read(self):
data = await self._loop.sock_recv(self._sock, self._read_buffer_size) """
Read data from the underlying socket. This function waits,
if reading is paused using the pause_reading function.
:returns bytes
"""
await self._is_reading.wait()
data = await self._loop.sock_recv(self._itr_sock, self._read_buffer_size)
if not data:
# disconnect happened
logger.error('No data received.')
self._protocol.connection_lost()
raise NotConnectedError('No data received.')
if self._capture_file is not None: if self._capture_file is not None:
# write data to log file # write data to log file
@@ -66,17 +117,17 @@ class L2CAP_Transport(asyncio.Transport):
""" """
:returns True if the reader is running :returns True if the reader is running
""" """
return self._is_reading.is_set() return self._reader is not None and self._is_reading.is_set()
def pause_reading(self) -> None: def pause_reading(self) -> None:
""" """
Pauses the reader Pauses any 'read' function calls.
""" """
self._is_reading.clear() self._is_reading.clear()
def resume_reading(self) -> None: def resume_reading(self) -> None:
""" """
Resumes the reader Resumes all 'read' function calls.
""" """
self._is_reading.set() self._is_reading.set()
@@ -96,10 +147,19 @@ class L2CAP_Transport(asyncio.Transport):
self._capture_file.write(_time + size + _bytes) self._capture_file.write(_time + size + _bytes)
#logger.debug(f'sending "{_bytes}"') #logger.debug(f'sending "{_bytes}"')
await self._loop.sock_sendall(self._sock, _bytes) try:
await self._loop.sock_sendall(self._itr_sock, _bytes)
except OSError as err:
logger.error(err)
self._protocol.connection_lost()
raise NotConnectedError(err)
except ConnectionResetError as err:
logger.error(err)
self._protocol.connection_lost()
raise err
def abort(self) -> None: def abort(self) -> None:
super().abort() raise NotImplementedError
def get_extra_info(self, name: Any, default=None) -> Any: def get_extra_info(self, name: Any, default=None) -> Any:
return self._extra_info.get(name, default) return self._extra_info.get(name, default)
@@ -109,16 +169,21 @@ class L2CAP_Transport(asyncio.Transport):
async def close(self): async def close(self):
""" """
Stops socket reader and closes socket Stops reader and closes underlying socket
""" """
if not self._is_closing:
# was not already closed
self._is_closing = True self._is_closing = True
self._read_thread.cancel() if self._read_thread.cancel():
# wait for reader to cancel # wait for reader to cancel
try: try:
await self._read_thread await self._read_thread
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
self._sock.close()
# interrupt connection should be closed first
self._itr_sock.close()
self._ctr_sock.close()
def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: def set_protocol(self, protocol: asyncio.BaseProtocol) -> None:
self._protocol = protocol self._protocol = protocol
+19
View File
@@ -12,6 +12,25 @@ def flip_bit(value, n):
return value ^ (1 << n) return value ^ (1 << n)
def create_error_check_callback(ignore=None):
"""
Creates callback causing errors of a finished future to be raised.
Useful for debugging futures that are never awaited.
:param ignore: Any number of errors to ignore.
:returns callback which can be added to a future with future.add_done_callback(...)
"""
def callback(future):
if ignore:
try:
future.result()
except ignore:
# ignore suppressed errors
pass
else:
future.result()
return callback
async def run_system_command(cmd): async def run_system_command(cmd):
proc = await asyncio.create_subprocess_shell( proc = await asyncio.create_subprocess_shell(
cmd, cmd,
+24 -7
View File
@@ -2,12 +2,13 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
from contextlib import contextmanager from contextlib import contextmanager, suppress
from joycontrol import logging_default as log from joycontrol import logging_default as log
from joycontrol.controller_state import ControllerState, button_push from joycontrol.controller_state import ControllerState, button_push
from joycontrol.protocol import controller_protocol_factory, Controller from joycontrol.protocol import controller_protocol_factory, Controller
from joycontrol.server import create_hid_server from joycontrol.server import create_hid_server
from joycontrol.transport import NotConnectedError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -65,22 +66,24 @@ async def test_controller_buttons(controller_state: ControllerState):
if 'home' in button_list: if 'home' in button_list:
button_list.remove('home') button_list.remove('home')
# push all buttons consecutively until KeyboardInterrupt # push all buttons consecutively
try:
while True: while True:
for button in button_list: for button in button_list:
await button_push(controller_state, button) await button_push(controller_state, button)
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
async def _main(controller, capture_file=None, spi_flash=None, device_id=None): async def _main(controller, capture_file=None, spi_flash=None, device_id=None):
factory = controller_protocol_factory(controller, spi_flash=spi_flash) factory = controller_protocol_factory(controller, spi_flash=spi_flash)
transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id) transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id)
try:
await test_controller_buttons(protocol.get_controller_state()) await test_controller_buttons(protocol.get_controller_state())
except KeyboardInterrupt:
pass
except NotConnectedError:
logger.error('Connection was lost.')
finally:
logger.info('Stopping communication...') logger.info('Stopping communication...')
await transport.close() await transport.close()
@@ -131,6 +134,20 @@ if __name__ == '__main__':
with get_output(args.log) as capture_file: with get_output(args.log) as capture_file:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(
main_function = asyncio.ensure_future(
_main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id) _main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id)
) )
# run main function until keyboard interrupt
try:
loop.run_until_complete(main_function)
except KeyboardInterrupt:
pass
finally:
# make sure main function has a chance to clean up
with suppress(asyncio.CancelledError):
main_function.cancel()
loop.run_until_complete(
main_function
)