diff --git a/device.py b/device.py index e69de29..2eeabda 100644 --- a/device.py +++ b/device.py @@ -0,0 +1,54 @@ +import logging +import uuid + +import dbus + +import utils + +logger = logging.getLogger(__name__) + + +class HidDevice: + _HID_UUID = '00001124-0000-1000-8000-00805f9b34fb' + _HID_PATH = '/bluez/switch/hid' + + PRO_CONTROLLER = 'Pro Controller' + JOYCON_R = 'Joy-Con (R)' + JOYCON_L = 'Joy-Con (L)' + + def __init__(self): + self._uuid = str(uuid.uuid4()) + + # Setting up dbus to advertise the service record + bus = dbus.SystemBus() + obj = bus.get_object('org.bluez', '/org/bluez/hci0') + self.adapter = dbus.Interface(obj, 'org.bluez.Adapter1') + self.properties = dbus.Interface(self.adapter, 'org.freedesktop.DBus.Properties') + + def discoverable(self, boolean=True): + #self.properties.Set(self.adapter.dbus_interface, 'Powered', True) + self.properties.Set(self.adapter.dbus_interface, 'Discoverable', boolean) + + async def set_class(self, cls=0x002508): + """ + :param cls: default 0x002508 (Gamepad/joystick device class) + """ + logger.info(f'setting device class to {cls}...') + await utils.run_system_command(f'hciconfig hci0 class {cls}') + + async def set_name(self, name: str): + logger.info(f'setting device name to {name}...') + await utils.run_system_command(f'hciconfig hci0 name "{name}"') + + def register_sdp_record(self, record_path): + with open(record_path) as record: + opts = { + 'ServiceRecord': record.read(), + 'Role': 'server', + 'Service': self._HID_UUID, + 'RequireAuthentication': False, + 'RequireAuthorization': False + } + bus = dbus.SystemBus() + manager = dbus.Interface(bus.get_object("org.bluez", "/org/bluez"), "org.bluez.ProfileManager1") + manager.RegisterProfile(self._HID_PATH, self._uuid, opts) \ No newline at end of file diff --git a/protocol.py b/protocol.py index 14df895..590602a 100644 --- a/protocol.py +++ b/protocol.py @@ -1,15 +1,9 @@ import asyncio import enum import logging -import socket from asyncio import BaseTransport, BaseProtocol from typing import Optional, Union, Tuple, Text -import logging_default as log -import utils -from device import HidDevice -from transport import L2CAP_Transport - logger = logging.getLogger(__name__) @@ -42,6 +36,12 @@ class ControllerProtocol(BaseProtocol): def __init__(self, controller: Controller): self.transport = None + self._data_received = asyncio.Event() + + async def wait_for_output_report(self): + self._data_received.clear() + await self._data_received.wait() + def connection_made(self, transport: BaseTransport) -> None: logger.debug('Connection established.') self.transport = transport @@ -50,7 +50,7 @@ class ControllerProtocol(BaseProtocol): raise NotImplementedError() async def report_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None: - raise NotImplementedError() + self._data_received.set() def error_received(self, exc: Exception) -> None: raise NotImplementedError() diff --git a/report.py b/report.py index e69de29..d17ede7 100644 --- a/report.py +++ b/report.py @@ -0,0 +1,71 @@ +from enum import Enum, auto + + +class InputReport: + def __init__(self): + self.data = [0x00] * 50 + # all input reports are prepended with 0xA1 + self.data[0] = 0xA1 + + def set(self, input_report_id, timer=0x00): + self.data[1] = input_report_id + self.data[2] = timer % 256 + # battery level + connection info + self.data[3] = 0x8E + + # Todo: Button status, analog stick data, vibrator input + + # ACK byte for subcmd reply + self.data[14] = 0x82 + + # Reply-to subcommand ID + self.data[14] = 0x02 + + def sub_0x2_device_info(self, mac, fm_version=(0x03, 0x48), controller=0x01): + """ + Sub command 0x02 request device info response. + + :param mac: Controller MAC address in Big Endian (6 Bytes) + :param fm_version: TODO + :param controller: 1=Left Joy-Con, 2=Right Joy-Con, 3=Pro Controller + """ + if len(fm_version) != 2: + raise ValueError('Firmware version must consist of 2 bytes!') + elif len(mac) != 6: + raise ValueError('Bluetooth mac address must consist of 6 bytes!') + + # reply to sub command ID + self.data[14] = 0x02 + + # sub command reply data + offset = 15 + self.data[offset: offset + 1] = fm_version + self.data[offset + 2] = controller + self.data[offset + 3] = 0x02 + self.data[offset + 4: offset + 9] = mac + self.data[offset + 10] = 0x01 + self.data[offset + 11] = 0x01 + + def __bytes__(self): + return bytes(self.data) + + +class SubCommand(Enum): + REQUEST_DEVICE_INFO = auto() + NOT_IMPLEMENTED = auto() + + +class OutputReport: + def __init__(self, data): + if data[0] != 0xA2: + raise ValueError('Output reports must start with 0xA2') + self.data = data + + def sub_command(self): + if self.data[11] == 0x02: + return SubCommand.REQUEST_DEVICE_INFO + else: + return None + + def __bytes__(self): + return bytes(self.data) diff --git a/run_and_pair_switch.py b/run_and_pair_switch.py index e69de29..6bc3969 100644 --- a/run_and_pair_switch.py +++ b/run_and_pair_switch.py @@ -0,0 +1,84 @@ +import asyncio +import logging +import socket + +import logging_default as log +import utils +from device import HidDevice +from protocol import controller_protocol_factory, Controller +from report import InputReport +from transport import L2CAP_Transport + +logger = logging.getLogger(__name__) + + +async def create_hid_server(protocol_factory, ctl_psm, itr_psm): + ctl_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + itr_sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + + # for some reason we need to restart bluetooth here, the Switch does not connect to the sockets if we don't... + logger.info('Restarting bluetooth service...') + await utils.run_system_command('systemctl restart bluetooth.service') + await asyncio.sleep(1) + + ctl_sock.setblocking(False) + itr_sock.setblocking(False) + + ctl_sock.bind((socket.BDADDR_ANY, ctl_psm)) + itr_sock.bind((socket.BDADDR_ANY, itr_psm)) + + ctl_sock.listen(1) + itr_sock.listen(1) + + hid = HidDevice() + # setting bluetooth adapter name and class to the device we wish to emulate + await hid.set_name(HidDevice.JOYCON_L) + await hid.set_class() + + logger.info('Advertising the Bluetooth SDP record...') + hid.register_sdp_record('profile/sdp_record_hid_pro.xml') + hid.discoverable() + + loop = asyncio.get_event_loop() + client_ctl, address = await loop.sock_accept(ctl_sock) + logger.info(f'Accepted connection at psm {ctl_psm} from {address}') + client_itr, address = await loop.sock_accept(itr_sock) + logger.info(f'Accepted connection at psm {itr_psm} from {address}') + + protocol = protocol_factory() + transport = L2CAP_Transport(asyncio.get_event_loop(), protocol, client_itr, address, 50) + protocol.connection_made(transport) + + return transport, protocol + + +async def send_empty_input_reports(transport): + report = InputReport() + + while True: + await transport.write(bytes(report)) + await asyncio.sleep(1) + + +async def main(): + transport, protocol = await create_hid_server(controller_protocol_factory(Controller.JOYCON_L), 17, 19) + + future = asyncio.ensure_future(send_empty_input_reports(transport)) + + await protocol.wait_for_output_report() + + future.cancel() + try: + await future + except asyncio.CancelledError: + pass + + await transport.close() + + +if __name__ == '__main__': + # setup logging + log.configure() + + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) \ No newline at end of file diff --git a/transport.py b/transport.py index e69de29..e19217d 100644 --- a/transport.py +++ b/transport.py @@ -0,0 +1,84 @@ +import asyncio +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class L2CAP_Transport(asyncio.Transport): + def __init__(self, loop, protocol, l2cap_socket, client_addr, read_buffer_size) -> None: + self._loop = loop + self._protocol = protocol + + self._sock = l2cap_socket + self._client_addr = client_addr + self._read_buffer_size = read_buffer_size + + self._read_thread = asyncio.ensure_future(self._read()) + + self._is_closing = False + self._is_reading = asyncio.Event() + self._is_reading.set() + + async def _read(self): + try: + while True: + + await self._is_reading.wait() + + data = await self._loop.sock_recv(self._sock, self._read_buffer_size) + logger.debug(f'received "{data}') + await self._protocol.report_received(data, self._client_addr) + except asyncio.CancelledError: + # reading has been stopped + pass + + def is_reading(self) -> bool: + return self._is_reading.is_set() + + def pause_reading(self) -> None: + self._is_reading.clear() + + def resume_reading(self) -> None: + self._is_reading.set() + + def set_read_buffer_size(self, size): + self._read_buffer_size = size + + def set_write_buffer_limits(self, high: int = ..., low: int = ...) -> None: + super().set_write_buffer_limits(high, low) + + def get_write_buffer_size(self) -> int: + return super().get_write_buffer_size() + + async def write(self, data: Any) -> None: + logger.debug(f'sending "{data}"') + await self._loop.sock_sendall(self._sock, data) + + def abort(self) -> None: + super().abort() + + def get_extra_info(self, name: Any, default: Any = ...) -> Any: + return super().get_extra_info(name, default) + + def is_closing(self) -> bool: + return self._is_closing + + async def close(self): + """ + Stops socket reader and closes socket + """ + self._is_closing = True + self._read_thread.cancel() + # wait for reader to cancel + try: + await self._read_thread + except asyncio.CancelledError: + pass + self._sock.close() + + def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: + self._protocol = protocol + + def get_protocol(self) -> asyncio.BaseProtocol: + return self._protocol \ No newline at end of file diff --git a/utils.py b/utils.py index e69de29..4c951f9 100644 --- a/utils.py +++ b/utils.py @@ -0,0 +1,21 @@ +import asyncio +import logging + +logger = logging.getLogger(__name__) + + +async def run_system_command(cmd): + proc = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + + stdout, stderr = await proc.communicate() + + logger.debug(f'[{cmd!r} exited with {proc.returncode}]') + if stdout: + logger.debug(f'[stdout]\n{stdout.decode()}') + if stderr: + logger.debug(f'[stderr]\n{stderr.decode()}') + + return proc.returncode