fix multiple wakenet words and custom wake word (#1226)

* fix multiple wakenet words and custom wake word

* fix idf_component.yml
This commit is contained in:
Xiaoxia
2025-09-22 10:49:08 +08:00
committed by GitHub
parent 96e39bea1b
commit d3e7fee828
5 changed files with 217 additions and 146 deletions

View File

@@ -17,8 +17,6 @@ import shutil
import sys
import json
import struct
import math
from pathlib import Path
from datetime import datetime
@@ -156,9 +154,9 @@ def copy_directory(src, dst):
return False
def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_dir):
def process_sr_models(wakenet_model_dirs, multinet_model_dirs, build_dir, assets_dir):
"""Process SR models (wakenet and multinet) and generate srmodels.bin"""
if not wakenet_model_dir and not multinet_model_dirs:
if not wakenet_model_dirs and not multinet_model_dirs:
return None
# Create SR models build directory
@@ -169,13 +167,14 @@ def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_
models_processed = 0
# Copy wakenet model if available
if wakenet_model_dir:
wakenet_name = os.path.basename(wakenet_model_dir)
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
if copy_directory(wakenet_model_dir, wakenet_dst):
models_processed += 1
print(f"Added wakenet model: {wakenet_name}")
# Copy wakenet models if available
if wakenet_model_dirs:
for wakenet_model_dir in wakenet_model_dirs:
wakenet_name = os.path.basename(wakenet_model_dir)
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
if copy_directory(wakenet_model_dir, wakenet_dst):
models_processed += 1
print(f"Added wakenet model: {wakenet_name}")
# Copy multinet models if available
if multinet_model_dirs:
@@ -203,11 +202,6 @@ def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_
return None
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
"""Process wakenet_model parameter (legacy compatibility function)"""
return process_sr_models(wakenet_model_dir, None, build_dir, assets_dir)
def process_text_font(text_font_file, assets_dir):
"""Process text_font parameter"""
if not text_font_file:
@@ -440,12 +434,12 @@ def pack_assets_simple(target_path, include_path, out_file, assets_path, max_nam
def read_wakenet_from_sdkconfig(sdkconfig_path):
"""
Read wakenet model from sdkconfig (based on movemodel.py logic)
Returns the wakenet model name or None if no wakenet is configured
Read wakenet models from sdkconfig (based on movemodel.py logic)
Returns a list of wakenet model names
"""
if not os.path.exists(sdkconfig_path):
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
return None
return []
models = []
with io.open(sdkconfig_path, "r") as f:
@@ -461,8 +455,7 @@ def read_wakenet_from_sdkconfig(sdkconfig_path):
model_name = label.split("_SR_WN_")[-1].lower()
models.append(model_name)
# Return the first model found, or None if no models
return models[0] if models else None
return models
def read_multinet_from_sdkconfig(sdkconfig_path):
@@ -514,6 +507,46 @@ def read_multinet_from_sdkconfig(sdkconfig_path):
return models
def read_wake_word_type_from_sdkconfig(sdkconfig_path):
"""
Read wake word type configuration from sdkconfig
Returns a dict with wake word type info
"""
if not os.path.exists(sdkconfig_path):
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
return {
'use_esp_wake_word': False,
'use_afe_wake_word': False,
'use_custom_wake_word': False,
'wake_word_disabled': True
}
config_values = {
'use_esp_wake_word': False,
'use_afe_wake_word': False,
'use_custom_wake_word': False,
'wake_word_disabled': False
}
with io.open(sdkconfig_path, "r") as f:
for line in f:
line = line.strip("\n")
if line.startswith('#'):
continue
# Check for wake word type configuration
if 'CONFIG_USE_ESP_WAKE_WORD=y' in line:
config_values['use_esp_wake_word'] = True
elif 'CONFIG_USE_AFE_WAKE_WORD=y' in line:
config_values['use_afe_wake_word'] = True
elif 'CONFIG_USE_CUSTOM_WAKE_WORD=y' in line:
config_values['use_custom_wake_word'] = True
elif 'CONFIG_WAKE_WORD_DISABLED=y' in line:
config_values['wake_word_disabled'] = True
return config_values
def read_custom_wake_word_from_sdkconfig(sdkconfig_path):
"""
Read custom wake word configuration from sdkconfig
@@ -591,19 +624,23 @@ def get_language_from_multinet_models(multinet_models):
return 'cn' # Default to Chinese
def get_wakenet_model_path(model_name, esp_sr_model_path):
def get_wakenet_model_paths(model_names, esp_sr_model_path):
"""
Get the full path to the wakenet model directory
Get the full paths to the wakenet model directories
Returns a list of valid model paths
"""
if not model_name:
return None
if not model_names:
return []
wakenet_model_path = os.path.join(esp_sr_model_path, 'wakenet_model', model_name)
if os.path.exists(wakenet_model_path):
return wakenet_model_path
else:
print(f"Warning: Wakenet model directory not found: {wakenet_model_path}")
return None
valid_paths = []
for model_name in model_names:
wakenet_model_path = os.path.join(esp_sr_model_path, 'wakenet_model', model_name)
if os.path.exists(wakenet_model_path):
valid_paths.append(wakenet_model_path)
else:
print(f"Warning: Wakenet model directory not found: {wakenet_model_path}")
return valid_paths
def get_multinet_model_paths(model_names, esp_sr_model_path):
@@ -661,7 +698,7 @@ def get_emoji_collection_path(default_emoji_collection, xiaozhi_fonts_path):
return None
def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
def build_assets_integrated(wakenet_model_paths, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
"""
Build assets using integrated functions (no external dependencies)
"""
@@ -679,7 +716,7 @@ def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_
print("Starting to build assets...")
# Process each component
srmodels = process_sr_models(wakenet_model_path, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_path or multinet_model_paths) else None
srmodels = process_sr_models(wakenet_model_paths, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_paths or multinet_model_paths) else None
text_font = process_text_font(text_font_path, assets_dir) if text_font_path else None
emoji_collection = process_emoji_collection(emoji_collection_path, assets_dir) if emoji_collection_path else None
extra_files = process_extra_files(extra_files_path, assets_dir) if extra_files_path else None
@@ -734,19 +771,17 @@ def main():
args = parser.parse_args()
# Get script directory (not needed anymore but keep for future use)
script_dir = os.path.dirname(os.path.abspath(__file__))
# Set default paths if not provided
if not args.esp_sr_model_path:
# Default ESP-SR model path relative to project root
project_root = os.path.dirname(os.path.dirname(script_dir))
args.esp_sr_model_path = os.path.join(project_root, "managed_components", "espressif__esp-sr", "model")
if not args.xiaozhi_fonts_path:
# Default xiaozhi-fonts path relative to project root
project_root = os.path.dirname(os.path.dirname(script_dir))
args.xiaozhi_fonts_path = os.path.join(project_root, "managed_components", "78__xiaozhi-fonts")
if not args.esp_sr_model_path or not args.xiaozhi_fonts_path:
# Calculate project root from script location
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if not args.esp_sr_model_path:
args.esp_sr_model_path = os.path.join(project_root, "managed_components", "espressif__esp-sr", "model")
if not args.xiaozhi_fonts_path:
args.xiaozhi_fonts_path = os.path.join(project_root, "components", "xiaozhi-fonts")
print("Building default assets...")
print(f" sdkconfig: {args.sdkconfig}")
@@ -754,19 +789,40 @@ def main():
print(f" emoji_collection: {args.emoji_collection}")
print(f" output: {args.output}")
# Read wake word type configuration from sdkconfig
wake_word_config = read_wake_word_type_from_sdkconfig(args.sdkconfig)
# Read SR models from sdkconfig
wakenet_model_name = read_wakenet_from_sdkconfig(args.sdkconfig)
wakenet_model_names = read_wakenet_from_sdkconfig(args.sdkconfig)
multinet_model_names = read_multinet_from_sdkconfig(args.sdkconfig)
# Get model paths
wakenet_model_path = get_wakenet_model_path(wakenet_model_name, args.esp_sr_model_path)
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
# Apply wake word logic to decide which models to package
wakenet_model_paths = []
multinet_model_paths = []
# Print model information
if wakenet_model_name:
print(f" wakenet model: {wakenet_model_name}")
if multinet_model_names:
print(f" multinet models: {', '.join(multinet_model_names)}")
# 1. Only package wakenet models if USE_ESP_WAKE_WORD=y or USE_AFE_WAKE_WORD=y
if wake_word_config['use_esp_wake_word'] or wake_word_config['use_afe_wake_word']:
wakenet_model_paths = get_wakenet_model_paths(wakenet_model_names, args.esp_sr_model_path)
elif wakenet_model_names:
print(f" Note: Found wakenet models {wakenet_model_names} but wake word type is not ESP/AFE, skipping")
# 2. Error check: if USE_CUSTOM_WAKE_WORD=y but no multinet models selected, report error
if wake_word_config['use_custom_wake_word'] and not multinet_model_names:
print("Error: USE_CUSTOM_WAKE_WORD is enabled but no multinet models are selected in sdkconfig")
print("Please select appropriate CONFIG_SR_MN_* options in menuconfig, or disable USE_CUSTOM_WAKE_WORD")
sys.exit(1)
# 3. Only package multinet models if USE_CUSTOM_WAKE_WORD=y
if wake_word_config['use_custom_wake_word']:
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
elif multinet_model_names:
print(f" Note: Found multinet models {multinet_model_names} but USE_CUSTOM_WAKE_WORD is disabled, skipping")
# Print model information (only for models that will actually be packaged)
if wakenet_model_paths:
print(f" wakenet models: {', '.join(wakenet_model_names)} (will be packaged)")
if multinet_model_paths:
print(f" multinet models: {', '.join(multinet_model_names)} (will be packaged)")
# Get text font path if needed
text_font_path = get_text_font_path(args.builtin_text_font, args.xiaozhi_fonts_path)
@@ -781,7 +837,7 @@ def main():
custom_wake_word_config = read_custom_wake_word_from_sdkconfig(args.sdkconfig)
multinet_model_info = None
if custom_wake_word_config and multinet_model_names:
if custom_wake_word_config and multinet_model_paths:
# Determine language from multinet models
language = get_language_from_multinet_models(multinet_model_names)
@@ -803,7 +859,7 @@ def main():
print(f" wake word threshold: {custom_wake_word_config['threshold']}")
# Check if we have anything to build
if not wakenet_model_path and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
if not wakenet_model_paths and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
print("Warning: No assets to build (no SR models, text font, emoji collection, extra files, or custom wake word)")
# Create an empty assets.bin file
os.makedirs(os.path.dirname(args.output), exist_ok=True)
@@ -813,7 +869,7 @@ def main():
return
# Build the assets
success = build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path,
success = build_assets_integrated(wakenet_model_paths, multinet_model_paths, text_font_path, emoji_collection_path,
extra_files_path, args.output, multinet_model_info)
if not success: