diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 29bebdd7..616fd05f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -37,6 +37,38 @@ jobs: sed -i 's/NO_ENERGY_SPEEDOMETER_USER/ENERGY_SPEEDOMETER_USER=\\"${{secrets.ENERGY_SPEEDOMETER_USER}}\\"/g' platformio.ini sed -i 's/NO_ENERGY_SPEEDOMETER_PASS/ENERGY_SPEEDOMETER_PASS=\\"${{secrets.ENERGY_SPEEDOMETER_PASS}}\\"/g' platformio.ini + - name: Write MQTT defaults + env: + MQTT_DEFAULT_HOST: ${{ secrets.MQTT_DEFAULT_HOST }} + MQTT_DEFAULT_PORT: ${{ secrets.MQTT_DEFAULT_PORT }} + MQTT_DEFAULT_USERNAME: ${{ secrets.MQTT_DEFAULT_USERNAME }} + MQTT_DEFAULT_PASSWORD: ${{ secrets.MQTT_DEFAULT_PASSWORD }} + MQTT_DEFAULT_CLIENT_ID: ${{ secrets.MQTT_DEFAULT_CLIENT_ID }} + MQTT_DEFAULT_PUBLISH_TOPIC: ${{ secrets.MQTT_DEFAULT_PUBLISH_TOPIC }} + MQTT_DEFAULT_SUBSCRIBE_TOPIC: ${{ secrets.MQTT_DEFAULT_SUBSCRIBE_TOPIC }} + MQTT_DEFAULT_PAYLOAD_FORMAT: ${{ secrets.MQTT_DEFAULT_PAYLOAD_FORMAT }} + MQTT_DEFAULT_SSL: ${{ secrets.MQTT_DEFAULT_SSL }} + MQTT_DEFAULT_STATE_UPDATE: ${{ secrets.MQTT_DEFAULT_STATE_UPDATE }} + MQTT_DEFAULT_STATE_UPDATE_INTERVAL: ${{ secrets.MQTT_DEFAULT_STATE_UPDATE_INTERVAL }} + MQTT_DEFAULT_TIMEOUT: ${{ secrets.MQTT_DEFAULT_TIMEOUT }} + MQTT_DEFAULT_KEEPALIVE: ${{ secrets.MQTT_DEFAULT_KEEPALIVE }} + run: | + { + printf 'MQTT_DEFAULT_HOST="%s"\n' "${MQTT_DEFAULT_HOST}" + printf 'MQTT_DEFAULT_PORT="%s"\n' "${MQTT_DEFAULT_PORT}" + printf 'MQTT_DEFAULT_USERNAME="%s"\n' "${MQTT_DEFAULT_USERNAME}" + printf 'MQTT_DEFAULT_PASSWORD="%s"\n' "${MQTT_DEFAULT_PASSWORD}" + printf 'MQTT_DEFAULT_CLIENT_ID="%s"\n' "${MQTT_DEFAULT_CLIENT_ID}" + printf 'MQTT_DEFAULT_PUBLISH_TOPIC="%s"\n' "${MQTT_DEFAULT_PUBLISH_TOPIC}" + printf 'MQTT_DEFAULT_SUBSCRIBE_TOPIC="%s"\n' "${MQTT_DEFAULT_SUBSCRIBE_TOPIC}" + printf 'MQTT_DEFAULT_PAYLOAD_FORMAT="%s"\n' "${MQTT_DEFAULT_PAYLOAD_FORMAT}" + printf 'MQTT_DEFAULT_SSL="%s"\n' "${MQTT_DEFAULT_SSL}" + printf 'MQTT_DEFAULT_STATE_UPDATE="%s"\n' "${MQTT_DEFAULT_STATE_UPDATE}" + printf 'MQTT_DEFAULT_STATE_UPDATE_INTERVAL="%s"\n' "${MQTT_DEFAULT_STATE_UPDATE_INTERVAL}" + printf 'MQTT_DEFAULT_TIMEOUT="%s"\n' "${MQTT_DEFAULT_TIMEOUT}" + printf 'MQTT_DEFAULT_KEEPALIVE="%s"\n' "${MQTT_DEFAULT_KEEPALIVE}" + } > .env + - name: Cache Python dependencies uses: actions/cache@v4 with: diff --git a/README.md b/README.md index 66bb876a..f4811dc7 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,47 @@ If you want devices to connect to a known MQTT broker immediately after flashing Any field you leave empty will fall back to the defaults in `lib/AmsConfiguration/include/MqttDefaults.h`, meaning the web UI will prompt for credentials during first-time setup. +### Shipping credentials with GitHub releases (without committing secrets) + +The OTA manifest generated by `scripts/package_firmware.py` now carries an +optional `mqtt` block. If the build machine provides values for +`MQTT_DEFAULT_*` (through environment variables or a `.env` file), the script +embeds those defaults alongside the firmware checksum. Devices that upgrade via +GitHub Pages will download the manifest, detect the `mqtt` section, and apply +the broker settings automatically—unless the installer has already customised +the device through the web UI. + +To keep secrets out of source control while still provisioning releases: + +1. Store your broker credentials as GitHub Action secrets (for example + `MQTT_DEFAULT_USERNAME`, `MQTT_DEFAULT_PASSWORD`, etc.). +2. In the release workflow, write a temporary `.env` file before invoking the + PlatformIO build: + + ```yaml + - name: Write MQTT defaults + run: | + cat <<'EOF' > .env + MQTT_DEFAULT_HOST=${{ secrets.MQTT_DEFAULT_HOST }} + MQTT_DEFAULT_PORT=${{ secrets.MQTT_DEFAULT_PORT }} + MQTT_DEFAULT_USERNAME=${{ secrets.MQTT_DEFAULT_USERNAME }} + MQTT_DEFAULT_PASSWORD=${{ secrets.MQTT_DEFAULT_PASSWORD }} + MQTT_DEFAULT_CLIENT_ID=${{ secrets.MQTT_DEFAULT_CLIENT_ID }} + MQTT_DEFAULT_PUBLISH_TOPIC=${{ secrets.MQTT_DEFAULT_PUBLISH_TOPIC }} + MQTT_DEFAULT_SUBSCRIBE_TOPIC=${{ secrets.MQTT_DEFAULT_SUBSCRIBE_TOPIC }} + EOF + ``` + +3. Build the firmware and run `scripts/package_firmware.py` as usual; the + generated `manifest.json` will include the broker defaults. +4. Upload `dist/` to GitHub Pages (the existing release workflow already covers + this), so devices retrieving the manifest can bootstrap the MQTT connection + immediately after flashing. + +Because the `.env` file is created on-the-fly inside CI and never committed, +your credentials remain private while every release published to GitHub ships +with working MQTT settings out of the box. + # How to wipe bricked board? diff --git a/lib/AmsFirmwareUpdater/include/AmsFirmwareUpdater.h b/lib/AmsFirmwareUpdater/include/AmsFirmwareUpdater.h index a88bcb37..364909a8 100644 --- a/lib/AmsFirmwareUpdater/include/AmsFirmwareUpdater.h +++ b/lib/AmsFirmwareUpdater/include/AmsFirmwareUpdater.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include "HwTools.h" #include "AmsData.h" #include "AmsConfiguration.h" @@ -44,9 +45,9 @@ class AmsFirmwareUpdater { public: #if defined(AMS_REMOTE_DEBUG) - AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState); + AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration); #else - AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState); + AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration); #endif bool relocateOrRepartitionIfNecessary(); void loop(); @@ -111,6 +112,7 @@ private: String downloadUrl; String md5; unsigned long fetchedAt = 0; + bool mqttApplied = false; } manifestInfo; bool loadManifest(bool force = false); @@ -120,16 +122,19 @@ private: bool fetchNextVersion(); bool fetchVersionDetails(); bool fetchFirmwareChunk(HTTPClient& http); - bool writeBufferToFlash(); + bool writeBufferToFlash(size_t length); bool verifyChecksum(); bool activateNewFirmware(); bool writeUpdateStatus(); bool isFlashReadyForNextUpdateVersion(uint32_t size); + bool applyManifestMqttDefaults(JsonVariantConst mqttSection); uint8_t* buf = NULL; uint16_t bufPos = 0; int lastHttpStatus = 0; + AmsConfiguration* configuration; + #if defined(ESP32) bool readPartition(uint8_t num, const esp_partition_info_t* info); bool writePartition(uint8_t num, const esp_partition_info_t* info); diff --git a/lib/AmsFirmwareUpdater/src/AmsFirmwareUpdater.cpp b/lib/AmsFirmwareUpdater/src/AmsFirmwareUpdater.cpp index 086f3b97..e58057be 100644 --- a/lib/AmsFirmwareUpdater/src/AmsFirmwareUpdater.cpp +++ b/lib/AmsFirmwareUpdater/src/AmsFirmwareUpdater.cpp @@ -3,6 +3,7 @@ #include "FirmwareVersion.h" #include "UpgradeDefaults.h" #include +#include #if defined(ESP32) #include "esp_ota_ops.h" @@ -16,13 +17,14 @@ #endif #if defined(AMS_REMOTE_DEBUG) -AmsFirmwareUpdater::AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState) { +AmsFirmwareUpdater::AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration) { #else -AmsFirmwareUpdater::AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState) { +AmsFirmwareUpdater::AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration) { #endif this->debugger = debugger; this->hw = hw; this->meterState = meterState; + this->configuration = configuration; memset(nextVersion, 0, sizeof(nextVersion)); firmwareVariant = 0; autoUpgrade = false; @@ -169,7 +171,7 @@ void AmsFirmwareUpdater::loop() { debugger->printf_P(PSTR("read buffer took %lums (%lu bytes, %d left)\n"), end-start, bytes, client->available()); if(bytes > 0) { start = millis(); - if(!writeBufferToFlash()) { + if(!writeBufferToFlash(bytes)) { http.end(); return; } @@ -483,6 +485,13 @@ bool AmsFirmwareUpdater::loadManifest(bool force) { manifestInfo.md5 = checksum; manifestInfo.loaded = true; manifestInfo.fetchedAt = millis(); + manifestInfo.mqttApplied = false; + + JsonVariantConst mqttSection = doc["mqtt"]; + if(!mqttSection.isNull()) { + applyManifestMqttDefaults(mqttSection); + } + manifestInfo.mqttApplied = true; success = true; } } else { @@ -504,6 +513,158 @@ bool AmsFirmwareUpdater::loadManifest(bool force) { } #endif +bool AmsFirmwareUpdater::applyManifestMqttDefaults(JsonVariantConst mqttSection) { + if(configuration == NULL || mqttSection.isNull()) { + return false; + } + + SystemConfig sys; + configuration->getSystemConfig(sys); + if(sys.userConfigured) { +#if defined(AMS_REMOTE_DEBUG) + if (debugger->isActive(RemoteDebug::DEBUG)) +#endif + debugger->println(F("Skipping manifest MQTT defaults: user configuration in place")); + return false; + } + + MqttConfig mqtt; + configuration->getMqttConfig(mqtt); + bool changed = false; + + JsonVariantConst hostVariant = mqttSection["host"]; + bool hostProvided = false; + if(hostVariant.is()) { + const char* rawHost = hostVariant.as(); + hostProvided = rawHost != NULL && rawHost[0] != '\0'; + } + + auto updateString = [&](const char* key, char* dest, size_t len) { + JsonVariantConst value = mqttSection[key]; + if(value.isNull() || !value.is()) { + return; + } + const char* raw = value.as(); + if(raw == NULL || raw[0] == '\0') { + return; + } + if(strncmp(dest, raw, len) != 0) { + size_t copyLen = strlen(raw); + if(copyLen >= len) { + copyLen = len - 1; + } + memset(dest, 0, len); + memcpy(dest, raw, copyLen); + changed = true; + } + }; + + auto updateUint16 = [&](const char* key, uint16_t& field) { + JsonVariantConst value = mqttSection[key]; + if(value.isNull()) { + return; + } + long parsed = 0; + if(value.is() || value.is() || value.is() || value.is()) { + parsed = value.as(); + } else if(value.is()) { + parsed = static_cast(value.as()); + } else { + return; + } + if(parsed < 0) { + return; + } + if(parsed > 0xFFFF) { + parsed = 0xFFFF; + } + uint16_t converted = static_cast(parsed); + if(field != converted) { + field = converted; + changed = true; + } + }; + + auto updateUint8 = [&](const char* key, uint8_t& field) { + JsonVariantConst value = mqttSection[key]; + if(value.isNull()) { + return; + } + long parsed = 0; + if(value.is() || value.is() || value.is() || value.is()) { + parsed = value.as(); + } else if(value.is()) { + parsed = static_cast(value.as()); + } else { + return; + } + if(parsed < 0) { + return; + } + if(parsed > 0xFF) { + parsed = 0xFF; + } + uint8_t converted = static_cast(parsed); + if(field != converted) { + field = converted; + changed = true; + } + }; + + auto updateBool = [&](const char* key, bool& field) { + JsonVariantConst value = mqttSection[key]; + if(value.isNull()) { + return; + } + bool parsed; + if(value.is()) { + parsed = value.as(); + } else if(value.is() || value.is() || value.is() || value.is()) { + parsed = value.as() != 0; + } else { + return; + } + if(field != parsed) { + field = parsed; + changed = true; + } + }; + + updateString("host", mqtt.host, sizeof(mqtt.host)); + updateUint16("port", mqtt.port); + updateString("client_id", mqtt.clientId, sizeof(mqtt.clientId)); + updateString("publish_topic", mqtt.publishTopic, sizeof(mqtt.publishTopic)); + updateString("subscribe_topic", mqtt.subscribeTopic, sizeof(mqtt.subscribeTopic)); + updateString("username", mqtt.username, sizeof(mqtt.username)); + updateString("password", mqtt.password, sizeof(mqtt.password)); + updateUint8("payload_format", mqtt.payloadFormat); + updateBool("ssl", mqtt.ssl); + updateBool("state_update", mqtt.stateUpdate); + updateUint16("state_update_interval", mqtt.stateUpdateInterval); + updateUint16("timeout", mqtt.timeout); + updateUint8("keepalive", mqtt.keepalive); + + bool sysChanged = false; + if(hostProvided && !sys.vendorConfigured) { + sys.vendorConfigured = true; + sysChanged = true; + } + + if(changed) { + configuration->setMqttConfig(mqtt); +#if defined(AMS_REMOTE_DEBUG) + if (debugger->isActive(RemoteDebug::INFO)) +#endif + debugger->println(F("Applied MQTT defaults from manifest")); + } + + if(sysChanged) { + configuration->setSystemConfig(sys); + } + + return changed || sysChanged; +} + bool AmsFirmwareUpdater::writeUpdateStatus() { if(updateStatus.block_position - lastSaveBlocksWritten > 32) { updateStatusChanged = true; @@ -538,7 +699,7 @@ bool AmsFirmwareUpdater::addFirmwareUploadChunk(uint8_t* buf, size_t length) { for(size_t i = 0; i < length; i++) { this->buf[bufPos++] = buf[i]; if(bufPos == UPDATE_BUF_SIZE) { - if(!writeBufferToFlash()) { + if(!writeBufferToFlash(UPDATE_BUF_SIZE)) { #if defined(AMS_REMOTE_DEBUG) if (debugger->isActive(RemoteDebug::ERROR)) #endif @@ -559,7 +720,8 @@ bool AmsFirmwareUpdater::completeFirmwareUpload(uint32_t size) { debugger->printf_P(PSTR("Firmware write complete\n")); if(bufPos > 0) { - writeBufferToFlash(); + writeBufferToFlash(bufPos); + memset(this->buf, 0, UPDATE_BUF_SIZE); bufPos = 0; } if(md5.equals(F("unknown"))) { @@ -613,7 +775,15 @@ bool AmsFirmwareUpdater::isFlashReadyForNextUpdateVersion(uint32_t size) { return true; } -bool AmsFirmwareUpdater::writeBufferToFlash() { +bool AmsFirmwareUpdater::writeBufferToFlash(size_t length) { + if(length == 0) { + return true; + } + + if(length > UPDATE_BUF_SIZE) { + length = UPDATE_BUF_SIZE; + } + uint32_t offset = updateStatus.block_position * UPDATE_BUF_SIZE; const esp_partition_t* partition = esp_ota_get_next_update_partition(NULL); esp_err_t eraseErr = esp_partition_erase_range(partition, offset, UPDATE_BUF_SIZE); @@ -625,7 +795,7 @@ bool AmsFirmwareUpdater::writeBufferToFlash() { updateStatus.errorCode = AMS_UPDATE_ERR_ERASE; return false; } - esp_err_t writeErr = esp_partition_write(partition, offset, buf, UPDATE_BUF_SIZE); + esp_err_t writeErr = esp_partition_write(partition, offset, buf, length); if(writeErr != ESP_OK) { #if defined(AMS_REMOTE_DEBUG) if (debugger->isActive(RemoteDebug::ERROR)) @@ -1341,7 +1511,21 @@ bool AmsFirmwareUpdater::isFlashReadyForNextUpdateVersion(uint32_t size) { return true; } -bool AmsFirmwareUpdater::writeBufferToFlash() { +bool AmsFirmwareUpdater::writeBufferToFlash(size_t length) { + if(length == 0) { + return true; + } + + if(length > UPDATE_BUF_SIZE) { + length = UPDATE_BUF_SIZE; + } + + // ESP8266 flash writes must be 4-byte aligned + size_t paddedLength = (length + 3) & ~((size_t)3); + if(paddedLength > UPDATE_BUF_SIZE) { + paddedLength = UPDATE_BUF_SIZE; + } + #if defined(AMS_REMOTE_DEBUG) if (debugger->isActive(RemoteDebug::INFO)) #endif @@ -1372,11 +1556,11 @@ bool AmsFirmwareUpdater::writeBufferToFlash() { #endif debugger->printf_P(PSTR("flashWrite(%lu)\n"), sector); yield(); - if(!ESP.flashWrite(currentAddress, buf, UPDATE_BUF_SIZE)) { + if(!ESP.flashWrite(currentAddress, buf, paddedLength)) { #if defined(AMS_REMOTE_DEBUG) if (debugger->isActive(RemoteDebug::ERROR)) #endif - debugger->printf_P(PSTR("flashWrite(%lu, buf, %lu) failed\n"), currentAddress, UPDATE_BUF_SIZE); + debugger->printf_P(PSTR("flashWrite(%lu, buf, %lu) failed\n"), currentAddress, paddedLength); updateStatus.errorCode = AMS_UPDATE_ERR_WRITE; return false; } diff --git a/scripts/package_firmware.py b/scripts/package_firmware.py index 4c058755..90f656fa 100644 --- a/scripts/package_firmware.py +++ b/scripts/package_firmware.py @@ -25,7 +25,7 @@ import json import os from pathlib import Path from datetime import datetime, timezone -from typing import Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple DEFAULT_CHANNEL = "stable" DEFAULT_OUTPUT = Path("dist") @@ -108,6 +108,82 @@ def compute_md5(path: Path) -> str: return hash_md5.hexdigest() +MQTT_FIELD_SPECS: Tuple[Tuple[str, str, type], ...] = ( + ("host", "MQTT_DEFAULT_HOST", str), + ("port", "MQTT_DEFAULT_PORT", int), + ("username", "MQTT_DEFAULT_USERNAME", str), + ("password", "MQTT_DEFAULT_PASSWORD", str), + ("client_id", "MQTT_DEFAULT_CLIENT_ID", str), + ("publish_topic", "MQTT_DEFAULT_PUBLISH_TOPIC", str), + ("subscribe_topic", "MQTT_DEFAULT_SUBSCRIBE_TOPIC", str), + ("payload_format", "MQTT_DEFAULT_PAYLOAD_FORMAT", int), + ("ssl", "MQTT_DEFAULT_SSL", bool), + ("state_update", "MQTT_DEFAULT_STATE_UPDATE", bool), + ("state_update_interval", "MQTT_DEFAULT_STATE_UPDATE_INTERVAL", int), + ("timeout", "MQTT_DEFAULT_TIMEOUT", int), + ("keepalive", "MQTT_DEFAULT_KEEPALIVE", int), +) + + +def _parse_env_file(path: Path) -> Dict[str, str]: + data: Dict[str, str] = {} + if not path.exists(): + return data + with path.open("r", encoding="utf-8") as handle: + for raw_line in handle: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + key, sep, value = line.partition("=") + if not sep: + continue + key = key.strip() + value = value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}: + value = value[1:-1] + data[key] = value + return data + + +def _to_bool(value: str) -> bool: + truthy = {"1", "true", "yes", "on"} + falsy = {"0", "false", "no", "off"} + lowered = value.lower() + if lowered in truthy: + return True + if lowered in falsy: + return False + raise ValueError(f"Invalid boolean literal: {value}") + + +def load_mqtt_defaults(project_dir: Path) -> Dict[str, Any]: + env_path = project_dir / ".env" + file_values = _parse_env_file(env_path) + manifest_values: Dict[str, Any] = {} + + for manifest_key, env_key, value_type in MQTT_FIELD_SPECS: + raw_value = os.getenv(env_key) + if raw_value is None: + raw_value = file_values.get(env_key) + if raw_value is None or raw_value == "": + continue + + try: + if value_type is bool: + converted: Any = _to_bool(raw_value) + elif value_type is int: + converted = int(raw_value, 0) + else: + converted = raw_value + except ValueError as exc: + print(f"WARN: Skipping MQTT field {env_key}: {exc}") + continue + + manifest_values[manifest_key] = converted + + return manifest_values + + def package_environment( env: str, chip: str, @@ -116,6 +192,7 @@ def package_environment( version: str, output_dir: Path, published_at: str, + mqtt_defaults: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, str]]: firmware_path = build_dir / env / "firmware.bin" if not firmware_path.exists(): @@ -144,6 +221,8 @@ def package_environment( "published_at": published_at, "env": env, } + if mqtt_defaults: + manifest["mqtt"] = mqtt_defaults manifest_path.write_text(json.dumps(manifest, indent=2)) # Optional: bundle flashing zip if produced by scripts/mkzip.sh @@ -249,6 +328,11 @@ def main() -> None: else datetime.now(timezone.utc).isoformat(timespec="seconds") ) + project_dir = Path.cwd() + mqtt_defaults = load_mqtt_defaults(project_dir) + if mqtt_defaults: + print("Including MQTT defaults in manifest for release packaging") + summaries = [] for env in envs: chip = ENV_TO_CHIP[env] @@ -260,6 +344,7 @@ def main() -> None: version=version, output_dir=args.output, published_at=published_at, + mqtt_defaults=mqtt_defaults, ) if summary: summaries.append(summary) diff --git a/src/AmsToMqttBridge.cpp b/src/AmsToMqttBridge.cpp index 20345faf..e2282ffe 100644 --- a/src/AmsToMqttBridge.cpp +++ b/src/AmsToMqttBridge.cpp @@ -181,7 +181,7 @@ bool ntpEnabled = false; bool mdnsEnabled = false; -AmsFirmwareUpdater updater(&Debug, &hw, &meterState); +AmsFirmwareUpdater updater(&Debug, &hw, &meterState, &config); AmsDataStorage ds(&Debug); #if defined(_CLOUDCONNECTOR_H)