From d5d8b34b2b74800d260f5a7a7f2d6db35847f538 Mon Sep 17 00:00:00 2001 From: Terrence Date: Mon, 14 Apr 2025 15:44:06 +0800 Subject: [PATCH] Add activation version 2 --- main/application.cc | 76 ++++++++----- main/audio_processing/audio_processor.cc | 6 + main/audio_processing/wake_word_detect.cc | 6 + main/display/display.cc | 1 + main/ota.cc | 131 +++++++++++++++++++--- main/ota.h | 12 ++ 6 files changed, 194 insertions(+), 38 deletions(-) diff --git a/main/application.cc b/main/application.cc index 3147c1b0..6d4292c6 100644 --- a/main/application.cc +++ b/main/application.cc @@ -65,20 +65,31 @@ Application::~Application() { void Application::CheckNewVersion() { const int MAX_RETRY = 10; int retry_count = 0; + int retry_delay = 10; // 初始重试延迟为10秒 while (true) { + SetDeviceState(kDeviceStateActivating); auto display = Board::GetInstance().GetDisplay(); + display->SetStatus(Lang::Strings::CHECKING_NEW_VERSION); + if (!ota_.CheckVersion()) { retry_count++; if (retry_count >= MAX_RETRY) { ESP_LOGE(TAG, "Too many retries, exit version check"); return; } - ESP_LOGW(TAG, "Check new version failed, retry in %d seconds (%d/%d)", 60, retry_count, MAX_RETRY); - vTaskDelay(pdMS_TO_TICKS(60000)); + ESP_LOGW(TAG, "Check new version failed, retry in %d seconds (%d/%d)", retry_delay, retry_count, MAX_RETRY); + for (int i = 0; i < retry_delay; i++) { + vTaskDelay(pdMS_TO_TICKS(1000)); + if (device_state_ == kDeviceStateIdle) { + break; + } + } + retry_delay *= 2; // 每次重试后延迟时间翻倍 continue; } retry_count = 0; + retry_delay = 10; // 重置重试延迟时间 if (ota_.HasNewVersion()) { Alert(Lang::Strings::OTA_UPGRADE, Lang::Strings::UPGRADING, "happy", Lang::Sounds::P3_UPGRADE); @@ -125,24 +136,34 @@ void Application::CheckNewVersion() { // No new version, mark the current version as valid ota_.MarkCurrentVersionValid(); - if (ota_.HasActivationCode()) { - // Activation code is valid - SetDeviceState(kDeviceStateActivating); - ShowActivationCode(); - - // Check again in 60 seconds or until the device is idle - for (int i = 0; i < 60; ++i) { - if (device_state_ == kDeviceStateIdle) { - break; - } - vTaskDelay(pdMS_TO_TICKS(1000)); - } - continue; + if (!ota_.HasActivationCode() && !ota_.HasActivationChallenge()) { + xEventGroupSetBits(event_group_, CHECK_NEW_VERSION_DONE_EVENT); + // Exit the loop if done checking new version + break; } - xEventGroupSetBits(event_group_, CHECK_NEW_VERSION_DONE_EVENT); - // Exit the loop if done checking new version - break; + display->SetStatus(Lang::Strings::ACTIVATION); + // Activation code is shown to the user and waiting for the user to input + if (ota_.HasActivationCode()) { + ShowActivationCode(); + } + + // This will block the loop until the activation is done or timeout + for (int i = 0; i < 10; ++i) { + ESP_LOGI(TAG, "Activating... %d/%d", i + 1, 10); + esp_err_t err = ota_.Activate(); + if (err == ESP_OK) { + xEventGroupSetBits(event_group_, CHECK_NEW_VERSION_DONE_EVENT); + break; + } else if (err == ESP_ERR_TIMEOUT) { + vTaskDelay(pdMS_TO_TICKS(3000)); + } else { + vTaskDelay(pdMS_TO_TICKS(10000)); + } + if (device_state_ == kDeviceStateIdle) { + break; + } + } } } @@ -347,7 +368,6 @@ void Application::Start() { board.StartNetwork(); // Check for new firmware version or get the MQTT broker address - display->SetStatus(Lang::Strings::CHECKING_NEW_VERSION); CheckNewVersion(); // Initialize the protocol @@ -648,17 +668,23 @@ void Application::OnAudioInput() { #if CONFIG_USE_WAKE_WORD_DETECT if (wake_word_detect_.IsDetectionRunning()) { std::vector data; - ReadAudio(data, 16000, wake_word_detect_.GetFeedSize()); - wake_word_detect_.Feed(data); - return; + int samples = wake_word_detect_.GetFeedSize(); + if (samples > 0) { + ReadAudio(data, 16000, samples); + wake_word_detect_.Feed(data); + return; + } } #endif #if CONFIG_USE_AUDIO_PROCESSOR if (audio_processor_.IsRunning()) { std::vector data; - ReadAudio(data, 16000, audio_processor_.GetFeedSize()); - audio_processor_.Feed(data); - return; + int samples = audio_processor_.GetFeedSize(); + if (samples > 0) { + ReadAudio(data, 16000, samples); + audio_processor_.Feed(data); + return; + } } #else if (device_state_ == kDeviceStateListening) { diff --git a/main/audio_processing/audio_processor.cc b/main/audio_processing/audio_processor.cc index b726e1c1..9bab939b 100644 --- a/main/audio_processing/audio_processor.cc +++ b/main/audio_processing/audio_processor.cc @@ -65,10 +65,16 @@ AudioProcessor::~AudioProcessor() { } size_t AudioProcessor::GetFeedSize() { + if (afe_data_ == nullptr) { + return 0; + } return afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); } void AudioProcessor::Feed(const std::vector& data) { + if (afe_data_ == nullptr) { + return; + } afe_iface_->feed(afe_data_, data.data()); } diff --git a/main/audio_processing/wake_word_detect.cc b/main/audio_processing/wake_word_detect.cc index 42ed1024..f623eb37 100644 --- a/main/audio_processing/wake_word_detect.cc +++ b/main/audio_processing/wake_word_detect.cc @@ -93,10 +93,16 @@ bool WakeWordDetect::IsDetectionRunning() { } void WakeWordDetect::Feed(const std::vector& data) { + if (afe_data_ == nullptr) { + return; + } afe_iface_->feed(afe_data_, data.data()); } size_t WakeWordDetect::GetFeedSize() { + if (afe_data_ == nullptr) { + return 0; + } return afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); } diff --git a/main/display/display.cc b/main/display/display.cc index 633b67a2..71b0466a 100644 --- a/main/display/display.cc +++ b/main/display/display.cc @@ -178,6 +178,7 @@ void Display::Update() { kDeviceStateStarting, kDeviceStateWifiConfiguring, kDeviceStateListening, + kDeviceStateActivating, }; if (std::find(allowed_states.begin(), allowed_states.end(), device_state) != allowed_states.end()) { icon = board.GetNetworkStateIcon(); diff --git a/main/ota.cc b/main/ota.cc index 8544ff64..6527a86b 100644 --- a/main/ota.cc +++ b/main/ota.cc @@ -1,6 +1,5 @@ #include "ota.h" #include "system_info.h" -#include "board.h" #include "settings.h" #include "assets/lang_config.h" @@ -9,6 +8,9 @@ #include #include #include +#include +#include +#include #include #include @@ -20,6 +22,17 @@ Ota::Ota() { SetCheckVersionUrl(CONFIG_OTA_VERSION_URL); + + // Read Serial Number from efuse user_data + uint8_t serial_number[33] = {0}; + if (esp_efuse_read_field_blob(ESP_EFUSE_USER_DATA, serial_number, 32 * 8) == ESP_OK) { + if (serial_number[0] == 0) { + has_serial_number_ = false; + } else { + serial_number_ = std::string(reinterpret_cast(serial_number), 32); + has_serial_number_ = true; + } + } } Ota::~Ota() { @@ -33,6 +46,25 @@ void Ota::SetHeader(const std::string& key, const std::string& value) { headers_[key] = value; } +Http* Ota::SetupHttp() { + auto& board = Board::GetInstance(); + auto app_desc = esp_app_get_description(); + + auto http = board.CreateHttp(); + for (const auto& header : headers_) { + http->SetHeader(header.first, header.second); + } + + http->SetHeader("Activation-Version", has_serial_number_ ? "2" : "1"); + http->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str()); + http->SetHeader("Client-Id", board.GetUuid()); + http->SetHeader("User-Agent", std::string(BOARD_NAME "/") + app_desc->version); + http->SetHeader("Accept-Language", Lang::CODE); + http->SetHeader("Content-Type", "application/json"); + + return http; +} + bool Ota::CheckVersion() { auto& board = Board::GetInstance(); auto app_desc = esp_app_get_description(); @@ -46,17 +78,7 @@ bool Ota::CheckVersion() { return false; } - auto http = board.CreateHttp(); - for (const auto& header : headers_) { - http->SetHeader(header.first, header.second); - } - - http->SetHeader("Ota-Version", "2"); - http->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str()); - http->SetHeader("Client-Id", board.GetUuid()); - http->SetHeader("User-Agent", std::string(BOARD_NAME "/") + app_desc->version); - http->SetHeader("Accept-Language", Lang::CODE); - http->SetHeader("Content-Type", "application/json"); + auto http = SetupHttp(); std::string data = board.GetJson(); std::string method = data.length() > 0 ? "POST" : "GET"; @@ -81,6 +103,7 @@ bool Ota::CheckVersion() { } has_activation_code_ = false; + has_activation_challenge_ = false; cJSON *activation = cJSON_GetObjectItem(root, "activation"); if (activation != NULL) { cJSON* message = cJSON_GetObjectItem(activation, "message"); @@ -90,8 +113,17 @@ bool Ota::CheckVersion() { cJSON* code = cJSON_GetObjectItem(activation, "code"); if (code != NULL) { activation_code_ = code->valuestring; + has_activation_code_ = true; + } + cJSON* challenge = cJSON_GetObjectItem(activation, "challenge"); + if (challenge != NULL) { + activation_challenge_ = challenge->valuestring; + has_activation_challenge_ = true; + } + cJSON* timeout_ms = cJSON_GetObjectItem(activation, "timeout_ms"); + if (timeout_ms != NULL) { + activation_timeout_ms_ = timeout_ms->valueint; } - has_activation_code_ = true; } has_mqtt_config_ = false; @@ -327,3 +359,76 @@ bool Ota::IsNewVersionAvailable(const std::string& currentVersion, const std::st return newer.size() > current.size(); } + +std::string Ota::GetActivationPayload() { + if (!has_serial_number_) { + ESP_LOGI(TAG, "No serial number found"); + return "{}"; + } + + uint8_t hmac_result[32]; // SHA-256 输出为32字节 + + // 使用Key0计算HMAC + esp_err_t ret = esp_hmac_calculate(HMAC_KEY0, (uint8_t*)activation_challenge_.data(), activation_challenge_.size(), hmac_result); + if (ret != ESP_OK) { + ESP_LOGE(TAG, "HMAC calculation failed: %s", esp_err_to_name(ret)); + return "{}"; + } + + std::string hmac_hex; + for (size_t i = 0; i < sizeof(hmac_result); i++) { + char buffer[3]; + sprintf(buffer, "%02x", hmac_result[i]); + hmac_hex += buffer; + } + + cJSON *payload = cJSON_CreateObject(); + cJSON_AddStringToObject(payload, "algorithm", "hmac-sha256"); + cJSON_AddStringToObject(payload, "serial_number", serial_number_.c_str()); + cJSON_AddStringToObject(payload, "challenge", activation_challenge_.c_str()); + cJSON_AddStringToObject(payload, "hmac", hmac_hex.c_str()); + std::string json = cJSON_Print(payload); + cJSON_Delete(payload); + + ESP_LOGI(TAG, "Activation payload: %s", json.c_str()); + return json; +} + +esp_err_t Ota::Activate() { + if (!has_activation_challenge_) { + ESP_LOGW(TAG, "No activation challenge found"); + return ESP_FAIL; + } + + std::string url = check_version_url_; + if (url.back() != '/') { + url += "/activate"; + } else { + url += "activate"; + } + + auto http = SetupHttp(); + + std::string data = GetActivationPayload(); + if (!http->Open("POST", url, data)) { + ESP_LOGE(TAG, "Failed to open HTTP connection"); + delete http; + return ESP_FAIL; + } + + auto status_code = http->GetStatusCode(); + data = http->GetBody(); + http->Close(); + delete http; + + if (status_code == 202) { + return ESP_ERR_TIMEOUT; + } + if (status_code != 200) { + ESP_LOGE(TAG, "Failed to activate, code: %d, body: %s", status_code, data.c_str()); + return ESP_FAIL; + } + + ESP_LOGI(TAG, "Activation successful"); + return ESP_OK; +} diff --git a/main/ota.h b/main/ota.h index acf21d98..d7aedd1d 100644 --- a/main/ota.h +++ b/main/ota.h @@ -5,6 +5,9 @@ #include #include +#include +#include "board.h" + class Ota { public: Ota(); @@ -13,6 +16,8 @@ public: void SetCheckVersionUrl(std::string check_version_url); void SetHeader(const std::string& key, const std::string& value); bool CheckVersion(); + esp_err_t Activate(); + bool HasActivationChallenge() { return has_activation_challenge_; } bool HasNewVersion() { return has_new_version_; } bool HasMqttConfig() { return has_mqtt_config_; } bool HasActivationCode() { return has_activation_code_; } @@ -33,15 +38,22 @@ private: bool has_mqtt_config_ = false; bool has_server_time_ = false; bool has_activation_code_ = false; + bool has_serial_number_ = false; + bool has_activation_challenge_ = false; std::string current_version_; std::string firmware_version_; std::string firmware_url_; + std::string activation_challenge_; + std::string serial_number_; + int activation_timeout_ms_ = 30000; std::map headers_; void Upgrade(const std::string& firmware_url); std::function upgrade_callback_; std::vector ParseVersion(const std::string& version); bool IsNewVersionAvailable(const std::string& currentVersion, const std::string& newVersion); + std::string GetActivationPayload(); + Http* SetupHttp(); }; #endif // _OTA_H