fix multinet model for v2 (#1208)

This commit is contained in:
Xiaoxia
2025-09-16 19:05:28 +08:00
committed by GitHub
parent d188415949
commit d2e99bae34
3 changed files with 317 additions and 37 deletions

View File

@@ -1,11 +1,13 @@
#include "custom_wake_word.h"
#include "audio_service.h"
#include "system_info.h"
#include "assets.h"
#include <esp_log.h>
#include "esp_mn_iface.h"
#include "esp_mn_models.h"
#include "esp_mn_speech_commands.h"
#include <esp_mn_iface.h>
#include <esp_mn_models.h>
#include <esp_mn_speech_commands.h>
#include <cJSON.h>
#define TAG "CustomWakeWord"
@@ -34,13 +36,66 @@ CustomWakeWord::~CustomWakeWord() {
}
}
void CustomWakeWord::ParseWakenetModelConfig() {
// Read index.json
auto& assets = Assets::GetInstance();
void* ptr = nullptr;
size_t size = 0;
if (!assets.GetAssetData("index.json", ptr, size)) {
ESP_LOGE(TAG, "Failed to read index.json");
return;
}
cJSON* root = cJSON_ParseWithLength(static_cast<char*>(ptr), size);
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse index.json");
return;
}
cJSON* multinet_model = cJSON_GetObjectItem(root, "multinet_model");
if (cJSON_IsObject(multinet_model)) {
cJSON* language = cJSON_GetObjectItem(multinet_model, "language");
cJSON* duration = cJSON_GetObjectItem(multinet_model, "duration");
cJSON* threshold = cJSON_GetObjectItem(multinet_model, "threshold");
cJSON* commands = cJSON_GetObjectItem(multinet_model, "commands");
if (cJSON_IsString(language)) {
language_ = language->valuestring;
}
if (cJSON_IsNumber(duration)) {
duration_ = duration->valueint;
}
if (cJSON_IsNumber(threshold)) {
threshold_ = threshold->valuedouble;
}
if (cJSON_IsArray(commands)) {
for (int i = 0; i < cJSON_GetArraySize(commands); i++) {
cJSON* command = cJSON_GetArrayItem(commands, i);
if (cJSON_IsObject(command)) {
cJSON* command_name = cJSON_GetObjectItem(command, "command");
cJSON* text = cJSON_GetObjectItem(command, "text");
cJSON* action = cJSON_GetObjectItem(command, "action");
if (cJSON_IsString(command_name) && cJSON_IsString(text) && cJSON_IsString(action)) {
commands_.push_back({command_name->valuestring, text->valuestring, action->valuestring});
ESP_LOGI(TAG, "Command: %s, Text: %s, Action: %s", command_name->valuestring, text->valuestring, action->valuestring);
}
}
}
}
}
cJSON_Delete(root);
}
bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list) {
codec_ = codec;
commands_.clear();
if (models_list == nullptr) {
language_ = "cn";
models_ = esp_srmodel_init("model");
threshold_ = CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f;
commands_.push_back({CONFIG_CUSTOM_WAKE_WORD, CONFIG_CUSTOM_WAKE_WORD_DISPLAY, "wake"});
} else {
models_ = models_list;
ParseWakenetModelConfig();
}
if (models_ == nullptr || models_->num == -1) {
@@ -49,19 +104,20 @@ bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list)
}
// 初始化 multinet (命令词识别)
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, ESP_MN_CHINESE);
mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, language_.c_str());
if (mn_name_ == nullptr) {
ESP_LOGE(TAG, "Failed to initialize multinet, mn_name is nullptr");
ESP_LOGI(TAG, "Please refer to https://pcn7cs20v8cr.feishu.cn/wiki/CpQjwQsCJiQSWSkYEvrcxcbVnwh to add custom wake word");
return false;
}
ESP_LOGI(TAG, "multinet: %s", mn_name_);
multinet_ = esp_mn_handle_from_name(mn_name_);
multinet_model_data_ = multinet_->create(mn_name_, 3000); // 3 秒超时
multinet_->set_det_threshold(multinet_model_data_, CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f);
multinet_model_data_ = multinet_->create(mn_name_, duration_);
multinet_->set_det_threshold(multinet_model_data_, threshold_);
esp_mn_commands_clear();
esp_mn_commands_add(1, CONFIG_CUSTOM_WAKE_WORD);
for (int i = 0; i < commands_.size(); i++) {
esp_mn_commands_add(i + 1, commands_[i].command.c_str());
}
esp_mn_commands_update();
multinet_->print_active_speech_commands(multinet_model_data_);
@@ -104,17 +160,19 @@ void CustomWakeWord::Feed(const std::vector<int16_t>& data) {
return;
} else if (mn_state == ESP_MN_STATE_DETECTED) {
esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_);
for (int i = 0; i < mn_result->num && running_; i++) {
ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
mn_result->command_id[0], mn_result->string, mn_result->prob[0]);
if (mn_result->command_id[0] == 1) {
last_detected_wake_word_ = CONFIG_CUSTOM_WAKE_WORD_DISPLAY;
}
mn_result->command_id[i], mn_result->string, mn_result->prob[i]);
auto& command = commands_[mn_result->command_id[i] - 1];
if (command.action == "wake") {
last_detected_wake_word_ = command.text;
running_ = false;
if (wake_word_detected_callback_) {
wake_word_detected_callback_(last_detected_wake_word_);
}
}
}
multinet_->clean(multinet_model_data_);
} else if (mn_state == ESP_MN_STATE_TIMEOUT) {
ESP_LOGD(TAG, "Command word detection timeout, cleaning state");

View File

@@ -33,11 +33,21 @@ public:
const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; }
private:
// One recognizable multinet speech command, loaded from index.json.
struct Command {
std::string command;  // phrase registered via esp_mn_commands_add()
std::string text;     // display text reported as the detected wake word
std::string action;   // action tag; "wake" triggers the wake-word callback
};
// multinet 相关成员变量
esp_mn_iface_t* multinet_ = nullptr;
model_iface_data_t* multinet_model_data_ = nullptr;
srmodel_list_t *models_ = nullptr;
char* mn_name_ = nullptr;
std::string language_ = "cn";
int duration_ = 3000;
float threshold_ = 0.2;
std::deque<Command> commands_;
std::function<void(const std::string& wake_word)> wake_word_detected_callback_;
AudioCodec* codec_ = nullptr;
@@ -53,6 +63,7 @@ private:
std::condition_variable wake_word_cv_;
void StoreWakeWordData(const std::vector<int16_t>& data);
void ParseWakenetModelConfig();
};
#endif

