diff --git a/pycoax/coax/__init__.py b/pycoax/coax/__init__.py index 31433d0..b144d46 100644 --- a/pycoax/coax/__init__.py +++ b/pycoax/coax/__init__.py @@ -2,6 +2,8 @@ from .__about__ import __version__ from .interface1 import Interface1 +from .serial_interface import SerialInterface + from .protocol import ( PollAction, PollResponse, diff --git a/pycoax/coax/interface.py b/pycoax/coax/interface.py new file mode 100644 index 0000000..77ad984 --- /dev/null +++ b/pycoax/coax/interface.py @@ -0,0 +1,19 @@ +""" +coax.interface +~~~~~~~~~~~~~~ +""" + +class Interface: + def reset(self): + raise NotImplementedError + + def transmit(self, words, repeat_count=None, repeat_offset=1): + raise NotImplementedError + + def receive(self, length=None, timeout=None): + raise NotImplementedError + + def transmit_receive(self, transmit_words, transmit_repeat_count=None, + transmit_repeat_offset=1, receive_length=None, + receive_timeout=None): + raise NotImplementedError diff --git a/pycoax/coax/serial_interface.py b/pycoax/coax/serial_interface.py new file mode 100644 index 0000000..0688bba --- /dev/null +++ b/pycoax/coax/serial_interface.py @@ -0,0 +1,176 @@ +""" +coax.serial_interface +~~~~~~~~~~~~~~~~~~~~~ +""" + +import struct +from sliplib import SlipWrapper, ProtocolError + +from .interface import Interface +from .exceptions import InterfaceError, InterfaceTimeout, ReceiveError, ReceiveTimeout + +class SerialInterface(Interface): + def __init__(self, serial): + if serial is None: + raise ValueError('Serial port is required') + + self.serial = serial + + self.slip_serial = SlipSerial(self.serial) + + def reset(self): + original_serial_timeout = self.serial.timeout + + self.serial.reset_input_buffer() + + self._write_message(bytes([0x01])) + + self.serial.timeout = 5 + + try: + message = self._read_message() + finally: + self.serial.timeout = original_serial_timeout + + if message[0] != 0x01: + raise _convert_error(message) + + if len(message) != 4: + raise InterfaceError('Invalid reset response') + + (major, minor, patch) = struct.unpack('BBB', message[1:]) + + return '{}.{}.{}'.format(major, minor, patch) + + def transmit(self, words, repeat_count=None, repeat_offset=1): + message = bytes([0x02]) + + message += _pack_transmit_header(len(words), repeat_count, repeat_offset) + message += _pack_transmit_data(words) + + self._write_message(message) + + message = self._read_message() + + if message[0] != 0x01: + raise _convert_error(message) + + def receive(self, length=None, timeout=None): + timeout_milliseconds = self._calculate_timeout_milliseconds(timeout) + + message = bytes([0x04]) + + message += _pack_receive_header(length, timeout_milliseconds) + + self._write_message(message) + + message = self._read_message() + + if message[0] != 0x01: + raise _convert_error(message) + + return _unpack_receive_response(message[1:]) + + def transmit_receive(self, transmit_words, transmit_repeat_count=None, + transmit_repeat_offset=1, receive_length=None, + receive_timeout=None): + timeout_milliseconds = self._calculate_timeout_milliseconds(receive_timeout) + + message = bytes([0x06]) + + message += _pack_transmit_header(len(transmit_words), transmit_repeat_count, + transmit_repeat_offset) + message += _pack_transmit_data(transmit_words) + message += _pack_receive_header(receive_length, timeout_milliseconds) + + self._write_message(message) + + message = self._read_message() + + if message[0] != 0x01: + raise _convert_error(message) + + return _unpack_receive_response(message[1:]) + + def _calculate_timeout_milliseconds(self, timeout): + milliseconds = 0 + + if timeout: + if self.serial.timeout and timeout > self.serial.timeout: + raise ValueError('Timeout cannot be greater than serial timeout') + + milliseconds = int(timeout * 1000) + + return milliseconds + + def _read_message(self): + try: + message = self.slip_serial.recv_msg() + except ProtocolError: + raise InterfaceError('SLIP protocol error') + + if len(message) < 4: + raise InterfaceError('Invalid response message') + + (length,) = struct.unpack(">H", message[:2]) + + if length != len(message) - 4: + raise InterfaceError('Response message length mismatch') + + if length < 1: + raise InterfaceError('Empty response message') + + return message[2:-2] + + def _write_message(self, message): + self.slip_serial.send_msg(struct.pack(">H", len(message)) + message + + struct.pack(">H", 0)) + +def _pack_transmit_header(length, repeat_count, repeat_offset): + repeat = ((repeat_offset << 15) | repeat_count) if repeat_count else 0 + + return struct.pack(">HH", length, repeat) + +def _pack_transmit_data(words): + bytes_ = bytearray() + + for word in words: + bytes_ += struct.pack(">H", word) + + return bytes_ + +def _pack_receive_header(length, timeout_milliseconds): + return struct.pack(">HH", length or 0, timeout_milliseconds) + +def _unpack_receive_response(message): + pass + +def _convert_error(message): + # TODO + + return InterfaceError('Unknown error') + +class SlipSerial(SlipWrapper): + """sliplib wrapper for pySerial.""" + + def send_bytes(self, packet): + """Sends a packet over the serial port.""" + self.stream.write(packet) + self.stream.flush() + + def recv_bytes(self): + """Receive data from the serial port.""" + if self.stream.closed: + return b'' + + count = self.stream.in_waiting + + if count: + return self.stream.read(count) + + byte = self.stream.read(1) + + if byte == b'': + raise InterfaceTimeout() + + return byte diff --git a/pycoax/tests/test_serial_interface.py b/pycoax/tests/test_serial_interface.py new file mode 100644 index 0000000..56ded69 --- /dev/null +++ b/pycoax/tests/test_serial_interface.py @@ -0,0 +1,124 @@ +import unittest +from unittest.mock import Mock +import sliplib + +import context + +from coax import SerialInterface, InterfaceError, ReceiveError, ReceiveTimeout + +class SerialInterfaceResetTestCase(unittest.TestCase): + def setUp(self): + self.serial = Mock() + + self.serial.timeout = None + + self.interface = SerialInterface(self.serial) + + self.interface._write_message = Mock() + self.interface._read_message = Mock(return_value=bytes.fromhex('01 01 02 03')) + + def test_message_is_sent(self): + # Act + self.interface.reset() + + # Assert + self.interface._write_message.assert_called_with(bytes.fromhex('01')) + + def test_version_is_formatted_correctly(self): + self.assertEqual(self.interface.reset(), '1.2.3') + + def test_timeout_is_restored_after_reset(self): + # Arrange + self.serial.timeout = 123 + + # Act + self.interface.reset() + + # Assert + self.assertEqual(self.serial.timeout, 123) + + def test_invalid_message_length_is_handled_correctly(self): + # Arrange + self.interface._read_message = Mock(return_value=bytes.fromhex('01 01')) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'Invalid reset response'): + self.interface.reset() + + def test_error_is_handled_correctly(self): + # Arrange + self.interface._read_message = Mock(return_value=bytes.fromhex('02 01')) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'Invalid request message'): + self.interface.reset() + +# TODO... + +class SerialInterfaceReadMessageTestCase(unittest.TestCase): + def setUp(self): + self.serial = Mock() + + self.interface = SerialInterface(self.serial) + + self.interface.slip_serial = Mock() + + def test(self): + # Arrange + self.interface.slip_serial.recv_msg = Mock(return_value=bytes.fromhex('00 04 01 02 03 04 00 00')) + + # Act + message = self.interface._read_message() + + # Assert + self.assertEqual(message, bytes.fromhex('01 02 03 04')) + + def test_protocol_error_is_handled_correctly(self): + # Arrange + self.interface.slip_serial.recv_msg = Mock(side_effect=sliplib.ProtocolError) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'SLIP protocol error'): + self.interface._read_message() + + def test_invalid_message_length_is_handled_correctly(self): + # Arrange + self.interface.slip_serial.recv_msg = Mock(return_value=bytes.fromhex('00')) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'Invalid response message'): + self.interface._read_message() + + def test_message_length_mismatch_is_handled_correctly(self): + # Arrange + self.interface.slip_serial.recv_msg = Mock(return_value=bytes.fromhex('00 05 01 02 03 04 00 00')) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'Response message length mismatch'): + self.interface._read_message() + + def test_empty_message_is_handled_correctly(self): + # Arrange + self.interface.slip_serial.recv_msg = Mock(return_value=bytes.fromhex('00 00 00 00')) + + # Act and assert + with self.assertRaisesRegex(InterfaceError, 'Empty response message'): + self.interface._read_message() + +class SerialInterfaceWriteMessageTestCase(unittest.TestCase): + def setUp(self): + self.serial = Mock() + + self.interface = SerialInterface(self.serial) + + self.interface.slip_serial = Mock() + + def test(self): + # Act + self.interface._write_message(bytes.fromhex('01 02 03 04')) + + # Assert + self.interface.slip_serial.send_msg.assert_called_with(bytes.fromhex('00 04 01 02 03 04 00 00')) + +if __name__ == '__main__': + unittest.main()