diff --git a/main/audio/wake_words/custom_wake_word.cc b/main/audio/wake_words/custom_wake_word.cc
index 7c7f3fb3..56ac1a03 100644
--- a/main/audio/wake_words/custom_wake_word.cc
+++ b/main/audio/wake_words/custom_wake_word.cc
@@ -1,11 +1,13 @@
 #include "custom_wake_word.h"
 #include "audio_service.h"
 #include "system_info.h"
+#include "assets.h"
 
 #include <esp_log.h>
-#include "esp_mn_iface.h"
-#include "esp_mn_models.h"
-#include "esp_mn_speech_commands.h"
+#include <esp_mn_iface.h>
+#include <esp_mn_models.h>
+#include <esp_mn_speech_commands.h>
+#include <cJSON.h>
 
 #define TAG "CustomWakeWord"
 
@@ -34,13 +36,66 @@ CustomWakeWord::~CustomWakeWord() {
     }
 }
 
+void CustomWakeWord::ParseWakenetModelConfig() {
+    // Read index.json
+    auto& assets = Assets::GetInstance();
+    void* ptr = nullptr;
+    size_t size = 0;
+    if (!assets.GetAssetData("index.json", ptr, size)) {
+        ESP_LOGE(TAG, "Failed to read index.json");
+        return;
+    }
+    cJSON* root = cJSON_ParseWithLength(static_cast<char*>(ptr), size);
+    if (root == nullptr) {
+        ESP_LOGE(TAG, "Failed to parse index.json");
+        return;
+    }
+    cJSON* multinet_model = cJSON_GetObjectItem(root, "multinet_model");
+    if (cJSON_IsObject(multinet_model)) {
+        cJSON* language = cJSON_GetObjectItem(multinet_model, "language");
+        cJSON* duration = cJSON_GetObjectItem(multinet_model, "duration");
+        cJSON* threshold = cJSON_GetObjectItem(multinet_model, "threshold");
+        cJSON* commands = cJSON_GetObjectItem(multinet_model, "commands");
+        if (cJSON_IsString(language)) {
+            language_ = language->valuestring;
+        }
+        if (cJSON_IsNumber(duration)) {
+            duration_ = duration->valueint;
+        }
+        if (cJSON_IsNumber(threshold)) {
+            threshold_ = threshold->valuedouble;
+        }
+        if (cJSON_IsArray(commands)) {
+            for (int i = 0; i < cJSON_GetArraySize(commands); i++) {
+                cJSON* command = cJSON_GetArrayItem(commands, i);
+                if (cJSON_IsObject(command)) {
+                    cJSON* command_name = cJSON_GetObjectItem(command, "command");
+                    cJSON* text = cJSON_GetObjectItem(command, "text");
+                    cJSON* action = cJSON_GetObjectItem(command, "action");
+                    if (cJSON_IsString(command_name) && cJSON_IsString(text) && cJSON_IsString(action)) {
+                        commands_.push_back({command_name->valuestring, text->valuestring, action->valuestring});
+                        ESP_LOGI(TAG, "Command: %s, Text: %s, Action: %s", command_name->valuestring, text->valuestring, action->valuestring);
+                    }
+                }
+            }
+        }
+    }
+    cJSON_Delete(root);
+}
+
+
 bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list) {
     codec_ = codec;
+    commands_.clear();
 
     if (models_list == nullptr) {
+        language_ = "cn";
         models_ = esp_srmodel_init("model");
+        threshold_ = CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f;
+        commands_.push_back({CONFIG_CUSTOM_WAKE_WORD, CONFIG_CUSTOM_WAKE_WORD_DISPLAY, "wake"});
     } else {
         models_ = models_list;
+        ParseWakenetModelConfig();
     }
 
     if (models_ == nullptr || models_->num == -1) {
@@ -49,19 +104,20 @@ bool CustomWakeWord::Initialize(AudioCodec* codec, srmodel_list_t* models_list)
     }
 
     // Initialize multinet (command word recognition)
-    mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, ESP_MN_CHINESE);
+    mn_name_ = esp_srmodel_filter(models_, ESP_MN_PREFIX, language_.c_str());
     if (mn_name_ == nullptr) {
         ESP_LOGE(TAG, "Failed to initialize multinet, mn_name is nullptr");
         ESP_LOGI(TAG, "Please refer to https://pcn7cs20v8cr.feishu.cn/wiki/CpQjwQsCJiQSWSkYEvrcxcbVnwh to add custom wake word");
         return false;
     }
-    ESP_LOGI(TAG, "multinet: %s", mn_name_);
 
     multinet_ = esp_mn_handle_from_name(mn_name_);
-    multinet_model_data_ = multinet_->create(mn_name_, 3000);  // 3 second timeout
-    multinet_->set_det_threshold(multinet_model_data_, CONFIG_CUSTOM_WAKE_WORD_THRESHOLD / 100.0f);
+    multinet_model_data_ = multinet_->create(mn_name_, duration_);
+    multinet_->set_det_threshold(multinet_model_data_, threshold_);
 
     esp_mn_commands_clear();
-    esp_mn_commands_add(1, CONFIG_CUSTOM_WAKE_WORD);
+    for (int i = 0; i < commands_.size(); i++) {
+        esp_mn_commands_add(i + 1, commands_[i].command.c_str());
+    }
     esp_mn_commands_update();
     multinet_->print_active_speech_commands(multinet_model_data_);
@@ -104,16 +160,18 @@ void CustomWakeWord::Feed(const std::vector<int16_t>& data) {
         return;
     } else if (mn_state == ESP_MN_STATE_DETECTED) {
         esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_);
-        ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
-            mn_result->command_id[0], mn_result->string, mn_result->prob[0]);
-
-        if (mn_result->command_id[0] == 1) {
-            last_detected_wake_word_ = CONFIG_CUSTOM_WAKE_WORD_DISPLAY;
-        }
-        running_ = false;
-
-        if (wake_word_detected_callback_) {
-            wake_word_detected_callback_(last_detected_wake_word_);
+        for (int i = 0; i < mn_result->num && running_; i++) {
+            ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f",
+                mn_result->command_id[i], mn_result->string, mn_result->prob[i]);
+            auto& command = commands_[mn_result->command_id[i] - 1];
+            if (command.action == "wake") {
+                last_detected_wake_word_ = command.text;
+                running_ = false;
+
+                if (wake_word_detected_callback_) {
+                    wake_word_detected_callback_(last_detected_wake_word_);
+                }
+            }
         }
         multinet_->clean(multinet_model_data_);
     } else if (mn_state == ESP_MN_STATE_TIMEOUT) {
diff --git a/main/audio/wake_words/custom_wake_word.h b/main/audio/wake_words/custom_wake_word.h
index e7157470..d4e6d8c3 100644
--- a/main/audio/wake_words/custom_wake_word.h
+++ b/main/audio/wake_words/custom_wake_word.h
@@ -33,11 +33,21 @@ public:
     const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; }
 
 private:
+    struct Command {
+        std::string command;
+        std::string text;
+        std::string action;
+    };
+
     // multinet-related member variables
    esp_mn_iface_t* multinet_ = nullptr;
     model_iface_data_t* multinet_model_data_ = nullptr;
     srmodel_list_t *models_ = nullptr;
     char* mn_name_ = nullptr;
+    std::string language_ = "cn";
+    int duration_ = 3000;
+    float threshold_ = 0.2;
+    std::deque<Command> commands_;
 
     std::function<void(const std::string& wake_word)> wake_word_detected_callback_;
     AudioCodec* codec_ = nullptr;
@@ -53,6 +63,7 @@ private:
     std::condition_variable wake_word_cv_;
 
     void StoreWakeWordData(const std::vector<int16_t>& data);
+    void ParseWakenetModelConfig();
 };
 
 #endif
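Note: `ParseWakenetModelConfig()` above reads the `multinet_model` object from the assets partition's `index.json`. A minimal sketch of that fragment, mirroring the keys the parser looks up and the shape the build script below emits; the command string, display text, and threshold values here are illustrative placeholders, not values taken from this diff:

```python
# Illustrative only: the index.json fragment consumed by ParseWakenetModelConfig().
import json

multinet_model_info = {
    "language": "cn",        # matched against ESP_MN_PREFIX via esp_srmodel_filter()
    "duration": 3000,        # passed to multinet_->create(), in milliseconds
    "threshold": 0.2,        # passed to set_det_threshold()
    "commands": [
        {
            "command": "ni hao xiao zhi",   # placeholder command string
            "text": "Hello XiaoZhi",        # placeholder display text
            "action": "wake"                # only "wake" entries fire the detection callback
        }
    ]
}

print(json.dumps({"version": 1, "multinet_model": multinet_model_info}, ensure_ascii=False, indent=2))
```

Only `language`, `duration`, `threshold`, and `commands` are consulted; any other keys in `index.json` are ignored by the parser.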
diff --git a/scripts/build_default_assets.py b/scripts/build_default_assets.py
index fa1c0832..36c7c107 100755
--- a/scripts/build_default_assets.py
+++ b/scripts/build_default_assets.py
@@ -156,21 +156,44 @@ def copy_directory(src, dst):
         return False
 
 
-def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
-    """Process wakenet_model parameter"""
-    if not wakenet_model_dir:
+def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_dir):
+    """Process SR models (wakenet and multinet) and generate srmodels.bin"""
+    if not wakenet_model_dir and not multinet_model_dirs:
         return None
 
-    # Copy input directory to build directory
-    wakenet_build_dir = os.path.join(build_dir, "wakenet_model")
-    if os.path.exists(wakenet_build_dir):
-        shutil.rmtree(wakenet_build_dir)
-    copy_directory(wakenet_model_dir, os.path.join(wakenet_build_dir, os.path.basename(wakenet_model_dir)))
+    # Create SR models build directory
+    sr_models_build_dir = os.path.join(build_dir, "srmodels")
+    if os.path.exists(sr_models_build_dir):
+        shutil.rmtree(sr_models_build_dir)
+    os.makedirs(sr_models_build_dir)
+
+    models_processed = 0
+
+    # Copy wakenet model if available
+    if wakenet_model_dir:
+        wakenet_name = os.path.basename(wakenet_model_dir)
+        wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
+        if copy_directory(wakenet_model_dir, wakenet_dst):
+            models_processed += 1
+            print(f"Added wakenet model: {wakenet_name}")
+
+    # Copy multinet models if available
+    if multinet_model_dirs:
+        for multinet_model_dir in multinet_model_dirs:
+            multinet_name = os.path.basename(multinet_model_dir)
+            multinet_dst = os.path.join(sr_models_build_dir, multinet_name)
+            if copy_directory(multinet_model_dir, multinet_dst):
+                models_processed += 1
+                print(f"Added multinet model: {multinet_name}")
+
+    if models_processed == 0:
+        print("Warning: No SR models were successfully processed")
+        return None
 
     # Use pack_models function to generate srmodels.bin
-    srmodels_output = os.path.join(wakenet_build_dir, "srmodels.bin")
+    srmodels_output = os.path.join(sr_models_build_dir, "srmodels.bin")
     try:
-        pack_models(wakenet_build_dir, "srmodels.bin")
+        pack_models(sr_models_build_dir, "srmodels.bin")
         print(f"Generated: {srmodels_output}")
         # Copy srmodels.bin to assets directory
         copy_file(srmodels_output, os.path.join(assets_dir, "srmodels.bin"))
@@ -180,6 +203,11 @@ def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
         return None
 
 
+def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
+    """Process wakenet_model parameter (legacy compatibility function)"""
+    return process_sr_models(wakenet_model_dir, None, build_dir, assets_dir)
+
+
 def process_text_font(text_font_file, assets_dir):
     """Process text_font parameter"""
     if not text_font_file:
@@ -250,7 +278,7 @@ def process_extra_files(extra_files_dir, assets_dir):
     return extra_files_list
 
 
-def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None):
+def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None, multinet_model_info=None):
     """Generate index.json file"""
     index_data = {
         "version": 1
@@ -268,6 +296,9 @@
     if extra_files:
         index_data["extra_files"] = extra_files
 
+    if multinet_model_info:
+        index_data["multinet_model"] = multinet_model_info
+
     # Write index.json
     index_path = os.path.join(assets_dir, "index.json")
     with open(index_path, 'w', encoding='utf-8') as f:
@@ -434,6 +465,132 @@ def read_wakenet_from_sdkconfig(sdkconfig_path):
     return models[0] if models else None
 
 
+def read_multinet_from_sdkconfig(sdkconfig_path):
+    """
+    Read multinet models from sdkconfig (based on movemodel.py logic)
+    Returns a list of multinet model names
+    """
+    if not os.path.exists(sdkconfig_path):
+        print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
+        return []
+
+    with io.open(sdkconfig_path, "r") as f:
+        models_string = ''
+        for label in f:
+            label = label.strip("\n")
+            if 'CONFIG_SR_MN' in label and label[0] != '#':
+                models_string += label
+
+    models = []
+    if "CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION" in models_string:
+        models.append('mn3_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8" in models_string:
+        models.append('mn4q8_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION" in models_string:
+        models.append('mn4_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8" in models_string:
+        models.append('mn5q8_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET6_QUANT" in models_string:
+        models.append('mn6_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET6_AC_QUANT" in models_string:
+        models.append('mn6_cn_ac')
+    elif "CONFIG_SR_MN_CN_MULTINET7_QUANT" in models_string:
+        models.append('mn7_cn')
+    elif "CONFIG_SR_MN_CN_MULTINET7_AC_QUANT" in models_string:
+        models.append('mn7_cn_ac')
+
+    if "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8" in models_string:
+        models.append('mn5q8_en')
+    elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION" in models_string:
+        models.append('mn5_en')
+    elif "CONFIG_SR_MN_EN_MULTINET6_QUANT" in models_string:
+        models.append('mn6_en')
+    elif "CONFIG_SR_MN_EN_MULTINET7_QUANT" in models_string:
+        models.append('mn7_en')
+
+    if "MULTINET6" in models_string or "MULTINET7" in models_string:
+        models.append('fst')
+
+    return models
+
+
+def read_custom_wake_word_from_sdkconfig(sdkconfig_path):
+    """
+    Read custom wake word configuration from sdkconfig
+    Returns a dict with custom wake word info or None if not configured
+    """
+    if not os.path.exists(sdkconfig_path):
+        print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
+        return None
+
+    config_values = {}
+    with io.open(sdkconfig_path, "r") as f:
+        for line in f:
+            line = line.strip("\n")
+            if line.startswith('#') or '=' not in line:
+                continue
+
+            # Check for custom wake word configuration
+            if 'CONFIG_USE_CUSTOM_WAKE_WORD=y' in line:
+                config_values['use_custom_wake_word'] = True
+            elif 'CONFIG_CUSTOM_WAKE_WORD=' in line and not line.startswith('#'):
+                # Extract string value (remove quotes)
+                value = line.split('=', 1)[1].strip('"')
+                config_values['wake_word'] = value
+            elif 'CONFIG_CUSTOM_WAKE_WORD_DISPLAY=' in line and not line.startswith('#'):
+                # Extract string value (remove quotes)
+                value = line.split('=', 1)[1].strip('"')
+                config_values['display'] = value
+            elif 'CONFIG_CUSTOM_WAKE_WORD_THRESHOLD=' in line and not line.startswith('#'):
+                # Extract numeric value
+                value = line.split('=', 1)[1]
+                try:
+                    config_values['threshold'] = int(value)
+                except ValueError:
+                    try:
+                        config_values['threshold'] = float(value)
+                    except ValueError:
+                        print(f"Warning: Invalid threshold value: {value}")
+                        config_values['threshold'] = 20  # default (will be converted to 0.2)
+
+    # Return config only if custom wake word is enabled and required fields are present
+    if (config_values.get('use_custom_wake_word', False) and
+            'wake_word' in config_values and
+            'display' in config_values and
+            'threshold' in config_values):
+        return {
+            'wake_word': config_values['wake_word'],
+            'display': config_values['display'],
+            'threshold': config_values['threshold'] / 100.0  # Convert to decimal (20 -> 0.2)
+        }
+
+    return None
+
+
+def get_language_from_multinet_models(multinet_models):
+    """
+    Determine language from multinet model names
+    Returns 'cn', 'en', or None
+    """
+    if not multinet_models:
+        return None
+
+    # Check for Chinese models
+    cn_indicators = ['_cn', 'cn_']
+    en_indicators = ['_en', 'en_']
+
+    has_cn = any(any(indicator in model for indicator in cn_indicators) for model in multinet_models)
+    has_en = any(any(indicator in model for indicator in en_indicators) for model in multinet_models)
+
+    # If both or neither, default to cn
+    if has_cn and not has_en:
+        return 'cn'
+    elif has_en and not has_cn:
+        return 'en'
+    else:
+        return 'cn'  # Default to Chinese
+
+
 def get_wakenet_model_path(model_name, esp_sr_model_path):
     """
     Get the full path to the wakenet model directory
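For reference, a self-contained sketch of the sdkconfig options the two readers above match on, together with the values they would derive. The option values shown are illustrative placeholders, not project defaults:

```python
# Minimal sketch (not part of the build script): input the readers would scan,
# and the values they would produce for it.
sample_sdkconfig = """\
CONFIG_SR_MN_CN_MULTINET7_QUANT=y
CONFIG_USE_CUSTOM_WAKE_WORD=y
CONFIG_CUSTOM_WAKE_WORD="ni hao xiao zhi"
CONFIG_CUSTOM_WAKE_WORD_DISPLAY="Hello XiaoZhi"
CONFIG_CUSTOM_WAKE_WORD_THRESHOLD=20
"""
print(sample_sdkconfig)

# read_multinet_from_sdkconfig() would report ['mn7_cn', 'fst'] for this input
# (the MULTINET7 option also triggers the 'fst' entry), and
# read_custom_wake_word_from_sdkconfig() would return:
expected_wake_word = {
    "wake_word": "ni hao xiao zhi",
    "display": "Hello XiaoZhi",
    "threshold": 20 / 100.0,   # stored as 0.2 in index.json
}
print(expected_wake_word)
```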
@@ -449,6 +606,25 @@ def get_wakenet_model_path(model_name, esp_sr_model_path):
     return None
 
 
+def get_multinet_model_paths(model_names, esp_sr_model_path):
+    """
+    Get the full paths to the multinet model directories
+    Returns a list of valid model paths
+    """
+    if not model_names:
+        return []
+
+    valid_paths = []
+    for model_name in model_names:
+        multinet_model_path = os.path.join(esp_sr_model_path, 'multinet_model', model_name)
+        if os.path.exists(multinet_model_path):
+            valid_paths.append(multinet_model_path)
+        else:
+            print(f"Warning: Multinet model directory not found: {multinet_model_path}")
+
+    return valid_paths
+
+
 def get_text_font_path(builtin_text_font, xiaozhi_fonts_path):
     """
     Get the text font path if needed
@@ -485,7 +661,7 @@ def get_emoji_collection_path(default_emoji_collection, xiaozhi_fonts_path):
     return None
 
 
-def build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path, extra_files_path, output_path):
+def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
     """
     Build assets using integrated functions (no external dependencies)
     """
@@ -503,13 +679,13 @@
     print("Starting to build assets...")
 
     # Process each component
-    srmodels = process_wakenet_model(wakenet_model_path, temp_build_dir, assets_dir) if wakenet_model_path else None
+    srmodels = process_sr_models(wakenet_model_path, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_path or multinet_model_paths) else None
     text_font = process_text_font(text_font_path, assets_dir) if text_font_path else None
     emoji_collection = process_emoji_collection(emoji_collection_path, assets_dir) if emoji_collection_path else None
     extra_files = process_extra_files(extra_files_path, assets_dir) if extra_files_path else None
 
     # Generate index.json
-    generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files)
+    generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files, multinet_model_info)
 
     # Generate config.json for packing
     config_path = generate_config_json(temp_build_dir, assets_dir)
@@ -578,9 +754,19 @@ def main():
     print(f"  emoji_collection: {args.emoji_collection}")
     print(f"  output: {args.output}")
 
-    # Read wakenet model from sdkconfig
+    # Read SR models from sdkconfig
     wakenet_model_name = read_wakenet_from_sdkconfig(args.sdkconfig)
+    multinet_model_names = read_multinet_from_sdkconfig(args.sdkconfig)
+
+    # Get model paths
     wakenet_model_path = get_wakenet_model_path(wakenet_model_name, args.esp_sr_model_path)
+    multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
+
+    # Print model information
+    if wakenet_model_name:
+        print(f"  wakenet model: {wakenet_model_name}")
+    if multinet_model_names:
+        print(f"  multinet models: {', '.join(multinet_model_names)}")
 
     # Get text font path if needed
     text_font_path = get_text_font_path(args.builtin_text_font, args.xiaozhi_fonts_path)
@@ -591,9 +777,34 @@ def main():
     # Get extra files path if provided
     extra_files_path = args.extra_files
 
+    # Read custom wake word configuration
+    custom_wake_word_config = read_custom_wake_word_from_sdkconfig(args.sdkconfig)
+    multinet_model_info = None
+
+    if custom_wake_word_config and multinet_model_names:
+        # Determine language from multinet models
+        language = get_language_from_multinet_models(multinet_model_names)
+
+        # Build multinet_model info structure
+        multinet_model_info = {
+            "language": language,
+            "duration": 3000,  # Default duration in ms
+            "threshold": custom_wake_word_config['threshold'],
+            "commands": [
+                {
+                    "command": custom_wake_word_config['wake_word'],
+                    "text": custom_wake_word_config['display'],
+                    "action": "wake"
+                }
+            ]
+        }
+        print(f"  custom wake word: {custom_wake_word_config['wake_word']} ({custom_wake_word_config['display']})")
+        print(f"  wake word language: {language}")
+        print(f"  wake word threshold: {custom_wake_word_config['threshold']}")
+
     # Check if we have anything to build
-    if not wakenet_model_path and not text_font_path and not emoji_collection_path and not extra_files_path:
-        print("Warning: No assets to build (no wakenet, text font, emoji collection, or extra files)")
+    if not wakenet_model_path and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
+        print("Warning: No assets to build (no SR models, text font, emoji collection, extra files, or custom wake word)")
         # Create an empty assets.bin file
         os.makedirs(os.path.dirname(args.output), exist_ok=True)
         with open(args.output, 'wb') as f:
@@ -602,8 +813,8 @@ def main():
         return
 
     # Build the assets
-    success = build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path,
-                                      extra_files_path, args.output)
+    success = build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path,
+                                      extra_files_path, args.output, multinet_model_info)
 
     if not success:
         sys.exit(1)
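One design note on the `commands` array: the firmware loop in `CustomWakeWord::Feed()` walks every recognition result and only fires the wake callback for entries whose `action` is `"wake"`, so the array could in principle carry additional, non-wake entries. A hypothetical sketch follows; the second command and its `volume_up` action are invented for illustration, and the build script in this diff only ever emits the single wake entry:

```python
# Hypothetical extension (not implemented in this diff): several command entries.
commands = [
    {"command": "ni hao xiao zhi", "text": "Hello XiaoZhi", "action": "wake"},
    {"command": "da sheng yi dian", "text": "Louder", "action": "volume_up"},  # would not trigger the wake callback
]

# esp_mn_commands_add() is called with 1-based IDs (i + 1), so a reported
# command_id N maps back to commands_[N - 1] on the firmware side.
for command_id, entry in enumerate(commands, start=1):
    print(command_id, entry["command"], entry["action"])
```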