From 19eb051726e31efbf3a7a2e991dcf7f640be6964 Mon Sep 17 00:00:00 2001 From: Robert Martin Date: Fri, 31 Jan 2020 20:09:17 +0900 Subject: [PATCH] Implemented goto "Test Controller Buttons" menu --- button_state.py | 75 ++++++++++++++++++++++++++++++++ buttons.py | 76 -------------------------------- controller_state.py | 38 ++++++++++++++++ protocol.py | 75 ++++++++++++++------------------ report.py | 36 ++++++++++++--- run_and_pair_switch.py | 99 +++++++++++++++++++++++++++++------------- transport.py | 5 ++- utils.py | 8 ++++ 8 files changed, 256 insertions(+), 156 deletions(-) create mode 100644 button_state.py delete mode 100644 buttons.py create mode 100644 controller_state.py diff --git a/button_state.py b/button_state.py new file mode 100644 index 0000000..37b14fd --- /dev/null +++ b/button_state.py @@ -0,0 +1,75 @@ +import utils + + +class ButtonState: + """ + Utility class to set buttons in the input report + https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md + Byte 0 1 2 3 4 5 6 7 + 1 Y X B A SR SL R ZR + 2 Minus Plus R Stick L Stick Home Capture + 3 Down Up Right Left SR SL L ZL + """ + def __init__(self): + # 3 bytes + self._byte_1 = 0 + self._byte_2 = 0 + self._byte_3 = 0 + + # generating methods for each button + def button_method_factory(byte, bit): + def flip(): + setattr(self, byte, utils.flip_bit(getattr(self, byte), bit)) + + def getter(): + return utils.get_bit(getattr(self, byte), bit) + return flip, getter + + # byte 1 + self.y, self.y_is_set = button_method_factory('_byte_1', 0) + self.x, self.x_is_set = button_method_factory('_byte_1', 1) + self.b, self.b_is_set = button_method_factory('_byte_1', 2) + self.a, self.a_is_set = button_method_factory('_byte_1', 3) + self.right_sr, self.right_sr_is_set = button_method_factory('_byte_1', 4) + self.right_sl, self.right_sl_is_set = button_method_factory('_byte_1', 5) + self.r, self.r_is_set = button_method_factory('_byte_1', 6) + self.zr, self.zr_is_set = button_method_factory('_byte_1', 7) + + # byte 2 + self.minus, self.minus_is_set = button_method_factory('_byte_2', 0) + self.plus, self.plus_is_set = button_method_factory('_byte_2', 1) + self.r_stick, self.r_stick_is_set = button_method_factory('_byte_2', 2) + self.l_stick, self.l_stick_is_set = button_method_factory('_byte_2', 3) + self.home, self.home_is_set = button_method_factory('_byte_2', 4) + self.capture, self.capture_is_set = button_method_factory('_byte_2', 5) + + # byte 3 + self.down, self.down_is_set = button_method_factory('_byte_3', 0) + self.up, self.up_is_set = button_method_factory('_byte_3', 1) + self.right, self.right_is_set = button_method_factory('_byte_3', 2) + self.left, self.left_is_set = button_method_factory('_byte_3', 3) + self.left_sr, self.left_sr_is_set = button_method_factory('_byte_3', 4) + self.left_sl, self.left_sl_is_set = button_method_factory('_byte_3', 5) + self.l, self.l_is_set = button_method_factory('_byte_3', 6) + self.zl, self.zl_is_set = button_method_factory('_byte_3', 7) + + """ + Example for generated methods: home button (byte_2, 4) + + def home(self): + self.byte_2 = flip_bit(self.byte_2, 4) + + def home_is_set(self): + return get_bit(self.byte_2, 4) + """ + + def __iter__(self): + """ + @returns iterator of the button bytes + """ + yield self._byte_1 + yield self._byte_2 + yield self._byte_3 + + def clear(self): + self._byte_1 = self._byte_2 = self._byte_3 = 0 \ No newline at end of file diff --git a/buttons.py b/buttons.py deleted file mode 100644 index b7e7267..0000000 --- a/buttons.py +++ /dev/null @@ -1,76 +0,0 @@ - -def get_bit(value, n): - return (value >> n & 1) != 0 - - -def flip_bit(value, n): - return value ^ (1 << n) - - -class Buttons: - """ - Utility class to set buttons in the input report - https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md - Byte 0 1 2 3 4 5 6 7 - 1 Y X B A SR SL R ZR - 2 Minus Plus R Stick L Stick Home Capture - 3 Down Up Right Left SR SL L ZL - """ - def __init__(self): - # 3 bytes - self.byte_1 = 0 - self.byte_2 = 0 - self.byte_3 = 0 - - # generating methods for each button - def button_method_factory(byte, bit): - def flip(): - setattr(self, byte, flip_bit(getattr(self, byte), bit)) - - def getter(): - return get_bit(getattr(self, byte), bit) - return flip, getter - - # byte 1 - self.y, self.y_is_set = button_method_factory('byte_1', 0) - self.x, self.x_is_set = button_method_factory('byte_1', 1) - self.b, self.b_is_set = button_method_factory('byte_1', 2) - self.a, self.a_is_set = button_method_factory('byte_1', 3) - self.right_sr, self.right_sr_is_set = button_method_factory('byte_1', 4) - self.right_sl, self.right_sl_is_set = button_method_factory('byte_1', 5) - self.r, self.r_is_set = button_method_factory('byte_1', 6) - self.zr, self.zr_is_set = button_method_factory('byte_1', 7) - - # byte 2 - self.minus, self.minus_is_set = button_method_factory('byte_2', 0) - self.plus, self.plus_is_set = button_method_factory('byte_2', 1) - self.r_stick, self.r_stick_is_set = button_method_factory('byte_2', 2) - self.l_stick, self.l_stick_is_set = button_method_factory('byte_2', 3) - self.home, self.home_is_set = button_method_factory('byte_2', 4) - self.capture, self.capture_is_set = button_method_factory('byte_2', 5) - - # byte 3 - self.down, self.down_is_set = button_method_factory('byte_3', 0) - self.up, self.up_is_set = button_method_factory('byte_3', 1) - self.right, self.right_is_set = button_method_factory('byte_3', 2) - self.left, self.left_is_set = button_method_factory('byte_3', 3) - self.left_sr, self.left_sr_is_set = button_method_factory('byte_3', 4) - self.left_sl, self.left_sl_is_set = button_method_factory('byte_3', 5) - self.l, self.l_is_set = button_method_factory('byte_3', 6) - self.zl, self.zl_is_set = button_method_factory('byte_3', 7) - - """ - Example for generated methods: home button (byte_2, 4) - - def home(self): - self.byte_2 = flip_bit(self.byte_2, 4) - - def home_is_set(self): - return get_bit(self.byte_2, 4) - """ - - def to_list(self): - return [self.byte_1, self.byte_2, self.byte_3] - - def clear(self): - self.byte_1 = self.byte_2 = self.byte_3 = 0 diff --git a/controller_state.py b/controller_state.py new file mode 100644 index 0000000..fc759b7 --- /dev/null +++ b/controller_state.py @@ -0,0 +1,38 @@ +import asyncio + +from button_state import ButtonState +from protocol import ControllerProtocol + + +class ControllerState: + def __init__(self, transport: asyncio.Transport, protocol: ControllerProtocol): + super().__init__() + self.transport = transport + + self.protocol = protocol + + async def send(self): + await self.protocol.button_input_report.write(self.transport) + + async def connect(self): + """ + Waits until the switch is paired with the controller and accepts button commands + """ + # TODO HACK: Hard to say for now. + await self.protocol.wait_for_output_report() + # The switch sends data to our device, it shouldn't take long until the connection is fully established. + await asyncio.sleep(5) + + def set_button_state(self, button_state: ButtonState): + """ + Sets the button status bytes in the input report + """ + self.protocol.button_input_report.set_button_status(button_state) + + def set_stick_state(self): + """ + TODO + """ + raise NotImplementedError() + + diff --git a/protocol.py b/protocol.py index c0d09ad..06748e0 100644 --- a/protocol.py +++ b/protocol.py @@ -21,6 +21,11 @@ class ControllerProtocol(BaseProtocol): self.transport = None + # This must always be an 0x21 input report to be compatible with button events + self.button_input_report = InputReport() + self.button_input_report.set_input_report_id(0x21) + self.button_input_report.set_misc() + self._data_received = asyncio.Event() async def wait_for_output_report(self): @@ -49,7 +54,9 @@ class ControllerProtocol(BaseProtocol): # classify sub command sc_byte, sub_command = report.get_sub_command() logging.info(f'received output report - {sub_command}') - if sub_command == SubCommand.REQUEST_DEVICE_INFO: + if sub_command is None: + logger.warning(f'Received output report does not contain a sub command') + elif sub_command == SubCommand.REQUEST_DEVICE_INFO: await self._command_request_device_info(report) elif sub_command == SubCommand.SET_SHIPMENT_STATE: @@ -65,67 +72,51 @@ class ControllerProtocol(BaseProtocol): await self._command_trigger_buttons_elapsed_time(report) elif sub_command == SubCommand.NOT_IMPLEMENTED: - logger.error(f'Sub command 0x{sc_byte:02x} not implemented - ignoring') + logger.warning(f'Sub command 0x{sc_byte:02x} not implemented - ignoring') async def _command_request_device_info(self, output_report): address = self.transport.get_extra_info('sockname') assert address is not None bd_address = list(map(lambda x: int(x, 16), address[0].split(':'))) - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x82) - #input_report.set_button_status() - #input_report.set_left_analog_stick() - #input_report.set_right_analog_stick() - #input_report.set_vibrator_input() - input_report.sub_0x02_device_info(bd_address, controller=self.controller) + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x82) + self.button_input_report.sub_0x02_device_info(bd_address, controller=self.controller) - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) async def _command_set_shipment_state(self, output_report): - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x80) - input_report.sub_0x08_shipment() + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x80) + self.button_input_report.sub_0x08_shipment() - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) async def _command_spi_flash_read(self, output_report): - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x90) - input_report.sub_0x10_spi_flash_read(output_report) + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x90) + self.button_input_report.sub_0x10_spi_flash_read(output_report) - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) async def _command_set_input_report_mode(self, output_report): - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x80) - input_report.sub_0x03_set_input_report_mode() + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x80) + self.button_input_report.sub_0x03_set_input_report_mode() - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) async def _command_trigger_buttons_elapsed_time(self, output_report): - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x83) - input_report.sub_0x04_trigger_buttons_elapsed_time() + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x83) + self.button_input_report.sub_0x04_trigger_buttons_elapsed_time() - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) async def _enable_6axis_sensor(self, output_report): - input_report = InputReport() - input_report.set_input_report_id(0x21) - input_report.set_misc() - input_report.set_ack(0x80) + self.button_input_report.set_misc() + self.button_input_report.set_ack(0x80) - input_report.reply_to_subcommand_id(0x40) + self.button_input_report.reply_to_subcommand_id(0x40) - asyncio.ensure_future(self.transport.write(input_report)) + await self.button_input_report.write(self.transport) diff --git a/report.py b/report.py index 2ca619f..21fd3a4 100644 --- a/report.py +++ b/report.py @@ -1,5 +1,7 @@ +import asyncio from enum import Enum +from button_state import ButtonState from controller import Controller @@ -9,10 +11,19 @@ class InputReport: https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md """ def __init__(self): - self.data = [0x00] * 50 + self.data = [0x00] * 51 # all input reports are prepended with 0xA1 self.data[0] = 0xA1 + self.subcommand_is_set = False + + self.is_writing = None + + def clear_sub_command(self): + for i in range(14, 51): + self.data[i] = 0x00 + self.subcommand_is_set = False + def set_input_report_id(self, _id): """ :param _id: e.g. 0x21 Standard input reports used for sub command replies @@ -30,11 +41,11 @@ class InputReport: # battery level + connection info self.data[3] = 0x8E - def set_button_status(self): + def set_button_status(self, button_status: ButtonState): """ - TODO + Sets the button status bytes """ - self.data[4:7] = [0x84, 0x00, 0x12] + self.data[4:7] = iter(button_status) def set_left_analog_stick(self): """ @@ -74,8 +85,7 @@ class InputReport: elif len(mac) != 6: raise ValueError('Bluetooth mac address must consist of 6 bytes!') - # reply to sub command ID - self.data[15] = 0x02 + self.reply_to_subcommand_id(0x02) # sub command reply data offset = 16 @@ -87,6 +97,7 @@ class InputReport: self.data[offset + 11] = 0x01 def reply_to_subcommand_id(self, id_): + self.subcommand_is_set = True self.data[15] = id_ def sub_0x08_shipment(self): @@ -106,8 +117,17 @@ class InputReport: blub = [0x00, 0xCC, 0x00, 0xEE, 0x00, 0xFF] self.data[16:22] = blub + async def write(self, transport): + if self.is_writing is None: + self.is_writing = asyncio.ensure_future(transport.write(self)) + await self.is_writing + self.is_writing = None + def __bytes__(self): - return bytes(self.data) + if self.subcommand_is_set: + return bytes(self.data) + else: + return bytes(self.data[:15]) class SubCommand(Enum): @@ -127,6 +147,8 @@ class OutputReport: self.data = data def get_sub_command(self): + if len(self.data) < 12: + return None, None try: return self.data[11], SubCommand(self.data[11]) except ValueError: diff --git a/run_and_pair_switch.py b/run_and_pair_switch.py index 5d62269..a7b933b 100644 --- a/run_and_pair_switch.py +++ b/run_and_pair_switch.py @@ -5,7 +5,7 @@ import socket import logging_default as log import utils -from buttons import Buttons +from controller_state import ButtonState, ControllerState from device import HidDevice from protocol import controller_protocol_factory, Controller from report import InputReport @@ -64,37 +64,80 @@ async def send_empty_input_reports(transport): await asyncio.sleep(1) -async def mash_buttons(transport): +async def button_push(controller_state, button, sec=0.1): + button_state = ButtonState() + + # push button + getattr(button_state, button)() + + # send report + controller_state.set_button_state(button_state) + await controller_state.send() + await asyncio.sleep(sec) + + # release button + getattr(button_state, button)() + + # send report + controller_state.set_button_state(button_state) + await controller_state.send() + + +async def test_controller_buttons(controller_state: ControllerState): + """ + Goes to the "Test Controller Buttons" menu and presses all buttons + """ + await controller_state.connect() + + # We assume we are in the "Change Grip/Order" menu of the switch + await button_push(controller_state, 'home') + + # wait for the animation + await asyncio.sleep(1) + + # Goto settings + await button_push(controller_state, 'down') + await asyncio.sleep(0.3) + for _ in range(4): + await button_push(controller_state, 'right') + await asyncio.sleep(0.3) + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # go all the way down + await button_push(controller_state, 'down', sec=3) + await asyncio.sleep(0.3) + + # goto "Controllers and Sensors" menu + for _ in range(2): + await button_push(controller_state, 'up') + await asyncio.sleep(0.3) + await button_push(controller_state, 'right') + await asyncio.sleep(0.3) + + # go all the way down + await button_push(controller_state, 'down', sec=3) + await asyncio.sleep(0.3) + + # goto "Test Input Devices" menu + await button_push(controller_state, 'up') + await asyncio.sleep(0.3) + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # goto "Test Controller Buttons" menu + await button_push(controller_state, 'a') + await asyncio.sleep(0.3) + + # push all buttons button_list = ['y', 'x', 'b', 'a', 'r', 'zr', 'minus', 'plus', 'r_stick', 'l_stick', 'down', 'up', 'right', 'left', 'l', 'zl'] - - report = InputReport() - report.set_input_report_id(0x21) - report.set_misc() - - buttons = Buttons() - for i in range(10): for button in button_list: - logger.info(f'Pressing Button {button}...') - - # push button - getattr(buttons, button)() - - # send report - report.data[4:7] = buttons.to_list() - await transport.write(report) + await button_push(controller_state, button) await asyncio.sleep(0.1) - # release button - getattr(buttons, button)() - - # send report - report.data[4:7] = buttons.to_list() - await transport.write(report) - await asyncio.sleep(0.3) - async def main(): transport, protocol = await create_hid_server(controller_protocol_factory(Controller.PRO_CONTROLLER), 17, 19) @@ -108,12 +151,8 @@ async def main(): except asyncio.CancelledError: pass - await asyncio.sleep(20) + await test_controller_buttons(ControllerState(transport, protocol)) - await mash_buttons(transport) - - # stop communication after some time - await asyncio.sleep(60) logger.info('Stopping communication...') await transport.close() diff --git a/transport.py b/transport.py index 577b86b..18b8176 100644 --- a/transport.py +++ b/transport.py @@ -33,7 +33,7 @@ class L2CAP_Transport(asyncio.Transport): await self._is_reading.wait() data = await self._loop.sock_recv(self._sock, self._read_buffer_size) - logger.debug(f'received "{data}"') + logger.debug(f'received "{list(map(hex, list(data)))}"') await self._protocol.report_received(data, self._sock.getpeername()) def is_reading(self) -> bool: @@ -62,6 +62,9 @@ class L2CAP_Transport(asyncio.Transport): data.set_timer(self._input_report_timer) self._input_report_timer = (self._input_report_timer + 1) % 256 _bytes = bytes(data) + + if data.subcommand_is_set: + data.clear_sub_command() else: raise ValueError('data must be bytes or InputReport') diff --git a/utils.py b/utils.py index d304fb3..a0045f3 100644 --- a/utils.py +++ b/utils.py @@ -5,6 +5,14 @@ import re logger = logging.getLogger(__name__) +def get_bit(value, n): + return (value >> n & 1) != 0 + + +def flip_bit(value, n): + return value ^ (1 << n) + + async def run_system_command(cmd): proc = await asyncio.create_subprocess_shell( cmd,