View File

@@ -156,21 +156,44 @@ def copy_directory(src, dst):
return False
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
"""Process wakenet_model parameter"""
if not wakenet_model_dir:
def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_dir):
"""Process SR models (wakenet and multinet) and generate srmodels.bin"""
if not wakenet_model_dir and not multinet_model_dirs:
return None
# Copy input directory to build directory
wakenet_build_dir = os.path.join(build_dir, "wakenet_model")
if os.path.exists(wakenet_build_dir):
shutil.rmtree(wakenet_build_dir)
copy_directory(wakenet_model_dir, os.path.join(wakenet_build_dir, os.path.basename(wakenet_model_dir)))
# Create SR models build directory
sr_models_build_dir = os.path.join(build_dir, "srmodels")
if os.path.exists(sr_models_build_dir):
shutil.rmtree(sr_models_build_dir)
os.makedirs(sr_models_build_dir)
models_processed = 0
# Copy wakenet model if available
if wakenet_model_dir:
wakenet_name = os.path.basename(wakenet_model_dir)
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
if copy_directory(wakenet_model_dir, wakenet_dst):
models_processed += 1
print(f"Added wakenet model: {wakenet_name}")
# Copy multinet models if available
if multinet_model_dirs:
for multinet_model_dir in multinet_model_dirs:
multinet_name = os.path.basename(multinet_model_dir)
multinet_dst = os.path.join(sr_models_build_dir, multinet_name)
if copy_directory(multinet_model_dir, multinet_dst):
models_processed += 1
print(f"Added multinet model: {multinet_name}")
if models_processed == 0:
print("Warning: No SR models were successfully processed")
return None
# Use pack_models function to generate srmodels.bin
srmodels_output = os.path.join(wakenet_build_dir, "srmodels.bin")
srmodels_output = os.path.join(sr_models_build_dir, "srmodels.bin")
try:
pack_models(wakenet_build_dir, "srmodels.bin")
pack_models(sr_models_build_dir, "srmodels.bin")
print(f"Generated: {srmodels_output}")
# Copy srmodels.bin to assets directory
copy_file(srmodels_output, os.path.join(assets_dir, "srmodels.bin"))
@@ -180,6 +203,11 @@ def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
return None
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
    """Legacy compatibility shim.

    Older callers only know about the wakenet model, so delegate to
    process_sr_models() with no multinet models.
    """
    multinet_dirs = None  # legacy path never supplies multinet models
    return process_sr_models(wakenet_model_dir, multinet_dirs, build_dir, assets_dir)
