diff --git a/CMakeLists.txt b/CMakeLists.txt index 83395293..d5c4564c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ # CMakeLists in this exact order for cmake to work correctly cmake_minimum_required(VERSION 3.16) -set(PROJECT_VER "0.9.1") +set(PROJECT_VER "0.9.2") include($ENV{IDF_PATH}/tools/cmake/project.cmake) project(xiaozhi) diff --git a/main/application.cc b/main/application.cc index 5b76d64c..19b4147c 100644 --- a/main/application.cc +++ b/main/application.cc @@ -123,20 +123,52 @@ void Application::ToggleChatState() { Schedule([this]() { if (chat_state_ == kChatStateIdle) { SetChatState(kChatStateConnecting); - if (protocol_->OpenAudioChannel()) { - opus_encoder_.ResetState(); - SetChatState(kChatStateListening); - } else { + if (!protocol_->OpenAudioChannel()) { + ESP_LOGE(TAG, "Failed to open audio channel"); SetChatState(kChatStateIdle); + return; } + + keep_listening_ = true; + protocol_->SendStartListening(kListeningModeAutoStop); + SetChatState(kChatStateListening); } else if (chat_state_ == kChatStateSpeaking) { - AbortSpeaking(); + AbortSpeaking(kAbortReasonNone); } else if (chat_state_ == kChatStateListening) { protocol_->CloseAudioChannel(); } }); } +void Application::StartListening() { + Schedule([this]() { + keep_listening_ = false; + if (chat_state_ == kChatStateIdle) { + if (!protocol_->IsAudioChannelOpened()) { + SetChatState(kChatStateConnecting); + if (!protocol_->OpenAudioChannel()) { + SetChatState(kChatStateIdle); + ESP_LOGE(TAG, "Failed to open audio channel"); + return; + } + } + protocol_->SendStartListening(kListeningModeManualStop); + SetChatState(kChatStateListening); + } else if (chat_state_ == kChatStateSpeaking) { + AbortSpeaking(kAbortReasonNone); + protocol_->SendStartListening(kListeningModeManualStop); + SetChatState(kChatStateListening); + } + }); +} + +void Application::StopListening() { + Schedule([this]() { + protocol_->SendStopListening(); + SetChatState(kChatStateIdle); + }); +} + void Application::Start() { auto& board = Board::GetInstance(); board.Initialize(); @@ -248,26 +280,31 @@ void Application::Start() { }); }); - wake_word_detect_.OnWakeWordDetected([this]() { - Schedule([this]() { + wake_word_detect_.OnWakeWordDetected([this](const std::string& wake_word) { + Schedule([this, &wake_word]() { if (chat_state_ == kChatStateIdle) { SetChatState(kChatStateConnecting); wake_word_detect_.EncodeWakeWordData(); - if (protocol_->OpenAudioChannel()) { - std::string opus; - // Encode and send the wake word data to the server - while (wake_word_detect_.GetWakeWordOpus(opus)) { - protocol_->SendAudio(opus); - } - opus_encoder_.ResetState(); - // Send a ready message to indicate the server that the wake word data is sent - SetChatState(kChatStateWakeWordDetected); - } else { + if (!protocol_->OpenAudioChannel()) { + ESP_LOGE(TAG, "Failed to open audio channel"); SetChatState(kChatStateIdle); + wake_word_detect_.StartDetection(); + return; } + + std::string opus; + // Encode and send the wake word data to the server + while (wake_word_detect_.GetWakeWordOpus(opus)) { + protocol_->SendAudio(opus); + } + // Set the chat state to wake word detected + protocol_->SendWakeWordDetected(wake_word); + ESP_LOGI(TAG, "Wake word detected: %s", wake_word.c_str()); + keep_listening_ = true; + SetChatState(kChatStateListening); } else if (chat_state_ == kChatStateSpeaking) { - AbortSpeaking(); + AbortSpeaking(kAbortReasonWakeWordDetected); } // Resume detection @@ -313,15 +350,23 @@ void Application::Start() { auto state = cJSON_GetObjectItem(root, "state"); if (strcmp(state->valuestring, "start") == 0) { Schedule([this]() { - skip_to_end_ = false; - SetChatState(kChatStateSpeaking); + if (chat_state_ == kChatStateIdle || chat_state_ == kChatStateListening) { + skip_to_end_ = false; + opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE); + SetChatState(kChatStateSpeaking); + } }); } else if (strcmp(state->valuestring, "stop") == 0) { Schedule([this]() { auto codec = Board::GetInstance().GetAudioCodec(); codec->WaitForOutputDone(); if (chat_state_ == kChatStateSpeaking) { - SetChatState(kChatStateListening); + if (keep_listening_) { + protocol_->SendStartListening(kListeningModeAutoStop); + SetChatState(kChatStateListening); + } else { + SetChatState(kChatStateIdle); + } } }); } else if (strcmp(state->valuestring, "sentence_start") == 0) { @@ -375,9 +420,9 @@ void Application::MainLoop() { } } -void Application::AbortSpeaking() { +void Application::AbortSpeaking(AbortReason reason) { ESP_LOGI(TAG, "Abort speaking"); - protocol_->SendAbort(); + protocol_->SendAbortSpeaking(reason); skip_to_end_ = true; auto codec = Board::GetInstance().GetAudioCodec(); @@ -391,7 +436,6 @@ void Application::SetChatState(ChatState state) { "connecting", "listening", "speaking", - "wake_word_detected", "upgrading", "invalid_state" }; @@ -399,12 +443,10 @@ void Application::SetChatState(ChatState state) { // No need to update the state return; } - chat_state_ = state; - ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]); auto display = Board::GetInstance().GetDisplay(); auto builtin_led = Board::GetInstance().GetBuiltinLed(); - switch (chat_state_) { + switch (state) { case kChatStateUnknown: case kChatStateIdle: builtin_led->TurnOff(); @@ -424,6 +466,7 @@ void Application::SetChatState(ChatState state) { builtin_led->TurnOn(); display->SetStatus("聆听中..."); display->SetEmotion("neutral"); + opus_encoder_.ResetState(); #ifdef CONFIG_USE_AFE_SR audio_processor_.Start(); #endif @@ -436,17 +479,17 @@ void Application::SetChatState(ChatState state) { audio_processor_.Stop(); #endif break; - case kChatStateWakeWordDetected: - builtin_led->SetBlue(); - builtin_led->TurnOn(); - break; case kChatStateUpgrading: builtin_led->SetGreen(); builtin_led->StartContinuousBlink(100); break; + default: + ESP_LOGE(TAG, "Invalid chat state: %d", chat_state_); + return; } - protocol_->SendState(state_str[chat_state_]); + chat_state_ = state; + ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]); } void Application::AudioEncodeTask() { diff --git a/main/application.h b/main/application.h index fe774eca..fc91f242 100644 --- a/main/application.h +++ b/main/application.h @@ -29,7 +29,6 @@ enum ChatState { kChatStateConnecting, kChatStateListening, kChatStateSpeaking, - kChatStateWakeWordDetected, kChatStateUpgrading }; @@ -41,17 +40,19 @@ public: static Application instance; return instance; } + // 删除拷贝构造函数和赋值运算符 + Application(const Application&) = delete; + Application& operator=(const Application&) = delete; void Start(); ChatState GetChatState() const { return chat_state_; } void Schedule(std::function callback); void SetChatState(ChatState state); void Alert(const std::string&& title, const std::string&& message); - void AbortSpeaking(); + void AbortSpeaking(AbortReason reason); void ToggleChatState(); - // 删除拷贝构造函数和赋值运算符 - Application(const Application&) = delete; - Application& operator=(const Application&) = delete; + void StartListening(); + void StopListening(); private: Application(); @@ -68,6 +69,7 @@ private: Protocol* protocol_ = nullptr; EventGroupHandle_t event_group_; volatile ChatState chat_state_ = kChatStateUnknown; + bool keep_listening_ = false; bool skip_to_end_ = false; // Audio encode / decode diff --git a/main/audio_codecs/no_audio_codec.cc b/main/audio_codecs/no_audio_codec.cc index ef12bf50..d92d7944 100644 --- a/main/audio_codecs/no_audio_codec.cc +++ b/main/audio_codecs/no_audio_codec.cc @@ -65,8 +65,6 @@ NoAudioCodec::NoAudioCodec(int input_sample_rate, int output_sample_rate, gpio_n }; ESP_ERROR_CHECK(i2s_channel_init_std_mode(tx_handle_, &std_cfg)); ESP_ERROR_CHECK(i2s_channel_init_std_mode(rx_handle_, &std_cfg)); - ESP_ERROR_CHECK(i2s_channel_enable(tx_handle_)); - ESP_ERROR_CHECK(i2s_channel_enable(rx_handle_)); ESP_LOGI(TAG, "Duplex channels created"); } diff --git a/main/boards/bread-compact-wifi/compact_wifi_board.cc b/main/boards/bread-compact-wifi/compact_wifi_board.cc index 67c31bb8..f5f19a0c 100644 --- a/main/boards/bread-compact-wifi/compact_wifi_board.cc +++ b/main/boards/bread-compact-wifi/compact_wifi_board.cc @@ -7,6 +7,7 @@ #include "led.h" #include "config.h" +#include #include #include @@ -38,7 +39,11 @@ private: void InitializeButtons() { boot_button_.OnClick([this]() { - Application::GetInstance().ToggleChatState(); + auto& app = Application::GetInstance(); + if (app.GetChatState() == kChatStateUnknown && !WifiStation::GetInstance().IsConnected()) { + ResetWifiConfiguration(); + } + app.ToggleChatState(); }); volume_up_button_.OnClick([this]() { diff --git a/main/boards/common/button.cc b/main/boards/common/button.cc index 15dc3b80..837bfd66 100644 --- a/main/boards/common/button.cc +++ b/main/boards/common/button.cc @@ -30,15 +30,28 @@ Button::~Button() { } } -void Button::OnPress(std::function callback) { +void Button::OnPressDown(std::function callback) { if (button_handle_ == nullptr) { return; } - on_press_ = callback; + on_press_down_ = callback; iot_button_register_cb(button_handle_, BUTTON_PRESS_DOWN, [](void* handle, void* usr_data) { Button* button = static_cast(usr_data); - if (button->on_press_) { - button->on_press_(); + if (button->on_press_down_) { + button->on_press_down_(); + } + }, this); +} + +void Button::OnPressUp(std::function callback) { + if (button_handle_ == nullptr) { + return; + } + on_press_up_ = callback; + iot_button_register_cb(button_handle_, BUTTON_PRESS_UP, [](void* handle, void* usr_data) { + Button* button = static_cast(usr_data); + if (button->on_press_up_) { + button->on_press_up_(); } }, this); } diff --git a/main/boards/common/button.h b/main/boards/common/button.h index 347df5c0..a43dc352 100644 --- a/main/boards/common/button.h +++ b/main/boards/common/button.h @@ -10,7 +10,8 @@ public: Button(gpio_num_t gpio_num); ~Button(); - void OnPress(std::function callback); + void OnPressDown(std::function callback); + void OnPressUp(std::function callback); void OnLongPress(std::function callback); void OnClick(std::function callback); void OnDoubleClick(std::function callback); @@ -19,7 +20,8 @@ private: button_handle_t button_handle_; - std::function on_press_; + std::function on_press_down_; + std::function on_press_up_; std::function on_long_press_; std::function on_click_; std::function on_double_click_; diff --git a/main/boards/common/wifi_board.cc b/main/boards/common/wifi_board.cc index 40e0458a..622ea405 100644 --- a/main/boards/common/wifi_board.cc +++ b/main/boards/common/wifi_board.cc @@ -2,6 +2,7 @@ #include "application.h" #include "system_info.h" #include "font_awesome_symbols.h" +#include "settings.h" #include #include @@ -149,3 +150,15 @@ void WifiBoard::SetPowerSaveMode(bool enabled) { auto& wifi_station = WifiStation::GetInstance(); wifi_station.SetPowerSaveMode(enabled); } + +void WifiBoard::ResetWifiConfiguration() { + // Reset the wifi station + { + Settings settings("wifi", true); + settings.EraseAll(); + } + GetDisplay()->ShowNotification("已重置 WiFi..."); + vTaskDelay(pdMS_TO_TICKS(1000)); + // Reboot the device + esp_restart(); +} diff --git a/main/boards/common/wifi_board.h b/main/boards/common/wifi_board.h index 185d971a..546039a2 100644 --- a/main/boards/common/wifi_board.h +++ b/main/boards/common/wifi_board.h @@ -19,6 +19,7 @@ public: virtual bool GetNetworkState(std::string& network_name, int& signal_quality, std::string& signal_quality_text) override; virtual const char* GetNetworkStateIcon() override; virtual void SetPowerSaveMode(bool enabled) override; + virtual void ResetWifiConfiguration(); }; #endif // WIFI_BOARD_H diff --git a/main/boards/kevin-box-2/axp2101.cc b/main/boards/kevin-box-2/axp2101.cc index daaeaf1d..b905f8b3 100644 --- a/main/boards/kevin-box-2/axp2101.cc +++ b/main/boards/kevin-box-2/axp2101.cc @@ -20,7 +20,7 @@ Axp2101::Axp2101(i2c_master_bus_handle_t i2c_bus, uint8_t addr) : I2cDevice(i2c_ WriteReg(0x64, 0x03); // CV charger voltage setting to 4.2V WriteReg(0x61, 0x05); // set Main battery precharge current to 125mA - WriteReg(0x62, 0x10); // set Main battery charger current to 1000mA ( 0x08-200mA, 0x09-300mA, 0x0A-400mA ) + WriteReg(0x62, 0x0A); // set Main battery charger current to 400mA ( 0x08-200mA, 0x09-300mA, 0x0A-400mA ) WriteReg(0x63, 0x15); // set Main battery term charge current to 125mA WriteReg(0x14, 0x00); // set minimum system voltage to 4.1V (default 4.7V), for poor USB cables diff --git a/main/boards/kevin-box-2/kevin_box_board.cc b/main/boards/kevin-box-2/kevin_box_board.cc index 5a0c8c06..26156ff5 100644 --- a/main/boards/kevin-box-2/kevin_box_board.cc +++ b/main/boards/kevin-box-2/kevin_box_board.cc @@ -118,8 +118,15 @@ private: } void InitializeButtons() { - boot_button_.OnClick([this]() { - Application::GetInstance().ToggleChatState(); + // 测试按住说话 + // boot_button_.OnClick([this]() { + // Application::GetInstance().ToggleChatState(); + // }); + boot_button_.OnPressDown([this]() { + Application::GetInstance().StartListening(); + }); + boot_button_.OnPressUp([this]() { + Application::GetInstance().StopListening(); }); volume_up_button_.OnClick([this]() { diff --git a/main/boards/lichuang-dev/lichuang_dev_board.cc b/main/boards/lichuang-dev/lichuang_dev_board.cc index cebc6f6e..3ef44c65 100644 --- a/main/boards/lichuang-dev/lichuang_dev_board.cc +++ b/main/boards/lichuang-dev/lichuang_dev_board.cc @@ -11,6 +11,7 @@ #include #include #include +#include #define TAG "LichuangDevBoard" @@ -71,7 +72,11 @@ private: void InitializeButtons() { boot_button_.OnClick([this]() { - Application::GetInstance().ToggleChatState(); + auto& app = Application::GetInstance(); + if (app.GetChatState() == kChatStateUnknown && !WifiStation::GetInstance().IsConnected()) { + ResetWifiConfiguration(); + } + app.ToggleChatState(); }); } diff --git a/main/display/ssd1306_display.cc b/main/display/ssd1306_display.cc index 00331bdd..9c9369b3 100644 --- a/main/display/ssd1306_display.cc +++ b/main/display/ssd1306_display.cc @@ -184,7 +184,6 @@ void Ssd1306Display::SetupUI_128x64() { status_label_ = lv_label_create(status_bar_); lv_obj_set_flex_grow(status_label_, 1); - lv_label_set_long_mode(status_label_, LV_LABEL_LONG_SCROLL_CIRCULAR); lv_label_set_text(status_label_, "正在初始化"); lv_obj_set_style_text_align(status_label_, LV_TEXT_ALIGN_CENTER, 0); @@ -255,11 +254,14 @@ void Ssd1306Display::SetupUI_128x32() { status_label_ = lv_label_create(side_bar_); lv_obj_set_flex_grow(status_label_, 1); + lv_obj_set_width(status_label_, width_ - 32); lv_label_set_long_mode(status_label_, LV_LABEL_LONG_SCROLL_CIRCULAR); lv_label_set_text(status_label_, "正在初始化"); notification_label_ = lv_label_create(side_bar_); lv_obj_set_flex_grow(notification_label_, 1); + lv_obj_set_width(notification_label_, width_ - 32); + lv_label_set_long_mode(notification_label_, LV_LABEL_LONG_SCROLL_CIRCULAR); lv_label_set_text(notification_label_, "通知"); lv_obj_add_flag(notification_label_, LV_OBJ_FLAG_HIDDEN); } diff --git a/main/display/st7789_display.cc b/main/display/st7789_display.cc index 32a6f722..20d0c387 100644 --- a/main/display/st7789_display.cc +++ b/main/display/st7789_display.cc @@ -198,7 +198,6 @@ void St7789Display::SetupUI() { status_label_ = lv_label_create(status_bar_); lv_obj_set_flex_grow(status_label_, 1); - lv_label_set_long_mode(status_label_, LV_LABEL_LONG_SCROLL_CIRCULAR); lv_label_set_text(status_label_, "正在初始化"); lv_obj_set_style_text_align(status_label_, LV_TEXT_ALIGN_CENTER, 0); diff --git a/main/idf_component.yml b/main/idf_component.yml index b22373c7..a3bcc214 100644 --- a/main/idf_component.yml +++ b/main/idf_component.yml @@ -1,6 +1,6 @@ ## IDF Component Manager Manifest File dependencies: - 78/esp-wifi-connect: "~1.4.0" + 78/esp-wifi-connect: "~1.4.1" 78/esp-opus-encoder: "~1.1.0" 78/esp-ml307: "~1.6.3" espressif/led_strip: "^2.4.1" diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index cbf4b580..ba6d9d06 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -72,9 +72,9 @@ bool MqttProtocol::StartMqttClient() { } else if (strcmp(type->valuestring, "goodbye") == 0) { auto session_id = cJSON_GetObjectItem(root, "session_id"); if (session_id == nullptr || session_id_ == session_id->valuestring) { - if (on_audio_channel_closed_ != nullptr) { - on_audio_channel_closed_(); - } + Application::GetInstance().Schedule([this]() { + CloseAudioChannel(); + }); } } else if (on_incoming_json_ != nullptr) { on_incoming_json_(root); @@ -129,23 +129,6 @@ void MqttProtocol::SendAudio(const std::string& data) { udp_->Send(encrypted); } -void MqttProtocol::SendState(const std::string& state) { - std::string message = "{"; - message += "\"session_id\":\"" + session_id_ + "\","; - message += "\"type\":\"state\","; - message += "\"state\":\"" + state + "\""; - message += "}"; - SendText(message); -} - -void MqttProtocol::SendAbort() { - std::string message = "{"; - message += "\"session_id\":\"" + session_id_ + "\","; - message += "\"type\":\"abort\""; - message += "}"; - SendText(message); -} - void MqttProtocol::CloseAudioChannel() { { std::lock_guard lock(channel_mutex_); diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index 6cc6be89..c6da3efe 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -26,9 +26,6 @@ public: ~MqttProtocol(); void SendAudio(const std::string& data) override; - void SendText(const std::string& text) override; - void SendState(const std::string& state) override; - void SendAbort() override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; @@ -52,11 +49,12 @@ private: int udp_port_; uint32_t local_sequence_; uint32_t remote_sequence_; - std::string session_id_; bool StartMqttClient(); void ParseServerHello(const cJSON* root); std::string DecodeHexString(const std::string& hex_string); + + void SendText(const std::string& text) override; }; diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc index d9906c2a..5c55141e 100644 --- a/main/protocols/protocol.cc +++ b/main/protocols/protocol.cc @@ -23,3 +23,37 @@ void Protocol::OnAudioChannelClosed(std::function callback) { void Protocol::OnNetworkError(std::function callback) { on_network_error_ = callback; } + +void Protocol::SendAbortSpeaking(AbortReason reason) { + std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\""; + if (reason == kAbortReasonWakeWordDetected) { + message += ",\"reason\":\"wake_word_detected\""; + } + message += "}"; + SendText(message); +} + +void Protocol::SendWakeWordDetected(const std::string& wake_word) { + std::string json = "{\"session_id\":\"" + session_id_ + + "\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}"; + SendText(json); +} + +void Protocol::SendStartListening(ListeningMode mode) { + std::string message = "{\"session_id\":\"" + session_id_ + "\""; + message += ",\"type\":\"listen\",\"state\":\"start\""; + if (mode == kListeningModeAlwaysOn) { + message += ",\"mode\":\"realtime\""; + } else if (mode == kListeningModeAutoStop) { + message += ",\"mode\":\"auto\""; + } else { + message += ",\"mode\":\"manual\""; + } + message += "}"; + SendText(message); +} + +void Protocol::SendStopListening() { + std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}"; + SendText(message); +} diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h index 6261b9ca..5b6216ab 100644 --- a/main/protocols/protocol.h +++ b/main/protocols/protocol.h @@ -12,6 +12,16 @@ struct BinaryProtocol3 { uint8_t payload[]; } __attribute__((packed)); +enum AbortReason { + kAbortReasonNone, + kAbortReasonWakeWordDetected +}; + +enum ListeningMode { + kListeningModeAutoStop, + kListeningModeManualStop, + kListeningModeAlwaysOn // 需要 AEC 支持 +}; class Protocol { public: @@ -27,13 +37,14 @@ public: void OnAudioChannelClosed(std::function callback); void OnNetworkError(std::function callback); - virtual void SendAudio(const std::string& data) = 0; - virtual void SendText(const std::string& text) = 0; - virtual void SendState(const std::string& state) = 0; - virtual void SendAbort() = 0; virtual bool OpenAudioChannel() = 0; virtual void CloseAudioChannel() = 0; virtual bool IsAudioChannelOpened() const = 0; + virtual void SendAudio(const std::string& data) = 0; + virtual void SendWakeWordDetected(const std::string& wake_word); + virtual void SendStartListening(ListeningMode mode); + virtual void SendStopListening(); + virtual void SendAbortSpeaking(AbortReason reason); protected: std::function on_incoming_json_; @@ -43,6 +54,9 @@ protected: std::function on_network_error_; int server_sample_rate_ = 16000; + std::string session_id_; + + virtual void SendText(const std::string& text) = 0; }; #endif // PROTOCOL_H diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc index 078185ad..500615be 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -39,21 +39,6 @@ void WebsocketProtocol::SendText(const std::string& text) { websocket_->Send(text); } -void WebsocketProtocol::SendState(const std::string& state) { - std::string message = "{"; - message += "\"type\":\"state\","; - message += "\"state\":\"" + state + "\""; - message += "}"; - SendText(message); -} - -void WebsocketProtocol::SendAbort() { - std::string message = "{"; - message += "\"type\":\"abort\""; - message += "}"; - SendText(message); -} - bool WebsocketProtocol::IsAudioChannelOpened() const { return websocket_ != nullptr; } diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h index b4bd7670..f62b04f0 100644 --- a/main/protocols/websocket_protocol.h +++ b/main/protocols/websocket_protocol.h @@ -16,9 +16,6 @@ public: ~WebsocketProtocol(); void SendAudio(const std::string& data) override; - void SendText(const std::string& text) override; - void SendState(const std::string& state) override; - void SendAbort() override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; @@ -28,6 +25,7 @@ private: WebSocket* websocket_ = nullptr; void ParseServerHello(const cJSON* root); + void SendText(const std::string& text) override; }; #endif diff --git a/main/settings.cc b/main/settings.cc index ec27ffb2..fb63f477 100644 --- a/main/settings.cc +++ b/main/settings.cc @@ -63,3 +63,19 @@ void Settings::SetInt(const std::string& key, int32_t value) { ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str()); } } + +void Settings::EraseKey(const std::string& key) { + if (read_write_) { + ESP_ERROR_CHECK(nvs_erase_key(nvs_handle_, key.c_str())); + } else { + ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str()); + } +} + +void Settings::EraseAll() { + if (read_write_) { + ESP_ERROR_CHECK(nvs_erase_all(nvs_handle_)); + } else { + ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str()); + } +} diff --git a/main/settings.h b/main/settings.h index 2273b5a8..0fe13885 100644 --- a/main/settings.h +++ b/main/settings.h @@ -13,6 +13,8 @@ public: void SetString(const std::string& key, const std::string& value); int32_t GetInt(const std::string& key, int32_t default_value = 0); void SetInt(const std::string& key, int32_t value); + void EraseKey(const std::string& key); + void EraseAll(); private: std::string ns_; diff --git a/main/wake_word_detect.cc b/main/wake_word_detect.cc index 622447fc..68b05f32 100644 --- a/main/wake_word_detect.cc +++ b/main/wake_word_detect.cc @@ -4,9 +4,9 @@ #include #include #include +#include #define DETECTION_RUNNING_EVENT 1 -#define WAKE_WORD_ENCODED_EVENT 2 static const char* TAG = "WakeWordDetect"; @@ -40,6 +40,13 @@ void WakeWordDetect::Initialize(int channels, bool reference) { ESP_LOGI(TAG, "Model %d: %s", i, models->model_name[i]); if (strstr(models->model_name[i], ESP_WN_PREFIX) != NULL) { wakenet_model_ = models->model_name[i]; + auto words = esp_srmodel_get_wake_words(models, wakenet_model_); + // split by ";" to get all wake words + std::stringstream ss(words); + std::string word; + while (std::getline(ss, word, ';')) { + wake_words_.push_back(word); + } } } @@ -84,7 +91,7 @@ void WakeWordDetect::Initialize(int channels, bool reference) { }, "audio_detection", 4096 * 2, this, 1, nullptr); } -void WakeWordDetect::OnWakeWordDetected(std::function callback) { +void WakeWordDetect::OnWakeWordDetected(std::function callback) { wake_word_detected_callback_ = callback; } @@ -144,11 +151,11 @@ void WakeWordDetect::AudioDetectionTask() { } if (res->wakeup_state == WAKENET_DETECTED) { - ESP_LOGI(TAG, "Wake word detected"); StopDetection(); + last_detected_wake_word_ = wake_words_[res->wake_word_index - 1]; if (wake_word_detected_callback_) { - wake_word_detected_callback_(); + wake_word_detected_callback_(last_detected_wake_word_); } } } @@ -165,7 +172,6 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) { } void WakeWordDetect::EncodeWakeWordData() { - xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT); wake_word_opus_.clear(); if (wake_word_encode_task_stack_ == nullptr) { wake_word_encode_task_stack_ = (StackType_t*)heap_caps_malloc(4096 * 8, MALLOC_CAP_SPIRAM); @@ -182,15 +188,18 @@ void WakeWordDetect::EncodeWakeWordData() { encoder->Encode(pcm, [this_](const uint8_t* opus, size_t opus_size) { std::lock_guard lock(this_->wake_word_mutex_); this_->wake_word_opus_.emplace_back(std::string(reinterpret_cast(opus), opus_size)); - this_->wake_word_cv_.notify_one(); + this_->wake_word_cv_.notify_all(); }); } this_->wake_word_pcm_.clear(); auto end_time = esp_timer_get_time(); ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000); - xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT); - this_->wake_word_cv_.notify_one(); + { + std::lock_guard lock(this_->wake_word_mutex_); + this_->wake_word_opus_.push_back(""); + this_->wake_word_cv_.notify_all(); + } delete encoder; vTaskDelete(NULL); }, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_); @@ -199,12 +208,9 @@ void WakeWordDetect::EncodeWakeWordData() { bool WakeWordDetect::GetWakeWordOpus(std::string& opus) { std::unique_lock lock(wake_word_mutex_); wake_word_cv_.wait(lock, [this]() { - return !wake_word_opus_.empty() || (xEventGroupGetBits(event_group_) & WAKE_WORD_ENCODED_EVENT); + return !wake_word_opus_.empty(); }); - if (wake_word_opus_.empty()) { - return false; - } opus.swap(wake_word_opus_.front()); wake_word_opus_.pop_front(); - return true; + return !opus.empty(); } diff --git a/main/wake_word_detect.h b/main/wake_word_detect.h index 7a472be9..892ea56b 100644 --- a/main/wake_word_detect.h +++ b/main/wake_word_detect.h @@ -23,24 +23,27 @@ public: void Initialize(int channels, bool reference); void Feed(std::vector& data); - void OnWakeWordDetected(std::function callback); + void OnWakeWordDetected(std::function callback); void OnVadStateChange(std::function callback); void StartDetection(); void StopDetection(); bool IsDetectionRunning(); void EncodeWakeWordData(); bool GetWakeWordOpus(std::string& opus); + const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; } private: esp_afe_sr_data_t* afe_detection_data_ = nullptr; char* wakenet_model_ = NULL; + std::vector wake_words_; std::vector input_buffer_; EventGroupHandle_t event_group_; - std::function wake_word_detected_callback_; + std::function wake_word_detected_callback_; std::function vad_state_change_callback_; bool is_speaking_ = false; int channels_; bool reference_; + std::string last_detected_wake_word_; TaskHandle_t wake_word_encode_task_ = nullptr; StaticTask_t wake_word_encode_task_buffer_;