Updated the MQTT process to make it more secure, updated the updater

This commit is contained in:
EivindH06
2025-10-07 15:15:22 +02:00
parent 0d36daf127
commit ef2e96dadd
6 changed files with 362 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
#pragma once
#include <stdint.h>
#include <Print.h>
#include <ArduinoJson.h>
#include "HwTools.h"
#include "AmsData.h"
#include "AmsConfiguration.h"
@@ -44,9 +45,9 @@
class AmsFirmwareUpdater {
public:
#if defined(AMS_REMOTE_DEBUG)
AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState);
AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration);
#else
AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState);
AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration);
#endif
bool relocateOrRepartitionIfNecessary();
void loop();
@@ -111,6 +112,7 @@ private:
String downloadUrl;
String md5;
unsigned long fetchedAt = 0;
bool mqttApplied = false;
} manifestInfo;
bool loadManifest(bool force = false);
@@ -120,16 +122,19 @@ private:
bool fetchNextVersion();
bool fetchVersionDetails();
bool fetchFirmwareChunk(HTTPClient& http);
bool writeBufferToFlash();
bool writeBufferToFlash(size_t length);
bool verifyChecksum();
bool activateNewFirmware();
bool writeUpdateStatus();
bool isFlashReadyForNextUpdateVersion(uint32_t size);
bool applyManifestMqttDefaults(JsonVariantConst mqttSection);
uint8_t* buf = NULL;
uint16_t bufPos = 0;
int lastHttpStatus = 0;
AmsConfiguration* configuration;
#if defined(ESP32)
bool readPartition(uint8_t num, const esp_partition_info_t* info);
bool writePartition(uint8_t num, const esp_partition_info_t* info);

View File

