mirror of
https://github.com/lowobservable/oec.git
synced 2026-01-11 23:53:04 +00:00
385 lines
13 KiB
Python
385 lines
13 KiB
Python
"""
|
|
oec.controller
|
|
~~~~~~~~~~~~~~
|
|
"""
|
|
|
|
from enum import Enum
|
|
import time
|
|
import logging
|
|
import selectors
|
|
from concurrent import futures
|
|
from itertools import groupby
|
|
from coax import InterfaceFeature, Poll, PollAck, KeystrokePollResponse, \
|
|
ReceiveTimeout, ReceiveError, ProtocolError
|
|
from coax.multiplexer import PORT_MAP_3299
|
|
|
|
from .device import address_commands, format_address, UnsupportedDeviceError
|
|
from .keyboard import Key
|
|
from .session import SessionDisconnectedError
|
|
|
|
class SessionState(Enum):
    """Lifecycle stage of a session tracked by the controller.

    The controller stores each session as a ``(state, value)`` pair: for
    ``STARTING`` and ``TERMINATING`` the value is the future of the task
    submitted to the session executor, for ``ACTIVE`` it is the session
    itself.
    """

    # Session start has been submitted to the executor.
    STARTING = 1

    # Session has started and is being serviced by the controller.
    ACTIVE = 2

    # Session termination has been submitted to the executor.
    TERMINATING = 3
