diff --git a/pycoax/coax/protocol.py b/pycoax/coax/protocol.py index d3c0e0d..d2eb03b 100644 --- a/pycoax/coax/protocol.py +++ b/pycoax/coax/protocol.py @@ -4,6 +4,7 @@ coax.protocol """ from enum import Enum +from more_itertools import chunked from .exceptions import ProtocolError from .parity import odd_parity @@ -438,20 +439,42 @@ def _execute_read_command(interface, command_word, response_length=1, return unpack_data_words(response) if unpack else response -def _execute_write_command(interface, command_word, data=None, **kwargs): +def _execute_write_command(interface, command_word, data=None, + jumbo_write_strategy=None, **kwargs): """Execute a standard write command.""" - data_words = [] - transmit_repeat_count = None + length = 1 if isinstance(data, tuple): - data_words = pack_data_words(data[0]) - transmit_repeat_count = data[1] + length += len(data[0]) * data[1] elif data is not None: - data_words = pack_data_words(data) + length += len(data) - response = interface.transmit_receive([command_word, *data_words], - transmit_repeat_count, - receive_length=1, **kwargs) + max_length = 1024 + + if jumbo_write_strategy == 'split' and length > max_length: + if isinstance(data, tuple): + data_words = pack_data_words(data[0]) * data[1] + else: + data_words = pack_data_words(data) + + for words in chunked([command_word, *data_words], max_length): + _execute_write(interface, words, None, **kwargs) + else: + data_words = [] + transmit_repeat_count = None + + if isinstance(data, tuple): + data_words = pack_data_words(data[0]) + transmit_repeat_count = data[1] + elif data is not None: + data_words = pack_data_words(data) + + _execute_write(interface, [command_word, *data_words], + transmit_repeat_count, **kwargs) + +def _execute_write(interface, words, transmit_repeat_count, **kwargs): + response = interface.transmit_receive(words, transmit_repeat_count, + receive_length=1, **kwargs) if len(response) != 1: raise ProtocolError(f'Expected 1 word response: {response}') diff --git a/pycoax/requirements.txt b/pycoax/requirements.txt index 926a4ab..3b0bf4d 100644 --- a/pycoax/requirements.txt +++ b/pycoax/requirements.txt @@ -1,2 +1,3 @@ +more-itertools==8.7.0 pyserial==3.5 sliplib==0.6.2 diff --git a/pycoax/setup.py b/pycoax/setup.py index 0d304bf..eb1e0d7 100644 --- a/pycoax/setup.py +++ b/pycoax/setup.py @@ -21,7 +21,7 @@ setup( author='Andrew Kay', author_email='projects@ajk.me', packages=['coax'], - install_requires=['pyserial==3.5', 'sliplib==0.6.2'], + install_requires=['more-itertools==8.7.0', 'pyserial==3.5', 'sliplib==0.6.2'], long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', classifiers=[ diff --git a/pycoax/tests/test_protocol.py b/pycoax/tests/test_protocol.py index 112f981..8eb2789 100644 --- a/pycoax/tests/test_protocol.py +++ b/pycoax/tests/test_protocol.py @@ -198,13 +198,37 @@ class ExecuteWriteCommandTestCase(unittest.TestCase): self.interface = Mock() def test(self): + for jumbo_write_strategy in [None, 'split']: + with self.subTest(jumbo_write_strategy=jumbo_write_strategy): + # Arrange + command_word = pack_command_word(Command.WRITE_DATA) + + self.interface.transmit_receive = Mock(return_value=[0b0000000000]) + + # Act + _execute_write_command(self.interface, command_word, bytes.fromhex('de ad be ef'), jumbo_write_strategy=jumbo_write_strategy) + + # Assert + self.interface.transmit_receive.assert_called_once_with([0x0031, 0x037a, 0x02b4, 0x02fa, 0x03bc], None, receive_length=1) + + def test_jumbo_write_split_strategy(self): # Arrange command_word = pack_command_word(Command.WRITE_DATA) self.interface.transmit_receive = Mock(return_value=[0b0000000000]) - # Act and assert - _execute_write_command(self.interface, command_word, bytes.fromhex('de ad be ef')) + data = (bytes.fromhex('01') * 1023) + (bytes.fromhex('02') * 32) + + # Act + _execute_write_command(self.interface, command_word, data, jumbo_write_strategy='split') + + # Assert + self.assertEqual(self.interface.transmit_receive.call_count, 2) + + call_args_list = self.interface.transmit_receive.call_args_list + + self.assertEqual(len(call_args_list[0][0][0]), 1024) + self.assertEqual(len(call_args_list[1][0][0]), 32) def test_unexpected_response_length(self): # Arrange