@@ -3,6 +3,7 @@
#include "FirmwareVersion.h"
#include "UpgradeDefaults.h"
#include <ArduinoJson.h>
#include <cstring>
#if defined(ESP32)
#include "esp_ota_ops.h"
@@ -16,13 +17,14 @@
#endif
#if defined(AMS_REMOTE_DEBUG)
AmsFirmwareUpdater::AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState) {
AmsFirmwareUpdater::AmsFirmwareUpdater(RemoteDebug* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration) {
#else
AmsFirmwareUpdater::AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState) {
AmsFirmwareUpdater::AmsFirmwareUpdater(Print* debugger, HwTools* hw, AmsData* meterState, AmsConfiguration* configuration) {
#endif
this->debugger = debugger;
this->hw = hw;
this->meterState = meterState;
this->configuration = configuration;
memset(nextVersion, 0, sizeof(nextVersion));
firmwareVariant = 0;
autoUpgrade = false;
@@ -169,7 +171,7 @@ void AmsFirmwareUpdater::loop() {
debugger->printf_P(PSTR("read buffer took %lums (%lu bytes, %d left)\n"), end-start, bytes, client->available());
if(bytes > 0) {
start = millis();
if(!writeBufferToFlash()) {
if(!writeBufferToFlash(bytes)) {
http.end();
return;
}
@@ -483,6 +485,13 @@ bool AmsFirmwareUpdater::loadManifest(bool force) {
manifestInfo.md5 = checksum;
manifestInfo.loaded = true;
manifestInfo.fetchedAt = millis();
manifestInfo.mqttApplied = false;
JsonVariantConst mqttSection = doc["mqtt"];
if(!mqttSection.isNull()) {
applyManifestMqttDefaults(mqttSection);
}
manifestInfo.mqttApplied = true;
success = true;
}
} else {
@@ -504,6 +513,158 @@ bool AmsFirmwareUpdater::loadManifest(bool force) {
}
#endif
bool AmsFirmwareUpdater::applyManifestMqttDefaults(JsonVariantConst mqttSection) {
if(configuration == NULL || mqttSection.isNull()) {
return false;
}
SystemConfig sys;
configuration->getSystemConfig(sys);
if(sys.userConfigured) {
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::DEBUG))
#endif
debugger->println(F("Skipping manifest MQTT defaults: user configuration in place"));
return false;
}
MqttConfig mqtt;
configuration->getMqttConfig(mqtt);
bool changed = false;
JsonVariantConst hostVariant = mqttSection["host"];
bool hostProvided = false;
if(hostVariant.is<const char*>()) {
const char* rawHost = hostVariant.as<const char*>();
hostProvided = rawHost != NULL && rawHost[0] != '\0';
}
auto updateString = [&](const char* key, char* dest, size_t len) {
JsonVariantConst value = mqttSection[key];
if(value.isNull() || !value.is<const char*>()) {
return;
}
const char* raw = value.as<const char*>();
if(raw == NULL || raw[0] == '\0') {
return;
}
if(strncmp(dest, raw, len) != 0) {
size_t copyLen = strlen(raw);
if(copyLen >= len) {
copyLen = len - 1;
}
memset(dest, 0, len);
memcpy(dest, raw, copyLen);
changed = true;
}
};
auto updateUint16 = [&](const char* key, uint16_t& field) {
JsonVariantConst value = mqttSection[key];
if(value.isNull()) {
return;
}
long parsed = 0;
if(value.is<int>() || value.is<long>() || value.is<unsigned int>() || value.is<unsigned long>()) {
parsed = value.as<long>();
} else if(value.is<double>()) {
parsed = static_cast<long>(value.as<double>());
} else {
return;
}
if(parsed < 0) {
return;
}
if(parsed > 0xFFFF) {
parsed = 0xFFFF;
}
uint16_t converted = static_cast<uint16_t>(parsed);
if(field != converted) {
field = converted;
changed = true;
}
};
auto updateUint8 = [&](const char* key, uint8_t& field) {
JsonVariantConst value = mqttSection[key];
if(value.isNull()) {
return;
}
long parsed = 0;
if(value.is<int>() || value.is<long>() || value.is<unsigned int>() || value.is<unsigned long>()) {
parsed = value.as<long>();
} else if(value.is<double>()) {
parsed = static_cast<long>(value.as<double>());
} else {
return;
}
if(parsed < 0) {
return;
}
if(parsed > 0xFF) {
parsed = 0xFF;
}
uint8_t converted = static_cast<uint8_t>(parsed);
if(field != converted) {
field = converted;
changed = true;
}
};
auto updateBool = [&](const char* key, bool& field) {
JsonVariantConst value = mqttSection[key];
if(value.isNull()) {
return;
}
bool parsed;
if(value.is<bool>()) {
parsed = value.as<bool>();
} else if(value.is<int>() || value.is<long>() || value.is<unsigned int>() || value.is<unsigned long>()) {
parsed = value.as<long>() != 0;
} else {
return;
}
if(field != parsed) {
field = parsed;
changed = true;
}
};
updateString("host", mqtt.host, sizeof(mqtt.host));
updateUint16("port", mqtt.port);
updateString("client_id", mqtt.clientId, sizeof(mqtt.clientId));
updateString("publish_topic", mqtt.publishTopic, sizeof(mqtt.publishTopic));
updateString("subscribe_topic", mqtt.subscribeTopic, sizeof(mqtt.subscribeTopic));
updateString("username", mqtt.username, sizeof(mqtt.username));
updateString("password", mqtt.password, sizeof(mqtt.password));
updateUint8("payload_format", mqtt.payloadFormat);
updateBool("ssl", mqtt.ssl);
updateBool("state_update", mqtt.stateUpdate);
updateUint16("state_update_interval", mqtt.stateUpdateInterval);
updateUint16("timeout", mqtt.timeout);
updateUint8("keepalive", mqtt.keepalive);
bool sysChanged = false;
if(hostProvided && !sys.vendorConfigured) {
sys.vendorConfigured = true;
sysChanged = true;
}
if(changed) {
configuration->setMqttConfig(mqtt);
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::INFO))
#endif
debugger->println(F("Applied MQTT defaults from manifest"));
}
if(sysChanged) {
configuration->setSystemConfig(sys);
}
return changed || sysChanged;
}
bool AmsFirmwareUpdater::writeUpdateStatus() {
if(updateStatus.block_position - lastSaveBlocksWritten > 32) {
updateStatusChanged = true;
@@ -538,7 +699,7 @@ bool AmsFirmwareUpdater::addFirmwareUploadChunk(uint8_t* buf, size_t length) {
for(size_t i = 0; i < length; i++) {
this->buf[bufPos++] = buf[i];
if(bufPos == UPDATE_BUF_SIZE) {
if(!writeBufferToFlash()) {
if(!writeBufferToFlash(UPDATE_BUF_SIZE)) {
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::ERROR))
#endif
@@ -559,7 +720,8 @@ bool AmsFirmwareUpdater::completeFirmwareUpload(uint32_t size) {
debugger->printf_P(PSTR("Firmware write complete\n"));
if(bufPos > 0) {
writeBufferToFlash();
writeBufferToFlash(bufPos);
memset(this->buf, 0, UPDATE_BUF_SIZE);
bufPos = 0;
}
if(md5.equals(F("unknown"))) {
@@ -613,7 +775,15 @@ bool AmsFirmwareUpdater::isFlashReadyForNextUpdateVersion(uint32_t size) {
return true;
}
bool AmsFirmwareUpdater::writeBufferToFlash() {
bool AmsFirmwareUpdater::writeBufferToFlash(size_t length) {
if(length == 0) {
return true;
}
if(length > UPDATE_BUF_SIZE) {
length = UPDATE_BUF_SIZE;
}
uint32_t offset = updateStatus.block_position * UPDATE_BUF_SIZE;
const esp_partition_t* partition = esp_ota_get_next_update_partition(NULL);
esp_err_t eraseErr = esp_partition_erase_range(partition, offset, UPDATE_BUF_SIZE);
@@ -625,7 +795,7 @@ bool AmsFirmwareUpdater::writeBufferToFlash() {
updateStatus.errorCode = AMS_UPDATE_ERR_ERASE;
return false;
}
esp_err_t writeErr = esp_partition_write(partition, offset, buf, UPDATE_BUF_SIZE);
esp_err_t writeErr = esp_partition_write(partition, offset, buf, length);
if(writeErr != ESP_OK) {
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::ERROR))
@@ -1341,7 +1511,21 @@ bool AmsFirmwareUpdater::isFlashReadyForNextUpdateVersion(uint32_t size) {
return true;
}
bool AmsFirmwareUpdater::writeBufferToFlash() {
bool AmsFirmwareUpdater::writeBufferToFlash(size_t length) {
if(length == 0) {
return true;
}
if(length > UPDATE_BUF_SIZE) {
length = UPDATE_BUF_SIZE;
}
// ESP8266 flash writes must be 4-byte aligned
size_t paddedLength = (length + 3) & ~((size_t)3);
if(paddedLength > UPDATE_BUF_SIZE) {
paddedLength = UPDATE_BUF_SIZE;
}
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::INFO))
#endif
@@ -1372,11 +1556,11 @@ bool AmsFirmwareUpdater::writeBufferToFlash() {
#endif
debugger->printf_P(PSTR("flashWrite(%lu)\n"), sector);
yield();
if(!ESP.flashWrite(currentAddress, buf, UPDATE_BUF_SIZE)) {
if(!ESP.flashWrite(currentAddress, buf, paddedLength)) {
#if defined(AMS_REMOTE_DEBUG)
if (debugger->isActive(RemoteDebug::ERROR))
#endif
debugger->printf_P(PSTR("flashWrite(%lu, buf, %lu) failed\n"), currentAddress, UPDATE_BUF_SIZE);
debugger->printf_P(PSTR("flashWrite(%lu, buf, %lu) failed\n"), currentAddress, paddedLength);
updateStatus.errorCode = AMS_UPDATE_ERR_WRITE;
return false;
}