Implemented goto "Test Controller Buttons" menu

This commit is contained in:
Robert Martin
2020-01-31 20:09:17 +09:00
parent 6db94676f8
commit 19eb051726
8 changed files with 256 additions and 156 deletions
+75
View File
@@ -0,0 +1,75 @@
import utils
class ButtonState:
"""
Utility class to set buttons in the input report
https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md
Byte 0 1 2 3 4 5 6 7
1 Y X B A SR SL R ZR
2 Minus Plus R Stick L Stick Home Capture
3 Down Up Right Left SR SL L ZL
"""
def __init__(self):
# 3 bytes
self._byte_1 = 0
self._byte_2 = 0
self._byte_3 = 0
# generating methods for each button
def button_method_factory(byte, bit):
def flip():
setattr(self, byte, utils.flip_bit(getattr(self, byte), bit))
def getter():
return utils.get_bit(getattr(self, byte), bit)
return flip, getter
# byte 1
self.y, self.y_is_set = button_method_factory('_byte_1', 0)
self.x, self.x_is_set = button_method_factory('_byte_1', 1)
self.b, self.b_is_set = button_method_factory('_byte_1', 2)
self.a, self.a_is_set = button_method_factory('_byte_1', 3)
self.right_sr, self.right_sr_is_set = button_method_factory('_byte_1', 4)
self.right_sl, self.right_sl_is_set = button_method_factory('_byte_1', 5)
self.r, self.r_is_set = button_method_factory('_byte_1', 6)
self.zr, self.zr_is_set = button_method_factory('_byte_1', 7)
# byte 2
self.minus, self.minus_is_set = button_method_factory('_byte_2', 0)
self.plus, self.plus_is_set = button_method_factory('_byte_2', 1)
self.r_stick, self.r_stick_is_set = button_method_factory('_byte_2', 2)
self.l_stick, self.l_stick_is_set = button_method_factory('_byte_2', 3)
self.home, self.home_is_set = button_method_factory('_byte_2', 4)
self.capture, self.capture_is_set = button_method_factory('_byte_2', 5)
# byte 3
self.down, self.down_is_set = button_method_factory('_byte_3', 0)
self.up, self.up_is_set = button_method_factory('_byte_3', 1)
self.right, self.right_is_set = button_method_factory('_byte_3', 2)
self.left, self.left_is_set = button_method_factory('_byte_3', 3)
self.left_sr, self.left_sr_is_set = button_method_factory('_byte_3', 4)
self.left_sl, self.left_sl_is_set = button_method_factory('_byte_3', 5)
self.l, self.l_is_set = button_method_factory('_byte_3', 6)
self.zl, self.zl_is_set = button_method_factory('_byte_3', 7)
"""
Example for generated methods: home button (byte_2, 4)
def home(self):
self.byte_2 = flip_bit(self.byte_2, 4)
def home_is_set(self):
return get_bit(self.byte_2, 4)
"""
def __iter__(self):
"""
@returns iterator of the button bytes
"""
yield self._byte_1
yield self._byte_2
yield self._byte_3
def clear(self):
self._byte_1 = self._byte_2 = self._byte_3 = 0
-76
View File
@@ -1,76 +0,0 @@
def get_bit(value, n):
return (value >> n & 1) != 0
def flip_bit(value, n):
return value ^ (1 << n)
class Buttons:
"""
Utility class to set buttons in the input report
https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md
Byte 0 1 2 3 4 5 6 7
1 Y X B A SR SL R ZR
2 Minus Plus R Stick L Stick Home Capture
3 Down Up Right Left SR SL L ZL
"""
def __init__(self):
# 3 bytes
self.byte_1 = 0
self.byte_2 = 0
self.byte_3 = 0
# generating methods for each button
def button_method_factory(byte, bit):
def flip():
setattr(self, byte, flip_bit(getattr(self, byte), bit))
def getter():
return get_bit(getattr(self, byte), bit)
return flip, getter
# byte 1
self.y, self.y_is_set = button_method_factory('byte_1', 0)
self.x, self.x_is_set = button_method_factory('byte_1', 1)
self.b, self.b_is_set = button_method_factory('byte_1', 2)
self.a, self.a_is_set = button_method_factory('byte_1', 3)
self.right_sr, self.right_sr_is_set = button_method_factory('byte_1', 4)
self.right_sl, self.right_sl_is_set = button_method_factory('byte_1', 5)
self.r, self.r_is_set = button_method_factory('byte_1', 6)
self.zr, self.zr_is_set = button_method_factory('byte_1', 7)
# byte 2
self.minus, self.minus_is_set = button_method_factory('byte_2', 0)
self.plus, self.plus_is_set = button_method_factory('byte_2', 1)
self.r_stick, self.r_stick_is_set = button_method_factory('byte_2', 2)
self.l_stick, self.l_stick_is_set = button_method_factory('byte_2', 3)
self.home, self.home_is_set = button_method_factory('byte_2', 4)
self.capture, self.capture_is_set = button_method_factory('byte_2', 5)
# byte 3
self.down, self.down_is_set = button_method_factory('byte_3', 0)
self.up, self.up_is_set = button_method_factory('byte_3', 1)
self.right, self.right_is_set = button_method_factory('byte_3', 2)
self.left, self.left_is_set = button_method_factory('byte_3', 3)
self.left_sr, self.left_sr_is_set = button_method_factory('byte_3', 4)
self.left_sl, self.left_sl_is_set = button_method_factory('byte_3', 5)
self.l, self.l_is_set = button_method_factory('byte_3', 6)
self.zl, self.zl_is_set = button_method_factory('byte_3', 7)
"""
Example for generated methods: home button (byte_2, 4)
def home(self):
self.byte_2 = flip_bit(self.byte_2, 4)
def home_is_set(self):
return get_bit(self.byte_2, 4)
"""
def to_list(self):
return [self.byte_1, self.byte_2, self.byte_3]
def clear(self):
self.byte_1 = self.byte_2 = self.byte_3 = 0
+38
View File
@@ -0,0 +1,38 @@
import asyncio
from button_state import ButtonState
from protocol import ControllerProtocol
class ControllerState:
def __init__(self, transport: asyncio.Transport, protocol: ControllerProtocol):
super().__init__()
self.transport = transport
self.protocol = protocol
async def send(self):
await self.protocol.button_input_report.write(self.transport)
async def connect(self):
"""
Waits until the switch is paired with the controller and accepts button commands
"""
# TODO HACK: Hard to say for now.
await self.protocol.wait_for_output_report()
# The switch sends data to our device, it shouldn't take long until the connection is fully established.
await asyncio.sleep(5)
def set_button_state(self, button_state: ButtonState):
"""
Sets the button status bytes in the input report
"""
self.protocol.button_input_report.set_button_status(button_state)
def set_stick_state(self):
"""
TODO
"""
raise NotImplementedError()
+33 -42
View File
@@ -21,6 +21,11 @@ class ControllerProtocol(BaseProtocol):
self.transport = None
# This must always be an 0x21 input report to be compatible with button events
self.button_input_report = InputReport()
self.button_input_report.set_input_report_id(0x21)
self.button_input_report.set_misc()
self._data_received = asyncio.Event()
async def wait_for_output_report(self):
@@ -49,7 +54,9 @@ class ControllerProtocol(BaseProtocol):
# classify sub command
sc_byte, sub_command = report.get_sub_command()
logging.info(f'received output report - {sub_command}')
if sub_command == SubCommand.REQUEST_DEVICE_INFO:
if sub_command is None:
logger.warning(f'Received output report does not contain a sub command')
elif sub_command == SubCommand.REQUEST_DEVICE_INFO:
await self._command_request_device_info(report)
elif sub_command == SubCommand.SET_SHIPMENT_STATE:
@@ -65,67 +72,51 @@ class ControllerProtocol(BaseProtocol):
await self._command_trigger_buttons_elapsed_time(report)
elif sub_command == SubCommand.NOT_IMPLEMENTED:
logger.error(f'Sub command 0x{sc_byte:02x} not implemented - ignoring')
logger.warning(f'Sub command 0x{sc_byte:02x} not implemented - ignoring')
async def _command_request_device_info(self, output_report):
address = self.transport.get_extra_info('sockname')
assert address is not None
bd_address = list(map(lambda x: int(x, 16), address[0].split(':')))
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x82)
#input_report.set_button_status()
#input_report.set_left_analog_stick()
#input_report.set_right_analog_stick()
#input_report.set_vibrator_input()
input_report.sub_0x02_device_info(bd_address, controller=self.controller)
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x82)
self.button_input_report.sub_0x02_device_info(bd_address, controller=self.controller)
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
async def _command_set_shipment_state(self, output_report):
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x80)
input_report.sub_0x08_shipment()
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x80)
self.button_input_report.sub_0x08_shipment()
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
async def _command_spi_flash_read(self, output_report):
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x90)
input_report.sub_0x10_spi_flash_read(output_report)
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x90)
self.button_input_report.sub_0x10_spi_flash_read(output_report)
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
async def _command_set_input_report_mode(self, output_report):
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x80)
input_report.sub_0x03_set_input_report_mode()
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x80)
self.button_input_report.sub_0x03_set_input_report_mode()
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
async def _command_trigger_buttons_elapsed_time(self, output_report):
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x83)
input_report.sub_0x04_trigger_buttons_elapsed_time()
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x83)
self.button_input_report.sub_0x04_trigger_buttons_elapsed_time()
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
async def _enable_6axis_sensor(self, output_report):
input_report = InputReport()
input_report.set_input_report_id(0x21)
input_report.set_misc()
input_report.set_ack(0x80)
self.button_input_report.set_misc()
self.button_input_report.set_ack(0x80)
input_report.reply_to_subcommand_id(0x40)
self.button_input_report.reply_to_subcommand_id(0x40)
asyncio.ensure_future(self.transport.write(input_report))
await self.button_input_report.write(self.transport)
+28 -6
View File
@@ -1,5 +1,7 @@
import asyncio
from enum import Enum
from button_state import ButtonState
from controller import Controller
@@ -9,10 +11,19 @@ class InputReport:
https://github.com/dekuNukem/Nintendo_Switch_Reverse_Engineering/blob/master/bluetooth_hid_notes.md
"""
def __init__(self):
self.data = [0x00] * 50
self.data = [0x00] * 51
# all input reports are prepended with 0xA1
self.data[0] = 0xA1
self.subcommand_is_set = False
self.is_writing = None
def clear_sub_command(self):
for i in range(14, 51):
self.data[i] = 0x00
self.subcommand_is_set = False
def set_input_report_id(self, _id):
"""
:param _id: e.g. 0x21 Standard input reports used for sub command replies
@@ -30,11 +41,11 @@ class InputReport:
# battery level + connection info
self.data[3] = 0x8E
def set_button_status(self):
def set_button_status(self, button_status: ButtonState):
"""
TODO
Sets the button status bytes
"""
self.data[4:7] = [0x84, 0x00, 0x12]
self.data[4:7] = iter(button_status)
def set_left_analog_stick(self):
"""
@@ -74,8 +85,7 @@ class InputReport:
elif len(mac) != 6:
raise ValueError('Bluetooth mac address must consist of 6 bytes!')
# reply to sub command ID
self.data[15] = 0x02
self.reply_to_subcommand_id(0x02)
# sub command reply data
offset = 16
@@ -87,6 +97,7 @@ class InputReport:
self.data[offset + 11] = 0x01
def reply_to_subcommand_id(self, id_):
self.subcommand_is_set = True
self.data[15] = id_
def sub_0x08_shipment(self):
@@ -106,8 +117,17 @@ class InputReport:
blub = [0x00, 0xCC, 0x00, 0xEE, 0x00, 0xFF]
self.data[16:22] = blub
async def write(self, transport):
if self.is_writing is None:
self.is_writing = asyncio.ensure_future(transport.write(self))
await self.is_writing
self.is_writing = None
def __bytes__(self):
if self.subcommand_is_set:
return bytes(self.data)
else:
return bytes(self.data[:15])
class SubCommand(Enum):
@@ -127,6 +147,8 @@ class OutputReport:
self.data = data
def get_sub_command(self):
if len(self.data) < 12:
return None, None
try:
return self.data[11], SubCommand(self.data[11])
except ValueError:
+69 -30
View File
@@ -5,7 +5,7 @@ import socket
import logging_default as log
import utils
from buttons import Buttons
from controller_state import ButtonState, ControllerState
from device import HidDevice
from protocol import controller_protocol_factory, Controller
from report import InputReport
@@ -64,37 +64,80 @@ async def send_empty_input_reports(transport):
await asyncio.sleep(1)
async def mash_buttons(transport):
async def button_push(controller_state, button, sec=0.1):
button_state = ButtonState()
# push button
getattr(button_state, button)()
# send report
controller_state.set_button_state(button_state)
await controller_state.send()
await asyncio.sleep(sec)
# release button
getattr(button_state, button)()
# send report
controller_state.set_button_state(button_state)
await controller_state.send()
async def test_controller_buttons(controller_state: ControllerState):
"""
Goes to the "Test Controller Buttons" menu and presses all buttons
"""
await controller_state.connect()
# We assume we are in the "Change Grip/Order" menu of the switch
await button_push(controller_state, 'home')
# wait for the animation
await asyncio.sleep(1)
# Goto settings
await button_push(controller_state, 'down')
await asyncio.sleep(0.3)
for _ in range(4):
await button_push(controller_state, 'right')
await asyncio.sleep(0.3)
await button_push(controller_state, 'a')
await asyncio.sleep(0.3)
# go all the way down
await button_push(controller_state, 'down', sec=3)
await asyncio.sleep(0.3)
# goto "Controllers and Sensors" menu
for _ in range(2):
await button_push(controller_state, 'up')
await asyncio.sleep(0.3)
await button_push(controller_state, 'right')
await asyncio.sleep(0.3)
# go all the way down
await button_push(controller_state, 'down', sec=3)
await asyncio.sleep(0.3)
# goto "Test Input Devices" menu
await button_push(controller_state, 'up')
await asyncio.sleep(0.3)
await button_push(controller_state, 'a')
await asyncio.sleep(0.3)
# goto "Test Controller Buttons" menu
await button_push(controller_state, 'a')
await asyncio.sleep(0.3)
# push all buttons
button_list = ['y', 'x', 'b', 'a', 'r', 'zr',
'minus', 'plus', 'r_stick', 'l_stick',
'down', 'up', 'right', 'left', 'l', 'zl']
report = InputReport()
report.set_input_report_id(0x21)
report.set_misc()
buttons = Buttons()
for i in range(10):
for button in button_list:
logger.info(f'Pressing Button {button}...')
# push button
getattr(buttons, button)()
# send report
report.data[4:7] = buttons.to_list()
await transport.write(report)
await button_push(controller_state, button)
await asyncio.sleep(0.1)
# release button
getattr(buttons, button)()
# send report
report.data[4:7] = buttons.to_list()
await transport.write(report)
await asyncio.sleep(0.3)
async def main():
transport, protocol = await create_hid_server(controller_protocol_factory(Controller.PRO_CONTROLLER), 17, 19)
@@ -108,12 +151,8 @@ async def main():
except asyncio.CancelledError:
pass
await asyncio.sleep(20)
await test_controller_buttons(ControllerState(transport, protocol))
await mash_buttons(transport)
# stop communication after some time
await asyncio.sleep(60)
logger.info('Stopping communication...')
await transport.close()
+4 -1
View File
@@ -33,7 +33,7 @@ class L2CAP_Transport(asyncio.Transport):
await self._is_reading.wait()
data = await self._loop.sock_recv(self._sock, self._read_buffer_size)
logger.debug(f'received "{data}"')
logger.debug(f'received "{list(map(hex, list(data)))}"')
await self._protocol.report_received(data, self._sock.getpeername())
def is_reading(self) -> bool:
@@ -62,6 +62,9 @@ class L2CAP_Transport(asyncio.Transport):
data.set_timer(self._input_report_timer)
self._input_report_timer = (self._input_report_timer + 1) % 256
_bytes = bytes(data)
if data.subcommand_is_set:
data.clear_sub_command()
else:
raise ValueError('data must be bytes or InputReport')
+8
View File
@@ -5,6 +5,14 @@ import re
logger = logging.getLogger(__name__)
def get_bit(value, n):
return (value >> n & 1) != 0
def flip_bit(value, n):
return value ^ (1 << n)
async def run_system_command(cmd):
proc = await asyncio.create_subprocess_shell(
cmd,