forked from mirror/joycontrol
disconnect error handling
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
|||||||
from aioconsole import ainput
|
from aioconsole import ainput
|
||||||
|
|
||||||
from joycontrol.controller_state import button_push, ControllerState
|
from joycontrol.controller_state import button_push, ControllerState
|
||||||
|
from joycontrol.transport import NotConnectedError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -116,4 +117,8 @@ class ControllerCLI:
|
|||||||
if buttons_to_push:
|
if buttons_to_push:
|
||||||
await button_push(self.controller_state, *buttons_to_push)
|
await button_push(self.controller_state, *buttons_to_push)
|
||||||
else:
|
else:
|
||||||
await self.controller_state.send()
|
try:
|
||||||
|
await self.controller_state.send()
|
||||||
|
except NotConnectedError:
|
||||||
|
logger.info('Connection was lost.')
|
||||||
|
return
|
||||||
|
|||||||
@@ -45,8 +45,11 @@ class ControllerState:
|
|||||||
return self._spi_flash
|
return self._spi_flash
|
||||||
|
|
||||||
async def send(self):
|
async def send(self):
|
||||||
self.sig_is_send.clear()
|
"""
|
||||||
await self.sig_is_send.wait()
|
Invokes protocol.send_controller_state(). Returns after the controller state was send.
|
||||||
|
Raises NotConnected exception if the connection was lost.
|
||||||
|
"""
|
||||||
|
await self._protocol.send_controller_state()
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
+101
-68
@@ -1,12 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from asyncio import BaseTransport, BaseProtocol
|
from asyncio import BaseTransport, BaseProtocol
|
||||||
|
from contextlib import suppress
|
||||||
from typing import Optional, Union, Tuple, Text
|
from typing import Optional, Union, Tuple, Text
|
||||||
|
|
||||||
|
from joycontrol import utils
|
||||||
from joycontrol.controller import Controller
|
from joycontrol.controller import Controller
|
||||||
from joycontrol.controller_state import ControllerState
|
from joycontrol.controller_state import ControllerState
|
||||||
from joycontrol.memory import FlashMemory
|
from joycontrol.memory import FlashMemory
|
||||||
from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID
|
from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID
|
||||||
|
from joycontrol.transport import NotConnectedError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,6 +20,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None):
|
|||||||
|
|
||||||
def create_controller_protocol():
|
def create_controller_protocol():
|
||||||
return ControllerProtocol(controller, spi_flash=spi_flash)
|
return ControllerProtocol(controller, spi_flash=spi_flash)
|
||||||
|
|
||||||
return create_controller_protocol
|
return create_controller_protocol
|
||||||
|
|
||||||
|
|
||||||
@@ -27,23 +31,48 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
|
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
|
||||||
# Increases for each input report send, overflows at 0x100
|
# Increases for each input report send, should overflow at 0x100
|
||||||
self._input_report_timer = 0x00
|
self._input_report_timer = 0x00
|
||||||
|
|
||||||
self._data_received = asyncio.Event()
|
self._data_received = asyncio.Event()
|
||||||
|
|
||||||
self._controller_state = ControllerState(self, controller, spi_flash=spi_flash)
|
self._controller_state = ControllerState(self, controller, spi_flash=spi_flash)
|
||||||
|
self._controller_state_sender = None
|
||||||
|
|
||||||
self._0x30_input_report_sender = None
|
# None = Just answer to sub commands
|
||||||
|
self._input_report_mode = None
|
||||||
|
|
||||||
# This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs
|
# This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs
|
||||||
self.sig_set_player_lights = asyncio.Event()
|
self.sig_set_player_lights = asyncio.Event()
|
||||||
|
|
||||||
|
async def send_controller_state(self):
|
||||||
|
"""
|
||||||
|
Waits for the controller state to be send.
|
||||||
|
|
||||||
|
Raises NotConnected exception if the transport is not connected or the connection was lost.
|
||||||
|
"""
|
||||||
|
# TODO: Call write directly if not in 0x30 input report mode
|
||||||
|
|
||||||
|
if self.transport is None:
|
||||||
|
raise NotConnectedError('Transport not registered.')
|
||||||
|
|
||||||
|
self._controller_state.sig_is_send.clear()
|
||||||
|
|
||||||
|
# wrap into a future to be able to set an exception in case of a disconnect
|
||||||
|
self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait())
|
||||||
|
await self._controller_state_sender
|
||||||
|
self._controller_state_sender = None
|
||||||
|
|
||||||
async def write(self, input_report: InputReport):
|
async def write(self, input_report: InputReport):
|
||||||
"""
|
"""
|
||||||
Sets timer byte and current button state in the input report and sends it.
|
Sets timer byte and current button state in the input report and sends it.
|
||||||
Fires sig_is_send event afterwards.
|
Fires sig_is_send event in the controller state afterwards.
|
||||||
|
|
||||||
|
Raises NotConnected exception if the transport is not connected or the connection was lost.
|
||||||
"""
|
"""
|
||||||
|
if self.transport is None:
|
||||||
|
raise NotConnectedError('Transport not registered.')
|
||||||
|
|
||||||
# set button and stick data of input report
|
# set button and stick data of input report
|
||||||
input_report.set_button_status(self._controller_state.button_state)
|
input_report.set_button_status(self._controller_state.button_state)
|
||||||
if self._controller_state.l_stick_state is None:
|
if self._controller_state.l_stick_state is None:
|
||||||
@@ -61,6 +90,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
self._input_report_timer = (self._input_report_timer + 1) % 0x100
|
self._input_report_timer = (self._input_report_timer + 1) % 0x100
|
||||||
|
|
||||||
await self.transport.write(input_report)
|
await self.transport.write(input_report)
|
||||||
|
|
||||||
self._controller_state.sig_is_send.set()
|
self._controller_state.sig_is_send.set()
|
||||||
|
|
||||||
def get_controller_state(self) -> ControllerState:
|
def get_controller_state(self) -> ControllerState:
|
||||||
@@ -68,7 +98,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
|
|
||||||
async def wait_for_output_report(self):
|
async def wait_for_output_report(self):
|
||||||
"""
|
"""
|
||||||
Blocks until an output report from the Switch is received.
|
Waits until an output report from the Switch is received.
|
||||||
"""
|
"""
|
||||||
self._data_received.clear()
|
self._data_received.clear()
|
||||||
await self._data_received.wait()
|
await self._data_received.wait()
|
||||||
@@ -77,12 +107,17 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
logger.debug('Connection established.')
|
logger.debug('Connection established.')
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||||
# TODO
|
if self.transport is not None:
|
||||||
raise NotImplementedError()
|
logger.error('Connection lost.')
|
||||||
|
asyncio.ensure_future(self.transport.close())
|
||||||
|
self.transport = None
|
||||||
|
|
||||||
|
if self._controller_state_sender is not None:
|
||||||
|
self._controller_state_sender.set_exception(NotConnectedError)
|
||||||
|
|
||||||
def error_received(self, exc: Exception) -> None:
|
def error_received(self, exc: Exception) -> None:
|
||||||
# TODO
|
# TODO?
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def input_report_mode_0x30(self):
|
async def input_report_mode_0x30(self):
|
||||||
@@ -99,52 +134,58 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
|
|
||||||
reader = asyncio.ensure_future(self.transport.read())
|
reader = asyncio.ensure_future(self.transport.read())
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
if self.controller == Controller.PRO_CONTROLLER:
|
while True:
|
||||||
# send state at 120Hz
|
# TODO: improve timing
|
||||||
await asyncio.sleep(1 / 120)
|
if self.controller == Controller.PRO_CONTROLLER:
|
||||||
else:
|
# send state at 120Hz
|
||||||
# send state at 60Hz
|
await asyncio.sleep(1 / 120)
|
||||||
await asyncio.sleep(1 / 60)
|
else:
|
||||||
|
# send state at 60Hz
|
||||||
|
await asyncio.sleep(1 / 60)
|
||||||
|
|
||||||
reply_send = False
|
reply_send = False
|
||||||
if reader.done():
|
if reader.done():
|
||||||
data = await reader
|
data = await reader
|
||||||
if not data:
|
|
||||||
# disconnect happened
|
|
||||||
logger.error('No data received (most likely due to a disconnect).')
|
|
||||||
break
|
|
||||||
|
|
||||||
reader = asyncio.ensure_future(self.transport.read())
|
reader = asyncio.ensure_future(self.transport.read())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
report = OutputReport(list(data))
|
report = OutputReport(list(data))
|
||||||
output_report_id = report.get_output_report_id()
|
output_report_id = report.get_output_report_id()
|
||||||
|
|
||||||
if output_report_id == OutputReportID.RUMBLE_ONLY:
|
if output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||||
# TODO
|
# TODO
|
||||||
pass
|
pass
|
||||||
elif output_report_id == OutputReportID.SUB_COMMAND:
|
elif output_report_id == OutputReportID.SUB_COMMAND:
|
||||||
reply_send = await self._reply_to_sub_command(report)
|
reply_send = await self._reply_to_sub_command(report)
|
||||||
except ValueError as v_err:
|
except ValueError as v_err:
|
||||||
logger.warning(f'Report parsing error "{v_err}" - IGNORE')
|
logger.warning(f'Report parsing error "{v_err}" - IGNORE')
|
||||||
except NotImplementedError as err:
|
except NotImplementedError as err:
|
||||||
logger.warning(err)
|
logger.warning(err)
|
||||||
|
|
||||||
if reply_send:
|
if reply_send:
|
||||||
# Hack: Adding a delay here to avoid flooding during pairing
|
# Hack: Adding a delay here to avoid flooding during pairing
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
else:
|
else:
|
||||||
# write 0x30 input report. TODO: set some sensor data
|
# write 0x30 input report.
|
||||||
input_report.set_6axis_data()
|
# TODO: set some sensor data
|
||||||
await self.write(input_report)
|
input_report.set_6axis_data()
|
||||||
|
|
||||||
|
await self.write(input_report)
|
||||||
|
|
||||||
|
except NotConnectedError as err:
|
||||||
|
# Stop 0x30 input report mode if disconnected.
|
||||||
|
logger.error(err)
|
||||||
|
finally:
|
||||||
|
# cleanup
|
||||||
|
self._input_report_mode = None
|
||||||
|
# cancel the reader
|
||||||
|
with suppress(asyncio.CancelledError, NotConnectedError):
|
||||||
|
if reader.cancel():
|
||||||
|
await reader
|
||||||
|
|
||||||
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
|
async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
|
||||||
if not data:
|
|
||||||
# disconnect happened
|
|
||||||
logger.error('No data received (most likely due to a disconnect).')
|
|
||||||
return
|
|
||||||
|
|
||||||
self._data_received.set()
|
self._data_received.set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -161,7 +202,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
|
|
||||||
if output_report_id == OutputReportID.SUB_COMMAND:
|
if output_report_id == OutputReportID.SUB_COMMAND:
|
||||||
await self._reply_to_sub_command(report)
|
await self._reply_to_sub_command(report)
|
||||||
#elif output_report_id == OutputReportID.RUMBLE_ONLY:
|
# elif output_report_id == OutputReportID.RUMBLE_ONLY:
|
||||||
# pass
|
# pass
|
||||||
else:
|
else:
|
||||||
logger.warning(f'Output report {output_report_id} not implemented - ignoring')
|
logger.warning(f'Output report {output_report_id} not implemented - ignoring')
|
||||||
@@ -216,7 +257,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring')
|
logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring')
|
||||||
return False
|
return False
|
||||||
except Exception as err:
|
except NotImplementedError as err:
|
||||||
logger.error(f'Failed to answer {sub_command} - {err}')
|
logger.error(f'Failed to answer {sub_command} - {err}')
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -266,7 +307,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
size = sub_command_data[4]
|
size = sub_command_data[4]
|
||||||
|
|
||||||
if self.spi_flash is not None:
|
if self.spi_flash is not None:
|
||||||
spi_flash_data = self.spi_flash[offset: offset+size]
|
spi_flash_data = self.spi_flash[offset: offset + size]
|
||||||
input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data)
|
input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data)
|
||||||
else:
|
else:
|
||||||
spi_flash_data = size * [0x00]
|
spi_flash_data = size * [0x00]
|
||||||
@@ -288,28 +329,20 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
await self.write(input_report)
|
await self.write(input_report)
|
||||||
|
|
||||||
# start sending 0x30 input reports
|
# start sending 0x30 input reports
|
||||||
if self._0x30_input_report_sender is None:
|
if self._input_report_mode != 0x30:
|
||||||
|
self._input_report_mode = 0x30
|
||||||
|
|
||||||
self.transport.pause_reading()
|
self.transport.pause_reading()
|
||||||
|
new_reader = asyncio.ensure_future(self.input_report_mode_0x30())
|
||||||
|
|
||||||
# create callback to check for exceptions
|
# We need to swap the reader in the future because this function was probably called by it
|
||||||
def callback(future):
|
|
||||||
try:
|
|
||||||
future.result()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Future may be cancelled at anytime
|
|
||||||
pass
|
|
||||||
except Exception as err:
|
|
||||||
logger.exception(err)
|
|
||||||
|
|
||||||
self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30())
|
|
||||||
self._0x30_input_report_sender.add_done_callback(callback)
|
|
||||||
|
|
||||||
# We have to swap the reader in the future because this function was probably called by it
|
|
||||||
async def set_reader():
|
async def set_reader():
|
||||||
await self.transport.set_reader(self._0x30_input_report_sender)
|
await self.transport.set_reader(new_reader)
|
||||||
self.transport.resume_reading()
|
self.transport.resume_reading()
|
||||||
|
|
||||||
asyncio.ensure_future(set_reader()).add_done_callback(callback)
|
asyncio.ensure_future(set_reader()).add_done_callback(
|
||||||
|
utils.create_error_check_callback()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request')
|
logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request')
|
||||||
|
|
||||||
@@ -362,7 +395,7 @@ class ControllerProtocol(BaseProtocol):
|
|||||||
# TODO
|
# TODO
|
||||||
data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200]
|
data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200]
|
||||||
for i in range(len(data)):
|
for i in range(len(data)):
|
||||||
input_report.data[16+i] = data[i]
|
input_report.data[16 + i] = data[i]
|
||||||
|
|
||||||
await self.write(input_report)
|
await self.write(input_report)
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ async def create_hid_server(protocol_factory, ctl_psm=17, itr_psm=19, device_id=
|
|||||||
client_itr.setblocking(False)
|
client_itr.setblocking(False)
|
||||||
|
|
||||||
# create transport for the established connection and activate the HID protocol
|
# create transport for the established connection and activate the HID protocol
|
||||||
transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, 50, capture_file=capture_file)
|
transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, client_ctl, 50, capture_file=capture_file)
|
||||||
protocol.connection_made(transport)
|
protocol.connection_made(transport)
|
||||||
|
|
||||||
# send some empty input reports until the Switch decides to reply
|
# send some empty input reports until the Switch decides to reply
|
||||||
|
|||||||
+66
-36
@@ -4,21 +4,31 @@ import struct
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from joycontrol import utils
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotConnectedError(ConnectionResetError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class L2CAP_Transport(asyncio.Transport):
|
class L2CAP_Transport(asyncio.Transport):
|
||||||
def __init__(self, loop, protocol, l2cap_socket, read_buffer_size, capture_file=None) -> None:
|
def __init__(self, loop, protocol, itr_sock, ctr_sock, read_buffer_size, capture_file=None) -> None:
|
||||||
|
super(L2CAP_Transport, self).__init__()
|
||||||
|
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
|
|
||||||
self._sock = l2cap_socket
|
self._itr_sock = itr_sock
|
||||||
|
self._ctr_sock = ctr_sock
|
||||||
|
|
||||||
self._read_buffer_size = read_buffer_size
|
self._read_buffer_size = read_buffer_size
|
||||||
|
|
||||||
self._extra_info = {
|
self._extra_info = {
|
||||||
'peername': self._sock.getpeername(),
|
'peername': self._itr_sock.getpeername(),
|
||||||
'sockname': self._sock.getsockname(),
|
'sockname': self._itr_sock.getsockname(),
|
||||||
'socket': self._sock
|
'socket': self._itr_sock
|
||||||
}
|
}
|
||||||
|
|
||||||
self._is_closing = False
|
self._is_closing = False
|
||||||
@@ -33,10 +43,14 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
|
|
||||||
async def _reader(self):
|
async def _reader(self):
|
||||||
while True:
|
while True:
|
||||||
data = await self.read()
|
try:
|
||||||
|
data = await self.read()
|
||||||
|
except NotConnectedError:
|
||||||
|
self._read_thread = None
|
||||||
|
break
|
||||||
|
|
||||||
#logger.debug(f'received "{list(data)}"')
|
#logger.debug(f'received "{list(data)}"')
|
||||||
await self._protocol.report_received(data, self._sock.getpeername())
|
await self._protocol.report_received(data, self._itr_sock.getpeername())
|
||||||
|
|
||||||
def start_reader(self):
|
def start_reader(self):
|
||||||
"""
|
"""
|
||||||
@@ -47,16 +61,8 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
|
|
||||||
self._read_thread = asyncio.ensure_future(self._reader())
|
self._read_thread = asyncio.ensure_future(self._reader())
|
||||||
|
|
||||||
# create callback to check for exceptions
|
# Create callback in case the reader is failing
|
||||||
def callback(future):
|
callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
|
||||||
try:
|
|
||||||
future.result()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Future may be cancelled at anytime
|
|
||||||
pass
|
|
||||||
except Exception as err:
|
|
||||||
logger.exception(err)
|
|
||||||
|
|
||||||
self._read_thread.add_done_callback(callback)
|
self._read_thread.add_done_callback(callback)
|
||||||
|
|
||||||
async def set_reader(self, reader: asyncio.Future):
|
async def set_reader(self, reader: asyncio.Future):
|
||||||
@@ -68,11 +74,15 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
"""
|
"""
|
||||||
if self._read_thread is not None:
|
if self._read_thread is not None:
|
||||||
# cancel currently running reader
|
# cancel currently running reader
|
||||||
self._read_thread.cancel()
|
if self._read_thread.cancel():
|
||||||
try:
|
try:
|
||||||
await self._read_thread
|
await self._read_thread
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Create callback for debugging in case the reader is failing
|
||||||
|
err_callback = utils.create_error_check_callback(ignore=asyncio.CancelledError)
|
||||||
|
reader.add_done_callback(err_callback)
|
||||||
|
|
||||||
self._read_thread = reader
|
self._read_thread = reader
|
||||||
|
|
||||||
@@ -81,13 +91,19 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
|
|
||||||
async def read(self):
|
async def read(self):
|
||||||
"""
|
"""
|
||||||
Read data from the unterlying socket. This function "blocks",
|
Read data from the underlying socket. This function waits,
|
||||||
if reading is paused using the pause_reading function.
|
if reading is paused using the pause_reading function.
|
||||||
|
|
||||||
:returns bytes
|
:returns bytes
|
||||||
"""
|
"""
|
||||||
await self._is_reading.wait()
|
await self._is_reading.wait()
|
||||||
data = await self._loop.sock_recv(self._sock, self._read_buffer_size)
|
data = await self._loop.sock_recv(self._itr_sock, self._read_buffer_size)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
# disconnect happened
|
||||||
|
logger.error('No data received.')
|
||||||
|
self._protocol.connection_lost()
|
||||||
|
raise NotConnectedError('No data received.')
|
||||||
|
|
||||||
if self._capture_file is not None:
|
if self._capture_file is not None:
|
||||||
# write data to log file
|
# write data to log file
|
||||||
@@ -105,13 +121,13 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
|
|
||||||
def pause_reading(self) -> None:
|
def pause_reading(self) -> None:
|
||||||
"""
|
"""
|
||||||
Pauses the reader
|
Pauses any 'read' function calls.
|
||||||
"""
|
"""
|
||||||
self._is_reading.clear()
|
self._is_reading.clear()
|
||||||
|
|
||||||
def resume_reading(self) -> None:
|
def resume_reading(self) -> None:
|
||||||
"""
|
"""
|
||||||
Resumes the reader
|
Resumes all 'read' function calls.
|
||||||
"""
|
"""
|
||||||
self._is_reading.set()
|
self._is_reading.set()
|
||||||
|
|
||||||
@@ -131,10 +147,19 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
self._capture_file.write(_time + size + _bytes)
|
self._capture_file.write(_time + size + _bytes)
|
||||||
|
|
||||||
#logger.debug(f'sending "{_bytes}"')
|
#logger.debug(f'sending "{_bytes}"')
|
||||||
await self._loop.sock_sendall(self._sock, _bytes)
|
try:
|
||||||
|
await self._loop.sock_sendall(self._itr_sock, _bytes)
|
||||||
|
except OSError as err:
|
||||||
|
logger.error(err)
|
||||||
|
self._protocol.connection_lost()
|
||||||
|
raise NotConnectedError(err)
|
||||||
|
except ConnectionResetError as err:
|
||||||
|
logger.error(err)
|
||||||
|
self._protocol.connection_lost()
|
||||||
|
raise err
|
||||||
|
|
||||||
def abort(self) -> None:
|
def abort(self) -> None:
|
||||||
super().abort()
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_extra_info(self, name: Any, default=None) -> Any:
|
def get_extra_info(self, name: Any, default=None) -> Any:
|
||||||
return self._extra_info.get(name, default)
|
return self._extra_info.get(name, default)
|
||||||
@@ -146,14 +171,19 @@ class L2CAP_Transport(asyncio.Transport):
|
|||||||
"""
|
"""
|
||||||
Stops reader and closes underlying socket
|
Stops reader and closes underlying socket
|
||||||
"""
|
"""
|
||||||
self._is_closing = True
|
if not self._is_closing:
|
||||||
self._read_thread.cancel()
|
# was not already closed
|
||||||
# wait for reader to cancel
|
self._is_closing = True
|
||||||
try:
|
if self._read_thread.cancel():
|
||||||
await self._read_thread
|
# wait for reader to cancel
|
||||||
except asyncio.CancelledError:
|
try:
|
||||||
pass
|
await self._read_thread
|
||||||
self._sock.close()
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# interrupt connection should be closed first
|
||||||
|
self._itr_sock.close()
|
||||||
|
self._ctr_sock.close()
|
||||||
|
|
||||||
def set_protocol(self, protocol: asyncio.BaseProtocol) -> None:
|
def set_protocol(self, protocol: asyncio.BaseProtocol) -> None:
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
|
|||||||
@@ -12,6 +12,25 @@ def flip_bit(value, n):
|
|||||||
return value ^ (1 << n)
|
return value ^ (1 << n)
|
||||||
|
|
||||||
|
|
||||||
|
def create_error_check_callback(ignore=None):
|
||||||
|
"""
|
||||||
|
Creates callback causing errors of a finished future to be raised.
|
||||||
|
Useful for debugging futures that are never awaited.
|
||||||
|
:param ignore: Any number of errors to ignore.
|
||||||
|
:returns callback which can be added to a future with future.add_done_callback(...)
|
||||||
|
"""
|
||||||
|
def callback(future):
|
||||||
|
if ignore:
|
||||||
|
try:
|
||||||
|
future.result()
|
||||||
|
except ignore:
|
||||||
|
# ignore suppressed errors
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
future.result()
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
async def run_system_command(cmd):
|
async def run_system_command(cmd):
|
||||||
proc = await asyncio.create_subprocess_shell(
|
proc = await asyncio.create_subprocess_shell(
|
||||||
cmd,
|
cmd,
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, suppress
|
||||||
|
|
||||||
from joycontrol import logging_default as log
|
from joycontrol import logging_default as log
|
||||||
from joycontrol.controller_state import ControllerState, button_push
|
from joycontrol.controller_state import ControllerState, button_push
|
||||||
from joycontrol.protocol import controller_protocol_factory, Controller
|
from joycontrol.protocol import controller_protocol_factory, Controller
|
||||||
from joycontrol.server import create_hid_server
|
from joycontrol.server import create_hid_server
|
||||||
|
from joycontrol.transport import NotConnectedError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -65,24 +66,26 @@ async def test_controller_buttons(controller_state: ControllerState):
|
|||||||
if 'home' in button_list:
|
if 'home' in button_list:
|
||||||
button_list.remove('home')
|
button_list.remove('home')
|
||||||
|
|
||||||
# push all buttons consecutively until KeyboardInterrupt
|
# push all buttons consecutively
|
||||||
try:
|
while True:
|
||||||
while True:
|
for button in button_list:
|
||||||
for button in button_list:
|
await button_push(controller_state, button)
|
||||||
await button_push(controller_state, button)
|
await asyncio.sleep(0.1)
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _main(controller, capture_file=None, spi_flash=None, device_id=None):
|
async def _main(controller, capture_file=None, spi_flash=None, device_id=None):
|
||||||
factory = controller_protocol_factory(controller, spi_flash=spi_flash)
|
factory = controller_protocol_factory(controller, spi_flash=spi_flash)
|
||||||
transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id)
|
transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file, device_id=device_id)
|
||||||
|
|
||||||
await test_controller_buttons(protocol.get_controller_state())
|
try:
|
||||||
|
await test_controller_buttons(protocol.get_controller_state())
|
||||||
logger.info('Stopping communication...')
|
except KeyboardInterrupt:
|
||||||
await transport.close()
|
pass
|
||||||
|
except NotConnectedError:
|
||||||
|
logger.error('Connection was lost.')
|
||||||
|
finally:
|
||||||
|
logger.info('Stopping communication...')
|
||||||
|
await transport.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -131,6 +134,20 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
with get_output(args.log) as capture_file:
|
with get_output(args.log) as capture_file:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.run_until_complete(
|
|
||||||
|
main_function = asyncio.ensure_future(
|
||||||
_main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id)
|
_main(controller, capture_file=capture_file, spi_flash=spi_flash, device_id=args.device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# run main function until keyboard interrupt
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(main_function)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
# make sure main function has a chance to clean up
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
main_function.cancel()
|
||||||
|
loop.run_until_complete(
|
||||||
|
main_function
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user