def process_text_font(text_font_file, assets_dir):
"""Process text_font parameter"""
if not text_font_file:
@@ -250,7 +278,7 @@ def process_extra_files(extra_files_dir, assets_dir):
return extra_files_list
def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None):
def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None, multinet_model_info=None):
"""Generate index.json file"""
index_data = {
"version": 1
@@ -268,6 +296,9 @@ def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra
if extra_files:
index_data["extra_files"] = extra_files
if multinet_model_info:
index_data["multinet_model"] = multinet_model_info
# Write index.json
index_path = os.path.join(assets_dir, "index.json")
with open(index_path, 'w', encoding='utf-8') as f:
@@ -434,6 +465,132 @@ def read_wakenet_from_sdkconfig(sdkconfig_path):
return models[0] if models else None
def read_multinet_from_sdkconfig(sdkconfig_path):
    """
    Read multinet models from sdkconfig (based on movemodel.py logic)
    Returns a list of multinet model names
    """
    if not os.path.exists(sdkconfig_path):
        print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
        return []

    # Concatenate every active (non-comment) CONFIG_SR_MN line so the
    # option lookups below reduce to substring tests.
    models_string = ''
    with io.open(sdkconfig_path, "r") as f:
        for label in f:
            label = label.strip("\n")
            if 'CONFIG_SR_MN' in label and label[0] != '#':
                models_string += label

    # Ordered (option, model) tables: the first match per language wins.
    # QUANT8 variants must precede their non-quantized prefixes because
    # the lookup is a plain substring test.
    cn_options = (
        ("CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION", 'mn3_cn'),
        ("CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8", 'mn4q8_cn'),
        ("CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION", 'mn4_cn'),
        ("CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8", 'mn5q8_cn'),
        ("CONFIG_SR_MN_CN_MULTINET6_QUANT", 'mn6_cn'),
        ("CONFIG_SR_MN_CN_MULTINET6_AC_QUANT", 'mn6_cn_ac'),
        ("CONFIG_SR_MN_CN_MULTINET7_QUANT", 'mn7_cn'),
        ("CONFIG_SR_MN_CN_MULTINET7_AC_QUANT", 'mn7_cn_ac'),
    )
    en_options = (
        ("CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8", 'mn5q8_en'),
        ("CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION", 'mn5_en'),
        ("CONFIG_SR_MN_EN_MULTINET6_QUANT", 'mn6_en'),
        ("CONFIG_SR_MN_EN_MULTINET7_QUANT", 'mn7_en'),
    )

    models = []
    for option_table in (cn_options, en_options):
        for option, model_name in option_table:
            if option in models_string:
                models.append(model_name)
                break

    # MultiNet 6/7 additionally ship an FST grammar model.
    if "MULTINET6" in models_string or "MULTINET7" in models_string:
        models.append('fst')
    return models
def read_custom_wake_word_from_sdkconfig(sdkconfig_path):
    """
    Read custom wake word configuration from sdkconfig
    Returns a dict with custom wake word info or None if not configured
    """
    if not os.path.exists(sdkconfig_path):
        print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
        return None

    enabled = False
    wake_word = None
    display = None
    threshold = None
    with io.open(sdkconfig_path, "r") as f:
        for raw in f:
            line = raw.strip("\n")
            # Skip comments and anything that is not a key=value pair.
            if line.startswith('#') or '=' not in line:
                continue
            if 'CONFIG_USE_CUSTOM_WAKE_WORD=y' in line:
                enabled = True
            elif 'CONFIG_CUSTOM_WAKE_WORD=' in line:
                # String value: strip the surrounding quotes.
                wake_word = line.split('=', 1)[1].strip('"')
            elif 'CONFIG_CUSTOM_WAKE_WORD_DISPLAY=' in line:
                display = line.split('=', 1)[1].strip('"')
            elif 'CONFIG_CUSTOM_WAKE_WORD_THRESHOLD=' in line:
                # Numeric value: prefer int, fall back to float, then default.
                value = line.split('=', 1)[1]
                try:
                    threshold = int(value)
                except ValueError:
                    try:
                        threshold = float(value)
                    except ValueError:
                        print(f"Warning: Invalid threshold value: {value}")
                        threshold = 20  # default (will be converted to 0.2)

    # Return config only if custom wake word is enabled and all required
    # fields were found.
    if enabled and wake_word is not None and display is not None and threshold is not None:
        return {
            'wake_word': wake_word,
            'display': display,
            'threshold': threshold / 100.0  # Convert to decimal (20 -> 0.2)
        }
    return None
def get_language_from_multinet_models(multinet_models):
    """
    Determine language from multinet model names
    Returns 'cn', 'en', or None
    """
    if not multinet_models:
        return None

    def _matches(model, markers):
        # True when any language marker appears in the model name.
        return any(marker in model for marker in markers)

    has_cn = any(_matches(m, ('_cn', 'cn_')) for m in multinet_models)
    has_en = any(_matches(m, ('_en', 'en_')) for m in multinet_models)

    # English only when unambiguous; Chinese is the default for mixed or
    # unmarked model lists.
    if has_en and not has_cn:
        return 'en'
    return 'cn'
def get_wakenet_model_path(model_name, esp_sr_model_path):
"""
Get the full path to the wakenet model directory
@@ -449,6 +606,25 @@ def get_wakenet_model_path(model_name, esp_sr_model_path):
return None
def get_multinet_model_paths(model_names, esp_sr_model_path):
    """
    Get the full paths to the multinet model directories
    Returns a list of valid model paths
    """
    if not model_names:
        return []

    base_dir = os.path.join(esp_sr_model_path, 'multinet_model')
    valid_paths = []
    for name in model_names:
        multinet_model_path = os.path.join(base_dir, name)
        if not os.path.exists(multinet_model_path):
            print(f"Warning: Multinet model directory not found: {multinet_model_path}")
            continue
        valid_paths.append(multinet_model_path)
    return valid_paths
def get_text_font_path(builtin_text_font, xiaozhi_fonts_path):
"""
Get the text font path if needed
@@ -485,7 +661,7 @@ def get_emoji_collection_path(default_emoji_collection, xiaozhi_fonts_path):
return None
def build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path, extra_files_path, output_path):
def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
"""
Build assets using integrated functions (no external dependencies)
"""
@@ -503,13 +679,13 @@ def build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection
print("Starting to build assets...")
# Process each component
srmodels = process_wakenet_model(wakenet_model_path, temp_build_dir, assets_dir) if wakenet_model_path else None
srmodels = process_sr_models(wakenet_model_path, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_path or multinet_model_paths) else None
text_font = process_text_font(text_font_path, assets_dir) if text_font_path else None
emoji_collection = process_emoji_collection(emoji_collection_path, assets_dir) if emoji_collection_path else None
extra_files = process_extra_files(extra_files_path, assets_dir) if extra_files_path else None
# Generate index.json
generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files)
generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files, multinet_model_info)
# Generate config.json for packing
config_path = generate_config_json(temp_build_dir, assets_dir)
@@ -578,9 +754,19 @@ def main():
print(f" emoji_collection: {args.emoji_collection}")
print(f" output: {args.output}")
# Read wakenet model from sdkconfig
# Read SR models from sdkconfig
wakenet_model_name = read_wakenet_from_sdkconfig(args.sdkconfig)
multinet_model_names = read_multinet_from_sdkconfig(args.sdkconfig)
# Get model paths
wakenet_model_path = get_wakenet_model_path(wakenet_model_name, args.esp_sr_model_path)
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
# Print model information
if wakenet_model_name:
print(f" wakenet model: {wakenet_model_name}")
if multinet_model_names:
print(f" multinet models: {', '.join(multinet_model_names)}")
# Get text font path if needed
text_font_path = get_text_font_path(args.builtin_text_font, args.xiaozhi_fonts_path)
@@ -591,9 +777,34 @@ def main():
# Get extra files path if provided
extra_files_path = args.extra_files
# Read custom wake word configuration
custom_wake_word_config = read_custom_wake_word_from_sdkconfig(args.sdkconfig)
multinet_model_info = None
if custom_wake_word_config and multinet_model_names:
# Determine language from multinet models
language = get_language_from_multinet_models(multinet_model_names)
# Build multinet_model info structure
multinet_model_info = {
"language": language,
"duration": 3000, # Default duration in ms
"threshold": custom_wake_word_config['threshold'],
"commands": [
{
"command": custom_wake_word_config['wake_word'],
"text": custom_wake_word_config['display'],
"action": "wake"
}
]
}
print(f" custom wake word: {custom_wake_word_config['wake_word']} ({custom_wake_word_config['display']})")
print(f" wake word language: {language}")
print(f" wake word threshold: {custom_wake_word_config['threshold']}")
# Check if we have anything to build
if not wakenet_model_path and not text_font_path and not emoji_collection_path and not extra_files_path:
print("Warning: No assets to build (no wakenet, text font, emoji collection, or extra files)")
if not wakenet_model_path and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
print("Warning: No assets to build (no SR models, text font, emoji collection, extra files, or custom wake word)")
# Create an empty assets.bin file
os.makedirs(os.path.dirname(args.output), exist_ok=True)
with open(args.output, 'wb') as f:
@@ -602,8 +813,8 @@ def main():
return
# Build the assets
success = build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path,
extra_files_path, args.output)
success = build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path,
extra_files_path, args.output, multinet_model_info)
if not success:
sys.exit(1)