diff --git a/dump_spi_flash.py b/dump_spi_flash.py new file mode 100644 index 0000000..74601ef --- /dev/null +++ b/dump_spi_flash.py @@ -0,0 +1,196 @@ +import argparse +import asyncio +import logging +import os +from contextlib import suppress + +import hid + +from joycontrol import logging_default as log +from joycontrol.report import OutputReport, InputReport, SubCommand + +logger = logging.getLogger(__name__) + + +VENDOR_ID = 1406 +PRODUCT_ID_JL = 8198 +PRODUCT_ID_JR = 8199 + + +class AsyncHID(hid.Device): + def __init__(self, *args, loop=asyncio.get_event_loop(), **kwargs): + super().__init__(*args, **kwargs) + self._loop = loop + + self._write_lock = asyncio.Lock() + self._read_lock = asyncio.Lock() + + async def read(self, size, timeout=None): + async with self._read_lock: + return await self._loop.run_in_executor(None, hid.Device.read, self, size, timeout) + + async def write(self, data): + async with self._write_lock: + return await self._loop.run_in_executor(None, hid.Device.write, self, data) + + +class DataReader: + def __init__(self): + self.pending_request = None + self.timer = 0 + self._stop_reading = False + + def close(self): + self._stop_reading = True + + async def send_spi_read_request(self, hid_device, offset, size): + report = OutputReport() + report.sub_0x10_spi_flash_read(offset, size) + + # event shall be set if data received + reply_event = asyncio.Event() + self.pending_request = (offset, size, reply_event) + + # send spi flash read request + while True: + report.set_timer(self.timer) + self.timer += 1 + + # remove 0xA2 output report padding byte since it's not needed for communication over hid library + data = report.data[1:] + await hid_device.write(bytes(data)) + + # wait for data received, send again if time out occurs (1 sec) + try: + await asyncio.wait_for(reply_event.wait(), 1) + self.pending_request = None + break + except asyncio.TimeoutError: + continue + + async def receive_data(self, hid_device, output_file=None): + while True: + data = await hid_device.read(size=255, timeout=3) + if self._stop_reading: + break + elif not data: + continue + + # add byte for input report + data = b'\xa1' + data + + input_report = InputReport(list(data)) + + # check if input report is spi flash read reply + if input_report.get_input_report_id() != 0x21: + continue + try: + sub_command_id = input_report.get_reply_to_subcommand_id() + if sub_command_id != SubCommand.SPI_FLASH_READ: + continue + except NotImplementedError: + continue + + assert input_report.get_ack() == 0x90 + + reply = input_report.get_sub_command_reply_data() + + # parse offset + offset = 0 + digit = 1 + for i in range(4): + offset += reply[i] * digit + digit *= 0x100 + + size = reply[4] + + # parse spi flash data + assert len(reply) >= 5+size + spi_data = reply[5:5+size] + + # check if received data is currently requested + if self.pending_request is None or self.pending_request[0] != offset or self.pending_request[1] != size: + continue + + # notify spi request sender that the data is received + self.pending_request[2].set() + + logger.info(f'received offset {offset}, size {size} - {spi_data}') + + # write data to file + if output_file is not None: + output_file.write(bytes(spi_data)) + + +async def dumb_spi_flash(hid_device, output_file=None): + SPI_FLASH_SIZE = 0x80000 + + spi_flash_reader = DataReader() + reader = asyncio.ensure_future(spi_flash_reader.receive_data(hid_device, output_file=output_file)) + + try: + # read data in 0x1D chunks + for i in range(SPI_FLASH_SIZE // 0x1D): + await spi_flash_reader.send_spi_read_request(hid_device, i * 0x1D, 0x1D) + + remainder = SPI_FLASH_SIZE % 0x1D + if remainder: + await spi_flash_reader.send_spi_read_request(hid_device, SPI_FLASH_SIZE - 1 - remainder, remainder) + except asyncio.CancelledError: + pass + finally: + spi_flash_reader.close() + # wait for reader to close + await reader + + +async def _main(args, loop): + logger.info('Waiting for HID devices... Please connect JoyCon over bluetooth.') + + controller = None + while controller is None: + for device in hid.enumerate(0, 0): + # looking for devices matching Nintendo's vendor id and JoyCon product id + if device['vendor_id'] == VENDOR_ID and device['product_id'] in (PRODUCT_ID_JL, PRODUCT_ID_JR): + controller = device + break + else: + await asyncio.sleep(2) + + logger.info(f'Found controller "{controller}".') + + if args.output: + with open(args.output, 'wb') as output: + with AsyncHID(path=controller['path'], loop=loop) as hid_controller: + await dumb_spi_flash(hid_controller, output_file=output) + else: + with AsyncHID(path=controller['path'], loop=loop) as hid_controller: + await dumb_spi_flash(hid_controller) + + +if __name__ == '__main__': + # check if root + if not os.geteuid() == 0: + raise PermissionError('Script must be run as root!') + + parser = argparse.ArgumentParser() + parser.add_argument('-o', '--output') + args = parser.parse_args() + + # setup logging + log.configure() + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future(_main(args, loop)) + + try: + loop.run_until_complete(task) + except KeyboardInterrupt: + task.cancel() + with suppress(asyncio.CancelledError): + loop.run_until_complete(task) + finally: + loop.stop() + loop.close() + + diff --git a/joycontrol/protocol.py b/joycontrol/protocol.py index 259f092..19442cb 100644 --- a/joycontrol/protocol.py +++ b/joycontrol/protocol.py @@ -10,15 +10,19 @@ from joycontrol.report import OutputReport, SubCommand, InputReport, OutputRepor logger = logging.getLogger(__name__) -def controller_protocol_factory(controller: Controller): +def controller_protocol_factory(controller: Controller, spi_flash=None): def create_controller_protocol(): - return ControllerProtocol(controller) + return ControllerProtocol(controller, spi_flash=spi_flash) return create_controller_protocol class ControllerProtocol(BaseProtocol): - def __init__(self, controller: Controller): + def __init__(self, controller: Controller, spi_flash=None): self.controller = controller + if spi_flash is not None: + self.spi_flash = list(spi_flash) + else: + self.spi_flash = None self.transport = None @@ -34,7 +38,7 @@ class ControllerProtocol(BaseProtocol): self.sig_wait_player_lights = asyncio.Event() async def write(self, input_report: InputReport): - # set button and TODO: stick date + # set button and TODO: stick data if self._controller_state.button_state is not None: input_report.set_button_status(self._controller_state.button_state) self._controller_state.sig_is_send.set() @@ -106,33 +110,37 @@ class ControllerProtocol(BaseProtocol): raise ValueError('Received output report does not contain a sub command') logging.info(f'received output report - Sub command {sub_command}') + + sub_command_data = report.get_sub_command_data() + assert sub_command_data is not None + # answer to sub command if sub_command == SubCommand.REQUEST_DEVICE_INFO: - await self._command_request_device_info(report) + await self._command_request_device_info(sub_command_data) elif sub_command == SubCommand.SET_SHIPMENT_STATE: - await self._command_set_shipment_state(report) + await self._command_set_shipment_state(sub_command_data) elif sub_command == SubCommand.SPI_FLASH_READ: - await self._command_spi_flash_read(report) + await self._command_spi_flash_read(sub_command_data) elif sub_command == SubCommand.SET_INPUT_REPORT_MODE: - await self._command_set_input_report_mode(report) + await self._command_set_input_report_mode(sub_command_data) elif sub_command == SubCommand.TRIGGER_BUTTONS_ELAPSED_TIME: - await self._command_trigger_buttons_elapsed_time(report) + await self._command_trigger_buttons_elapsed_time(sub_command_data) elif sub_command == SubCommand.ENABLE_6AXIS_SENSOR: - await self._command_enable_6axis_sensor(report) + await self._command_enable_6axis_sensor(sub_command_data) elif sub_command == SubCommand.ENABLE_VIBRATION: - await self._command_enable_vibration(report) + await self._command_enable_vibration(sub_command_data) elif sub_command == SubCommand.SET_NFC_IR_MCU_CONFIG: - await self._command_set_nfc_ir_mcu_config(report) + await self._command_set_nfc_ir_mcu_config(sub_command_data) elif sub_command == SubCommand.SET_PLAYER_LIGHTS: - await self._command_set_player_lights(report) + await self._command_set_player_lights(sub_command_data) else: logger.warning(f'Sub command 0x{sub_command.value:02x} not implemented - ignoring') @@ -141,7 +149,7 @@ class ControllerProtocol(BaseProtocol): else: logger.warning(f'Output report {output_report_id} not implemented - ignoring') - async def _command_request_device_info(self, output_report): + async def _command_request_device_info(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -155,7 +163,7 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_set_shipment_state(self, output_report): + async def _command_set_shipment_state(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -165,18 +173,33 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_spi_flash_read(self, output_report): + async def _command_spi_flash_read(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() input_report.set_ack(0x90) - input_report.sub_0x10_spi_flash_read(output_report) + + # parse offset + offset = 0 + digit = 1 + for i in range(4): + offset += sub_command_data[i] * digit + digit *= 0x100 + + size = sub_command_data[4] + + if self.spi_flash is not None: + spi_flash_data = self.spi_flash[offset: offset+size] + input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data) + else: + spi_flash_data = size * [0x00] + input_report.sub_0x10_spi_flash_read(offset, size, spi_flash_data) await self.write(input_report) - async def _command_set_input_report_mode(self, output_report): - if output_report.data[12] == 0x30: + async def _command_set_input_report_mode(self, sub_command_data): + if sub_command_data[0] == 0x30: logger.info('Setting input report mode to 0x30...') # start sending 0x30 input reports assert self._0x30_input_report_sender is None @@ -191,9 +214,9 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) else: - logger.error(f'input report mode {output_report.data[12]} not implemented - ignoring request') + logger.error(f'input report mode {sub_command_data[0]} not implemented - ignoring request') - async def _command_trigger_buttons_elapsed_time(self, output_report): + async def _command_trigger_buttons_elapsed_time(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -203,7 +226,7 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_enable_6axis_sensor(self, output_report): + async def _command_enable_6axis_sensor(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -213,7 +236,7 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_enable_vibration(self, output_report): + async def _command_enable_vibration(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -223,7 +246,7 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_set_nfc_ir_mcu_config(self, output_report): + async def _command_set_nfc_ir_mcu_config(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() @@ -236,7 +259,7 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - async def _command_set_player_lights(self, output_report): + async def _command_set_player_lights(self, sub_command_data): input_report = InputReport() input_report.set_input_report_id(0x21) input_report.set_misc() diff --git a/joycontrol/report.py b/joycontrol/report.py index 121fba9..2cfaffa 100644 --- a/joycontrol/report.py +++ b/joycontrol/report.py @@ -26,6 +26,12 @@ class InputReport: for i in range(14, 51): self.data[i] = 0x00 + def get_sub_command_reply_data(self): + if len(self.data) < 50: + raise ValueError('Not enough data') + + return self.data[16:51] + def set_input_report_id(self, _id): """ :param _id: e.g. 0x21 Standard input reports used for sub command replies @@ -39,7 +45,7 @@ class InputReport: def set_timer(self, timer): """ - Input report timer (0x00-0xFF), usually set by the transport + Input report timer [0x00-0xFF], usually set by the transport """ self.data[2] = timer % 256 @@ -78,6 +84,9 @@ class InputReport: """ self.data[14] = ack + def get_ack(self): + return self.data[14] + def set_6axis_data(self): """ Set accelerator and gyro of 0x30 input reports @@ -89,6 +98,14 @@ class InputReport: def reply_to_subcommand_id(self, id_): self.data[15] = id_ + def get_reply_to_subcommand_id(self): + if len(self.data) < 16: + return None + try: + return SubCommand(self.data[15]) + except ValueError: + raise NotImplementedError(f'Sub command id {hex(self.data[11])} not implemented') + def sub_0x02_device_info(self, mac, fm_version=(0x04, 0x00), controller=Controller.JOYCON_L): """ Sub command 0x02 request device info response. @@ -113,9 +130,21 @@ class InputReport: self.data[offset + 10] = 0x01 self.data[offset + 11] = 0x01 - def sub_0x10_spi_flash_read(self, output_report): + def sub_0x10_spi_flash_read(self, offset, size, data): + if len(data) != size: + raise ValueError(f'Length of data {len(data)} does not match size {size}') + if size > 0x1D: + raise ValueError(f'Size can not exceed {0x1D}') + self.reply_to_subcommand_id(0x10) - self.data[16:18] = output_report.data[12:14] + + # write offset to data + for i in range(16, 16 + 4): + self.data[i] = offset % 0x100 + offset = offset // 0x100 + + self.data[20] = size + self.data[21:21+len(data)] = data def sub_0x04_trigger_buttons_elapsed_time(self): self.reply_to_subcommand_id(0x04) @@ -149,7 +178,11 @@ class OutputReportID(Enum): class OutputReport: - def __init__(self, data): + def __init__(self, data=None): + if data is None: + data = 50 * [0x00] + data[0] = 0xA2 + if data[0] != 0xA2: raise ValueError('Output reports must start with 0xA2') self.data = data @@ -160,9 +193,21 @@ class OutputReport: except ValueError: raise NotImplementedError(f'Output report id {hex(self.data[1])} not implemented') + def set_output_report_id(self, _id): + if isinstance(_id, OutputReportID): + self.data[1] = _id.value + else: + self.data[1] = _id + def get_timer(self): return OutputReportID(self.data[2]) + def set_timer(self, timer): + """ + Output report timer [0x0 - 0xF] + """ + self.data[2] = timer % 0x10 + def get_rumble_data(self): return self.data[3:11] @@ -174,5 +219,37 @@ class OutputReport: except ValueError: raise NotImplementedError(f'Sub command id {hex(self.data[11])} not implemented') + def get_sub_command_data(self): + if len(self.data) < 13: + return None + return self.data[12:] + + def set_sub_command(self, _id): + if isinstance(_id, SubCommand): + self.data[11] = _id.value + else: + self.data[11] = _id + + def sub_0x10_spi_flash_read(self, offset, size): + """ + Creates output report data with spi flash read sub command. + :param offset: start byte of the spi flash to read in [0x00, 0x80000) + :param size: size of data to be read in [0x00, 0x1D] + """ + if size > 0x1D: + raise ValueError(f'Size read can not exceed {0x1D}') + if offset+size > 0x80000: + raise ValueError(f'Given address range exceeds max address {0x80000-1}') + + self.set_output_report_id(OutputReportID.SUB_COMMAND) + self.set_sub_command(SubCommand.SPI_FLASH_READ) + + # write offset to data + for i in range(12, 12+4): + self.data[i] = offset % 0x100 + offset = offset // 0x100 + + self.data[16] = size + def __bytes__(self): return bytes(self.data) diff --git a/joycontrol/transport.py b/joycontrol/transport.py index 8b7eeb5..d7e18db 100644 --- a/joycontrol/transport.py +++ b/joycontrol/transport.py @@ -59,12 +59,6 @@ class L2CAP_Transport(asyncio.Transport): def set_read_buffer_size(self, size): self._read_buffer_size = size - def set_write_buffer_limits(self, high: int = ..., low: int = ...) -> None: - super().set_write_buffer_limits(high, low) - - def get_write_buffer_size(self) -> int: - return super().get_write_buffer_size() - async def write(self, data: Any) -> None: if isinstance(data, bytes): _bytes = data diff --git a/joycontrol/utils.py b/joycontrol/utils.py index a0045f3..4aa33c8 100644 --- a/joycontrol/utils.py +++ b/joycontrol/utils.py @@ -1,6 +1,5 @@ import asyncio import logging -import re logger = logging.getLogger(__name__) diff --git a/run_test_controller_buttons.py b/run_test_controller_buttons.py index 93edd50..95c4410 100644 --- a/run_test_controller_buttons.py +++ b/run_test_controller_buttons.py @@ -1,3 +1,4 @@ +import argparse import asyncio import logging import os @@ -66,8 +67,14 @@ async def test_controller_buttons(controller_state: ControllerState): await asyncio.sleep(0.1) -async def main(): - transport, protocol = await create_hid_server(controller_protocol_factory(Controller.PRO_CONTROLLER), 17, 19) +async def _main(args): + spi_flash = None + if args.spi_flash: + with open(args.spi_flash, 'rb') as spi_flash_file: + spi_flash = spi_flash_file.read() + + factory = controller_protocol_factory(Controller.PRO_CONTROLLER, spi_flash=spi_flash) + transport, protocol = await create_hid_server(factory, 17, 19) await test_controller_buttons(protocol.get_controller_state()) @@ -80,8 +87,12 @@ if __name__ == '__main__': if not os.geteuid() == 0: raise PermissionError('Script must be run as root!') + parser = argparse.ArgumentParser() + parser.add_argument('--spi_flash') + args = parser.parse_args() + # setup logging log.configure() loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + loop.run_until_complete(_main(args))