forked from xiaozhi/xiaozhi-esp32
fix multinet model for v2 (#1208)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
#include "custom_wake_word.h"
|
||||
#include "audio_service.h"
|
||||
#include "system_info.h"
|
||||
#include "assets.h"
|
||||
|
||||
#include <esp_log.h>
|
||||
#include "esp_mn_iface.h"
|
||||
#include "esp_mn_models.h"
|
||||
#include "esp_mn_speech_commands.h"
|
||||
#include <esp_mn_iface.h>
|
||||
#include <esp_mn_models.h>
|
||||
#include <esp_mn_speech_commands.h>
|
||||
#include <cJSON.h>
|
||||
|
||||
|
||||
#define TAG "CustomWakeWord"
|
||||
@@ -34,13 +36,66 @@ CustomWakeWord::~CustomWakeWord() {
|
||||
}
|
||||
}
|
||||
|
||||
void CustomWakeWord::ParseWakenetModelConfig() {
|
||||
// Read index.json
|
||||
auto& assets = Assets::GetInstance();
|
||||
void* ptr = nullptr;
|
||||
size_t size = 0;
|
||||
if (!assets.GetAssetData("index.json", ptr, size)) {
|
||||
ESP_LOGE(TAG, "Failed to read index.json");
|
||||
return;
|
||||
}
|
||||
cJSON* root = cJSON_ParseWithLength(static_cast<char*>(ptr), size);
|
||||
if (root == nullptr) {
|
||||
ESP_LOGE(TAG, "Failed to parse index.json");
|
||||
return;
|
||||
}
|
||||
cJSON* multinet_model = cJSON_GetObjectItem(root, "multinet_model");
|
||||
if (cJSON_IsObject(multinet_model)) {
|
||||
cJSON* language = cJSON_GetObjectItem(multinet_model, "language");
|
||||
cJSON* duration = cJSON_GetObjectItem(multinet_model, "duration");
|
||||
cJSON* threshold = cJSON_GetObjectItem(multinet_model, "threshold");
|
||||
cJSON* commands = cJSON_GetObjectItem(multinet_model, "commands");
|
||||
if (cJSON_IsString(language)) {
|
||||
language_ = language->valuestring;
|
||||
}
|
||||
if (cJSON_IsNumber(duration)) {
|
||||
duration_ = duration->valueint;
|
||||
}
|
||||
if (cJSON_IsNumber(threshold)) {
|
||||
threshold_ = threshold->valuedouble;
|
||||
}
|
||||
if (cJSON_IsArray(commands)) {
|
||||
for (int i = 0; i < cJSON_GetArraySize(commands); i++) {
|
||||
cJSON* command = cJSON_GetArrayItem(commands, i);
|
||||
if (cJSON_IsObject(command)) {
|
||||
cJSON* command_name = cJSON_GetObjectItem(command, "command");
|
||||
cJSON* text = cJSON_GetObjectItem(command, "text");
|
||||
cJSON* action = cJSON_GetObjectItem(command, "action");
|
||||
if (cJSON_IsString(command_name) && cJSON_IsString(text) && cJSON_IsString(action)) {
|
||||
commands_.push_back({command_name->valuestring, text->valuestring, action->valuestring});
|
||||
ESP_LOGI(TAG, "Command: %s, Text: %s, Action: %s", command_name->valuestring, text->valuestring, action->valuestring);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
cJSON_Delete(root);
|
||||
}
|
||||
|
||||
|
||||
bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list) {
|
||||
codec_ = codec;
|
||||
commands_.clear();
|
||||
|
||||
if (models_list == nullptr) {
|
||||
language_ = "cn";
|
||||
models_ = esp_srmodel_init("model");
|
||||
threshold_ = CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f;
|
||||
commands_.push_back({CONFIG_CUSTOM_WAKE_WORD, CONFIG_CUSTOM_WAKE_WORD_DISPLAY, "wake"});
|
||||
} else {
|
||||
models_ = models_list;
|
||||
ParseWakenetModelConfig();
|
||||
}
|
||||
|
||||
if (models_ == nullptr || models_->num == -1) {
|
||||
@@ -49,19 +104,20 @@ bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list)
|
||||
}
|
||||
|
||||
// 初始化 multinet (命令词识别)
|
||||
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, ESP_MN_CHINESE);
|
||||
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, language_.c_str());
|
||||
if (mn_name_ == nullptr) {
|
||||
ESP_LOGE(TAG, "Failed to initialize multinet, mn_name is nullptr");
|
||||
ESP_LOGI(TAG, "Please refer to https://pcn7cs20v8cr.feishu.cn/wiki/CpQjwQsCJiQSWSkYEvrcxcbVnwh to add custom wake word");
|
||||
return false;
|
||||
}
|
||||
|
||||
ESP_LOGI(TAG, "multinet: %s", mn_name_);
|
||||
multinet_ = esp_mn_handle_from_name(mn_name_);
|
||||
multinet_model_data_ = multinet_->create(mn_name_, 3000); // 3 秒超时
|
||||
multinet_->set_det_threshold(multinet_model_data_, CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f);
|
||||
multinet_model_data_ = multinet_->create(mn_name_, duration_);
|
||||
multinet_->set_det_threshold(multinet_model_data_, threshold_);
|
||||
esp_mn_commands_clear();
|
||||
esp_mn_commands_add(1, CONFIG_CUSTOM_WAKE_WORD);
|
||||
for (int i = 0; i < commands_.size(); i++) {
|
||||
esp_mn_commands_add(i + 1, commands_[i].command.c_str());
|
||||
}
|
||||
esp_mn_commands_update();
|
||||
|
||||
multinet_->print_active_speech_commands(multinet_model_data_);
|
||||
@@ -104,16 +160,18 @@ void CustomWakeWord::Feed(const std::vector<int16_t>& data) {
|
||||
return;
|
||||
} else if (mn_state == ESP_MN_STATE_DETECTED) {
|
||||
esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_);
|
||||
ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
|
||||
mn_result->command_id[0], mn_result->string, mn_result->prob[0]);
|
||||
|
||||
if (mn_result->command_id[0] == 1) {
|
||||
last_detected_wake_word_ = CONFIG_CUSTOM_WAKE_WORD_DISPLAY;
|
||||
}
|
||||
running_ = false;
|
||||
|
||||
if (wake_word_detected_callback_) {
|
||||
wake_word_detected_callback_(last_detected_wake_word_);
|
||||
for (int i = 0; i < mn_result->num && running_; i++) {
|
||||
ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
|
||||
mn_result->command_id[i], mn_result->string, mn_result->prob[i]);
|
||||
auto& command = commands_[mn_result->command_id[i] - 1];
|
||||
if (command.action == "wake") {
|
||||
last_detected_wake_word_ = command.text;
|
||||
running_ = false;
|
||||
|
||||
if (wake_word_detected_callback_) {
|
||||
wake_word_detected_callback_(last_detected_wake_word_);
|
||||
}
|
||||
}
|
||||
}
|
||||
multinet_->clean(multinet_model_data_);
|
||||
} else if (mn_state == ESP_MN_STATE_TIMEOUT) {
|
||||
|
||||
@@ -33,11 +33,21 @@ public:
|
||||
const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; }
|
||||
|
||||
private:
|
||||
struct Command {
|
||||
std::string command;
|
||||
std::string text;
|
||||
std::string action;
|
||||
};
|
||||
|
||||
// multinet 相关成员变量
|
||||
esp_mn_iface_t* multinet_ = nullptr;
|
||||
model_iface_data_t* multinet_model_data_ = nullptr;
|
||||
srmodel_list_t *models_ = nullptr;
|
||||
char* mn_name_ = nullptr;
|
||||
std::string language_ = "cn";
|
||||
int duration_ = 3000;
|
||||
float threshold_ = 0.2;
|
||||
std::deque<Command> commands_;
|
||||
|
||||
std::function<void(const std::string& wake_word)> wake_word_detected_callback_;
|
||||
AudioCodec* codec_ = nullptr;
|
||||
@@ -53,6 +63,7 @@ private:
|
||||
std::condition_variable wake_word_cv_;
|
||||
|
||||
void StoreWakeWordData(const std::vector<int16_t>& data);
|
||||
void ParseWakenetModelConfig();
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user