lowobservable.oec/oec/controller.py

385 lines
13 KiB
Python

"""
oec.controller
~~~~~~~~~~~~~~
"""
from enum import Enum
import time
import logging
import selectors
from concurrent import futures
from itertools import groupby
from coax import InterfaceFeature, Poll, PollAck, KeystrokePollResponse, \
ReceiveTimeout, ReceiveError, ProtocolError
from coax.multiplexer import PORT_MAP_3299
from .device import address_commands, format_address, UnsupportedDeviceError
from .keyboard import Key
from .session import SessionDisconnectedError
class SessionState(Enum):
    """Lifecycle state of a session tracked by the controller.

    The controller stores ``(state, entry)`` tuples per device address; the
    state tells it how to interpret the entry.
    """

    # Session start has been submitted to the executor; the entry is a future.
    STARTING = 1

    # Session has started and is registered for host reads; the entry is the
    # session itself.
    ACTIVE = 2

    # Session termination has been submitted to the executor; the entry is a
    # future.
    TERMINATING = 3
class Controller:
    """The controller.

    Finds devices by POLLing the interface, keeps one host session per
    attached device, forwards keystrokes to the session and lets sessions
    handle host output between POLLs.
    """

    def __init__(self, interface, create_device, create_session):
        """Initialize the controller.

        Args:
            interface: Interface used to execute commands; it is also handed
                to ``create_device`` and used when formatting addresses.
            create_device: Callable invoked as
                ``create_device(interface, device_address, poll_response)``
                when a device is found. ``UnsupportedDeviceError`` raised by
                it is logged and the device is ignored.
            create_session: Callable invoked as ``create_session(device)``;
                the returned session is started on the session executor.
        """
        self.logger = logging.getLogger(__name__)

        self.interface = interface
        self.running = False

        self.create_device = create_device
        self.create_session = create_session

        # Attached devices, keyed by device address.
        self.devices = { }

        # Addresses without an attached device that are still to be POLLed in
        # the current discovery round.
        self.detatched_device_poll_queue = []

        # Sessions keyed by device address; each value is a
        # (SessionState, future or session) tuple.
        self.sessions = { }

        # Both are created in run() and released again when run() returns.
        self.session_selector = None
        self.session_executor = None

        # Target time between POLL commands in seconds, for attached devices
        # and for addresses with no device attached respectively.
        self.attached_poll_period = 1 / 15
        self.detatched_poll_period = 1 / 2

        # Maximum number of POLL commands to execute, per attached device, per run
        # loop iteration. If all attached devices respond with TT/AR the run loop
        # iteration will exit without reaching this maximum depth.
        #
        # This is an effort to improve the keystroke responsiveness.
        self.poll_depth = 3

        self.last_attached_poll_time = None
        self.last_detatched_poll_time = None

    def run(self):
        """Run the controller until stop() is called, then release resources.

        Blocks the calling thread. On exit, pending executor work is awaited,
        active sessions are terminated synchronously and all tracked devices
        and sessions are forgotten.
        """
        self.running = True

        self.session_selector = selectors.DefaultSelector()
        self.session_executor = futures.ThreadPoolExecutor()

        self.logger.info('Controller started')

        while self.running:
            self._run_loop()

        # Wait for any in-flight session start / terminate tasks.
        self.session_executor.shutdown(wait=True)
        self.session_executor = None

        # Snapshot the active sessions first, terminating a session removes it
        # from self.sessions.
        active_sessions = [session for (state, session) in self.sessions.values() if state == SessionState.ACTIVE]

        for session in active_sessions:
            self._terminate_session(session, blocking=True)

        self.session_selector.close()
        self.session_selector = None

        self.sessions.clear()
        self.devices.clear()
        self.detatched_device_poll_queue.clear()

        self.logger.info('Controller stopped')

    def stop(self):
        """Stop the controller."""
        # run() checks this flag once per loop iteration.
        self.running = False

    def _run_loop(self):
        """Execute one iteration: service sessions while waiting, then POLL."""
        poll_delay = self._calculate_poll_delay()

        # If POLLing is delayed, handle the host output, otherwise just sleep.
        start_time = time.perf_counter()

        if poll_delay > 0:
            self._update_sessions(poll_delay)

        # Sleep for whatever is left of the delay after servicing sessions.
        poll_delay -= (time.perf_counter() - start_time)

        if poll_delay > 0:
            time.sleep(poll_delay)

        # POLL devices.
        self._poll_attached_devices()
        self._poll_next_detatched_device()

    def _group_sessions_by_state(self):
        """Return ``{state: [(device_address, entry), ...]}`` for all sessions.

        Entries are grouped regardless of their order in ``self.sessions``.
        itertools.groupby is deliberately not used: it only merges consecutive
        equal keys, so interleaved states would be split into several groups
        and all but the last group per state would be dropped when collected
        into a dict.
        """
        grouped = { }

        for (device_address, (state, entry)) in self.sessions.items():
            grouped.setdefault(state, []).append((device_address, entry))

        return grouped

    def _update_sessions(self, duration):
        """Start, finalize and service sessions for at most ``duration`` seconds.

        Args:
            duration: Time budget in seconds, also used as the selector timeout
                while waiting for host output.
        """
        start_time = time.perf_counter()

        # Start any missing sessions.
        for device_address in self.devices.keys() - self.sessions.keys():
            self._start_session(self.devices[device_address])

        sessions = self._group_sessions_by_state()

        # Handle started sessions.
        started_sessions = []

        for (device_address, future) in sessions.get(SessionState.STARTING, []):
            if future.done():
                session = future.result()

                self.sessions[device_address] = (SessionState.ACTIVE, session)

                self.session_selector.register(session, selectors.EVENT_READ)

                started_sessions.append(session)

                self.logger.info(f'Session started for device @ {format_address(self.interface, device_address)}')

        # Handle terminated sessions.
        for (device_address, future) in sessions.get(SessionState.TERMINATING, []):
            if future.done():
                del self.sessions[device_address]

                self.logger.info(f'Session terminated for device @ {format_address(self.interface, device_address)}')

        # Update the duration based on the time taken handling futures.
        duration -= (time.perf_counter() - start_time)

        # Update active sessions.
        updated_sessions = set()

        is_first_iteration = True

        while duration > 0:
            start_time = time.perf_counter()

            sessions = set(self._select_sessions(duration))

            # Handle host output from started sessions immediately as the telnet client
            # buffer may contain commands that were buffered during negotiation. If we do
            # not handle them here, we will have to wait for further commands to trigger
            # the read select event.
            #
            # This ensures that messages such as "connection rejected, no available device"
            # are shown on the terminal.
            if is_first_iteration:
                sessions.update(started_sessions)

            if not sessions:
                break

            for session in sessions:
                try:
                    if session.handle_host():
                        updated_sessions.add(session)
                except SessionDisconnectedError:
                    # A disconnected session must not be rendered below.
                    updated_sessions.discard(session)

                    self._handle_session_disconnected(session)

            duration -= (time.perf_counter() - start_time)

            is_first_iteration = False

        # Render once per session, after all host output has been handled.
        for session in updated_sessions:
            session.render()

    def _select_sessions(self, duration):
        """Return the sessions that become readable within ``duration`` seconds."""
        # The Windows selector will raise an error if there are no handles registered while
        # other selectors may block for the provided duration.
        if not self.session_selector.get_map():
            return []

        selected = self.session_selector.select(duration)

        return [key.fileobj for (key, _) in selected]

    def _start_session(self, device):
        """Submit creation and start of a session for ``device`` to the executor."""
        device_address = device.device_address

        self.logger.info(f'Starting session for device @ {format_address(self.interface, device_address)}')

        def start_session():
            session = self.create_session(device)

            session.start()

            return session

        future = self.session_executor.submit(start_session)

        # _update_sessions promotes the entry to ACTIVE once the future is done.
        self.sessions[device_address] = (SessionState.STARTING, future)

    def _terminate_session(self, session, blocking=False):
        """Unregister and terminate ``session``.

        Args:
            session: The active session to terminate.
            blocking: When true, terminate in the calling thread and forget the
                session immediately; otherwise terminate on the executor and
                track the session as TERMINATING until the future completes.
        """
        device_address = session.terminal.device_address

        self.logger.info(f'Terminating session for device @ {format_address(self.interface, device_address)}')

        self.session_selector.unregister(session)

        if blocking:
            session.terminate()

            del self.sessions[device_address]
        else:
            future = self.session_executor.submit(session.terminate)

            self.sessions[device_address] = (SessionState.TERMINATING, future)

    def _handle_session_disconnected(self, session):
        """Terminate a session that reported a disconnect."""
        self.logger.info('Session disconnected')

        self._terminate_session(session)

    def _poll_attached_devices(self):
        """POLL every attached device, up to ``poll_depth`` times.

        Responses other than None / ReceiveTimeout are acknowledged and
        handled; a ReceiveTimeout marks the device as lost. Polling stops
        early once a pass yields nothing to handle.
        """
        self.last_attached_poll_time = time.perf_counter()

        for _ in range(self.poll_depth):
            # Snapshot, devices may be removed from self.devices further down.
            devices = list(self.devices.values())

            if not devices:
                break

            poll_commands = [address_commands(device.device_address, Poll(device.get_poll_action())) for device in devices]

            responses = self.interface.execute(poll_commands, receive_timeout_is_error=False)

            poll_responses = list(zip(devices, responses))

            # Handle POLL responses.
            handleable_poll_responses = [(device, poll_response) for (device, poll_response) in poll_responses if poll_response is not None and not isinstance(poll_response, ReceiveTimeout)]

            if handleable_poll_responses:
                poll_ack_commands = [address_commands(device.device_address, PollAck()) for (device, _) in handleable_poll_responses]

                self.interface.execute(poll_ack_commands)

                for (device, poll_response) in handleable_poll_responses:
                    self._handle_poll_response(device, poll_response)

            # Handle lost devices.
            for (device, poll_response) in poll_responses:
                if isinstance(poll_response, ReceiveTimeout):
                    self._handle_device_lost(device)

            if not handleable_poll_responses:
                break

    def _poll_next_detatched_device(self):
        """POLL the next unattached address, at most once per detached period."""
        if self.last_detatched_poll_time is not None and (time.perf_counter() - self.last_detatched_poll_time) < self.detatched_poll_period:
            return

        self.last_detatched_poll_time = time.perf_counter()

        # Start a new round over the unattached addresses once the queue drains.
        if not self.detatched_device_poll_queue:
            self.detatched_device_poll_queue = list(self._get_detatched_device_addresses())

        try:
            device_address = self.detatched_device_poll_queue.pop(0)
        except IndexError:
            # Nothing left to probe.
            return

        try:
            poll_response = self.interface.execute(address_commands(device_address, Poll()))
        except ReceiveTimeout:
            # No answer from this address.
            return
        except ReceiveError as error:
            self.logger.warning(f'POLL detatched device @ {format_address(self.interface, device_address)} receive error: {error}')
            return
        except ProtocolError as error:
            self.logger.warning(f'POLL detatched device @ {format_address(self.interface, device_address)} protocol error: {error}')
            return

        if poll_response:
            self.interface.execute(address_commands(device_address, PollAck()))

        self._handle_device_found(device_address, poll_response)

    def _handle_device_found(self, device_address, poll_response):
        """Create, set up and attach the device found at ``device_address``."""
        self.logger.info(f'Found device @ {format_address(self.interface, device_address)}')

        try:
            device = self.create_device(self.interface, device_address, poll_response)
        except UnsupportedDeviceError as error:
            self.logger.error(f'Unsupported device @ {format_address(self.interface, device_address)}: {error}')
            return

        device.setup()

        self.devices[device_address] = device

        self.logger.info(f'Attached device @ {format_address(self.interface, device_address)}')

    def _handle_device_lost(self, device):
        """Detach ``device`` and terminate its session if that session is active."""
        device_address = device.device_address

        self.logger.info(f'Lost device @ {format_address(self.interface, device_address)}')

        # Only ACTIVE entries hold a session; STARTING and TERMINATING entries
        # hold a future and are left for _update_sessions to resolve.
        if device_address in self.sessions:
            (session_state, session) = self.sessions[device_address]

            if session_state == SessionState.ACTIVE:
                self._terminate_session(session)

        del self.devices[device_address]

        self.logger.info(f'Detached device @ {format_address(self.interface, device_address)}')

    def _handle_poll_response(self, device, poll_response):
        """Dispatch a POLL response; only keystroke responses are handled."""
        if isinstance(poll_response, KeystrokePollResponse):
            self._handle_keystroke_poll_response(device, poll_response)

    def _handle_keystroke_poll_response(self, terminal, poll_response):
        """Translate a scan code and apply it locally or pass it to the session."""
        device_address = terminal.device_address

        scan_code = poll_response.scan_code

        (key, modifiers, modifiers_changed) = terminal.keyboard.get_key(scan_code)

        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug((f'Keystroke detected: Scan Code = {scan_code}, '
                               f'Key = {key}, Modifiers = {modifiers}'))

        # Update the status line if modifiers have changed.
        if modifiers_changed:
            terminal.display.status_line.write_keyboard_modifiers(modifiers)

        if not key:
            return

        # These keys are applied to the terminal itself, every other key goes
        # to the active session, if there is one.
        if key == Key.CURSOR_BLINK:
            terminal.display.toggle_cursor_blink()
        elif key == Key.ALT_CURSOR:
            terminal.display.toggle_cursor_reverse()
        elif key == Key.CLICKER:
            terminal.keyboard.toggle_clicker()
        elif device_address in self.sessions:
            (session_state, session) = self.sessions[device_address]

            if session_state == SessionState.ACTIVE:
                session.handle_key(key, modifiers, scan_code)

                session.render()

    def _calculate_poll_delay(self):
        """Return the seconds to wait before the next attached POLL, never negative."""
        if self.last_attached_poll_time is None:
            return 0

        return max((self.last_attached_poll_time + self.attached_poll_period) - time.perf_counter(), 0)

    def _get_detatched_device_addresses(self):
        """Return an iterator over the addresses to probe for new devices.

        ``None`` is the direct attached address, the entries of PORT_MAP_3299
        are the 3299 port addresses. Addresses with an attached device are
        excluded.
        """
        attached_addresses = set(self.devices.keys())

        # The 3299 is transparent, but if there is at least one device attached to a 3299
        # port then we can assume there is a 3299 attached and if there is one device
        # direct attached then we can assume there is not a 3299 attached.
        #
        # Compare against None explicitly rather than relying on truthiness, a port
        # address of 0 is falsy but is still an attached 3299 port.
        is_3299_attached = any(address is not None for address in attached_addresses)
        is_3299_not_attached = (None in attached_addresses)

        if is_3299_not_attached or InterfaceFeature.PROTOCOL_3299 not in self.interface.features:
            addresses = [None]
        elif is_3299_attached:
            addresses = PORT_MAP_3299
        else:
            addresses = [None, *PORT_MAP_3299]

        return filter(lambda address: address not in attached_addresses, addresses)