forked from xiaozhi/xiaozhi-esp32
fix multiple wakenet words and custom wake word (#1226)
* fix multiple wakenet words and custom wake word * fix idf_component.yml
This commit is contained in:
@@ -17,8 +17,6 @@ import shutil
|
||||
import sys
|
||||
import json
|
||||
import struct
|
||||
import math
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -156,9 +154,9 @@ def copy_directory(src, dst):
|
||||
return False
|
||||
|
||||
|
||||
def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_dir):
|
||||
def process_sr_models(wakenet_model_dirs, multinet_model_dirs, build_dir, assets_dir):
|
||||
"""Process SR models (wakenet and multinet) and generate srmodels.bin"""
|
||||
if not wakenet_model_dir and not multinet_model_dirs:
|
||||
if not wakenet_model_dirs and not multinet_model_dirs:
|
||||
return None
|
||||
|
||||
# Create SR models build directory
|
||||
@@ -169,13 +167,14 @@ def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_
|
||||
|
||||
models_processed = 0
|
||||
|
||||
# Copy wakenet model if available
|
||||
if wakenet_model_dir:
|
||||
wakenet_name = os.path.basename(wakenet_model_dir)
|
||||
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
|
||||
if copy_directory(wakenet_model_dir, wakenet_dst):
|
||||
models_processed += 1
|
||||
print(f"Added wakenet model: {wakenet_name}")
|
||||
# Copy wakenet models if available
|
||||
if wakenet_model_dirs:
|
||||
for wakenet_model_dir in wakenet_model_dirs:
|
||||
wakenet_name = os.path.basename(wakenet_model_dir)
|
||||
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
|
||||
if copy_directory(wakenet_model_dir, wakenet_dst):
|
||||
models_processed += 1
|
||||
print(f"Added wakenet model: {wakenet_name}")
|
||||
|
||||
# Copy multinet models if available
|
||||
if multinet_model_dirs:
|
||||
@@ -203,11 +202,6 @@ def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_
|
||||
return None
|
||||
|
||||
|
||||
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
|
||||
"""Process wakenet_model parameter (legacy compatibility function)"""
|
||||
return process_sr_models(wakenet_model_dir, None, build_dir, assets_dir)
|
||||
|
||||
|
||||
def process_text_font(text_font_file, assets_dir):
|
||||
"""Process text_font parameter"""
|
||||
if not text_font_file:
|
||||
@@ -440,12 +434,12 @@ def pack_assets_simple(target_path, include_path, out_file, assets_path, max_nam
|
||||
|
||||
def read_wakenet_from_sdkconfig(sdkconfig_path):
|
||||
"""
|
||||
Read wakenet model from sdkconfig (based on movemodel.py logic)
|
||||
Returns the wakenet model name or None if no wakenet is configured
|
||||
Read wakenet models from sdkconfig (based on movemodel.py logic)
|
||||
Returns a list of wakenet model names
|
||||
"""
|
||||
if not os.path.exists(sdkconfig_path):
|
||||
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
|
||||
return None
|
||||
return []
|
||||
|
||||
models = []
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
@@ -461,8 +455,7 @@ def read_wakenet_from_sdkconfig(sdkconfig_path):
|
||||
model_name = label.split("_SR_WN_")[-1].lower()
|
||||
models.append(model_name)
|
||||
|
||||
# Return the first model found, or None if no models
|
||||
return models[0] if models else None
|
||||
return models
|
||||
|
||||
|
||||
def read_multinet_from_sdkconfig(sdkconfig_path):
|
||||
@@ -514,6 +507,46 @@ def read_multinet_from_sdkconfig(sdkconfig_path):
|
||||
return models
|
||||
|
||||
|
||||
def read_wake_word_type_from_sdkconfig(sdkconfig_path):
|
||||
"""
|
||||
Read wake word type configuration from sdkconfig
|
||||
Returns a dict with wake word type info
|
||||
"""
|
||||
if not os.path.exists(sdkconfig_path):
|
||||
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
|
||||
return {
|
||||
'use_esp_wake_word': False,
|
||||
'use_afe_wake_word': False,
|
||||
'use_custom_wake_word': False,
|
||||
'wake_word_disabled': True
|
||||
}
|
||||
|
||||
config_values = {
|
||||
'use_esp_wake_word': False,
|
||||
'use_afe_wake_word': False,
|
||||
'use_custom_wake_word': False,
|
||||
'wake_word_disabled': False
|
||||
}
|
||||
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip("\n")
|
||||
if line.startswith('#'):
|
||||
continue
|
||||
|
||||
# Check for wake word type configuration
|
||||
if 'CONFIG_USE_ESP_WAKE_WORD=y' in line:
|
||||
config_values['use_esp_wake_word'] = True
|
||||
elif 'CONFIG_USE_AFE_WAKE_WORD=y' in line:
|
||||
config_values['use_afe_wake_word'] = True
|
||||
elif 'CONFIG_USE_CUSTOM_WAKE_WORD=y' in line:
|
||||
config_values['use_custom_wake_word'] = True
|
||||
elif 'CONFIG_WAKE_WORD_DISABLED=y' in line:
|
||||
config_values['wake_word_disabled'] = True
|
||||
|
||||
return config_values
|
||||
|
||||
|
||||
def read_custom_wake_word_from_sdkconfig(sdkconfig_path):
|
||||
"""
|
||||
Read custom wake word configuration from sdkconfig
|
||||
@@ -591,19 +624,23 @@ def get_language_from_multinet_models(multinet_models):
|
||||
return 'cn' # Default to Chinese
|
||||
|
||||
|
||||
def get_wakenet_model_path(model_name, esp_sr_model_path):
|
||||
def get_wakenet_model_paths(model_names, esp_sr_model_path):
|
||||
"""
|
||||
Get the full path to the wakenet model directory
|
||||
Get the full paths to the wakenet model directories
|
||||
Returns a list of valid model paths
|
||||
"""
|
||||
if not model_name:
|
||||
return None
|
||||
if not model_names:
|
||||
return []
|
||||
|
||||
wakenet_model_path = os.path.join(esp_sr_model_path, 'wakenet_model', model_name)
|
||||
if os.path.exists(wakenet_model_path):
|
||||
return wakenet_model_path
|
||||
else:
|
||||
print(f"Warning: Wakenet model directory not found: {wakenet_model_path}")
|
||||
return None
|
||||
valid_paths = []
|
||||
for model_name in model_names:
|
||||
wakenet_model_path = os.path.join(esp_sr_model_path, 'wakenet_model', model_name)
|
||||
if os.path.exists(wakenet_model_path):
|
||||
valid_paths.append(wakenet_model_path)
|
||||
else:
|
||||
print(f"Warning: Wakenet model directory not found: {wakenet_model_path}")
|
||||
|
||||
return valid_paths
|
||||
|
||||
|
||||
def get_multinet_model_paths(model_names, esp_sr_model_path):
|
||||
@@ -661,7 +698,7 @@ def get_emoji_collection_path(default_emoji_collection, xiaozhi_fonts_path):
|
||||
return None
|
||||
|
||||
|
||||
def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
|
||||
def build_assets_integrated(wakenet_model_paths, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
|
||||
"""
|
||||
Build assets using integrated functions (no external dependencies)
|
||||
"""
|
||||
@@ -679,7 +716,7 @@ def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_
|
||||
print("Starting to build assets...")
|
||||
|
||||
# Process each component
|
||||
srmodels = process_sr_models(wakenet_model_path, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_path or multinet_model_paths) else None
|
||||
srmodels = process_sr_models(wakenet_model_paths, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_paths or multinet_model_paths) else None
|
||||
text_font = process_text_font(text_font_path, assets_dir) if text_font_path else None
|
||||
emoji_collection = process_emoji_collection(emoji_collection_path, assets_dir) if emoji_collection_path else None
|
||||
extra_files = process_extra_files(extra_files_path, assets_dir) if extra_files_path else None
|
||||
@@ -734,19 +771,17 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get script directory (not needed anymore but keep for future use)
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Set default paths if not provided
|
||||
if not args.esp_sr_model_path:
|
||||
# Default ESP-SR model path relative to project root
|
||||
project_root = os.path.dirname(os.path.dirname(script_dir))
|
||||
args.esp_sr_model_path = os.path.join(project_root, "managed_components", "espressif__esp-sr", "model")
|
||||
|
||||
if not args.xiaozhi_fonts_path:
|
||||
# Default xiaozhi-fonts path relative to project root
|
||||
project_root = os.path.dirname(os.path.dirname(script_dir))
|
||||
args.xiaozhi_fonts_path = os.path.join(project_root, "managed_components", "78__xiaozhi-fonts")
|
||||
if not args.esp_sr_model_path or not args.xiaozhi_fonts_path:
|
||||
# Calculate project root from script location
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(script_dir)
|
||||
|
||||
if not args.esp_sr_model_path:
|
||||
args.esp_sr_model_path = os.path.join(project_root, "managed_components", "espressif__esp-sr", "model")
|
||||
|
||||
if not args.xiaozhi_fonts_path:
|
||||
args.xiaozhi_fonts_path = os.path.join(project_root, "components", "xiaozhi-fonts")
|
||||
|
||||
print("Building default assets...")
|
||||
print(f" sdkconfig: {args.sdkconfig}")
|
||||
@@ -754,19 +789,40 @@ def main():
|
||||
print(f" emoji_collection: {args.emoji_collection}")
|
||||
print(f" output: {args.output}")
|
||||
|
||||
# Read wake word type configuration from sdkconfig
|
||||
wake_word_config = read_wake_word_type_from_sdkconfig(args.sdkconfig)
|
||||
|
||||
# Read SR models from sdkconfig
|
||||
wakenet_model_name = read_wakenet_from_sdkconfig(args.sdkconfig)
|
||||
wakenet_model_names = read_wakenet_from_sdkconfig(args.sdkconfig)
|
||||
multinet_model_names = read_multinet_from_sdkconfig(args.sdkconfig)
|
||||
|
||||
# Get model paths
|
||||
wakenet_model_path = get_wakenet_model_path(wakenet_model_name, args.esp_sr_model_path)
|
||||
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
|
||||
# Apply wake word logic to decide which models to package
|
||||
wakenet_model_paths = []
|
||||
multinet_model_paths = []
|
||||
|
||||
# Print model information
|
||||
if wakenet_model_name:
|
||||
print(f" wakenet model: {wakenet_model_name}")
|
||||
if multinet_model_names:
|
||||
print(f" multinet models: {', '.join(multinet_model_names)}")
|
||||
# 1. Only package wakenet models if USE_ESP_WAKE_WORD=y or USE_AFE_WAKE_WORD=y
|
||||
if wake_word_config['use_esp_wake_word'] or wake_word_config['use_afe_wake_word']:
|
||||
wakenet_model_paths = get_wakenet_model_paths(wakenet_model_names, args.esp_sr_model_path)
|
||||
elif wakenet_model_names:
|
||||
print(f" Note: Found wakenet models {wakenet_model_names} but wake word type is not ESP/AFE, skipping")
|
||||
|
||||
# 2. Error check: if USE_CUSTOM_WAKE_WORD=y but no multinet models selected, report error
|
||||
if wake_word_config['use_custom_wake_word'] and not multinet_model_names:
|
||||
print("Error: USE_CUSTOM_WAKE_WORD is enabled but no multinet models are selected in sdkconfig")
|
||||
print("Please select appropriate CONFIG_SR_MN_* options in menuconfig, or disable USE_CUSTOM_WAKE_WORD")
|
||||
sys.exit(1)
|
||||
|
||||
# 3. Only package multinet models if USE_CUSTOM_WAKE_WORD=y
|
||||
if wake_word_config['use_custom_wake_word']:
|
||||
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
|
||||
elif multinet_model_names:
|
||||
print(f" Note: Found multinet models {multinet_model_names} but USE_CUSTOM_WAKE_WORD is disabled, skipping")
|
||||
|
||||
# Print model information (only for models that will actually be packaged)
|
||||
if wakenet_model_paths:
|
||||
print(f" wakenet models: {', '.join(wakenet_model_names)} (will be packaged)")
|
||||
if multinet_model_paths:
|
||||
print(f" multinet models: {', '.join(multinet_model_names)} (will be packaged)")
|
||||
|
||||
# Get text font path if needed
|
||||
text_font_path = get_text_font_path(args.builtin_text_font, args.xiaozhi_fonts_path)
|
||||
@@ -781,7 +837,7 @@ def main():
|
||||
custom_wake_word_config = read_custom_wake_word_from_sdkconfig(args.sdkconfig)
|
||||
multinet_model_info = None
|
||||
|
||||
if custom_wake_word_config and multinet_model_names:
|
||||
if custom_wake_word_config and multinet_model_paths:
|
||||
# Determine language from multinet models
|
||||
language = get_language_from_multinet_models(multinet_model_names)
|
||||
|
||||
@@ -803,7 +859,7 @@ def main():
|
||||
print(f" wake word threshold: {custom_wake_word_config['threshold']}")
|
||||
|
||||
# Check if we have anything to build
|
||||
if not wakenet_model_path and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
|
||||
if not wakenet_model_paths and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
|
||||
print("Warning: No assets to build (no SR models, text font, emoji collection, extra files, or custom wake word)")
|
||||
# Create an empty assets.bin file
|
||||
os.makedirs(os.path.dirname(args.output), exist_ok=True)
|
||||
@@ -813,7 +869,7 @@ def main():
|
||||
return
|
||||
|
||||
# Build the assets
|
||||
success = build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path,
|
||||
success = build_assets_integrated(wakenet_model_paths, multinet_model_paths, text_font_path, emoji_collection_path,
|
||||
extra_files_path, args.output, multinet_model_info)
|
||||
|
||||
if not success:
|
||||
|
||||
Reference in New Issue
Block a user