diff --git a/README.md b/README.md index 1b7a7b7..0c88ad6 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,60 @@ # joycontrol Emulate Nintendo Switch Controllers over Bluetooth. -Work in progress. +Tested on Ubuntu 19.10, and with Raspberry Pi 3B+ and 4B Raspbian GNU/Linux 10 (buster) -Pairing works, emulated controller shows up in the "Change Grip/Order" menu of the Switch. - -Tested on Ubuntu 19.10 and with Raspberry Pi 4B Raspbian GNU/Linux 10 (buster) +## Features +Emulation of JOYCON_R, JOYCON_L and PRO_CONTROLLER. Able to send: +- button commands +- stick state +- nfc data ## Installation -- Install dbus-python package +- Install dependencies + +Ubuntu: Install the `dbus-python` and `libhidapi-hidraw0` packages ```bash -sudo apt install python3-dbus +sudo apt install python3-dbus libhidapi-hidraw0 ``` + +Arch Linux Derivatives: Install the `hidapi` and `bluez-utils-compat`(AUR) packages + + - Clone the repository and install the joycontrol package to get missing dependencies (Note: Controller script needs super user rights, so python packages must be installed as root). In the joycontrol folder run: ```bash sudo pip3 install . ``` +- Disable the bluez "input" plugin, see [#8](https://github.com/mart1nro/joycontrol/issues/8) -## "Test Controller Buttons" example +## Command line interface example - Run the script ```bash -sudo python3 run_test_controller_buttons.py +sudo python3 run_controller_cli.py PRO_CONTROLLER ``` +This will create a PRO_CONTROLLER instance waiting for the Switch to connect. + - Open the "Change Grip/Order" menu of the Switch -- The emulated controller should pair with the Switch and automatically navigate to the "Test Controller Buttons" menu + +The Switch only pairs with new controllers in the "Change Grip/Order" menu. + +Note: If you already connected an emulated controller once, you can use the reconnect option of the script (-r "\"). +This does not require the "Change Grip/Order" menu to be opened. You can find out a paired mac address using the "bluetoothctl" system command. + +- After connecting, a command line interface is opened. Note: Press \ if you don't see a prompt. + +Call "help" to see a list of available commands. + +- If you call "test_buttons", the emulated controller automatically navigates to the "Test Controller Buttons" menu. + ## Issues -- When using a Raspberry Pi 4B the connection drops after some time. Might be a hardware issue, since it works fine on my laptop. Using a different bluetooth adapter may help, but haven't tested it yet. +- Some bluetooth adapters seem to cause disconnects for reasons unknown, try to use an usb adapter instead +- Incompatibility with Bluetooth "input" plugin requires a bluetooth restart, see [#8](https://github.com/mart1nro/joycontrol/issues/8) +- It seems like the Switch is slower processing incoming messages while in the "Change Grip/Order" menu. + This causes flooding of packets and makes pairing somewhat inconsistent. + Not sure yet what exactly a real controller does to prevent that. + A workaround is to use the reconnect option after a controller was paired once, so that + opening of the "Change Grip/Order" menu is not required. - ... diff --git a/joycontrol/command_line_interface.py b/joycontrol/command_line_interface.py index b7aead9..d3b5f90 100644 --- a/joycontrol/command_line_interface.py +++ b/joycontrol/command_line_interface.py @@ -1,28 +1,113 @@ import inspect import logging +import shlex from aioconsole import ainput from joycontrol.controller_state import button_push, ControllerState +from joycontrol.transport import NotConnectedError logger = logging.getLogger(__name__) -class ControllerCLI: - def __init__(self, controller_state: ControllerState): - self.controller_state = controller_state +def _print_doc(string): + """ + Attempts to remove common white space at the start of the lines in a doc string + to unify the output of doc strings with different indention levels. + + Keeps whitespace lines intact. + + :param fun: function to print the doc string of + """ + lines = string.split('\n') + if lines: + prefix_i = 0 + for i, line_0 in enumerate(lines): + # find non empty start lines + if line_0.strip(): + # traverse line and stop if character mismatch with other non empty lines + for prefix_i, c in enumerate(line_0): + if not c.isspace(): + break + if any(lines[j].strip() and (prefix_i >= len(lines[j]) or c != lines[j][prefix_i]) + for j in range(i+1, len(lines))): + break + break + + for line in lines: + print(line[prefix_i:] if line.strip() else line) + + +class CLI: + def __init__(self): self.commands = {} - async def cmd_help(self): - print('Buttons can be used as commands: ', ', '.join(self.controller_state.button_state.get_available_buttons())) + def add_command(self, name, command): + if name in self.commands: + raise ValueError(f'Command {name} already registered.') + self.commands[name] = command + async def cmd_help(self): + print('Commands:') for name, fun in inspect.getmembers(self): if name.startswith('cmd_') and fun.__doc__: - print(fun.__doc__) + _print_doc(fun.__doc__) + + for name, fun in self.commands.items(): + if fun.__doc__: + _print_doc(fun.__doc__) print('Commands can be chained using "&&"') print('Type "exit" to close.') + async def run(self): + while True: + user_input = await ainput(prompt='cmd >> ') + if not user_input: + continue + + for command in user_input.split('&&'): + cmd, *args = shlex.split(command) + + if cmd == 'exit': + return + + if hasattr(self, f'cmd_{cmd}'): + try: + result = await getattr(self, f'cmd_{cmd}')(*args) + if result: + print(result) + except Exception as e: + print(e) + elif cmd in self.commands: + try: + result = await self.commands[cmd](*args) + if result: + print(result) + except Exception as e: + print(e) + else: + print('command', cmd, 'not found, call help for help.') + + @staticmethod + def deprecated(message): + async def dep_printer(*args, **kwargs): + print(message) + + return dep_printer + + +class ControllerCLI(CLI): + def __init__(self, controller_state: ControllerState): + super().__init__() + self.controller_state = controller_state + + async def cmd_help(self): + print('Button commands:') + print(', '.join(self.controller_state.button_state.get_available_buttons())) + print() + await super().cmd_help() + @staticmethod def _set_stick(stick, direction, value): if direction == 'center': @@ -61,7 +146,7 @@ class ControllerCLI: stick - Command to set stick positions. :param side: 'l', 'left' for left control stick; 'r', 'right' for right control stick :param direction: 'center', 'up', 'down', 'left', 'right'; - 'h', 'horizontal' or 'v', 'vertical' to set the value directly to the "value" argument + 'h', 'horizontal' or 'v', 'vertical' to set the value directly to the "value" argument :param value: horizontal or vertical value """ if side in ('l', 'left'): @@ -73,11 +158,6 @@ class ControllerCLI: else: raise ValueError('Value of side must be "l", "left" or "r", "right"') - def add_command(self, name, command): - if name in self.commands: - raise ValueError(f'Command {name} already registered.') - self.commands[name] = command - async def run(self): while True: user_input = await ainput(prompt='cmd >> ') @@ -87,7 +167,7 @@ class ControllerCLI: buttons_to_push = [] for command in user_input.split('&&'): - cmd, *args = command.split() + cmd, *args = shlex.split(command) if cmd == 'exit': return @@ -103,7 +183,7 @@ class ControllerCLI: print(e) elif cmd in self.commands: try: - result = await self.commands[cmd](self, *args) + result = await self.commands[cmd](*args) if result: print(result) except Exception as e: @@ -116,4 +196,8 @@ class ControllerCLI: if buttons_to_push: await button_push(self.controller_state, *buttons_to_push) else: - await self.controller_state.send() + try: + await self.controller_state.send() + except NotConnectedError: + logger.info('Connection was lost.') + return diff --git a/joycontrol/controller.py b/joycontrol/controller.py index 33da3ae..3faecb4 100644 --- a/joycontrol/controller.py +++ b/joycontrol/controller.py @@ -18,3 +18,14 @@ class Controller(enum.Enum): return 'Pro Controller' else: raise NotImplementedError() + + @staticmethod + def from_arg(arg): + if arg == 'JOYCON_R': + return Controller.JOYCON_R + elif arg == 'JOYCON_L': + return Controller.JOYCON_L + elif arg == 'PRO_CONTROLLER': + return Controller.PRO_CONTROLLER + else: + raise ValueError(f'Unknown controller "{arg}".') diff --git a/joycontrol/controller_state.py b/joycontrol/controller_state.py index 6e2adf7..8dadd55 100644 --- a/joycontrol/controller_state.py +++ b/joycontrol/controller_state.py @@ -9,6 +9,7 @@ class ControllerState: def __init__(self, protocol, controller: Controller, spi_flash: FlashMemory = None): self._protocol = protocol self._controller = controller + self._nfc_content = None self._spi_flash = spi_flash @@ -26,6 +27,8 @@ class ControllerState: calibration = LeftStickCalibration.from_bytes(calibration_data) self.l_stick_state = StickState(calibration=calibration) + if calibration is not None: + self.l_stick_state.set_center() # create right stick state if controller in (Controller.PRO_CONTROLLER, Controller.JOYCON_R): @@ -38,15 +41,29 @@ class ControllerState: calibration = RightStickCalibration.from_bytes(calibration_data) self.r_stick_state = StickState(calibration=calibration) + if calibration is not None: + self.r_stick_state.set_center() self.sig_is_send = asyncio.Event() + def get_controller(self): + return self._controller + def get_flash_memory(self): return self._spi_flash + def set_nfc(self, nfc_content): + self._nfc_content = nfc_content + + def get_nfc(self): + return self._nfc_content + async def send(self): - self.sig_is_send.clear() - await self.sig_is_send.wait() + """ + Invokes protocol.send_controller_state(). Returns after the controller state was send. + Raises NotConnected exception if the connection was lost. + """ + await self._protocol.send_controller_state() async def connect(self): """ @@ -160,7 +177,7 @@ class ButtonState: def __iter__(self): """ - @returns iterator over the button bytes + :returns iterator over the button bytes """ yield self._byte_1 yield self._byte_2 diff --git a/joycontrol/device.py b/joycontrol/device.py index 1a42281..f92c0d6 100644 --- a/joycontrol/device.py +++ b/joycontrol/device.py @@ -1,6 +1,5 @@ import logging import uuid - import dbus from joycontrol import utils @@ -8,43 +7,82 @@ from joycontrol import utils logger = logging.getLogger(__name__) +HID_UUID = '00001124-0000-1000-8000-00805f9b34fb' +HID_PATH = '/bluez/switch/hid' + + class HidDevice: - _HID_UUID = '00001124-0000-1000-8000-00805f9b34fb' - _HID_PATH = '/bluez/switch/hid' - - def __init__(self): - self._uuid = str(uuid.uuid4()) - - # Setting up dbus to advertise the service record + def __init__(self, device_id=None): bus = dbus.SystemBus() - obj = bus.get_object('org.bluez', '/org/bluez/hci0') - self.adapter = dbus.Interface(obj, 'org.bluez.Adapter1') - self.properties = dbus.Interface(self.adapter, 'org.freedesktop.DBus.Properties') + + # Get Bluetooth adapter from dbus interface + manager = dbus.Interface(bus.get_object('org.bluez', '/'), 'org.freedesktop.DBus.ObjectManager') + for path, ifaces in manager.GetManagedObjects().items(): + adapter_info = ifaces.get('org.bluez.Adapter1') + if adapter_info is None: + continue + elif device_id is None or device_id == adapter_info['Address'] or path.endswith(str(device_id)): + obj = bus.get_object('org.bluez', path) + self.adapter = dbus.Interface(obj, 'org.bluez.Adapter1') + self.address = adapter_info['Address'] + self._adapter_name = path.split('/')[-1] + + self.properties = dbus.Interface(self.adapter, 'org.freedesktop.DBus.Properties') + break + else: + raise ValueError(f'Adapter {device_id} not found.') + + def get_address(self) -> str: + """ + :returns adapter Bluetooth address + """ + return self.address + + def powered(self, boolean=True): + self.properties.Set(self.adapter.dbus_interface, 'Powered', boolean) def discoverable(self, boolean=True): - #self.properties.Set(self.adapter.dbus_interface, 'Powered', True) + """ + Make adapter discoverable, starts advertising. + """ self.properties.Set(self.adapter.dbus_interface, 'Discoverable', boolean) - async def set_class(self, cls=0x002508): + def pairable(self, boolean=True): """ + Make adapter pairable + """ + self.properties.Set(self.adapter.dbus_interface, 'Pairable', boolean) + + async def set_class(self, cls='0x002508'): + """ + Sets Bluetooth device class. Requires hciconfig system command. :param cls: default 0x002508 (Gamepad/joystick device class) """ logger.info(f'setting device class to {cls}...') - await utils.run_system_command(f'hciconfig hci0 class {cls}') + await utils.run_system_command(f'hciconfig {self._adapter_name} class {cls}') async def set_name(self, name: str): + """ + Set Bluetooth device name. + :param name: to set. + """ logger.info(f'setting device name to {name}...') - await utils.run_system_command(f'hciconfig hci0 name "{name}"') + self.properties.Set(self.adapter.dbus_interface, 'Alias', name) + + @staticmethod + def register_sdp_record(record_path): + _uuid = str(uuid.uuid4()) - def register_sdp_record(self, record_path): with open(record_path) as record: opts = { 'ServiceRecord': record.read(), 'Role': 'server', - 'Service': self._HID_UUID, + 'Service': HID_UUID, 'RequireAuthentication': False, 'RequireAuthorization': False } bus = dbus.SystemBus() manager = dbus.Interface(bus.get_object("org.bluez", "/org/bluez"), "org.bluez.ProfileManager1") - manager.RegisterProfile(self._HID_PATH, self._uuid, opts) \ No newline at end of file + manager.RegisterProfile(HID_PATH, _uuid, opts) + + return _uuid diff --git a/joycontrol/ir_nfc_mcu.py b/joycontrol/ir_nfc_mcu.py new file mode 100644 index 0000000..50b387b --- /dev/null +++ b/joycontrol/ir_nfc_mcu.py @@ -0,0 +1,155 @@ +import logging +from enum import Enum +from crc8 import crc8 + +logger = logging.getLogger(__name__) + + +class Action(Enum): + NON = 0 + REQUEST_STATUS = 1 + START_TAG_POLLING = 2 + START_TAG_DISCOVERY = 3 + READ_TAG = 4 + READ_TAG_2 = 5 + READ_FINISHED = 6 + + +class McuState(Enum): + NOT_INITIALIZED = 0 + IRC = 1 + NFC = 2 + STAND_BY = 3 + BUSY = 4 + + +def copyarray(dest, offset, src): + for i in range(len(src)): + dest[offset + i] = src[i] + + +class IrNfcMcu: + """ + TODO: cleanup + """ + + def __init__(self): + self._fw_major = [0, 3] + self._fw_minor = [0, 5] + + self._bytes = [0] * 313 + + self._action = Action.NON + self._state = McuState.NOT_INITIALIZED + + self._nfc_content = None + + def get_fw_major(self): + return self._fw_major + + def get_fw_minor(self): + return self._fw_minor + + def set_action(self, v): + self._action = v + + def get_action(self): + return self._action + + def set_state(self, v): + self._state = v + + def get_state(self): + return self._state + + def _get_state_byte(self): + if self.get_state() == McuState.NFC: + return 4 + elif self.get_state() == McuState.BUSY: + return 6 + elif self.get_state() == McuState.NOT_INITIALIZED: + return 1 + elif self.get_state() == McuState.STAND_BY: + return 1 + else: + return 0 + + def update_status(self): + self._bytes[0] = 1 + self._bytes[1] = 0 + self._bytes[2] = 0 + self._bytes[3] = self._fw_major[0] + self._bytes[4] = self._fw_major[1] + self._bytes[5] = self._fw_minor[0] + self._bytes[6] = self._fw_minor[1] + self._bytes[7] = self._get_state_byte() + + def update_nfc_report(self): + self._bytes = [0] * 313 + if self.get_action() == Action.REQUEST_STATUS: + self.update_status() + elif self.get_action() == Action.NON: + self._bytes[0] = 0xff + elif self.get_action() == Action.START_TAG_DISCOVERY: + self._bytes[0] = 0x2a + self._bytes[1] = 0 + self._bytes[2] = 5 + self._bytes[3] = 0 + self._bytes[4] = 0 + self._bytes[5] = 9 + self._bytes[6] = 0x31 + self._bytes[7] = 0 + elif self.get_action() == Action.START_TAG_POLLING: + self._bytes[0] = 0x2a + self._bytes[1] = 0 + self._bytes[2] = 5 + self._bytes[3] = 0 + self._bytes[4] = 0 + if self._nfc_content is not None: + data = [0x09, 0x31, 0x09, 0x00, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x07] + copyarray(self._bytes, 5, data) + copyarray(self._bytes, 5 + len(data), self._nfc_content[0:3]) + copyarray(self._bytes, 5 + len(data) + 3, self._nfc_content[4:8]) + else: + logger.info('nfc content is none') + self._bytes[5] = 9 + self._bytes[6] = 0x31 + self._bytes[7] = 0 + elif self.get_action() in (Action.READ_TAG, Action.READ_TAG_2): + self._bytes[0] = 0x3a + self._bytes[1] = 0 + self._bytes[2] = 7 + if self.get_action() == Action.READ_TAG: + data1 = bytes.fromhex('010001310200000001020007') + copyarray(self._bytes, 3, data1) + copyarray(self._bytes, 3 + len(data1), self._nfc_content[0:3]) + copyarray(self._bytes, 3 + len(data1) + 3, self._nfc_content[4:8]) + data2 = bytes.fromhex('000000007DFDF0793651ABD7466E39C191BABEB856CEEDF1CE44CC75EAFB27094D087AE803003B3C7778860000') + copyarray(self._bytes, 3 + len(data1) + 3 + 4, data2) + copyarray(self._bytes, 3 + len(data1) + 3 + 4 + len(data2), self._nfc_content[0:245]) + self.set_action(Action.READ_TAG_2) + else: + data = bytes.fromhex('02000927') + copyarray(self._bytes, 3, data) + copyarray(self._bytes, 3 + len(data), self._nfc_content[245:540]) + self.set_action(Action.READ_FINISHED) + elif self.get_action() == Action.READ_FINISHED: + self._bytes[0] = 0x2a + self._bytes[1] = 0 + self._bytes[2] = 5 + self._bytes[3] = 0 + self._bytes[4] = 0 + data = bytes.fromhex('0931040000000101020007') + copyarray(self._bytes, 5, data) + copyarray(self._bytes, 5 + len(data), self._nfc_content[0:3]) + copyarray(self._bytes, 5 + len(data) + 3, self._nfc_content[4:8]) + + crc = crc8() + crc.update(bytes(self._bytes[:-1])) + self._bytes[-1] = ord(crc.digest()) + + def set_nfc(self, nfc_content): + self._nfc_content = nfc_content + + def __bytes__(self): + return bytes(self._bytes) diff --git a/joycontrol/memory.py b/joycontrol/memory.py index 272cb50..50b7788 100644 --- a/joycontrol/memory.py +++ b/joycontrol/memory.py @@ -1,14 +1,28 @@ class FlashMemory: - def __init__(self, spi_flash_memory_data=None, size=0x80000): + def __init__(self, spi_flash_memory_data=None, default_stick_cal=False, size=0x80000): + """ + :param spi_flash_memory_data: data from a memory dump (can be created using dump_spi_flash.py). + :param default_stick_cal: If True, override stick calibration bytes with factory default + :param size of the memory dump, should be constant + """ if spi_flash_memory_data is None: - self.data = size * [0x00] - else: - if len(spi_flash_memory_data) != size: - raise ValueError(f'Given data size {len(spi_flash_memory_data)} does not match size {size}.') - if isinstance(spi_flash_memory_data, bytes): - spi_flash_memory_data = list(spi_flash_memory_data) - self.data = spi_flash_memory_data + spi_flash_memory_data = [0xFF] * size # Blank data is all 0xFF + default_stick_cal = True + + if len(spi_flash_memory_data) != size: + raise ValueError(f'Given data size {len(spi_flash_memory_data)} does not match size {size}.') + if isinstance(spi_flash_memory_data, bytes): + spi_flash_memory_data = list(spi_flash_memory_data) + + # set default controller stick calibration + if default_stick_cal: + # L-stick factory calibration + spi_flash_memory_data[0x603D:0x6046] = [0x00, 0x07, 0x70, 0x00, 0x08, 0x80, 0x00, 0x07, 0x70] + # R-stick factory calibration + spi_flash_memory_data[0x6046:0x604F] = [0x00, 0x08, 0x80, 0x00, 0x07, 0x70, 0x00, 0x07, 0x70] + + self.data = spi_flash_memory_data def __getitem__(self, item): return self.data[item] diff --git a/joycontrol/profile/sdp_record_hid.xml b/joycontrol/profile/sdp_record_hid.xml index 36e5c93..0b6acb6 100644 --- a/joycontrol/profile/sdp_record_hid.xml +++ b/joycontrol/profile/sdp_record_hid.xml @@ -1,50 +1,50 @@ - + - + - + - - + + - + - + - + - + - - + + - + - + - + diff --git a/joycontrol/protocol.py b/joycontrol/protocol.py index e41d155..7134f0d 100644 --- a/joycontrol/protocol.py +++ b/joycontrol/protocol.py @@ -1,12 +1,18 @@ import asyncio import logging +import time from asyncio import BaseTransport, BaseProtocol +from contextlib import suppress from typing import Optional, Union, Tuple, Text +from joycontrol import utils from joycontrol.controller import Controller from joycontrol.controller_state import ControllerState from joycontrol.memory import FlashMemory from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID +from joycontrol.transport import NotConnectedError +from joycontrol.ir_nfc_mcu import IrNfcMcu, McuState, Action +from crc8 import crc8 logger = logging.getLogger(__name__) @@ -17,6 +23,7 @@ def controller_protocol_factory(controller: Controller, spi_flash=None): def create_controller_protocol(): return ControllerProtocol(controller, spi_flash=spi_flash) + return create_controller_protocol @@ -27,23 +34,50 @@ class ControllerProtocol(BaseProtocol): self.transport = None - # Increases for each input report send, overflows at 0x100 + # Increases for each input report send, should overflow at 0x100 self._input_report_timer = 0x00 self._data_received = asyncio.Event() self._controller_state = ControllerState(self, controller, spi_flash=spi_flash) + self._controller_state_sender = None - self._0x30_input_report_sender = None + self._mcu = IrNfcMcu() + + # None = Just answer to sub commands + self._input_report_mode = None # This event gets triggered once the Switch assigns a player number to the controller and accepts user inputs self.sig_set_player_lights = asyncio.Event() + async def send_controller_state(self): + """ + Waits for the controller state to be send. + + Raises NotConnected exception if the transport is not connected or the connection was lost. + """ + # TODO: Call write directly if in continuously sending input report mode + + if self.transport is None: + raise NotConnectedError('Transport not registered.') + + self._controller_state.sig_is_send.clear() + + # wrap into a future to be able to set an exception in case of a disconnect + self._controller_state_sender = asyncio.ensure_future(self._controller_state.sig_is_send.wait()) + await self._controller_state_sender + self._controller_state_sender = None + async def write(self, input_report: InputReport): """ Sets timer byte and current button state in the input report and sends it. - Fires sig_is_send event afterwards. + Fires sig_is_send event in the controller state afterwards. + + Raises NotConnected exception if the transport is not connected or the connection was lost. """ + if self.transport is None: + raise NotConnectedError('Transport not registered.') + # set button and stick data of input report input_report.set_button_status(self._controller_state.button_state) if self._controller_state.l_stick_state is None: @@ -61,6 +95,7 @@ class ControllerProtocol(BaseProtocol): self._input_report_timer = (self._input_report_timer + 1) % 0x100 await self.transport.write(input_report) + self._controller_state.sig_is_send.set() def get_controller_state(self) -> ControllerState: @@ -68,7 +103,7 @@ class ControllerProtocol(BaseProtocol): async def wait_for_output_report(self): """ - Blocks until an output report from the Switch is received. + Waits until an output report from the Switch is received. """ self._data_received.clear() await self._data_received.wait() @@ -77,72 +112,110 @@ class ControllerProtocol(BaseProtocol): logger.debug('Connection established.') self.transport = transport - def connection_lost(self, exc: Optional[Exception]) -> None: - raise NotImplementedError() + def connection_lost(self, exc: Optional[Exception] = None) -> None: + if self.transport is not None: + logger.error('Connection lost.') + asyncio.ensure_future(self.transport.close()) + self.transport = None + + if self._controller_state_sender is not None: + self._controller_state_sender.set_exception(NotConnectedError) def error_received(self, exc: Exception) -> None: + # TODO? raise NotImplementedError() - async def input_report_mode_0x30(self): + async def input_report_mode_full(self): """ - Continuously sends 0x30 input reports containing the controller state. + Continuously sends: + 0x30 input reports containing the controller state OR + 0x31 input reports containing the controller state and nfc data """ if self.transport.is_reading(): - raise ValueError('Transport must be paused in 0x30 input report mode') + raise ValueError('Transport must be paused in full input report mode') + + # send state at 66Hz + send_delay = 0.015 + await asyncio.sleep(send_delay) + last_send_time = time.time() input_report = InputReport() - input_report.set_input_report_id(0x30) input_report.set_vibrator_input() input_report.set_misc() + if self._input_report_mode is None: + raise ValueError('Input report mode is not set.') + input_report.set_input_report_id(self._input_report_mode) reader = asyncio.ensure_future(self.transport.read()) - while True: - if self.controller == Controller.PRO_CONTROLLER: - # send state at 120Hz - await asyncio.sleep(1 / 120) - else: - # send state at 60Hz - await asyncio.sleep(1 / 60) + try: + while True: + reply_send = False + if reader.done(): + data = await reader - reply_send = False - if reader.done(): - data = await reader - if not data: - # disconnect happened - logger.error('No data received (most likely due to a disconnect).') - break + reader = asyncio.ensure_future(self.transport.read()) - reader = asyncio.ensure_future(self.transport.read()) + try: + report = OutputReport(list(data)) + output_report_id = report.get_output_report_id() - try: - report = OutputReport(list(data)) - output_report_id = report.get_output_report_id() + if output_report_id == OutputReportID.RUMBLE_ONLY: + # TODO + pass + elif output_report_id == OutputReportID.SUB_COMMAND: + reply_send = await self._reply_to_sub_command(report) + elif output_report_id == OutputReportID.REQUEST_IR_NFC_MCU: + # TODO: This does not reply anything + # reply_send = await self._reply_to_ir_nfc_mcu(report) + await self._reply_to_ir_nfc_mcu(report) + else: + logger.warning(f'Report unknown output report "{output_report_id}" - IGNORE') + except ValueError as v_err: + logger.warning(f'Report parsing error "{v_err}" - IGNORE') + except NotImplementedError as err: + logger.warning(err) - if output_report_id == OutputReportID.RUMBLE_ONLY: - # TODO - pass - elif output_report_id == OutputReportID.SUB_COMMAND: - reply_send = await self._reply_to_sub_command(report) - except ValueError as v_err: - logger.warning(f'Report parsing error "{v_err}" - IGNORE') - except NotImplementedError as err: - logger.warning(err) + if reply_send: + # Hack: Adding a delay here to avoid flooding during pairing + await asyncio.sleep(0.3) + else: + # write 0x30 input report. + # TODO: set some sensor data + input_report.set_6axis_data() - if reply_send: - # Hack: Adding a delay here to avoid flooding during pairing - await asyncio.sleep(0.3) - else: - # write 0x30 input report. TODO: set some sensor data - input_report.set_6axis_data() - await self.write(input_report) + # set nfc data + if input_report.get_input_report_id() == 0x31: + self._mcu.set_nfc(self._controller_state.get_nfc()) + self._mcu.update_nfc_report() + input_report.set_ir_nfc_data(bytes(self._mcu)) + + await self.write(input_report) + + # calculate delay + current_time = time.time() + time_delta = time.time() - last_send_time + sleep_time = send_delay - time_delta + last_send_time = current_time + + if sleep_time < 0: + # logger.warning(f'Code is running {abs(sleep_time)} s too slow!') + sleep_time = 0 + + await asyncio.sleep(sleep_time) + + except NotConnectedError as err: + # Stop 0x30 input report mode if disconnected. + logger.error(err) + finally: + # cleanup + self._input_report_mode = None + # cancel the reader + with suppress(asyncio.CancelledError, NotConnectedError): + if reader.cancel(): + await reader async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None: - if not data: - # disconnect happened - logger.error('No data received (most likely due to a disconnect).') - return - self._data_received.set() try: @@ -159,11 +232,55 @@ class ControllerProtocol(BaseProtocol): if output_report_id == OutputReportID.SUB_COMMAND: await self._reply_to_sub_command(report) - #elif output_report_id == OutputReportID.RUMBLE_ONLY: + # elif output_report_id == OutputReportID.RUMBLE_ONLY: # pass else: logger.warning(f'Output report {output_report_id} not implemented - ignoring') + async def _reply_to_ir_nfc_mcu(self, report): + """ + TODO: Cleanup + We aren't replying to anything here, do we need to? + """ + sub_command = report.data[11] + sub_command_data = report.data[12:] + + # logging.info(f'received output report - Request MCU sub command {sub_command}') + + if self._mcu.get_action() in (Action.READ_TAG, Action.READ_TAG_2, Action.READ_FINISHED): + return + + # Request mcu state + if sub_command == 0x01: + # input_report = InputReport() + # input_report.set_input_report_id(0x21) + # input_report.set_misc() + + # input_report.set_ack(0xA0) + # input_report.reply_to_subcommand_id(0x21) + + self._mcu.set_action(Action.REQUEST_STATUS) + # input_report.set_mcu(self._mcu) + + # await self.write(input_report) + # Send Start tag discovery + elif sub_command == 0x02: + # 0: Cancel all, 4: StartWaitingReceive + if sub_command_data[0] == 0x04: + self._mcu.set_action(Action.START_TAG_DISCOVERY) + # 1: Start polling + elif sub_command_data[0] == 0x01: + self._mcu.set_action(Action.START_TAG_POLLING) + # 2: stop polling + elif sub_command_data[0] == 0x02: + self._mcu.set_action(Action.NON) + elif sub_command_data[0] == 0x06: + self._mcu.set_action(Action.READ_TAG) + else: + logging.info(f'Unknown sub_command_data arg {sub_command_data}') + else: + logging.info(f'Unknown MCU sub command {sub_command}') + async def _reply_to_sub_command(self, report): # classify sub command try: @@ -214,7 +331,7 @@ class ControllerProtocol(BaseProtocol): else: logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring') return False - except Exception as err: + except NotImplementedError as err: logger.error(f'Failed to answer {sub_command} - {err}') return False return True @@ -264,7 +381,7 @@ class ControllerProtocol(BaseProtocol): size = sub_command_data[4] if self.spi_flash is not None: - spi_flash_data = self.spi_flash[offset: offset+size] + spi_flash_data = self.spi_flash[offset: offset + size] input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data) else: spi_flash_data = size * [0x00] @@ -273,33 +390,43 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) async def _command_set_input_report_mode(self, sub_command_data): - if sub_command_data[0] == 0x30: - logger.info('Setting input report mode to 0x30...') + if self._input_report_mode == sub_command_data[0]: + logger.warning(f'Already in input report mode {sub_command_data[0]} - ignoring request') - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - - input_report.set_ack(0x80) - input_report.reply_to_subcommand_id(0x03) - - await self.write(input_report) - - # start sending 0x30 input reports - if self._0x30_input_report_sender is None: - self.transport.pause_reading() - self._0x30_input_report_sender = asyncio.ensure_future(self.input_report_mode_0x30()) - - # create callback to check for exceptions - def callback(future): - try: - future.result() - except Exception as err: - logger.exception(err) - - self._0x30_input_report_sender.add_done_callback(callback) + # Start input report reader + if sub_command_data[0] in (0x30, 0x31): + new_reader = asyncio.ensure_future(self.input_report_mode_full()) else: logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') + return + + # Replace the currently running reader with the input report mode sender, + # which will also handle incoming requests in the future + + self.transport.pause_reading() + + # We need to replace the reader in the future because this function was probably called by it + async def set_reader(): + await self.transport.set_reader(new_reader) + + logger.info(f'Setting input report mode to {hex(sub_command_data[0])}...') + self._input_report_mode = sub_command_data[0] + + self.transport.resume_reading() + + asyncio.ensure_future(set_reader()).add_done_callback( + utils.create_error_check_callback() + ) + + # Send acknowledgement + input_report = InputReport() + input_report.set_input_report_id(0x21) + input_report.set_misc() + + input_report.set_ack(0x80) + input_report.reply_to_subcommand_id(0x03) + + await self.write(input_report) async def _command_trigger_buttons_elapsed_time(self, sub_command_data): input_report = InputReport() @@ -347,10 +474,26 @@ class ControllerProtocol(BaseProtocol): input_report.set_ack(0xA0) input_report.reply_to_subcommand_id(SubCommand.SET_NFC_IR_MCU_CONFIG.value) - # TODO - data = [1, 0, 255, 0, 8, 0, 27, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200] + self._mcu.update_status() + data = list(bytes(self._mcu)[0:34]) + crc = crc8() + crc.update(bytes(data[:-1])) + checksum = crc.digest() + data[-1] = ord(checksum) + for i in range(len(data)): input_report.data[16+i] = data[i] + + # Set MCU mode cmd + if sub_command_data[1] == 0: + if sub_command_data[2] == 0: + self._mcu.set_state(McuState.STAND_BY) + elif sub_command_data[2] == 4: + self._mcu.set_state(McuState.NFC) + else: + logger.info(f"unknown mcu state {sub_command_data[2]}") + else: + logger.info(f"unknown mcu config command {sub_command_data}") await self.write(input_report) @@ -363,10 +506,13 @@ class ControllerProtocol(BaseProtocol): # 0x01 = Resume input_report.set_ack(0x80) input_report.reply_to_subcommand_id(SubCommand.SET_NFC_IR_MCU_STATE.value) + self._mcu.set_action(Action.NON) + self._mcu.set_state(McuState.STAND_BY) elif sub_command_data[0] == 0x00: # 0x00 = Suspend input_report.set_ack(0x80) input_report.reply_to_subcommand_id(SubCommand.SET_NFC_IR_MCU_STATE.value) + self._mcu.set_state(McuState.STAND_BY) else: raise NotImplementedError(f'Argument {sub_command_data[0]} of {SubCommand.SET_NFC_IR_MCU_STATE} ' f'not implemented.') diff --git a/joycontrol/report.py b/joycontrol/report.py index 0801592..849ac68 100644 --- a/joycontrol/report.py +++ b/joycontrol/report.py @@ -11,8 +11,7 @@ class InputReport: """ def __init__(self, data=None): if not data: - # TODO: not enough space for NFC/IR data input report - self.data = [0x00] * 51 + self.data = [0x00] * 364 # all input reports are prepended with 0xA1 self.data[0] = 0xA1 else: @@ -114,6 +113,14 @@ class InputReport: for i in range(14, 50): self.data[i] = 0x00 + def set_ir_nfc_data(self, data): + if 50 + len(data) > len(self.data): + raise ValueError('Too much data.') + + # write to data + for i in range(len(data)): + self.data[50 + i] = data[i] + def reply_to_subcommand_id(self, _id): if isinstance(_id, SubCommand): self.data[15] = _id.value @@ -196,8 +203,19 @@ class InputReport: return bytes(self.data[:51]) elif _id == 0x30: return bytes(self.data[:14]) + elif _id == 0x31: + return bytes(self.data[:363]) else: - return bytes(self.data) + return bytes(self.data[:51]) + + def __str__(self): + _id = f'Input {self.get_input_report_id():x}' + _info = '' + if self.get_input_report_id() == 0x21: + _info = self.get_reply_to_subcommand_id() + _bytes = ' '.join(f'{byte:x}' for byte in bytes(self)) + + return f'{_id} {_info}\n{_bytes}' class SubCommand(Enum): @@ -216,6 +234,7 @@ class SubCommand(Enum): class OutputReportID(Enum): SUB_COMMAND = 0x01 RUMBLE_ONLY = 0x10 + REQUEST_IR_NFC_MCU = 0x11 class OutputReport: @@ -349,3 +368,12 @@ class OutputReport: def __bytes__(self): return bytes(self.data) + + def __str__(self): + _id = f'Output {self.get_output_report_id()}' + _info = '' + if self.get_output_report_id() == OutputReportID.SUB_COMMAND: + _info = self.get_sub_command() + _bytes = ' '.join(f'{byte:x}' for byte in bytes(self)) + + return f'{_id} {_info}\n{_bytes}' diff --git a/joycontrol/server.py b/joycontrol/server.py index 1b27b04..7d7a212 100644 --- a/joycontrol/server.py +++ b/joycontrol/server.py @@ -2,6 +2,7 @@ import asyncio import logging import socket +import dbus import pkg_resources from joycontrol import utils @@ -15,61 +16,116 @@ logger = logging.getLogger(__name__) async def _send_empty_input_reports(transport): report = InputReport() - - while True: + for i in range(10): await transport.write(report) await asyncio.sleep(1) -async def create_hid_server(protocol_factory, ctl_psm, itr_psm, capture_file=None): - ctl_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) - itr_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) - - # for some reason we need to restart bluetooth here, the Switch does not connect to the sockets if we don't... - logger.info('Restarting bluetooth service...') - await utils.run_system_command('systemctl restart bluetooth.service') - await asyncio.sleep(1) - - ctl_sock.setblocking(False) - itr_sock.setblocking(False) - - ctl_sock.bind((socket.BDADDR_ANY, ctl_psm)) - itr_sock.bind((socket.BDADDR_ANY, itr_psm)) - - ctl_sock.listen(1) - itr_sock.listen(1) - +async def create_hid_server(protocol_factory, ctl_psm=17, itr_psm=19, device_id=None, reconnect_bt_addr=None, + capture_file=None): + """ + :param protocol_factory: Factory function returning a ControllerProtocol instance + :param ctl_psm: hid control channel port + :param itr_psm: hid interrupt channel port + :param device_id: ID of the bluetooth adapter. + Integer matching the digit in the hci* notation (e.g. hci0, hci1, ...) or + Bluetooth mac address in string notation of the adapter (e.g. "FF:FF:FF:FF:FF:FF"). + If None, choose any device. + Note: Selection of adapters may currently not work if the bluez "input" plugin is enabled. + :param reconnect_bt_addr: The Bluetooth address of the console that was previously connected. Defaults to None. + If None, a new hid server will be started for the initial paring. + Otherwise, the function assumes an initial pairing with the console was already done + and reconnects to the provided Bluetooth address. + :param capture_file: opened file to log incoming and outgoing messages + :returns transport for input reports and protocol which handles incoming output reports + """ protocol = protocol_factory() - hid = HidDevice() - # setting bluetooth adapter name and class to the device we wish to emulate - await hid.set_name(protocol.controller.device_name()) - await hid.set_class() + if reconnect_bt_addr is None: + ctl_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + itr_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + ctl_sock.setblocking(False) + itr_sock.setblocking(False) + ctl_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + itr_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + hid = HidDevice(device_id=device_id) - logger.info('Advertising the Bluetooth SDP record...') - hid.register_sdp_record(PROFILE_PATH) - hid.discoverable() + ctl_sock.bind((hid.address, ctl_psm)) + itr_sock.bind((hid.address, itr_psm)) + except OSError as err: + logger.warning(err) + # If the ports are already taken, this probably means that the bluez "input" plugin is enabled. + logger.warning('Fallback: Restarting bluetooth due to incompatibilities with the bluez "input" plugin. ' + 'Disable the plugin to avoid issues. See https://github.com/mart1nro/joycontrol/issues/8.') + # HACK: To circumvent incompatibilities with the bluetooth "input" plugin, we need to restart Bluetooth here. + # The Switch does not connect to the sockets if we don't. + # For more info see: https://github.com/mart1nro/joycontrol/issues/8 + logger.info('Restarting bluetooth service...') + await utils.run_system_command('systemctl restart bluetooth.service') + await asyncio.sleep(1) - loop = asyncio.get_event_loop() - client_ctl, ctl_address = await loop.sock_accept(ctl_sock) - logger.info(f'Accepted connection at psm {ctl_psm} from {ctl_address}') - client_itr, itr_address = await loop.sock_accept(itr_sock) - logger.info(f'Accepted connection at psm {itr_psm} from {itr_address}') - assert ctl_address[0] == itr_address[0] + hid = HidDevice(device_id=device_id) - # stop advertising - hid.discoverable(False) + ctl_sock.bind((socket.BDADDR_ANY, ctl_psm)) + itr_sock.bind((socket.BDADDR_ANY, itr_psm)) - transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, 50, capture_file=capture_file) + ctl_sock.listen(1) + itr_sock.listen(1) + + hid.powered(True) + hid.pairable(True) + + # setting bluetooth adapter name and class to the device we wish to emulate + await hid.set_name(protocol.controller.device_name()) + await hid.set_class() + + logger.info('Advertising the Bluetooth SDP record...') + try: + HidDevice.register_sdp_record(PROFILE_PATH) + except dbus.exceptions.DBusException as dbus_err: + # Already registered (If multiple controllers are being emulated and this method is called consecutive times) + logger.debug(dbus_err) + + # start advertising + hid.discoverable() + + logger.info('Waiting for Switch to connect... Please open the "Change Grip/Order" menu.') + + loop = asyncio.get_event_loop() + client_ctl, ctl_address = await loop.sock_accept(ctl_sock) + logger.info(f'Accepted connection at psm {ctl_psm} from {ctl_address}') + client_itr, itr_address = await loop.sock_accept(itr_sock) + logger.info(f'Accepted connection at psm {itr_psm} from {itr_address}') + assert ctl_address[0] == itr_address[0] + + # stop advertising + hid.discoverable(False) + hid.pairable(False) + + else: + # Reconnection to reconnect_bt_addr + client_ctl = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + client_itr = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + client_ctl.connect((reconnect_bt_addr, ctl_psm)) + client_itr.connect((reconnect_bt_addr, itr_psm)) + client_ctl.setblocking(False) + client_itr.setblocking(False) + + # create transport for the established connection and activate the HID protocol + transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, client_ctl, 50, capture_file=capture_file) protocol.connection_made(transport) - # send some empty input reports until the switch decides to reply + # HACK: send some empty input reports until the Switch decides to reply future = asyncio.ensure_future(_send_empty_input_reports(transport)) await protocol.wait_for_output_report() + """ future.cancel() try: await future except asyncio.CancelledError: pass + """ - return transport, protocol + return protocol.transport, protocol diff --git a/joycontrol/transport.py b/joycontrol/transport.py index d2e9c27..f2ff62d 100644 --- a/joycontrol/transport.py +++ b/joycontrol/transport.py @@ -4,55 +4,107 @@ import struct import time from typing import Any -from joycontrol.report import InputReport +from joycontrol import utils logger = logging.getLogger(__name__) +class NotConnectedError(ConnectionResetError): + pass + + class L2CAP_Transport(asyncio.Transport): - def __init__(self, loop, protocol, l2cap_socket, read_buffer_size, capture_file=None) -> None: + def __init__(self, loop, protocol, itr_sock, ctr_sock, read_buffer_size, capture_file=None) -> None: + super(L2CAP_Transport, self).__init__() + self._loop = loop self._protocol = protocol - self._sock = l2cap_socket + self._itr_sock = itr_sock + self._ctr_sock = ctr_sock + self._read_buffer_size = read_buffer_size self._extra_info = { - 'peername': self._sock.getpeername(), - 'sockname': self._sock.getsockname(), - 'socket': self._sock + 'peername': self._itr_sock.getpeername(), + 'sockname': self._itr_sock.getsockname(), + 'socket': self._itr_sock } - self._read_thread = asyncio.ensure_future(self._reader()) - - # create callback to check for exceptions - def callback(future): - try: - future.result() - except Exception as err: - logger.exception(err) - - self._read_thread.add_done_callback(callback) - self._is_closing = False self._is_reading = asyncio.Event() - self._is_reading.set() - - self._input_report_timer = 0x00 self._capture_file = capture_file + # start underlying reader + self._read_thread = None + self._is_reading.set() + self.start_reader() + async def _reader(self): while True: - await self._is_reading.wait() + try: + data = await self.read() + except NotConnectedError: + self._read_thread = None + break - data = await self.read() + await self._protocol.report_received(data, self._itr_sock.getpeername()) - #logger.debug(f'received "{list(data)}"') - await self._protocol.report_received(data, self._sock.getpeername()) + def start_reader(self): + """ + Starts the transport reader which calls the protocols report_received function for every incoming message + """ + if self._read_thread is not None: + raise ValueError('Reader is already running.') + + self._read_thread = asyncio.ensure_future(self._reader()) + + # Create callback in case the reader is failing + callback = utils.create_error_check_callback(ignore=asyncio.CancelledError) + self._read_thread.add_done_callback(callback) + + async def set_reader(self, reader: asyncio.Future): + """ + Cancel the currently running reader and register the new one. + A reader is a coroutine that calls this transports 'read' function. + The 'read' function calls can be paused by calling pause_reading of this transport. + :param reader: future reader + """ + if self._read_thread is not None: + # cancel currently running reader + if self._read_thread.cancel(): + try: + await self._read_thread + except asyncio.CancelledError: + pass + + # Create callback for debugging in case the reader is failing + err_callback = utils.create_error_check_callback(ignore=asyncio.CancelledError) + reader.add_done_callback(err_callback) + + self._read_thread = reader + + def get_reader(self): + return self._read_thread async def read(self): - data = await self._loop.sock_recv(self._sock, self._read_buffer_size) + """ + Read data from the underlying socket. This function waits, + if reading is paused using the pause_reading function. + + :returns bytes + """ + await self._is_reading.wait() + data = await self._loop.sock_recv(self._itr_sock, self._read_buffer_size) + + # logger.debug(f'received "{list(data)}"') + + if not data: + # disconnect happened + logger.error('No data received.') + self._protocol.connection_lost() + raise NotConnectedError('No data received.') if self._capture_file is not None: # write data to log file @@ -66,17 +118,17 @@ class L2CAP_Transport(asyncio.Transport): """ :returns True if the reader is running """ - return self._is_reading.is_set() + return self._reader is not None and self._is_reading.is_set() def pause_reading(self) -> None: """ - Pauses the reader + Pauses any 'read' function calls. """ self._is_reading.clear() def resume_reading(self) -> None: """ - Resumes the reader + Resumes all 'read' function calls. """ self._is_reading.set() @@ -95,11 +147,21 @@ class L2CAP_Transport(asyncio.Transport): size = struct.pack('i', len(_bytes)) self._capture_file.write(_time + size + _bytes) - #logger.debug(f'sending "{_bytes}"') - await self._loop.sock_sendall(self._sock, _bytes) + # logger.debug(f'sending "{_bytes}"') + + try: + await self._loop.sock_sendall(self._itr_sock, _bytes) + except OSError as err: + logger.error(err) + self._protocol.connection_lost() + raise NotConnectedError(err) + except ConnectionResetError as err: + logger.error(err) + self._protocol.connection_lost() + raise err def abort(self) -> None: - super().abort() + raise NotImplementedError def get_extra_info(self, name: Any, default=None) -> Any: return self._extra_info.get(name, default) @@ -109,16 +171,21 @@ class L2CAP_Transport(asyncio.Transport): async def close(self): """ - Stops socket reader and closes socket + Stops reader and closes underlying socket """ - self._is_closing = True - self._read_thread.cancel() - # wait for reader to cancel - try: - await self._read_thread - except asyncio.CancelledError: - pass - self._sock.close() + if not self._is_closing: + # was not already closed + self._is_closing = True + if self._read_thread.cancel(): + # wait for reader to cancel + try: + await self._read_thread + except asyncio.CancelledError: + pass + + # interrupt connection should be closed first + self._itr_sock.close() + self._ctr_sock.close() def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: self._protocol = protocol diff --git a/joycontrol/utils.py b/joycontrol/utils.py index 4aa33c8..4eadc31 100644 --- a/joycontrol/utils.py +++ b/joycontrol/utils.py @@ -1,9 +1,42 @@ import asyncio import logging +from contextlib import contextmanager + +import hid logger = logging.getLogger(__name__) +class AsyncHID(hid.Device): + def __init__(self, *args, loop=asyncio.get_event_loop(), **kwargs): + super().__init__(*args, **kwargs) + self._loop = loop + + self._write_lock = asyncio.Lock() + self._read_lock = asyncio.Lock() + + async def read(self, size, timeout=None): + async with self._read_lock: + return await self._loop.run_in_executor(None, hid.Device.read, self, size, timeout) + + async def write(self, data): + async with self._write_lock: + return await self._loop.run_in_executor(None, hid.Device.write, self, data) + + +@contextmanager +def get_output(path=None, open_flags='wb', default=None): + """ + Context manager that open the file a path was given, otherwise returns default value. + """ + if path is not None: + file = open(path, open_flags) + yield file + file.close() + else: + yield default + + def get_bit(value, n): return (value >> n & 1) != 0 @@ -12,6 +45,25 @@ def flip_bit(value, n): return value ^ (1 << n) +def create_error_check_callback(ignore=None): + """ + Creates callback causing errors of a finished future to be raised. + Useful for debugging futures that are never awaited. + :param ignore: Any number of errors to ignore. + :returns callback which can be added to a future with future.add_done_callback(...) + """ + def callback(future): + if ignore: + try: + future.result() + except ignore: + # ignore suppressed errors + pass + else: + future.result() + return callback + + async def run_system_command(cmd): proc = await asyncio.create_subprocess_shell( cmd, diff --git a/run_controller_cli.py b/run_controller_cli.py index acc2adf..c90e506 100644 --- a/run_controller_cli.py +++ b/run_controller_cli.py @@ -1,30 +1,257 @@ +#!/usr/bin/env python3 + import argparse import asyncio import logging import os -from contextlib import contextmanager -from joycontrol import logging_default as log +from aioconsole import ainput + +from joycontrol import logging_default as log, utils from joycontrol.command_line_interface import ControllerCLI from joycontrol.controller import Controller +from joycontrol.controller_state import ControllerState, button_push from joycontrol.memory import FlashMemory from joycontrol.protocol import controller_protocol_factory from joycontrol.server import create_hid_server logger = logging.getLogger(__name__) +"""Emulates Switch controller. Opens joycontrol.command_line_interface to send button commands and more. -async def _main(controller, capture_file=None, spi_flash=None): - factory = controller_protocol_factory(controller, spi_flash=spi_flash) - transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file) +While running the cli, call "help" for an explanation of available commands. - controller_state = protocol.get_controller_state() +Usage: + run_controller_cli.py [--device_id | -d ] + [--spi_flash ] + [--reconnect_bt_addr | -r ] + [--log | -l ] + [--nfc ] + run_controller_cli.py -h | --help - cli = ControllerCLI(controller_state) - await cli.run() +Arguments: + controller Choose which controller to emulate. Either "JOYCON_R", "JOYCON_L" or "PRO_CONTROLLER" - logger.info('Stopping communication...') - await transport.close() +Options: + -d --device_id ID of the bluetooth adapter. Integer matching the digit in the hci* notation + (e.g. hci0, hci1, ...) or Bluetooth mac address of the adapter in string + notation (e.g. "FF:FF:FF:FF:FF:FF"). + Note: Selection of adapters may not work if the bluez "input" plugin is + enabled. + + --spi_flash Memory dump of a real Switch controller. Required for joystick emulation. + Allows displaying of JoyCon colors. + Memory dumps can be created using the dump_spi_flash.py script. + + -r --reconnect_bt_addr Previously connected Switch console Bluetooth address in string + notation (e.g. "FF:FF:FF:FF:FF:FF") for reconnection. + Does not require the "Change Grip/Order" menu to be opened, + + -l --log Write hid communication (input reports and output reports) to a file. + + --nfc Sets the nfc data of the controller to a given nfc dump upon initial + connection. +""" + + +async def test_controller_buttons(controller_state: ControllerState): + """ + Example controller script. + Navigates to the "Test Controller Buttons" menu and presses all buttons. + """ + if controller_state.get_controller() != Controller.PRO_CONTROLLER: + raise ValueError('This script only works with the Pro Controller!') + + # waits until controller is fully connected + await controller_state.connect() + + await ainput(prompt='Make sure the Switch is in the Home menu and press to continue.') + + """ + # We assume we are in the "Change Grip/Order" menu of the switch + await button_push(controller_state, 'home') + + # wait for the animation + await asyncio.sleep(1) + """ + + # Goto settings + await button_push(controller_state, 'down', sec=1) + await button_push(controller_state, 'right', sec=2) + await asyncio.sleep(0.3) + await button_push(controller_state, 'left') + await asyncio.sleep(0.3) + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # go all the way down + await button_push(controller_state, 'down', sec=4) + await asyncio.sleep(0.3) + + # goto "Controllers and Sensors" menu + for _ in range(2): + await button_push(controller_state, 'up') + await asyncio.sleep(0.3) + await button_push(controller_state, 'right') + await asyncio.sleep(0.3) + + # go all the way down + await button_push(controller_state, 'down', sec=3) + await asyncio.sleep(0.3) + + # goto "Test Input Devices" menu + await button_push(controller_state, 'up') + await asyncio.sleep(0.3) + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # goto "Test Controller Buttons" menu + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # push all buttons except home and capture + button_list = controller_state.button_state.get_available_buttons() + if 'capture' in button_list: + button_list.remove('capture') + if 'home' in button_list: + button_list.remove('home') + + user_input = asyncio.ensure_future( + ainput(prompt='Pressing all buttons... Press to stop.') + ) + + # push all buttons consecutively until user input + while not user_input.done(): + for button in button_list: + await button_push(controller_state, button) + await asyncio.sleep(0.1) + + if user_input.done(): + break + + # await future to trigger exceptions in case something went wrong + await user_input + + # go back to home + await button_push(controller_state, 'home') + + +async def set_nfc(controller_state, file_path): + """ + Sets nfc content of the controller state to contents of the given file. + :param controller_state: Emulated controller state + :param file_path: Path to nfc dump file + """ + loop = asyncio.get_event_loop() + + with open(file_path, 'rb') as nfc_file: + content = await loop.run_in_executor(None, nfc_file.read) + controller_state.set_nfc(content) + + +async def mash_button(controller_state, button, interval): + # waits until controller is fully connected + await controller_state.connect() + + if button not in controller_state.button_state.get_available_buttons(): + raise ValueError(f'Button {button} does not exist on {controller_state.get_controller()}') + + user_input = asyncio.ensure_future( + ainput(prompt=f'Pressing the {button} button every {interval} seconds... Press to stop.') + ) + # push a button repeatedly until user input + while not user_input.done(): + await button_push(controller_state, button) + await asyncio.sleep(float(interval)) + + # await future to trigger exceptions in case something went wrong + await user_input + + +async def _main(args): + # parse the spi flash + if args.spi_flash: + with open(args.spi_flash, 'rb') as spi_flash_file: + spi_flash = FlashMemory(spi_flash_file.read()) + else: + # Create memory containing default controller stick calibration + spi_flash = FlashMemory() + + # Get controller name to emulate from arguments + controller = Controller.from_arg(args.controller) + + with utils.get_output(path=args.log, default=None) as capture_file: + factory = controller_protocol_factory(controller, spi_flash=spi_flash) + ctl_psm, itr_psm = 17, 19 + transport, protocol = await create_hid_server(factory, reconnect_bt_addr=args.reconnect_bt_addr, + ctl_psm=ctl_psm, + itr_psm=itr_psm, capture_file=capture_file, + device_id=args.device_id) + + controller_state = protocol.get_controller_state() + + # Create command line interface and add some extra commands + cli = ControllerCLI(controller_state) + + # Wrap the script so we can pass the controller state. The doc string will be printed when calling 'help' + async def _run_test_controller_buttons(): + """ + test_buttons - Navigates to the "Test Controller Buttons" menu and presses all buttons. + """ + await test_controller_buttons(controller_state) + + # add the script from above + cli.add_command('test_buttons', _run_test_controller_buttons) + + # Mash a button command + async def call_mash_button(*args): + """ + mash - Mash a specified button at a set interval + + Usage: + mash