forked from xiaozhi/xiaozhi-esp32
fix multinet model for v2 (#1208)
This commit is contained in:
@@ -156,21 +156,44 @@ def copy_directory(src, dst):
|
||||
return False
|
||||
|
||||
|
||||
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
|
||||
"""Process wakenet_model parameter"""
|
||||
if not wakenet_model_dir:
|
||||
def process_sr_models(wakenet_model_dir, multinet_model_dirs, build_dir, assets_dir):
|
||||
"""Process SR models (wakenet and multinet) and generate srmodels.bin"""
|
||||
if not wakenet_model_dir and not multinet_model_dirs:
|
||||
return None
|
||||
|
||||
# Copy input directory to build directory
|
||||
wakenet_build_dir = os.path.join(build_dir, "wakenet_model")
|
||||
if os.path.exists(wakenet_build_dir):
|
||||
shutil.rmtree(wakenet_build_dir)
|
||||
copy_directory(wakenet_model_dir, os.path.join(wakenet_build_dir, os.path.basename(wakenet_model_dir)))
|
||||
# Create SR models build directory
|
||||
sr_models_build_dir = os.path.join(build_dir, "srmodels")
|
||||
if os.path.exists(sr_models_build_dir):
|
||||
shutil.rmtree(sr_models_build_dir)
|
||||
os.makedirs(sr_models_build_dir)
|
||||
|
||||
models_processed = 0
|
||||
|
||||
# Copy wakenet model if available
|
||||
if wakenet_model_dir:
|
||||
wakenet_name = os.path.basename(wakenet_model_dir)
|
||||
wakenet_dst = os.path.join(sr_models_build_dir, wakenet_name)
|
||||
if copy_directory(wakenet_model_dir, wakenet_dst):
|
||||
models_processed += 1
|
||||
print(f"Added wakenet model: {wakenet_name}")
|
||||
|
||||
# Copy multinet models if available
|
||||
if multinet_model_dirs:
|
||||
for multinet_model_dir in multinet_model_dirs:
|
||||
multinet_name = os.path.basename(multinet_model_dir)
|
||||
multinet_dst = os.path.join(sr_models_build_dir, multinet_name)
|
||||
if copy_directory(multinet_model_dir, multinet_dst):
|
||||
models_processed += 1
|
||||
print(f"Added multinet model: {multinet_name}")
|
||||
|
||||
if models_processed == 0:
|
||||
print("Warning: No SR models were successfully processed")
|
||||
return None
|
||||
|
||||
# Use pack_models function to generate srmodels.bin
|
||||
srmodels_output = os.path.join(wakenet_build_dir, "srmodels.bin")
|
||||
srmodels_output = os.path.join(sr_models_build_dir, "srmodels.bin")
|
||||
try:
|
||||
pack_models(wakenet_build_dir, "srmodels.bin")
|
||||
pack_models(sr_models_build_dir, "srmodels.bin")
|
||||
print(f"Generated: {srmodels_output}")
|
||||
# Copy srmodels.bin to assets directory
|
||||
copy_file(srmodels_output, os.path.join(assets_dir, "srmodels.bin"))
|
||||
@@ -180,6 +203,11 @@ def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
|
||||
return None
|
||||
|
||||
|
||||
def process_wakenet_model(wakenet_model_dir, build_dir, assets_dir):
|
||||
"""Process wakenet_model parameter (legacy compatibility function)"""
|
||||
return process_sr_models(wakenet_model_dir, None, build_dir, assets_dir)
|
||||
|
||||
|
||||
def process_text_font(text_font_file, assets_dir):
|
||||
"""Process text_font parameter"""
|
||||
if not text_font_file:
|
||||
@@ -250,7 +278,7 @@ def process_extra_files(extra_files_dir, assets_dir):
|
||||
return extra_files_list
|
||||
|
||||
|
||||
def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None):
|
||||
def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files=None, multinet_model_info=None):
|
||||
"""Generate index.json file"""
|
||||
index_data = {
|
||||
"version": 1
|
||||
@@ -268,6 +296,9 @@ def generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra
|
||||
if extra_files:
|
||||
index_data["extra_files"] = extra_files
|
||||
|
||||
if multinet_model_info:
|
||||
index_data["multinet_model"] = multinet_model_info
|
||||
|
||||
# Write index.json
|
||||
index_path = os.path.join(assets_dir, "index.json")
|
||||
with open(index_path, 'w', encoding='utf-8') as f:
|
||||
@@ -434,6 +465,132 @@ def read_wakenet_from_sdkconfig(sdkconfig_path):
|
||||
return models[0] if models else None
|
||||
|
||||
|
||||
def read_multinet_from_sdkconfig(sdkconfig_path):
|
||||
"""
|
||||
Read multinet models from sdkconfig (based on movemodel.py logic)
|
||||
Returns a list of multinet model names
|
||||
"""
|
||||
if not os.path.exists(sdkconfig_path):
|
||||
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
|
||||
return []
|
||||
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
models_string = ''
|
||||
for label in f:
|
||||
label = label.strip("\n")
|
||||
if 'CONFIG_SR_MN' in label and label[0] != '#':
|
||||
models_string += label
|
||||
|
||||
models = []
|
||||
if "CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn3_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn4q8_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn4_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn5q8_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET6_QUANT" in models_string:
|
||||
models.append('mn6_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET6_AC_QUANT" in models_string:
|
||||
models.append('mn6_cn_ac')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET7_QUANT" in models_string:
|
||||
models.append('mn7_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET7_AC_QUANT" in models_string:
|
||||
models.append('mn7_cn_ac')
|
||||
|
||||
if "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn5q8_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn5_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET6_QUANT" in models_string:
|
||||
models.append('mn6_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET7_QUANT" in models_string:
|
||||
models.append('mn7_en')
|
||||
|
||||
if "MULTINET6" in models_string or "MULTINET7" in models_string:
|
||||
models.append('fst')
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def read_custom_wake_word_from_sdkconfig(sdkconfig_path):
|
||||
"""
|
||||
Read custom wake word configuration from sdkconfig
|
||||
Returns a dict with custom wake word info or None if not configured
|
||||
"""
|
||||
if not os.path.exists(sdkconfig_path):
|
||||
print(f"Warning: sdkconfig file not found: {sdkconfig_path}")
|
||||
return None
|
||||
|
||||
config_values = {}
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip("\n")
|
||||
if line.startswith('#') or '=' not in line:
|
||||
continue
|
||||
|
||||
# Check for custom wake word configuration
|
||||
if 'CONFIG_USE_CUSTOM_WAKE_WORD=y' in line:
|
||||
config_values['use_custom_wake_word'] = True
|
||||
elif 'CONFIG_CUSTOM_WAKE_WORD=' in line and not line.startswith('#'):
|
||||
# Extract string value (remove quotes)
|
||||
value = line.split('=', 1)[1].strip('"')
|
||||
config_values['wake_word'] = value
|
||||
elif 'CONFIG_CUSTOM_WAKE_WORD_DISPLAY=' in line and not line.startswith('#'):
|
||||
# Extract string value (remove quotes)
|
||||
value = line.split('=', 1)[1].strip('"')
|
||||
config_values['display'] = value
|
||||
elif 'CONFIG_CUSTOM_WAKE_WORD_THRESHOLD=' in line and not line.startswith('#'):
|
||||
# Extract numeric value
|
||||
value = line.split('=', 1)[1]
|
||||
try:
|
||||
config_values['threshold'] = int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
config_values['threshold'] = float(value)
|
||||
except ValueError:
|
||||
print(f"Warning: Invalid threshold value: {value}")
|
||||
config_values['threshold'] = 20 # default (will be converted to 0.2)
|
||||
|
||||
# Return config only if custom wake word is enabled and required fields are present
|
||||
if (config_values.get('use_custom_wake_word', False) and
|
||||
'wake_word' in config_values and
|
||||
'display' in config_values and
|
||||
'threshold' in config_values):
|
||||
return {
|
||||
'wake_word': config_values['wake_word'],
|
||||
'display': config_values['display'],
|
||||
'threshold': config_values['threshold'] / 100.0 # Convert to decimal (20 -> 0.2)
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_language_from_multinet_models(multinet_models):
|
||||
"""
|
||||
Determine language from multinet model names
|
||||
Returns 'cn', 'en', or None
|
||||
"""
|
||||
if not multinet_models:
|
||||
return None
|
||||
|
||||
# Check for Chinese models
|
||||
cn_indicators = ['_cn', 'cn_']
|
||||
en_indicators = ['_en', 'en_']
|
||||
|
||||
has_cn = any(any(indicator in model for indicator in cn_indicators) for model in multinet_models)
|
||||
has_en = any(any(indicator in model for indicator in en_indicators) for model in multinet_models)
|
||||
|
||||
# If both or neither, default to cn
|
||||
if has_cn and not has_en:
|
||||
return 'cn'
|
||||
elif has_en and not has_cn:
|
||||
return 'en'
|
||||
else:
|
||||
return 'cn' # Default to Chinese
|
||||
|
||||
|
||||
def get_wakenet_model_path(model_name, esp_sr_model_path):
|
||||
"""
|
||||
Get the full path to the wakenet model directory
|
||||
@@ -449,6 +606,25 @@ def get_wakenet_model_path(model_name, esp_sr_model_path):
|
||||
return None
|
||||
|
||||
|
||||
def get_multinet_model_paths(model_names, esp_sr_model_path):
|
||||
"""
|
||||
Get the full paths to the multinet model directories
|
||||
Returns a list of valid model paths
|
||||
"""
|
||||
if not model_names:
|
||||
return []
|
||||
|
||||
valid_paths = []
|
||||
for model_name in model_names:
|
||||
multinet_model_path = os.path.join(esp_sr_model_path, 'multinet_model', model_name)
|
||||
if os.path.exists(multinet_model_path):
|
||||
valid_paths.append(multinet_model_path)
|
||||
else:
|
||||
print(f"Warning: Multinet model directory not found: {multinet_model_path}")
|
||||
|
||||
return valid_paths
|
||||
|
||||
|
||||
def get_text_font_path(builtin_text_font, xiaozhi_fonts_path):
|
||||
"""
|
||||
Get the text font path if needed
|
||||
@@ -485,7 +661,7 @@ def get_emoji_collection_path(default_emoji_collection, xiaozhi_fonts_path):
|
||||
return None
|
||||
|
||||
|
||||
def build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path, extra_files_path, output_path):
|
||||
def build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path, extra_files_path, output_path, multinet_model_info=None):
|
||||
"""
|
||||
Build assets using integrated functions (no external dependencies)
|
||||
"""
|
||||
@@ -503,13 +679,13 @@ def build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection
|
||||
print("Starting to build assets...")
|
||||
|
||||
# Process each component
|
||||
srmodels = process_wakenet_model(wakenet_model_path, temp_build_dir, assets_dir) if wakenet_model_path else None
|
||||
srmodels = process_sr_models(wakenet_model_path, multinet_model_paths, temp_build_dir, assets_dir) if (wakenet_model_path or multinet_model_paths) else None
|
||||
text_font = process_text_font(text_font_path, assets_dir) if text_font_path else None
|
||||
emoji_collection = process_emoji_collection(emoji_collection_path, assets_dir) if emoji_collection_path else None
|
||||
extra_files = process_extra_files(extra_files_path, assets_dir) if extra_files_path else None
|
||||
|
||||
# Generate index.json
|
||||
generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files)
|
||||
generate_index_json(assets_dir, srmodels, text_font, emoji_collection, extra_files, multinet_model_info)
|
||||
|
||||
# Generate config.json for packing
|
||||
config_path = generate_config_json(temp_build_dir, assets_dir)
|
||||
@@ -578,9 +754,19 @@ def main():
|
||||
print(f" emoji_collection: {args.emoji_collection}")
|
||||
print(f" output: {args.output}")
|
||||
|
||||
# Read wakenet model from sdkconfig
|
||||
# Read SR models from sdkconfig
|
||||
wakenet_model_name = read_wakenet_from_sdkconfig(args.sdkconfig)
|
||||
multinet_model_names = read_multinet_from_sdkconfig(args.sdkconfig)
|
||||
|
||||
# Get model paths
|
||||
wakenet_model_path = get_wakenet_model_path(wakenet_model_name, args.esp_sr_model_path)
|
||||
multinet_model_paths = get_multinet_model_paths(multinet_model_names, args.esp_sr_model_path)
|
||||
|
||||
# Print model information
|
||||
if wakenet_model_name:
|
||||
print(f" wakenet model: {wakenet_model_name}")
|
||||
if multinet_model_names:
|
||||
print(f" multinet models: {', '.join(multinet_model_names)}")
|
||||
|
||||
# Get text font path if needed
|
||||
text_font_path = get_text_font_path(args.builtin_text_font, args.xiaozhi_fonts_path)
|
||||
@@ -591,9 +777,34 @@ def main():
|
||||
# Get extra files path if provided
|
||||
extra_files_path = args.extra_files
|
||||
|
||||
# Read custom wake word configuration
|
||||
custom_wake_word_config = read_custom_wake_word_from_sdkconfig(args.sdkconfig)
|
||||
multinet_model_info = None
|
||||
|
||||
if custom_wake_word_config and multinet_model_names:
|
||||
# Determine language from multinet models
|
||||
language = get_language_from_multinet_models(multinet_model_names)
|
||||
|
||||
# Build multinet_model info structure
|
||||
multinet_model_info = {
|
||||
"language": language,
|
||||
"duration": 3000, # Default duration in ms
|
||||
"threshold": custom_wake_word_config['threshold'],
|
||||
"commands": [
|
||||
{
|
||||
"command": custom_wake_word_config['wake_word'],
|
||||
"text": custom_wake_word_config['display'],
|
||||
"action": "wake"
|
||||
}
|
||||
]
|
||||
}
|
||||
print(f" custom wake word: {custom_wake_word_config['wake_word']} ({custom_wake_word_config['display']})")
|
||||
print(f" wake word language: {language}")
|
||||
print(f" wake word threshold: {custom_wake_word_config['threshold']}")
|
||||
|
||||
# Check if we have anything to build
|
||||
if not wakenet_model_path and not text_font_path and not emoji_collection_path and not extra_files_path:
|
||||
print("Warning: No assets to build (no wakenet, text font, emoji collection, or extra files)")
|
||||
if not wakenet_model_path and not multinet_model_paths and not text_font_path and not emoji_collection_path and not extra_files_path and not multinet_model_info:
|
||||
print("Warning: No assets to build (no SR models, text font, emoji collection, extra files, or custom wake word)")
|
||||
# Create an empty assets.bin file
|
||||
os.makedirs(os.path.dirname(args.output), exist_ok=True)
|
||||
with open(args.output, 'wb') as f:
|
||||
@@ -602,8 +813,8 @@ def main():
|
||||
return
|
||||
|
||||
# Build the assets
|
||||
success = build_assets_integrated(wakenet_model_path, text_font_path, emoji_collection_path,
|
||||
extra_files_path, args.output)
|
||||
success = build_assets_integrated(wakenet_model_path, multinet_model_paths, text_font_path, emoji_collection_path,
|
||||
extra_files_path, args.output, multinet_model_info)
|
||||
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
||||
Reference in New Issue
Block a user