diff --git a/scripts/dump_spi_flash.py b/scripts/dump_spi_flash.py index b92fd62..4eb24dc 100644 --- a/scripts/dump_spi_flash.py +++ b/scripts/dump_spi_flash.py @@ -6,8 +6,9 @@ from contextlib import suppress import hid -from joycontrol import logging_default as log +from joycontrol import logging_default as log, utils from joycontrol.report import OutputReport, InputReport, SubCommand +from joycontrol.utils import AsyncHID logger = logging.getLogger(__name__) @@ -17,23 +18,6 @@ PRODUCT_ID_JR = 8199 PRODUCT_ID_PC = 8201 -class AsyncHID(hid.Device): - def __init__(self, *args, loop=asyncio.get_event_loop(), **kwargs): - super().__init__(*args, **kwargs) - self._loop = loop - - self._write_lock = asyncio.Lock() - self._read_lock = asyncio.Lock() - - async def read(self, size, timeout=None): - async with self._read_lock: - return await self._loop.run_in_executor(None, hid.Device.read, self, size, timeout) - - async def write(self, data): - async with self._write_lock: - return await self._loop.run_in_executor(None, hid.Device.write, self, data) - - class DataReader: def __init__(self): self.pending_request = None @@ -160,13 +144,9 @@ async def _main(args, loop): logger.info(f'Found controller "{controller}".') - if args.output: - with open(args.output, 'wb') as output: - with AsyncHID(path=controller['path'], loop=loop) as hid_controller: - await dump_spi_flash(hid_controller, output_file=output) - else: + with utils.get_output(path=args.output, open_flags='wb', default=None) as output: with AsyncHID(path=controller['path'], loop=loop) as hid_controller: - await dump_spi_flash(hid_controller) + await dump_spi_flash(hid_controller, output_file=output) if __name__ == '__main__': @@ -193,5 +173,3 @@ if __name__ == '__main__': finally: loop.stop() loop.close() - -