fix multinet model for v2 (#1208)

This commit is contained in:
Xiaoxia
2025-09-16 19:05:28 +08:00
committed by GitHub
parent d188415949
commit d2e99bae34
3 changed files with 317 additions and 37 deletions

View File

@@ -1,11 +1,13 @@
#include "custom_wake_word.h"
#include "audio_service.h"
#include "system_info.h"
#include "assets.h"
#include <esp_log.h>
#include "esp_mn_iface.h"
#include "esp_mn_models.h"
#include "esp_mn_speech_commands.h"
#include <esp_mn_iface.h>
#include <esp_mn_models.h>
#include <esp_mn_speech_commands.h>
#include <cJSON.h>
#define TAG "CustomWakeWord"
@@ -34,13 +36,66 @@ CustomWakeWord::~CustomWakeWord() {
}
}
void CustomWakeWord::ParseWakenetModelConfig() {
// Read index.json
auto& assets = Assets::GetInstance();
void* ptr = nullptr;
size_t size = 0;
if (!assets.GetAssetData("index.json", ptr, size)) {
ESP_LOGE(TAG, "Failed to read index.json");
return;
}
cJSON* root = cJSON_ParseWithLength(static_cast<char*>(ptr), size);
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse index.json");
return;
}
cJSON* multinet_model = cJSON_GetObjectItem(root, "multinet_model");
if (cJSON_IsObject(multinet_model)) {
cJSON* language = cJSON_GetObjectItem(multinet_model, "language");
cJSON* duration = cJSON_GetObjectItem(multinet_model, "duration");
cJSON* threshold = cJSON_GetObjectItem(multinet_model, "threshold");
cJSON* commands = cJSON_GetObjectItem(multinet_model, "commands");
if (cJSON_IsString(language)) {
language_ = language->valuestring;
}
if (cJSON_IsNumber(duration)) {
duration_ = duration->valueint;
}
if (cJSON_IsNumber(threshold)) {
threshold_ = threshold->valuedouble;
}
if (cJSON_IsArray(commands)) {
for (int i = 0; i < cJSON_GetArraySize(commands); i++) {
cJSON* command = cJSON_GetArrayItem(commands, i);
if (cJSON_IsObject(command)) {
cJSON* command_name = cJSON_GetObjectItem(command, "command");
cJSON* text = cJSON_GetObjectItem(command, "text");
cJSON* action = cJSON_GetObjectItem(command, "action");
if (cJSON_IsString(command_name) && cJSON_IsString(text) && cJSON_IsString(action)) {
commands_.push_back({command_name->valuestring, text->valuestring, action->valuestring});
ESP_LOGI(TAG, "Command: %s, Text: %s, Action: %s", command_name->valuestring, text->valuestring, action->valuestring);
}
}
}
}
}
cJSON_Delete(root);
}
bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list) {
codec_ = codec;
commands_.clear();
if (models_list == nullptr) {
language_ = "cn";
models_ = esp_srmodel_init("model");
threshold_ = CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f;
commands_.push_back({CONFIG_CUSTOM_WAKE_WORD, CONFIG_CUSTOM_WAKE_WORD_DISPLAY, "wake"});
} else {
models_ = models_list;
ParseWakenetModelConfig();
}
if (models_ == nullptr || models_->num == -1) {
@@ -49,19 +104,20 @@ bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list)
}
// 初始化 multinet (命令词识别)
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, ESP_MN_CHINESE);
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, language_.c_str());
if (mn_name_ == nullptr) {
ESP_LOGE(TAG, "Failed to initialize multinet, mn_name is nullptr");
ESP_LOGI(TAG, "Please refer to https://pcn7cs20v8cr.feishu.cn/wiki/CpQjwQsCJiQSWSkYEvrcxcbVnwh to add custom wake word");
return false;
}
ESP_LOGI(TAG, "multinet: %s", mn_name_);
multinet_ = esp_mn_handle_from_name(mn_name_);
multinet_model_data_ = multinet_->create(mn_name_, 3000); // 3 秒超时
multinet_->set_det_threshold(multinet_model_data_, CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f);
multinet_model_data_ = multinet_->create(mn_name_, duration_);
multinet_->set_det_threshold(multinet_model_data_, threshold_);
esp_mn_commands_clear();
esp_mn_commands_add(1, CONFIG_CUSTOM_WAKE_WORD);
for (int i = 0; i < commands_.size(); i++) {
esp_mn_commands_add(i + 1, commands_[i].command.c_str());
}
esp_mn_commands_update();
multinet_->print_active_speech_commands(multinet_model_data_);
@@ -104,16 +160,18 @@ void CustomWakeWord::Feed(const std::vector<int16_t>& data) {
return;
} else if (mn_state == ESP_MN_STATE_DETECTED) {
esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_);
ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
mn_result->command_id[0], mn_result->string, mn_result->prob[0]);
if (mn_result->command_id[0] == 1) {
last_detected_wake_word_ = CONFIG_CUSTOM_WAKE_WORD_DISPLAY;
}
running_ = false;
if (wake_word_detected_callback_) {
wake_word_detected_callback_(last_detected_wake_word_);
for (int i = 0; i < mn_result->num && running_; i++) {
ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
mn_result->command_id[i], mn_result->string, mn_result->prob[i]);
auto& command = commands_[mn_result->command_id[i] - 1];
if (command.action == "wake") {
last_detected_wake_word_ = command.text;
running_ = false;
if (wake_word_detected_callback_) {
wake_word_detected_callback_(last_detected_wake_word_);
}
}
}
multinet_->clean(multinet_model_data_);
} else if (mn_state == ESP_MN_STATE_TIMEOUT) {

View File

@@ -33,11 +33,21 @@ public:
const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; }
private:
// One speech command entry. Instances are brace-initialised positionally as
// {command, text, action}, so the member order below must not change.
struct Command {
// Phrase string registered with multinet through esp_mn_commands_add().
std::string command;
// Text stored as the last detected wake word when this command triggers a wake.
std::string text;
// Action identifier; the value "wake" marks this command as a wake word.
std::string action;
};
// multinet 相关成员变量
esp_mn_iface_t* multinet_ = nullptr;
model_iface_data_t* multinet_model_data_ = nullptr;
srmodel_list_t *models_ = nullptr;
char* mn_name_ = nullptr;
std::string language_ = "cn";
int duration_ = 3000;
float threshold_ = 0.2;
std::deque<Command> commands_;
std::function<void(const std::string& wake_word)> wake_word_detected_callback_;
AudioCodec* codec_ = nullptr;
@@ -53,6 +63,7 @@ private:
std::condition_variable wake_word_cv_;
void StoreWakeWordData(const std::vector<int16_t>& data);
void ParseWakenetModelConfig();
};
#endif