|
|
|
|
class Controller:
    """The controller.

    Polls the devices reachable through ``interface``, attaches devices that
    respond, starts a session for every attached device and shuttles
    keystrokes and host output between devices and their sessions.

    ``create_device(interface, device_address, poll_response)`` and
    ``create_session(device)`` are factories supplied by the caller.
    """

    def __init__(self, interface, create_device, create_session):
        self.logger = logging.getLogger(__name__)

        self.interface = interface
        self.running = False

        self.create_device = create_device
        self.create_session = create_session

        # Attached devices, keyed by device address.
        self.devices = {}

        # Addresses without an attached device that are still waiting to be
        # POLLed; refilled when it runs empty.
        self.detatched_device_poll_queue = []

        # Sessions keyed by device address. Each value is a (SessionState, x)
        # pair where x is a future while STARTING or TERMINATING and the
        # session itself while ACTIVE.
        self.sessions = {}
        self.session_selector = None
        self.session_executor = None

        # Target time between POLL commands in seconds when a device is attached or
        # no device is attached.
        self.attached_poll_period = 1 / 15
        self.detatched_poll_period = 1 / 2

        # Maximum number of POLL commands to execute, per attached device, per run
        # loop iteration. If all attached devices respond with TT/AR the run loop
        # iteration will exit without reaching this maximum depth.
        #
        # This is an effort to improve the keystroke responsiveness.
        self.poll_depth = 3

        self.last_attached_poll_time = None
        self.last_detatched_poll_time = None

    def run(self):
        """Run the controller.

        Blocks, repeating the run loop until ``stop()`` clears ``running``,
        then tears down the executor, all sessions and the selector.
        """
        self.running = True

        self.session_selector = selectors.DefaultSelector()
        self.session_executor = futures.ThreadPoolExecutor()

        self.logger.info('Controller started')

        while self.running:
            self._run_loop()

        # Let any in-flight session start / terminate tasks finish first.
        self.session_executor.shutdown(wait=True)
        self.session_executor = None

        # Iterate over a snapshot as terminating removes entries from the map.
        for (device_address, (state, value)) in list(self.sessions.items()):
            if state == SessionState.ACTIVE:
                self._terminate_session(value, blocking=True)
            elif state == SessionState.STARTING and value.exception() is None:
                # The start task completed but the session was never promoted
                # to ACTIVE (and so never registered with the selector); it
                # still has to be terminated or it would be leaked.
                self.logger.info(f'Terminating session for device @ {format_address(self.interface, device_address)}')

                value.result().terminate()

        self.session_selector.close()
        self.session_selector = None

        self.sessions.clear()

        self.devices.clear()
        self.detatched_device_poll_queue.clear()

        self.logger.info('Controller stopped')

    def stop(self):
        """Stop the controller."""
        self.running = False

    def _run_loop(self):
        """Execute one iteration: service sessions while waiting, then POLL."""
        poll_delay = self._calculate_poll_delay()

        # If POLLing is delayed, handle the host output, otherwise just sleep.
        start_time = time.perf_counter()

        if poll_delay > 0:
            self._update_sessions(poll_delay)

        # Sleep away whatever part of the delay the sessions did not use up.
        poll_delay -= (time.perf_counter() - start_time)

        if poll_delay > 0:
            time.sleep(poll_delay)

        # POLL devices.
        self._poll_attached_devices()
        self._poll_next_detatched_device()

    def _group_sessions_by_state(self):
        """Return ``{state: [(device_address, value), ...]}`` for all sessions.

        Every session is included regardless of the order in which states
        appear in ``self.sessions``; within a state the insertion order of
        ``self.sessions`` is preserved.
        """
        grouped = {}

        for (device_address, (state, value)) in self.sessions.items():
            grouped.setdefault(state, []).append((device_address, value))

        return grouped

    def _update_sessions(self, duration):
        """Service sessions for at most approximately ``duration`` seconds.

        Starts sessions for attached devices that have none, collects
        finished start and terminate tasks, then handles host output from
        readable sessions and renders those that reported an update.
        """
        start_time = time.perf_counter()

        # Start any missing sessions.
        for device_address in self.devices.keys() - self.sessions.keys():
            self._start_session(self.devices[device_address])

        sessions_by_state = self._group_sessions_by_state()

        # Handle started sessions.
        started_sessions = []

        for (device_address, future) in sessions_by_state.get(SessionState.STARTING, []):
            if not future.done():
                continue

            session = future.result()

            self.sessions[device_address] = (SessionState.ACTIVE, session)

            self.session_selector.register(session, selectors.EVENT_READ)

            started_sessions.append(session)

            self.logger.info(f'Session started for device @ {format_address(self.interface, device_address)}')

        # Handle terminated sessions.
        for (device_address, future) in sessions_by_state.get(SessionState.TERMINATING, []):
            if not future.done():
                continue

            del self.sessions[device_address]

            self.logger.info(f'Session terminated for device @ {format_address(self.interface, device_address)}')

        # Update the duration based on the time taken handling futures.
        duration -= (time.perf_counter() - start_time)

        # Update active sessions.
        updated_sessions = set()

        is_first_iteration = True

        while duration > 0:
            start_time = time.perf_counter()

            ready_sessions = set(self._select_sessions(duration))

            # Handle host output from started sessions immediately as the telnet client
            # buffer may contain commands that were buffered during negotiation. If we do
            # not handle them here, we will have to wait for further commands to trigger
            # the read select event.
            #
            # This ensures that messages such as "connection rejected, no available device"
            # are shown on the terminal.
            if is_first_iteration:
                ready_sessions.update(started_sessions)

            if not ready_sessions:
                break

            for session in ready_sessions:
                try:
                    if session.handle_host():
                        updated_sessions.add(session)
                except SessionDisconnectedError:
                    # A disconnected session must not be rendered afterwards.
                    updated_sessions.discard(session)

                    self._handle_session_disconnected(session)

            duration -= (time.perf_counter() - start_time)
            is_first_iteration = False

        for session in updated_sessions:
            session.render()

    def _select_sessions(self, duration):
        """Return the registered sessions that become readable within ``duration`` seconds."""
        # The Windows selector will raise an error if there are no handles registered while
        # other selectors may block for the provided duration.
        if not self.session_selector.get_map():
            return []

        selected = self.session_selector.select(duration)

        return [key.fileobj for (key, _) in selected]

    def _start_session(self, device):
        """Submit creation and start of a session for ``device`` to the executor."""
        device_address = device.device_address

        self.logger.info(f'Starting session for device @ {format_address(self.interface, device_address)}')

        def start_session():
            session = self.create_session(device)

            session.start()

            return session

        future = self.session_executor.submit(start_session)

        self.sessions[device_address] = (SessionState.STARTING, future)

    def _terminate_session(self, session, blocking=False):
        """Unregister ``session`` from the selector and terminate it.

        With ``blocking=True`` the session is terminated on the calling
        thread and removed from ``self.sessions`` immediately; otherwise
        termination is submitted to the executor and the session is marked
        as TERMINATING until ``_update_sessions`` collects the task.
        """
        device_address = session.terminal.device_address

        self.logger.info(f'Terminating session for device @ {format_address(self.interface, device_address)}')

        self.session_selector.unregister(session)

        if blocking:
            session.terminate()

            del self.sessions[device_address]
        else:
            future = self.session_executor.submit(session.terminate)

            self.sessions[device_address] = (SessionState.TERMINATING, future)

    def _handle_session_disconnected(self, session):
        """Terminate a session that reported it has been disconnected."""
        self.logger.info('Session disconnected')

        self._terminate_session(session)

    def _poll_attached_devices(self):
        """POLL every attached device, up to ``poll_depth`` rounds.

        Responses are acknowledged and dispatched, devices whose POLL
        result is a ``ReceiveTimeout`` are detached, and the rounds stop
        early once no device has anything to report.
        """
        self.last_attached_poll_time = time.perf_counter()

        for _ in range(self.poll_depth):
            # Snapshot, as lost devices are removed from the map below.
            devices = list(self.devices.values())

            if not devices:
                break

            poll_commands = [
                address_commands(device.device_address, Poll(device.get_poll_action()))
                for device in devices
            ]

            poll_responses = list(zip(devices, self.interface.execute(poll_commands, receive_timeout_is_error=False)))

            # Handle POLL responses.
            handleable_poll_responses = [
                (device, poll_response)
                for (device, poll_response) in poll_responses
                if poll_response is not None and not isinstance(poll_response, ReceiveTimeout)
            ]

            if handleable_poll_responses:
                poll_ack_commands = [
                    address_commands(device.device_address, PollAck())
                    for (device, _) in handleable_poll_responses
                ]

                self.interface.execute(poll_ack_commands)

                for (device, poll_response) in handleable_poll_responses:
                    self._handle_poll_response(device, poll_response)

            # Handle lost devices.
            for (device, poll_response) in poll_responses:
                if isinstance(poll_response, ReceiveTimeout):
                    self._handle_device_lost(device)

            # Nothing was reported this round, so further rounds are pointless.
            if not handleable_poll_responses:
                break

    def _poll_next_detatched_device(self):
        """POLL the next unattached address, at most once per ``detatched_poll_period``."""
        now = time.perf_counter()

        if self.last_detatched_poll_time is not None and (now - self.last_detatched_poll_time) < self.detatched_poll_period:
            return

        self.last_detatched_poll_time = time.perf_counter()

        if not self.detatched_device_poll_queue:
            self.detatched_device_poll_queue = list(self._get_detatched_device_addresses())

        # The queue is still empty when every candidate address is attached.
        if not self.detatched_device_poll_queue:
            return

        device_address = self.detatched_device_poll_queue.pop(0)

        try:
            poll_response = self.interface.execute(address_commands(device_address, Poll()))
        except ReceiveTimeout:
            # Nothing answered at this address.
            return
        except ReceiveError as error:
            self.logger.warning(f'POLL detatched device @ {format_address(self.interface, device_address)} receive error: {error}')
            return
        except ProtocolError as error:
            self.logger.warning(f'POLL detatched device @ {format_address(self.interface, device_address)} protocol error: {error}')
            return

        if poll_response:
            self.interface.execute(address_commands(device_address, PollAck()))

        self._handle_device_found(device_address, poll_response)

    def _handle_device_found(self, device_address, poll_response):
        """Create, set up and attach the device that answered at ``device_address``."""
        self.logger.info(f'Found device @ {format_address(self.interface, device_address)}')

        try:
            device = self.create_device(self.interface, device_address, poll_response)
        except UnsupportedDeviceError as error:
            self.logger.error(f'Unsupported device @ {format_address(self.interface, device_address)}: {error}')
            return

        device.setup()

        self.devices[device_address] = device

        self.logger.info(f'Attached device @ {format_address(self.interface, device_address)}')

    def _handle_device_lost(self, device):
        """Detach ``device`` and terminate its session if one is active."""
        device_address = device.device_address

        self.logger.info(f'Lost device @ {format_address(self.interface, device_address)}')

        if device_address in self.sessions:
            (session_state, session) = self.sessions[device_address]

            # Only an ACTIVE entry holds a session; otherwise it is a future.
            if session_state == SessionState.ACTIVE:
                self._terminate_session(session)

        del self.devices[device_address]

        self.logger.info(f'Detached device @ {format_address(self.interface, device_address)}')

    def _handle_poll_response(self, device, poll_response):
        """Dispatch a POLL response; only keystroke responses are acted upon."""
        if isinstance(poll_response, KeystrokePollResponse):
            self._handle_keystroke_poll_response(device, poll_response)

    def _handle_keystroke_poll_response(self, terminal, poll_response):
        """Translate a scan code into a key and act on it.

        Cursor blink, alternate cursor and clicker keys are handled by the
        terminal itself; any other key is forwarded to the terminal's
        session when that session is active.
        """
        device_address = terminal.device_address
        scan_code = poll_response.scan_code

        (key, modifiers, modifiers_changed) = terminal.keyboard.get_key(scan_code)

        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug((f'Keystroke detected: Scan Code = {scan_code}, '
                               f'Key = {key}, Modifiers = {modifiers}'))

        # Update the status line if modifiers have changed.
        if modifiers_changed:
            terminal.display.status_line.write_keyboard_modifiers(modifiers)

        if not key:
            return

        if key == Key.CURSOR_BLINK:
            terminal.display.toggle_cursor_blink()
        elif key == Key.ALT_CURSOR:
            terminal.display.toggle_cursor_reverse()
        elif key == Key.CLICKER:
            terminal.keyboard.toggle_clicker()
        elif device_address in self.sessions:
            (session_state, session) = self.sessions[device_address]

            # Only an ACTIVE entry holds a session; otherwise it is a future.
            if session_state == SessionState.ACTIVE:
                session.handle_key(key, modifiers, scan_code)

                session.render()

    def _calculate_poll_delay(self):
        """Return the seconds to wait before the next attached-device POLL (never negative)."""
        if self.last_attached_poll_time is None:
            return 0

        next_poll_time = self.last_attached_poll_time + self.attached_poll_period

        return max(next_poll_time - time.perf_counter(), 0)

    def _get_detatched_device_addresses(self):
        """Return an iterable of the candidate addresses that have no attached device."""
        attached_addresses = set(self.devices.keys())

        # The 3299 is transparent, but if there is at least one device attached to a 3299
        # port then we can assume there is a 3299 attached and if there is one device
        # direct attached then we can assume there is not a 3299 attached.
        #
        # Test for a non-empty set rather than any(): any() checks the truthiness of the
        # addresses themselves and would ignore a falsy address such as 0.
        is_3299_attached = bool(attached_addresses.difference([None]))
        is_3299_not_attached = (None in attached_addresses)

        if is_3299_not_attached or InterfaceFeature.PROTOCOL_3299 not in self.interface.features:
            addresses = [None]
        elif is_3299_attached:
            addresses = PORT_MAP_3299
        else:
            addresses = [None, *PORT_MAP_3299]

        return filter(lambda address: address not in attached_addresses, addresses)
|