update protocol to support manual response mode

Terrence
2024-11-25 00:59:03 +08:00
parent aa806f676e
commit 472219d5bf
11 changed files with 166 additions and 100 deletions

View File

@@ -4,7 +4,7 @@
 # CMakeLists in this exact order for cmake to work correctly
 cmake_minimum_required(VERSION 3.16)
-set(PROJECT_VER "0.9.1")
+set(PROJECT_VER "0.9.2")
 include($ENV{IDF_PATH}/tools/cmake/project.cmake)
 project(xiaozhi)

View File

@@ -123,20 +123,52 @@ void Application::ToggleChatState() {
     Schedule([this]() {
         if (chat_state_ == kChatStateIdle) {
             SetChatState(kChatStateConnecting);
-            if (protocol_->OpenAudioChannel()) {
-                opus_encoder_.ResetState();
-                SetChatState(kChatStateListening);
-            } else {
+            if (!protocol_->OpenAudioChannel()) {
+                ESP_LOGE(TAG, "Failed to open audio channel");
                 SetChatState(kChatStateIdle);
+                return;
             }
+            keep_listening_ = true;
+            protocol_->SendStartListening(kListeningModeAutoStop);
+            SetChatState(kChatStateListening);
         } else if (chat_state_ == kChatStateSpeaking) {
-            AbortSpeaking();
+            AbortSpeaking(kAbortReasonNone);
         } else if (chat_state_ == kChatStateListening) {
             protocol_->CloseAudioChannel();
         }
     });
 }
 
+void Application::StartListening() {
+    Schedule([this]() {
+        keep_listening_ = false;
+        if (chat_state_ == kChatStateIdle) {
+            if (!protocol_->IsAudioChannelOpened()) {
+                SetChatState(kChatStateConnecting);
+                if (!protocol_->OpenAudioChannel()) {
+                    SetChatState(kChatStateIdle);
+                    ESP_LOGE(TAG, "Failed to open audio channel");
+                    return;
+                }
+            }
+            protocol_->SendStartListening(kListeningModeManualStop);
+            SetChatState(kChatStateListening);
+        } else if (chat_state_ == kChatStateSpeaking) {
+            AbortSpeaking(kAbortReasonNone);
+            protocol_->SendStartListening(kListeningModeManualStop);
+            SetChatState(kChatStateListening);
+        }
+    });
+}
+
+void Application::StopListening() {
+    Schedule([this]() {
+        protocol_->SendStopListening();
+        SetChatState(kChatStateIdle);
+    });
+}
+
 void Application::Start() {
     auto& board = Board::GetInstance();
     board.Initialize();
@@ -248,26 +280,31 @@ void Application::Start() {
         });
     });
 
-    wake_word_detect_.OnWakeWordDetected([this]() {
-        Schedule([this]() {
+    wake_word_detect_.OnWakeWordDetected([this](const std::string& wake_word) {
+        Schedule([this, &wake_word]() {
             if (chat_state_ == kChatStateIdle) {
                 SetChatState(kChatStateConnecting);
                 wake_word_detect_.EncodeWakeWordData();
-                if (protocol_->OpenAudioChannel()) {
-                    std::string opus;
-                    // Encode and send the wake word data to the server
-                    while (wake_word_detect_.GetWakeWordOpus(opus)) {
-                        protocol_->SendAudio(opus);
-                    }
-                    opus_encoder_.ResetState();
-                    // Send a ready message to indicate the server that the wake word data is sent
-                    SetChatState(kChatStateWakeWordDetected);
-                } else {
+                if (!protocol_->OpenAudioChannel()) {
+                    ESP_LOGE(TAG, "Failed to open audio channel");
                     SetChatState(kChatStateIdle);
+                    wake_word_detect_.StartDetection();
+                    return;
                 }
+                std::string opus;
+                // Encode and send the wake word data to the server
+                while (wake_word_detect_.GetWakeWordOpus(opus)) {
+                    protocol_->SendAudio(opus);
+                }
+                // Set the chat state to wake word detected
+                protocol_->SendWakeWordDetected(wake_word);
+                ESP_LOGI(TAG, "Wake word detected: %s", wake_word.c_str());
+                keep_listening_ = true;
+                SetChatState(kChatStateListening);
             } else if (chat_state_ == kChatStateSpeaking) {
-                AbortSpeaking();
+                AbortSpeaking(kAbortReasonWakeWordDetected);
             }
             // Resume detection
@@ -313,15 +350,23 @@ void Application::Start() {
             auto state = cJSON_GetObjectItem(root, "state");
             if (strcmp(state->valuestring, "start") == 0) {
                 Schedule([this]() {
-                    skip_to_end_ = false;
-                    SetChatState(kChatStateSpeaking);
+                    if (chat_state_ == kChatStateIdle || chat_state_ == kChatStateListening) {
+                        skip_to_end_ = false;
+                        opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE);
+                        SetChatState(kChatStateSpeaking);
+                    }
                 });
             } else if (strcmp(state->valuestring, "stop") == 0) {
                 Schedule([this]() {
                     auto codec = Board::GetInstance().GetAudioCodec();
                     codec->WaitForOutputDone();
                     if (chat_state_ == kChatStateSpeaking) {
-                        SetChatState(kChatStateListening);
+                        if (keep_listening_) {
+                            protocol_->SendStartListening(kListeningModeAutoStop);
+                            SetChatState(kChatStateListening);
+                        } else {
+                            SetChatState(kChatStateIdle);
+                        }
                     }
                 });
             } else if (strcmp(state->valuestring, "sentence_start") == 0) {
@@ -375,9 +420,9 @@ void Application::MainLoop() {
     }
 }
 
-void Application::AbortSpeaking() {
+void Application::AbortSpeaking(AbortReason reason) {
     ESP_LOGI(TAG, "Abort speaking");
-    protocol_->SendAbort();
+    protocol_->SendAbortSpeaking(reason);
     skip_to_end_ = true;
 
     auto codec = Board::GetInstance().GetAudioCodec();
@@ -391,7 +436,6 @@ void Application::SetChatState(ChatState state) {
"connecting", "connecting",
"listening", "listening",
"speaking", "speaking",
"wake_word_detected",
"upgrading", "upgrading",
"invalid_state" "invalid_state"
}; };
@@ -399,12 +443,10 @@ void Application::SetChatState(ChatState state) {
         // No need to update the state
         return;
     }
-    chat_state_ = state;
-    ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);
 
     auto display = Board::GetInstance().GetDisplay();
     auto builtin_led = Board::GetInstance().GetBuiltinLed();
-    switch (chat_state_) {
+    switch (state) {
         case kChatStateUnknown:
         case kChatStateIdle:
             builtin_led->TurnOff();
@@ -424,6 +466,7 @@ void Application::SetChatState(ChatState state) {
             builtin_led->TurnOn();
             display->SetStatus("聆听中...");
             display->SetEmotion("neutral");
+            opus_encoder_.ResetState();
 #ifdef CONFIG_USE_AFE_SR
             audio_processor_.Start();
 #endif
@@ -436,17 +479,17 @@ void Application::SetChatState(ChatState state) {
             audio_processor_.Stop();
 #endif
             break;
-        case kChatStateWakeWordDetected:
-            builtin_led->SetBlue();
-            builtin_led->TurnOn();
-            break;
         case kChatStateUpgrading:
             builtin_led->SetGreen();
             builtin_led->StartContinuousBlink(100);
             break;
+        default:
+            ESP_LOGE(TAG, "Invalid chat state: %d", chat_state_);
+            return;
     }
 
-    protocol_->SendState(state_str[chat_state_]);
+    chat_state_ = state;
+    ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);
 }
 
 void Application::AudioEncodeTask() {
@@ -474,7 +517,7 @@ void Application::AudioEncodeTask() {
             audio_decode_queue_.pop_front();
             lock.unlock();
-            if (skip_to_end_) {
+            if (skip_to_end_ || chat_state_ != kChatStateSpeaking) {
                 continue;
             }

View File

@@ -29,7 +29,6 @@ enum ChatState {
     kChatStateConnecting,
     kChatStateListening,
     kChatStateSpeaking,
-    kChatStateWakeWordDetected,
     kChatStateUpgrading
 };
@@ -41,17 +40,19 @@ public:
         static Application instance;
         return instance;
     }
+    // 删除拷贝构造函数和赋值运算符
+    Application(const Application&) = delete;
+    Application& operator=(const Application&) = delete;
 
     void Start();
     ChatState GetChatState() const { return chat_state_; }
     void Schedule(std::function<void()> callback);
     void SetChatState(ChatState state);
     void Alert(const std::string&& title, const std::string&& message);
-    void AbortSpeaking();
+    void AbortSpeaking(AbortReason reason);
     void ToggleChatState();
-    // 删除拷贝构造函数和赋值运算符
-    Application(const Application&) = delete;
-    Application& operator=(const Application&) = delete;
+    void StartListening();
+    void StopListening();
 
 private:
     Application();
@@ -68,6 +69,7 @@ private:
     Protocol* protocol_ = nullptr;
     EventGroupHandle_t event_group_;
     volatile ChatState chat_state_ = kChatStateUnknown;
+    bool keep_listening_ = false;
     bool skip_to_end_ = false;
 
     // Audio encode / decode

View File

@@ -72,9 +72,9 @@ bool MqttProtocol::StartMqttClient() {
             } else if (strcmp(type->valuestring, "goodbye") == 0) {
                 auto session_id = cJSON_GetObjectItem(root, "session_id");
                 if (session_id == nullptr || session_id_ == session_id->valuestring) {
-                    if (on_audio_channel_closed_ != nullptr) {
-                        on_audio_channel_closed_();
-                    }
+                    Application::GetInstance().Schedule([this]() {
+                        CloseAudioChannel();
+                    });
                 }
             } else if (on_incoming_json_ != nullptr) {
                 on_incoming_json_(root);
@@ -129,23 +129,6 @@ void MqttProtocol::SendAudio(const std::string& data) {
     udp_->Send(encrypted);
 }
 
-void MqttProtocol::SendState(const std::string& state) {
-    std::string message = "{";
-    message += "\"session_id\":\"" + session_id_ + "\",";
-    message += "\"type\":\"state\",";
-    message += "\"state\":\"" + state + "\"";
-    message += "}";
-    SendText(message);
-}
-
-void MqttProtocol::SendAbort() {
-    std::string message = "{";
-    message += "\"session_id\":\"" + session_id_ + "\",";
-    message += "\"type\":\"abort\"";
-    message += "}";
-    SendText(message);
-}
-
 void MqttProtocol::CloseAudioChannel() {
     {
         std::lock_guard<std::mutex> lock(channel_mutex_);

View File

@@ -26,9 +26,6 @@ public:
     ~MqttProtocol();
 
     void SendAudio(const std::string& data) override;
-    void SendText(const std::string& text) override;
-    void SendState(const std::string& state) override;
-    void SendAbort() override;
     bool OpenAudioChannel() override;
     void CloseAudioChannel() override;
     bool IsAudioChannelOpened() const override;
@@ -52,11 +49,12 @@ private:
     int udp_port_;
     uint32_t local_sequence_;
     uint32_t remote_sequence_;
-    std::string session_id_;
 
     bool StartMqttClient();
     void ParseServerHello(const cJSON* root);
     std::string DecodeHexString(const std::string& hex_string);
+    void SendText(const std::string& text) override;
 };

View File

@@ -23,3 +23,37 @@ void Protocol::OnAudioChannelClosed(std::function<void()> callback) {
 void Protocol::OnNetworkError(std::function<void(const std::string& message)> callback) {
     on_network_error_ = callback;
 }
+
+void Protocol::SendAbortSpeaking(AbortReason reason) {
+    std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\"";
+    if (reason == kAbortReasonWakeWordDetected) {
+        message += ",\"reason\":\"wake_word_detected\"";
+    }
+    message += "}";
+    SendText(message);
+}
+
+void Protocol::SendWakeWordDetected(const std::string& wake_word) {
+    std::string json = "{\"session_id\":\"" + session_id_ +
+        "\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}";
+    SendText(json);
+}
+
+void Protocol::SendStartListening(ListeningMode mode) {
+    std::string message = "{\"session_id\":\"" + session_id_ + "\"";
+    message += ",\"type\":\"listen\",\"state\":\"start\"";
+    if (mode == kListeningModeAlwaysOn) {
+        message += ",\"mode\":\"realtime\"";
+    } else if (mode == kListeningModeAutoStop) {
+        message += ",\"mode\":\"auto\"";
+    } else {
+        message += ",\"mode\":\"manual\"";
+    }
+    message += "}";
+    SendText(message);
+}
+
+void Protocol::SendStopListening() {
+    std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}";
+    SendText(message);
+}
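
For reference, the listen-control helpers above serialize to compact JSON payloads along these lines (illustrative examples; the session_id value is a placeholder for whatever the server assigned in its hello):

{"session_id":"<session id>","type":"listen","state":"start","mode":"manual"}
{"session_id":"<session id>","type":"listen","state":"stop"}
{"session_id":"<session id>","type":"listen","state":"detect","text":"<detected wake word>"}
{"session_id":"<session id>","type":"abort","reason":"wake_word_detected"}

The mode field is "auto" for kListeningModeAutoStop, "manual" for kListeningModeManualStop, and "realtime" for kListeningModeAlwaysOn; the reason field is only appended when the abort was triggered by a wake word.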

View File

@@ -12,6 +12,16 @@ struct BinaryProtocol3 {
     uint8_t payload[];
 } __attribute__((packed));
 
+enum AbortReason {
+    kAbortReasonNone,
+    kAbortReasonWakeWordDetected
+};
+
+enum ListeningMode {
+    kListeningModeAutoStop,
+    kListeningModeManualStop,
+    kListeningModeAlwaysOn // 需要 AEC 支持
+};
 
 class Protocol {
 public:
@@ -27,13 +37,14 @@ public:
     void OnAudioChannelClosed(std::function<void()> callback);
     void OnNetworkError(std::function<void(const std::string& message)> callback);
 
-    virtual void SendAudio(const std::string& data) = 0;
-    virtual void SendText(const std::string& text) = 0;
-    virtual void SendState(const std::string& state) = 0;
-    virtual void SendAbort() = 0;
     virtual bool OpenAudioChannel() = 0;
     virtual void CloseAudioChannel() = 0;
     virtual bool IsAudioChannelOpened() const = 0;
+    virtual void SendAudio(const std::string& data) = 0;
+    virtual void SendWakeWordDetected(const std::string& wake_word);
+    virtual void SendStartListening(ListeningMode mode);
+    virtual void SendStopListening();
+    virtual void SendAbortSpeaking(AbortReason reason);
 
 protected:
     std::function<void(const cJSON* root)> on_incoming_json_;
@@ -43,6 +54,9 @@ protected:
     std::function<void(const std::string& message)> on_network_error_;
 
     int server_sample_rate_ = 16000;
+    std::string session_id_;
+
+    virtual void SendText(const std::string& text) = 0;
 };
 
 #endif // PROTOCOL_H

View File

@@ -39,21 +39,6 @@ void WebsocketProtocol::SendText(const std::string& text) {
     websocket_->Send(text);
 }
 
-void WebsocketProtocol::SendState(const std::string& state) {
-    std::string message = "{";
-    message += "\"type\":\"state\",";
-    message += "\"state\":\"" + state + "\"";
-    message += "}";
-    SendText(message);
-}
-
-void WebsocketProtocol::SendAbort() {
-    std::string message = "{";
-    message += "\"type\":\"abort\"";
-    message += "}";
-    SendText(message);
-}
-
 bool WebsocketProtocol::IsAudioChannelOpened() const {
     return websocket_ != nullptr;
 }

View File

@@ -16,9 +16,6 @@ public:
     ~WebsocketProtocol();
 
     void SendAudio(const std::string& data) override;
-    void SendText(const std::string& text) override;
-    void SendState(const std::string& state) override;
-    void SendAbort() override;
     bool OpenAudioChannel() override;
     void CloseAudioChannel() override;
     bool IsAudioChannelOpened() const override;
@@ -28,6 +25,7 @@ private:
     WebSocket* websocket_ = nullptr;
 
     void ParseServerHello(const cJSON* root);
+    void SendText(const std::string& text) override;
 };
 
 #endif

View File

@@ -4,9 +4,9 @@
 #include <esp_log.h>
 #include <model_path.h>
 #include <arpa/inet.h>
+#include <sstream>
 
 #define DETECTION_RUNNING_EVENT 1
-#define WAKE_WORD_ENCODED_EVENT 2
 
 static const char* TAG = "WakeWordDetect";
@@ -40,6 +40,13 @@ void WakeWordDetect::Initialize(int channels, bool reference) {
ESP_LOGI(TAG, "Model %d: %s", i, models->model_name[i]); ESP_LOGI(TAG, "Model %d: %s", i, models->model_name[i]);
if (strstr(models->model_name[i], ESP_WN_PREFIX) != NULL) { if (strstr(models->model_name[i], ESP_WN_PREFIX) != NULL) {
wakenet_model_ = models->model_name[i]; wakenet_model_ = models->model_name[i];
auto words = esp_srmodel_get_wake_words(models, wakenet_model_);
// split by ";" to get all wake words
std::stringstream ss(words);
std::string word;
while (std::getline(ss, word, ';')) {
wake_words_.push_back(word);
}
} }
} }
@@ -84,7 +91,7 @@ void WakeWordDetect::Initialize(int channels, bool reference) {
}, "audio_detection", 4096 * 2, this, 1, nullptr); }, "audio_detection", 4096 * 2, this, 1, nullptr);
} }
void WakeWordDetect::OnWakeWordDetected(std::function<void()> callback) { void WakeWordDetect::OnWakeWordDetected(std::function<void(const std::string& wake_word)> callback) {
wake_word_detected_callback_ = callback; wake_word_detected_callback_ = callback;
} }
@@ -144,11 +151,11 @@ void WakeWordDetect::AudioDetectionTask() {
             }
 
             if (res->wakeup_state == WAKENET_DETECTED) {
-                ESP_LOGI(TAG, "Wake word detected");
                 StopDetection();
+                last_detected_wake_word_ = wake_words_[res->wake_word_index - 1];
 
                 if (wake_word_detected_callback_) {
-                    wake_word_detected_callback_();
+                    wake_word_detected_callback_(last_detected_wake_word_);
                 }
             }
         }
@@ -165,7 +172,6 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
 }
 
 void WakeWordDetect::EncodeWakeWordData() {
-    xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT);
     wake_word_opus_.clear();
     if (wake_word_encode_task_stack_ == nullptr) {
         wake_word_encode_task_stack_ = (StackType_t*)heap_caps_malloc(4096 * 8, MALLOC_CAP_SPIRAM);
@@ -182,15 +188,18 @@ void WakeWordDetect::EncodeWakeWordData() {
             encoder->Encode(pcm, [this_](const uint8_t* opus, size_t opus_size) {
                 std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
                 this_->wake_word_opus_.emplace_back(std::string(reinterpret_cast<const char*>(opus), opus_size));
-                this_->wake_word_cv_.notify_one();
+                this_->wake_word_cv_.notify_all();
             });
         }
         this_->wake_word_pcm_.clear();
         auto end_time = esp_timer_get_time();
         ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
-        xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT);
-        this_->wake_word_cv_.notify_one();
+        {
+            std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
+            this_->wake_word_opus_.push_back("");
+            this_->wake_word_cv_.notify_all();
+        }
         delete encoder;
         vTaskDelete(NULL);
     }, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_);
@@ -199,12 +208,9 @@ void WakeWordDetect::EncodeWakeWordData() {
 bool WakeWordDetect::GetWakeWordOpus(std::string& opus) {
     std::unique_lock<std::mutex> lock(wake_word_mutex_);
     wake_word_cv_.wait(lock, [this]() {
-        return !wake_word_opus_.empty() || (xEventGroupGetBits(event_group_) & WAKE_WORD_ENCODED_EVENT);
+        return !wake_word_opus_.empty();
     });
-    if (wake_word_opus_.empty()) {
-        return false;
-    }
     opus.swap(wake_word_opus_.front());
     wake_word_opus_.pop_front();
-    return true;
+    return !opus.empty();
 }
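
With the WAKE_WORD_ENCODED_EVENT bit gone, the encoder task signals end of stream by pushing an empty string onto wake_word_opus_, and GetWakeWordOpus() returns false once that sentinel is popped. The consumer loop in application.cc above therefore stays a simple drain, repeated here for context:

std::string opus;
// Blocks until a packet (or the empty end-of-stream sentinel) is available;
// returns false when the sentinel is reached, ending the loop.
while (wake_word_detect_.GetWakeWordOpus(opus)) {
    protocol_->SendAudio(opus);
}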

View File

@@ -23,24 +23,27 @@ public:
     void Initialize(int channels, bool reference);
     void Feed(std::vector<int16_t>& data);
-    void OnWakeWordDetected(std::function<void()> callback);
+    void OnWakeWordDetected(std::function<void(const std::string& wake_word)> callback);
     void OnVadStateChange(std::function<void(bool speaking)> callback);
     void StartDetection();
     void StopDetection();
     bool IsDetectionRunning();
     void EncodeWakeWordData();
     bool GetWakeWordOpus(std::string& opus);
+    const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; }
 
 private:
     esp_afe_sr_data_t* afe_detection_data_ = nullptr;
     char* wakenet_model_ = NULL;
+    std::vector<std::string> wake_words_;
     std::vector<int16_t> input_buffer_;
     EventGroupHandle_t event_group_;
-    std::function<void()> wake_word_detected_callback_;
+    std::function<void(const std::string& wake_word)> wake_word_detected_callback_;
     std::function<void(bool speaking)> vad_state_change_callback_;
     bool is_speaking_ = false;
     int channels_;
     bool reference_;
+    std::string last_detected_wake_word_;
 
     TaskHandle_t wake_word_encode_task_ = nullptr;
     StaticTask_t wake_word_encode_task_buffer_;