From 42b06f286283956ed2d91b8b4a3c6f37627e6096 Mon Sep 17 00:00:00 2001 From: Robert Martin Date: Tue, 11 Feb 2020 22:16:06 +0900 Subject: [PATCH] implemented controller joy sticks --- dump_spi_flash.py | 2 +- joycontrol/controller_state.py | 302 ++++++++++++++++++++++++++++----- joycontrol/memory.py | 46 +++++ joycontrol/protocol.py | 32 ++-- joycontrol/report.py | 11 ++ run_controller_cli.py | 190 +++++++++++++++++++++ run_test_controller_buttons.py | 70 +++++--- 7 files changed, 578 insertions(+), 75 deletions(-) create mode 100644 joycontrol/memory.py create mode 100644 run_controller_cli.py diff --git a/dump_spi_flash.py b/dump_spi_flash.py index 74601ef..aced347 100644 --- a/dump_spi_flash.py +++ b/dump_spi_flash.py @@ -174,7 +174,7 @@ if __name__ == '__main__': raise PermissionError('Script must be run as root!') parser = argparse.ArgumentParser() - parser.add_argument('-o', '--output') + parser.add_argument('output') args = parser.parse_args() # setup logging diff --git a/joycontrol/controller_state.py b/joycontrol/controller_state.py index bf41ec6..6e2adf7 100644 --- a/joycontrol/controller_state.py +++ b/joycontrol/controller_state.py @@ -1,17 +1,49 @@ import asyncio from joycontrol import utils +from joycontrol.controller import Controller +from joycontrol.memory import FlashMemory class ControllerState: - def __init__(self, protocol): + def __init__(self, protocol, controller: Controller, spi_flash: FlashMemory = None): self._protocol = protocol + self._controller = controller - self.button_state = None - self.stick_state = None + self._spi_flash = spi_flash + + self.button_state = ButtonState(controller) + + # create left stick state + self.l_stick_state = self.r_stick_state = None + if controller in (Controller.PRO_CONTROLLER, Controller.JOYCON_L): + # load calibration data from memory + calibration = None + if spi_flash is not None: + calibration_data = spi_flash.get_user_l_stick_calibration() + if calibration_data is None: + calibration_data = spi_flash.get_factory_l_stick_calibration() + calibration = LeftStickCalibration.from_bytes(calibration_data) + + self.l_stick_state = StickState(calibration=calibration) + + # create right stick state + if controller in (Controller.PRO_CONTROLLER, Controller.JOYCON_R): + # load calibration data from memory + calibration = None + if spi_flash is not None: + calibration_data = spi_flash.get_user_r_stick_calibration() + if calibration_data is None: + calibration_data = spi_flash.get_factory_r_stick_calibration() + calibration = RightStickCalibration.from_bytes(calibration_data) + + self.r_stick_state = StickState(calibration=calibration) self.sig_is_send = asyncio.Event() + def get_flash_memory(self): + return self._spi_flash + async def send(self): self.sig_is_send.clear() await self.sig_is_send.wait() @@ -20,7 +52,7 @@ class ControllerState: """ Waits until the switch is paired with the controller and accepts button commands """ - await self._protocol.sig_wait_player_lights.wait() + await self._protocol.sig_set_player_lights.wait() class ButtonState: @@ -31,8 +63,19 @@ class ButtonState: 1 Y X B A SR SL R ZR 2 Minus Plus R Stick L Stick Home Capture 3 Down Up Right Left SR SL L ZL + + Example for generated methods: home button (byte_2, 4) + + def home(self, pushed=True): + if pushed != utils.get_bit(self.byte_2, 4): + self.byte_2 = utils.flip_bit(self.byte_2, 4) + + def home_is_set(self): + return get_bit(self.byte_2, 4) """ - def __init__(self): + def __init__(self, controller: Controller): + self.controller = controller + # 3 bytes self._byte_1 = 0 self._byte_2 = 0 @@ -40,54 +83,84 @@ class ButtonState: # generating methods for each button def button_method_factory(byte, bit): - def flip(): - setattr(self, byte, utils.flip_bit(getattr(self, byte), bit)) + def setter(pushed=True): + _byte = getattr(self, byte) + + if pushed != utils.get_bit(_byte, bit): + setattr(self, byte, utils.flip_bit(_byte, bit)) def getter(): return utils.get_bit(getattr(self, byte), bit) - return flip, getter + return setter, getter + + if self.controller == Controller.PRO_CONTROLLER: + self._available_buttons = {'y', 'x', 'b', 'a', 'r', 'zr', + 'minus', 'plus', 'r_stick', 'l_stick', 'home', 'capture', + 'down', 'up', 'right', 'left', 'l', 'zl'} + elif self.controller == Controller.JOYCON_R: + self._available_buttons = {'y', 'x', 'b', 'a', 'sr', 'sl', 'r', 'zr', + 'plus', 'r_stick', 'home'} + elif self.controller == Controller.JOYCON_L: + self._available_buttons = {'plus', 'l_stick', 'capture', + 'down', 'up', 'right', 'left', 'sr', 'sl', 'l', 'zl'} # byte 1 - self.y, self.y_is_set = button_method_factory('_byte_1', 0) - self.x, self.x_is_set = button_method_factory('_byte_1', 1) - self.b, self.b_is_set = button_method_factory('_byte_1', 2) - self.a, self.a_is_set = button_method_factory('_byte_1', 3) - self.right_sr, self.right_sr_is_set = button_method_factory('_byte_1', 4) - self.right_sl, self.right_sl_is_set = button_method_factory('_byte_1', 5) - self.r, self.r_is_set = button_method_factory('_byte_1', 6) - self.zr, self.zr_is_set = button_method_factory('_byte_1', 7) + if self.controller == Controller.PRO_CONTROLLER or self.controller == Controller.JOYCON_R: + self.y, self.y_is_set = button_method_factory('_byte_1', 0) + self.x, self.x_is_set = button_method_factory('_byte_1', 1) + self.b, self.b_is_set = button_method_factory('_byte_1', 2) + self.a, self.a_is_set = button_method_factory('_byte_1', 3) + + if self.controller == Controller.JOYCON_R: + self.sr, self.sr_is_set = button_method_factory('_byte_1', 4) + self.sl, self.sl_is_set = button_method_factory('_byte_1', 5) + + self.r, self.r_is_set = button_method_factory('_byte_1', 6) + self.zr, self.zr_is_set = button_method_factory('_byte_1', 7) # byte 2 self.minus, self.minus_is_set = button_method_factory('_byte_2', 0) self.plus, self.plus_is_set = button_method_factory('_byte_2', 1) self.r_stick, self.r_stick_is_set = button_method_factory('_byte_2', 2) self.l_stick, self.l_stick_is_set = button_method_factory('_byte_2', 3) - self.home, self.home_is_set = button_method_factory('_byte_2', 4) - self.capture, self.capture_is_set = button_method_factory('_byte_2', 5) + if self.controller == Controller.JOYCON_R or self.controller == Controller.PRO_CONTROLLER: + self.home, self.home_is_set = button_method_factory('_byte_2', 4) + if self.controller == Controller.JOYCON_L or self.controller == Controller.PRO_CONTROLLER: + self.capture, self.capture_is_set = button_method_factory('_byte_2', 5) # byte 3 - self.down, self.down_is_set = button_method_factory('_byte_3', 0) - self.up, self.up_is_set = button_method_factory('_byte_3', 1) - self.right, self.right_is_set = button_method_factory('_byte_3', 2) - self.left, self.left_is_set = button_method_factory('_byte_3', 3) - self.left_sr, self.left_sr_is_set = button_method_factory('_byte_3', 4) - self.left_sl, self.left_sl_is_set = button_method_factory('_byte_3', 5) - self.l, self.l_is_set = button_method_factory('_byte_3', 6) - self.zl, self.zl_is_set = button_method_factory('_byte_3', 7) + if self.controller == Controller.PRO_CONTROLLER or self.controller == Controller.JOYCON_L: + self.down, self.down_is_set = button_method_factory('_byte_3', 0) + self.up, self.up_is_set = button_method_factory('_byte_3', 1) + self.right, self.right_is_set = button_method_factory('_byte_3', 2) + self.left, self.left_is_set = button_method_factory('_byte_3', 3) - """ - Example for generated methods: home button (byte_2, 4) + if self.controller == Controller.JOYCON_L: + self.sr, self.sr_is_set = button_method_factory('_byte_3', 4) + self.sl, self.sl_is_set = button_method_factory('_byte_3', 5) - def home(self): - self.byte_2 = flip_bit(self.byte_2, 4) + self.l, self.l_is_set = button_method_factory('_byte_3', 6) + self.zl, self.zl_is_set = button_method_factory('_byte_3', 7) - def home_is_set(self): - return get_bit(self.byte_2, 4) - """ + def set_button(self, button, pushed=True): + if button not in self._available_buttons: + raise ValueError(f'Given button "{button}" is not available to {self.controller.device_name()}.') + getattr(self, button)(pushed=pushed) + + def get_button(self, button): + if button not in self._available_buttons: + raise ValueError(f'Given button "{button}" is not available to {self.controller.device_name()}.') + return getattr(self, f'{button}_is_set')() + + def get_available_buttons(self): + """ + :returns set of valid buttons + """ + return set(self._available_buttons) def __iter__(self): """ - @returns iterator of the button bytes + @returns iterator over the button bytes """ yield self._byte_1 yield self._byte_2 @@ -97,25 +170,166 @@ class ButtonState: self._byte_1 = self._byte_2 = self._byte_3 = 0 -async def button_push(controller_state, button, sec=0.1): - button_state = ButtonState() +async def button_push(controller_state, *buttons, sec=0.1): + if not buttons: + raise ValueError('No Buttons were given.') - # push button - getattr(button_state, button)() + button_state = controller_state.button_state + + for button in buttons: + # push button + button_state.set_button(button) # send report - controller_state.button_state = button_state await controller_state.send() await asyncio.sleep(sec) - # release button - getattr(button_state, button)() + for button in buttons: + # release button + button_state.set_button(button, pushed=False) # send report - controller_state.button_state = button_state await controller_state.send() +class _StickCalibration: + def __init__(self, h_center, v_center, h_max_above_center, v_max_above_center, h_max_below_center, v_max_below_center): + self.h_center = h_center + self.v_center = v_center + + self.h_max_above_center = h_max_above_center + self.v_max_above_center = v_max_above_center + self.h_max_below_center = h_max_below_center + self.v_max_below_center = v_max_below_center + + def __str__(self): + return f'h_center:{self.h_center} v_center:{self.v_center} h_max_above_center:{self.h_max_above_center} ' \ + f'v_max_above_center:{self.v_max_above_center} h_max_below_center:{self.h_max_below_center} ' \ + f'v_max_below_center:{self.v_max_below_center}' + + +class LeftStickCalibration(_StickCalibration): + @staticmethod + def from_bytes(_9bytes): + h_max_above_center = (_9bytes[1] << 8) & 0xF00 | _9bytes[0] + v_max_above_center = (_9bytes[2] << 4) | (_9bytes[1] >> 4) + h_center = (_9bytes[4] << 8) & 0xF00 | _9bytes[3] + v_center = (_9bytes[5] << 4) | (_9bytes[4] >> 4) + h_max_below_center = (_9bytes[7] << 8) & 0xF00 | _9bytes[6] + v_max_below_center = (_9bytes[8] << 4) | (_9bytes[7] >> 4) + + return _StickCalibration(h_center, v_center, h_max_above_center, v_max_above_center, + h_max_below_center, v_max_below_center) + + +class RightStickCalibration(_StickCalibration): + @staticmethod + def from_bytes(_9bytes): + h_center = (_9bytes[1] << 8) & 0xF00 | _9bytes[0] + v_center = (_9bytes[2] << 4) | (_9bytes[1] >> 4) + h_max_below_center = (_9bytes[4] << 8) & 0xF00 | _9bytes[3] + v_max_below_center = (_9bytes[5] << 4) | (_9bytes[4] >> 4) + h_max_above_center = (_9bytes[7] << 8) & 0xF00 | _9bytes[6] + v_max_above_center = (_9bytes[8] << 4) | (_9bytes[7] >> 4) + + return _StickCalibration(h_center, v_center, h_max_above_center, v_max_above_center, + h_max_below_center, v_max_below_center) + + class StickState: - def __init__(self): - raise NotImplementedError() + def __init__(self, h=0, v=0, calibration: _StickCalibration = None): + for val in (h, v): + if not 0 <= val < 0x1000: + raise ValueError(f'Stick values must be in [0,{0x1000})') + + self._h_stick = h + self._v_stick = v + + self._calibration = calibration + + def set_h(self, value): + if not 0 <= value < 0x1000: + raise ValueError(f'Stick values must be in [0,{0x1000})') + self._h_stick = value + + def get_h(self): + return self._h_stick + + def set_v(self, value): + if not 0 <= value < 0x1000: + raise ValueError(f'Stick values must be in [0,{0x1000})') + self._v_stick = value + + def get_v(self): + return self._v_stick + + def set_center(self): + """ + Sets stick to center position using the calibration data. + """ + if self._calibration is None: + raise ValueError('No calibration data available.') + self._h_stick = self._calibration.h_center + self._v_stick = self._calibration.v_center + + def is_center(self, radius=0): + return self._calibration.h_center - radius <= self._h_stick <= self._calibration.h_center + radius and \ + self._calibration.v_center - radius <= self._v_stick <= self._calibration.v_center + radius + + def set_up(self): + """ + Sets stick to up position using the calibration data. + """ + if self._calibration is None: + raise ValueError('No calibration data available.') + self._h_stick = self._calibration.h_center + self._v_stick = self._calibration.v_center + self._calibration.v_max_above_center + + def set_down(self): + """ + Sets stick to down position using the calibration data. + """ + if self._calibration is None: + raise ValueError('No calibration data available.') + self._h_stick = self._calibration.h_center + self._v_stick = self._calibration.v_center - self._calibration.v_max_below_center + + def set_left(self): + """ + Sets stick to left position using the calibration data. + """ + if self._calibration is None: + raise ValueError('No calibration data available.') + self._h_stick = self._calibration.h_center - self._calibration.h_max_below_center + self._v_stick = self._calibration.v_center + + def set_right(self): + """ + Sets stick to right position using the calibration data. + """ + if self._calibration is None: + raise ValueError('No calibration data available.') + self._h_stick = self._calibration.h_center + self._calibration.h_max_above_center + self._v_stick = self._calibration.v_center + + def set_calibration(self, calibration): + self._calibration = calibration + + def get_calibration(self): + if self._calibration is None: + raise ValueError('No calibration data available.') + return self._calibration + + @staticmethod + def from_bytes(_3bytes): + stick_h = _3bytes[0] | ((_3bytes[1] & 0xF) << 8) + stick_v = (_3bytes[1] >> 4) | (_3bytes[2] << 4) + + return StickState(h=stick_h, v=stick_v) + + def __bytes__(self): + byte_1 = 0xFF & self._h_stick + byte_2 = (self._h_stick >> 8) | ((0xF & self._v_stick) << 4) + byte_3 = self._v_stick >> 4 + assert all(0 <= byte <= 0xFF for byte in (byte_1, byte_2, byte_3)) + return bytes((byte_1, byte_2, byte_3)) diff --git a/joycontrol/memory.py b/joycontrol/memory.py new file mode 100644 index 0000000..272cb50 --- /dev/null +++ b/joycontrol/memory.py @@ -0,0 +1,46 @@ + +class FlashMemory: + def __init__(self, spi_flash_memory_data=None, size=0x80000): + if spi_flash_memory_data is None: + self.data = size * [0x00] + else: + if len(spi_flash_memory_data) != size: + raise ValueError(f'Given data size {len(spi_flash_memory_data)} does not match size {size}.') + if isinstance(spi_flash_memory_data, bytes): + spi_flash_memory_data = list(spi_flash_memory_data) + self.data = spi_flash_memory_data + + def __getitem__(self, item): + return self.data[item] + + def get_factory_l_stick_calibration(self): + """ + :returns 9 left stick factory calibration bytes + """ + return self.data[0x603D:0x6046] + + def get_factory_r_stick_calibration(self): + """ + :returns 9 right stick factory calibration bytes + """ + return self.data[0x6046:0x604F] + + def get_user_l_stick_calibration(self): + """ + :returns 9 left stick user calibration bytes if the data is available, otherwise None + """ + # check if calibration data is available: + if self.data[0x8010] == 0xB2 and self.data[0x8011] == 0xA1: + return self.data[0x8012:0x801B] + else: + return None + + def get_user_r_stick_calibration(self): + """ + :returns 9 right stick user calibration bytes if the data is available, otherwise None + """ + # check if calibration data is available: + if self.data[0x801B] == 0xB2 and self.data[0x801C] == 0xA1: + return self.data[0x801D:0x8026] + else: + return None diff --git a/joycontrol/protocol.py b/joycontrol/protocol.py index 19442cb..fb74795 100644 --- a/joycontrol/protocol.py +++ b/joycontrol/protocol.py @@ -5,42 +5,52 @@ from typing import Optional, Union, Tuple, Text from joycontrol.controller import Controller from joycontrol.controller_state import ControllerState +from joycontrol.memory import FlashMemory from joycontrol.report import OutputReport, SubCommand, InputReport, OutputReportID logger = logging.getLogger(__name__) def controller_protocol_factory(controller: Controller, spi_flash=None): + if isinstance(spi_flash, bytes): + spi_flash = FlashMemory(spi_flash_memory_data=spi_flash) + def create_controller_protocol(): return ControllerProtocol(controller, spi_flash=spi_flash) return create_controller_protocol class ControllerProtocol(BaseProtocol): - def __init__(self, controller: Controller, spi_flash=None): + def __init__(self, controller: Controller, spi_flash: FlashMemory = None): self.controller = controller - if spi_flash is not None: - self.spi_flash = list(spi_flash) - else: - self.spi_flash = None + self.spi_flash = spi_flash self.transport = None self._data_received = asyncio.Event() - self._controller_state = ControllerState(self) + self._controller_state = ControllerState(self, controller, spi_flash=spi_flash) self._pending_write = None self._pending_input_report = None self._0x30_input_report_sender = None - self.sig_wait_player_lights = asyncio.Event() + self.sig_set_player_lights = asyncio.Event() async def write(self, input_report: InputReport): - # set button and TODO: stick data - if self._controller_state.button_state is not None: - input_report.set_button_status(self._controller_state.button_state) + # set button and stick data + input_report.set_button_status(self._controller_state.button_state) + if self._controller_state.l_stick_state is None: + l_stick = [0x00, 0x00, 0x00] + else: + l_stick = self._controller_state.l_stick_state + if self._controller_state.r_stick_state is None: + r_stick = [0x00, 0x00, 0x00] + else: + r_stick = self._controller_state.r_stick_state + input_report.set_stick_status(l_stick, r_stick) + self._controller_state.sig_is_send.set() await self.transport.write(input_report) @@ -269,4 +279,4 @@ class ControllerProtocol(BaseProtocol): await self.write(input_report) - self.sig_wait_player_lights.set() + self.sig_set_player_lights.set() diff --git a/joycontrol/report.py b/joycontrol/report.py index 2cfaffa..da41e27 100644 --- a/joycontrol/report.py +++ b/joycontrol/report.py @@ -26,6 +26,10 @@ class InputReport: for i in range(14, 51): self.data[i] = 0x00 + def get_stick_data(self): + # TODO: Not every input report has stick data + return self.data[7:13] + def get_sub_command_reply_data(self): if len(self.data) < 50: raise ValueError('Not enough data') @@ -59,6 +63,13 @@ class InputReport: """ self.data[4:7] = iter(button_status) + def set_stick_status(self, left_stick, right_stick): + """ + Sets the joystick status bytes + """ + self.data[7:10] = bytes(left_stick) + self.data[10:13] = bytes(right_stick) + def set_left_analog_stick(self): """ TODO diff --git a/run_controller_cli.py b/run_controller_cli.py new file mode 100644 index 0000000..509800d --- /dev/null +++ b/run_controller_cli.py @@ -0,0 +1,190 @@ +import argparse +import asyncio +import inspect +import logging +import os +from contextlib import contextmanager + +from aioconsole import ainput +from joycontrol import logging_default as log +from joycontrol.controller import Controller +from joycontrol.controller_state import button_push, ControllerState +from joycontrol.memory import FlashMemory +from joycontrol.protocol import controller_protocol_factory +from joycontrol.server import create_hid_server + + +logger = logging.getLogger(__name__) + + +class ControllerCLI: + def __init__(self, controller_state: ControllerState): + self.controller_state = controller_state + self.commands = {} + + async def cmd_help(self): + print('Buttons can be used as commands: ', ', '.join(self.controller_state.button_state.get_available_buttons())) + + for name, fun in inspect.getmembers(self): + if name.startswith('cmd_') and fun.__doc__: + print(fun.__doc__) + + print('Commands can be chained using "&&"') + print('Type "exit" to close.') + + @staticmethod + def _set_stick(stick, direction, value): + if direction == 'center': + stick.set_center() + elif direction == 'up': + stick.set_up() + elif direction == 'down': + stick.set_down() + elif direction == 'left': + stick.set_left() + elif direction == 'right': + stick.set_right() + elif direction in ('h', 'horizontal'): + if value is None: + raise ValueError(f'Missing value') + try: + val = int(value) + except ValueError: + raise ValueError(f'Unexpected stick value "{value}"') + stick.set_h(val) + elif direction in ('v', 'vertical'): + if value is None: + raise ValueError(f'Missing value') + try: + val = int(value) + except ValueError: + raise ValueError(f'Unexpected stick value "{value}"') + stick.set_v(val) + else: + raise ValueError(f'Unexpected argument "{direction}"') + + return f'{stick.__class__.__name__} was set to ({stick.get_h()}, {stick.get_v()}).' + + async def cmd_stick(self, side, direction, value=None): + """ + stick - Command to set stick positions. + :param side: 'l', 'left' for left control stick; 'r', 'right' for right control stick + :param direction: 'center', 'up', 'down', 'left', 'right'; + 'h', 'horizontal' or 'v', 'vertical' to set the value directly to the "value" argument + :param value: horizontal or vertical value + """ + if side in ('l', 'left'): + stick = self.controller_state.l_stick_state + return ControllerCLI._set_stick(stick, direction, value) + elif side in ('r', 'right'): + stick = self.controller_state.r_stick_state + return ControllerCLI._set_stick(stick, direction, value) + else: + raise ValueError('Value of side must be "l", "left" or "r", "right"') + + def add_command(self, name, command): + if name in self.commands: + raise ValueError(f'Command {name} already registered.') + self.commands[name] = command + + async def run(self): + while True: + user_input = await ainput(prompt='cmd >> ') + if not user_input: + continue + + buttons_to_push = [] + + for command in user_input.split('&&'): + cmd, *args = command.split() + + if cmd == 'exit': + return + + available_buttons = self.controller_state.button_state.get_available_buttons() + + if hasattr(self, f'cmd_{cmd}'): + try: + result = await getattr(self, f'cmd_{cmd}')(*args) + if result: + print(result) + except Exception as e: + print(e) + elif cmd in self.commands: + try: + result = await self.commands[cmd](*args) + if result: + print(result) + except Exception as e: + print(e) + elif cmd in available_buttons: + buttons_to_push.append(cmd) + else: + print('command', cmd, 'not found, call help for help.') + + if buttons_to_push: + await button_push(self.controller_state, *buttons_to_push) + else: + await self.controller_state.send() + + +async def _main(controller, capture_file=None, spi_flash=None): + factory = controller_protocol_factory(controller, spi_flash=spi_flash) + transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file) + + controller_state = protocol.get_controller_state() + + cli = ControllerCLI(controller_state) + await cli.run() + + logger.info('Stopping communication...') + await transport.close() + + +if __name__ == '__main__': + # check if root + if not os.geteuid() == 0: + raise PermissionError('Script must be run as root!') + + # setup logging + log.configure(console_level=logging.ERROR) + + parser = argparse.ArgumentParser() + parser.add_argument('controller', help='JOYCON_R, JOYCON_L or PRO_CONTROLLER') + parser.add_argument('-l', '--log') + parser.add_argument('--spi_flash') + args = parser.parse_args() + + if args.controller == 'JOYCON_R': + controller = Controller.JOYCON_R + elif args.controller == 'JOYCON_L': + controller = Controller.JOYCON_L + elif args.controller == 'PRO_CONTROLLER': + controller = Controller.PRO_CONTROLLER + else: + raise ValueError(f'Unknown controller "{args.controller}".') + + spi_flash = None + if args.spi_flash: + with open(args.spi_flash, 'rb') as spi_flash_file: + spi_flash = FlashMemory(spi_flash_file.read()) + + # creates file if arg is given + @contextmanager + def get_output(path=None): + """ + Opens file if path is given + """ + if path is not None: + file = open(path, 'wb') + yield file + file.close() + else: + yield None + + with get_output(args.log) as capture_file: + loop = asyncio.get_event_loop() + loop.run_until_complete(_main(controller, capture_file=capture_file, spi_flash=spi_flash)) + + + diff --git a/run_test_controller_buttons.py b/run_test_controller_buttons.py index 95c4410..4483744 100644 --- a/run_test_controller_buttons.py +++ b/run_test_controller_buttons.py @@ -2,6 +2,7 @@ import argparse import asyncio import logging import os +from contextlib import contextmanager from joycontrol import logging_default as log from joycontrol.controller_state import ControllerState, button_push @@ -33,7 +34,7 @@ async def test_controller_buttons(controller_state: ControllerState): await asyncio.sleep(0.3) # go all the way down - await button_push(controller_state, 'down', sec=3) + await button_push(controller_state, 'down', sec=4) await asyncio.sleep(0.3) # goto "Controllers and Sensors" menu @@ -57,24 +58,22 @@ async def test_controller_buttons(controller_state: ControllerState): await button_push(controller_state, 'a') await asyncio.sleep(0.3) - # push all buttons - button_list = ['y', 'x', 'b', 'a', 'r', 'zr', - 'minus', 'plus', 'r_stick', 'l_stick', - 'down', 'up', 'right', 'left', 'l', 'zl'] + # push all buttons except home and capture + button_list = controller_state.button_state.get_available_buttons() + if 'capture' in button_list: + button_list.remove('capture') + if 'home' in button_list: + button_list.remove('home') + for i in range(10): for button in button_list: await button_push(controller_state, button) await asyncio.sleep(0.1) -async def _main(args): - spi_flash = None - if args.spi_flash: - with open(args.spi_flash, 'rb') as spi_flash_file: - spi_flash = spi_flash_file.read() - - factory = controller_protocol_factory(Controller.PRO_CONTROLLER, spi_flash=spi_flash) - transport, protocol = await create_hid_server(factory, 17, 19) +async def _main(controller, capture_file=None, spi_flash=None): + factory = controller_protocol_factory(controller, spi_flash=spi_flash) + transport, protocol = await create_hid_server(factory, 17, 19, capture_file=capture_file) await test_controller_buttons(protocol.get_controller_state()) @@ -87,12 +86,45 @@ if __name__ == '__main__': if not os.geteuid() == 0: raise PermissionError('Script must be run as root!') - parser = argparse.ArgumentParser() - parser.add_argument('--spi_flash') - args = parser.parse_args() - # setup logging log.configure() - loop = asyncio.get_event_loop() - loop.run_until_complete(_main(args)) + parser = argparse.ArgumentParser() + #parser.add_argument('controller', help='JOYCON_R, JOYCON_L or PRO_CONTROLLER') + parser.add_argument('-l', '--log') + parser.add_argument('--spi_flash') + args = parser.parse_args() + + """ + if args.controller == 'JOYCON_R': + controller = Controller.JOYCON_R + elif args.controller == 'JOYCON_L': + controller = Controller.JOYCON_L + elif args.controller == 'PRO_CONTROLLER': + controller = Controller.PRO_CONTROLLER + else: + raise ValueError(f'Unknown controller "{args.controller}".') + """ + controller = Controller.PRO_CONTROLLER + + spi_flash = None + if args.spi_flash: + with open(args.spi_flash, 'rb') as spi_flash_file: + spi_flash = spi_flash_file.read() + + # creates file if arg is given + @contextmanager + def get_output(path=None): + """ + Opens file if path is given + """ + if path is not None: + file = open(path, 'wb') + yield file + file.close() + else: + yield None + + with get_output(args.log) as capture_file: + loop = asyncio.get_event_loop() + loop.run_until_complete(_main(controller, capture_file=capture_file, spi_flash=spi_flash))