diff --git a/CMakeLists.txt b/CMakeLists.txt index 56341858..12cbc4f2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ # CMakeLists in this exact order for cmake to work correctly cmake_minimum_required(VERSION 3.16) -set(PROJECT_VER "0.9.5") +set(PROJECT_VER "0.9.6") include($ENV{IDF_PATH}/tools/cmake/project.cmake) project(xiaozhi) diff --git a/main/application.cc b/main/application.cc index ef339242..40d93c46 100644 --- a/main/application.cc +++ b/main/application.cc @@ -21,6 +21,15 @@ extern const char p3_err_pin_end[] asm("_binary_err_pin_p3_end"); extern const char p3_err_wificonfig_start[] asm("_binary_err_wificonfig_p3_start"); extern const char p3_err_wificonfig_end[] asm("_binary_err_wificonfig_p3_end"); +static const char* const STATE_STRINGS[] = { + "unknown", + "idle", + "connecting", + "listening", + "speaking", + "upgrading", + "invalid_state" +}; Application::Application() : background_task_(4096 * 8) { event_group_ = xEventGroupCreate(); @@ -30,13 +39,6 @@ Application::Application() : background_task_(4096 * 8) { } Application::~Application() { - if (protocol_ != nullptr) { - delete protocol_; - } - if (opus_decoder_ != nullptr) { - opus_decoder_destroy(opus_decoder_); - } - vEventGroupDelete(event_group_); } @@ -83,7 +85,7 @@ void Application::CheckNewVersion() { } } -void Application::Alert(const std::string&& title, const std::string&& message) { +void Application::Alert(const std::string& title, const std::string& message) { ESP_LOGW(TAG, "Alert: %s, %s", title.c_str(), message.c_str()); auto display = Board::GetInstance().GetDisplay(); display->ShowNotification(message); @@ -105,7 +107,7 @@ void Application::PlayLocalFile(const char* data, size_t size) { p += sizeof(BinaryProtocol3); auto payload_size = ntohs(p3->payload_size); - std::string opus; + std::vector opus; opus.resize(payload_size); memcpy(opus.data(), p3->payload, payload_size); p += payload_size; @@ -117,10 +119,15 @@ void Application::PlayLocalFile(const char* data, size_t size) { void Application::ToggleChatState() { Schedule([this]() { + if (!protocol_) { + ESP_LOGE(TAG, "Protocol not initialized"); + return; + } + if (chat_state_ == kChatStateIdle) { SetChatState(kChatStateConnecting); if (!protocol_->OpenAudioChannel()) { - ESP_LOGE(TAG, "Failed to open audio channel"); + Alert("Error", "Failed to open audio channel"); SetChatState(kChatStateIdle); return; } @@ -138,13 +145,18 @@ void Application::ToggleChatState() { void Application::StartListening() { Schedule([this]() { + if (!protocol_) { + ESP_LOGE(TAG, "Protocol not initialized"); + return; + } + keep_listening_ = false; if (chat_state_ == kChatStateIdle) { if (!protocol_->IsAudioChannelOpened()) { SetChatState(kChatStateConnecting); if (!protocol_->OpenAudioChannel()) { SetChatState(kChatStateIdle); - ESP_LOGE(TAG, "Failed to open audio channel"); + Alert("Error", "Failed to open audio channel"); return; } } @@ -183,8 +195,8 @@ void Application::Start() { /* Setup the audio codec */ auto codec = board.GetAudioCodec(); opus_decode_sample_rate_ = codec->output_sample_rate(); - opus_decoder_ = opus_decoder_create(opus_decode_sample_rate_, 1, NULL); - opus_encoder_.Configure(16000, 1, OPUS_FRAME_DURATION_MS); + opus_decoder_ = std::make_unique(opus_decode_sample_rate_, 1); + opus_encoder_ = std::make_unique(16000, 1, OPUS_FRAME_DURATION_MS); if (codec->input_sample_rate() != 16000) { input_resampler_.Configure(codec->input_sample_rate(), 16000); reference_resampler_.Configure(codec->input_sample_rate(), 16000); @@ -221,9 +233,9 @@ void Application::Start() { #if CONFIG_IDF_TARGET_ESP32S3 audio_processor_.Initialize(codec->input_channels(), codec->input_reference()); audio_processor_.OnOutput([this](std::vector&& data) { - background_task_.Schedule([this, data = std::move(data)]() { - opus_encoder_.Encode(data, [this](const uint8_t* opus, size_t opus_size) { - Schedule([this, opus = std::string(reinterpret_cast(opus), opus_size)]() { + background_task_.Schedule([this, data = std::move(data)]() mutable { + opus_encoder_->Encode(std::move(data), [this](std::vector&& opus) { + Schedule([this, opus = std::move(opus)]() { protocol_->SendAudio(opus); }); }); @@ -258,7 +270,7 @@ void Application::Start() { return; } - std::string opus; + std::vector opus; // Encode and send the wake word data to the server while (wake_word_detect_.GetWakeWordOpus(opus)) { protocol_->SendAudio(opus); @@ -282,14 +294,14 @@ void Application::Start() { // Initialize the protocol display->SetStatus("初始化协议"); #ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET - protocol_ = new WebsocketProtocol(); + protocol_ = std::make_unique(); #else - protocol_ = new MqttProtocol(); + protocol_ = std::make_unique(); #endif protocol_->OnNetworkError([this](const std::string& message) { Alert("Error", std::move(message)); }); - protocol_->OnIncomingAudio([this](const std::string& data) { + protocol_->OnIncomingAudio([this](std::vector&& data) { std::lock_guard lock(mutex_); if (chat_state_ == kChatStateSpeaking) { audio_decode_queue_.emplace_back(std::move(data)); @@ -363,9 +375,8 @@ void Application::Start() { } void Application::Schedule(std::function callback) { - mutex_.lock(); - main_tasks_.push_back(callback); - mutex_.unlock(); + std::lock_guard lock(mutex_); + main_tasks_.push_back(std::move(callback)); xEventGroupSetBits(event_group_, SCHEDULE_EVENT); } @@ -397,7 +408,7 @@ void Application::MainLoop() { void Application::ResetDecoder() { std::lock_guard lock(mutex_); - opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE); + opus_decoder_->ResetState(); audio_decode_queue_.clear(); last_output_time_ = std::chrono::steady_clock::now(); Board::GetInstance().GetAudioCodec()->EnableOutput(true); @@ -430,24 +441,21 @@ void Application::OutputAudio() { audio_decode_queue_.pop_front(); lock.unlock(); - background_task_.Schedule([this, codec, opus = std::move(opus)]() { + background_task_.Schedule([this, codec, opus = std::move(opus)]() mutable { if (aborted_) { return; } - int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000; - std::vector pcm(frame_size); - int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0); - if (ret < 0) { - ESP_LOGE(TAG, "Failed to decode audio, error code: %d", ret); + std::vector pcm; + if (!opus_decoder_->Decode(std::move(opus), pcm)) { return; } // Resample if the sample rate is different if (opus_decode_sample_rate_ != codec->output_sample_rate()) { - int target_size = output_resampler_.GetOutputSamples(frame_size); + int target_size = output_resampler_.GetOutputSamples(pcm.size()); std::vector resampled(target_size); - output_resampler_.Process(pcm.data(), frame_size, resampled.data()); + output_resampler_.Process(pcm.data(), pcm.size(), resampled.data()); pcm = std::move(resampled); } @@ -495,9 +503,9 @@ void Application::InputAudio() { } #else if (chat_state_ == kChatStateListening) { - background_task_.Schedule([this, data = std::move(data)]() { - opus_encoder_.Encode(data, [this](const uint8_t* opus, size_t opus_size) { - Schedule([this, opus = std::string(reinterpret_cast(opus), opus_size)]() { + background_task_.Schedule([this, data = std::move(data)]() mutable { + opus_encoder_->Encode(std::move(data), [this](std::vector&& opus) { + Schedule([this, opus = std::move(opus)]() { protocol_->SendAudio(opus); }); }); @@ -513,22 +521,12 @@ void Application::AbortSpeaking(AbortReason reason) { } void Application::SetChatState(ChatState state) { - const char* state_str[] = { - "unknown", - "idle", - "connecting", - "listening", - "speaking", - "upgrading", - "invalid_state" - }; if (chat_state_ == state) { - // No need to update the state return; } - + chat_state_ = state; - ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]); + ESP_LOGI(TAG, "STATE: %s", STATE_STRINGS[chat_state_]); // The state is changed, wait for all background tasks to finish background_task_.WaitForCompletion(); @@ -555,7 +553,7 @@ void Application::SetChatState(ChatState state) { display->SetStatus("聆听中..."); display->SetEmotion("neutral"); ResetDecoder(); - opus_encoder_.ResetState(); + opus_encoder_->ResetState(); #if CONFIG_IDF_TARGET_ESP32S3 audio_processor_.Start(); #endif @@ -584,9 +582,8 @@ void Application::SetDecodeSampleRate(int sample_rate) { return; } - opus_decoder_destroy(opus_decoder_); opus_decode_sample_rate_ = sample_rate; - opus_decoder_ = opus_decoder_create(opus_decode_sample_rate_, 1, NULL); + opus_decoder_ = std::make_unique(opus_decode_sample_rate_, 1); auto codec = Board::GetInstance().GetAudioCodec(); if (opus_decode_sample_rate_ != codec->output_sample_rate()) { diff --git a/main/application.h b/main/application.h index ff7923b4..889d635d 100644 --- a/main/application.h +++ b/main/application.h @@ -10,6 +10,7 @@ #include #include "opus_encoder.h" +#include "opus_decoder.h" #include "opus_resampler.h" #include "protocol.h" @@ -52,7 +53,7 @@ public: ChatState GetChatState() const { return chat_state_; } void Schedule(std::function callback); void SetChatState(ChatState state); - void Alert(const std::string&& title, const std::string&& message); + void Alert(const std::string& title, const std::string& message); void AbortSpeaking(AbortReason reason); void ToggleChatState(); void StartListening(); @@ -69,7 +70,7 @@ private: Ota ota_; std::mutex mutex_; std::list> main_tasks_; - Protocol* protocol_ = nullptr; + std::unique_ptr protocol_; EventGroupHandle_t event_group_; volatile ChatState chat_state_ = kChatStateUnknown; bool keep_listening_ = false; @@ -78,10 +79,10 @@ private: // Audio encode / decode BackgroundTask background_task_; std::chrono::steady_clock::time_point last_output_time_; - std::list audio_decode_queue_; + std::list> audio_decode_queue_; - OpusEncoder opus_encoder_; - OpusDecoder* opus_decoder_ = nullptr; + std::unique_ptr opus_encoder_; + std::unique_ptr opus_decoder_; int opus_decode_sample_rate_ = -1; OpusResampler input_resampler_; diff --git a/main/audio_processing/audio_processor.cc b/main/audio_processing/audio_processor.cc index 803474ca..bb73c766 100644 --- a/main/audio_processing/audio_processor.cc +++ b/main/audio_processing/audio_processor.cc @@ -63,7 +63,7 @@ AudioProcessor::~AudioProcessor() { vEventGroupDelete(event_group_); } -void AudioProcessor::Input(std::vector& data) { +void AudioProcessor::Input(const std::vector& data) { input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); auto chunk_size = esp_afe_vc_v1.get_feed_chunksize(afe_communication_data_) * channels_; diff --git a/main/audio_processing/audio_processor.h b/main/audio_processing/audio_processor.h index d095e0cc..3c8fd90b 100644 --- a/main/audio_processing/audio_processor.h +++ b/main/audio_processing/audio_processor.h @@ -16,7 +16,7 @@ public: ~AudioProcessor(); void Initialize(int channels, bool reference); - void Input(std::vector& data); + void Input(const std::vector& data); void Start(); void Stop(); bool IsRunning(); diff --git a/main/audio_processing/wake_word_detect.cc b/main/audio_processing/wake_word_detect.cc index 68b05f32..307e9dd4 100644 --- a/main/audio_processing/wake_word_detect.cc +++ b/main/audio_processing/wake_word_detect.cc @@ -111,7 +111,7 @@ bool WakeWordDetect::IsDetectionRunning() { return xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT; } -void WakeWordDetect::Feed(std::vector& data) { +void WakeWordDetect::Feed(const std::vector& data) { input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); auto chunk_size = esp_afe_sr_v1.get_feed_chunksize(afe_detection_data_) * channels_; @@ -163,8 +163,7 @@ void WakeWordDetect::AudioDetectionTask() { void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) { // store audio data to wake_word_pcm_ - std::vector pcm(data, data + samples); - wake_word_pcm_.emplace_back(std::move(pcm)); + wake_word_pcm_.emplace_back(std::vector(data, data + samples)); // keep about 2 seconds of data, detect duration is 32ms (sample_rate == 16000, chunksize == 512) while (wake_word_pcm_.size() > 2000 / 32) { wake_word_pcm_.pop_front(); @@ -178,34 +177,33 @@ void WakeWordDetect::EncodeWakeWordData() { } wake_word_encode_task_ = xTaskCreateStatic([](void* arg) { auto this_ = (WakeWordDetect*)arg; - auto start_time = esp_timer_get_time(); - // encode detect packets - OpusEncoder* encoder = new OpusEncoder(); - encoder->Configure(16000, 1, 60); - encoder->SetComplexity(0); - - for (auto& pcm: this_->wake_word_pcm_) { - encoder->Encode(pcm, [this_](const uint8_t* opus, size_t opus_size) { - std::lock_guard lock(this_->wake_word_mutex_); - this_->wake_word_opus_.emplace_back(std::string(reinterpret_cast(opus), opus_size)); - this_->wake_word_cv_.notify_all(); - }); - } - this_->wake_word_pcm_.clear(); - - auto end_time = esp_timer_get_time(); - ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000); { + auto start_time = esp_timer_get_time(); + auto encoder = std::make_unique(16000, 1, OPUS_FRAME_DURATION_MS); + encoder->SetComplexity(0); // 0 is the fastest + + for (auto& pcm: this_->wake_word_pcm_) { + encoder->Encode(std::move(pcm), [this_](std::vector&& opus) { + std::lock_guard lock(this_->wake_word_mutex_); + this_->wake_word_opus_.emplace_back(std::move(opus)); + this_->wake_word_cv_.notify_all(); + }); + } + this_->wake_word_pcm_.clear(); + + auto end_time = esp_timer_get_time(); + ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", + this_->wake_word_opus_.size(), (end_time - start_time) / 1000); + std::lock_guard lock(this_->wake_word_mutex_); - this_->wake_word_opus_.push_back(""); + this_->wake_word_opus_.push_back(std::vector()); this_->wake_word_cv_.notify_all(); } - delete encoder; vTaskDelete(NULL); }, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_); } -bool WakeWordDetect::GetWakeWordOpus(std::string& opus) { +bool WakeWordDetect::GetWakeWordOpus(std::vector& opus) { std::unique_lock lock(wake_word_mutex_); wake_word_cv_.wait(lock, [this]() { return !wake_word_opus_.empty(); diff --git a/main/audio_processing/wake_word_detect.h b/main/audio_processing/wake_word_detect.h index 892ea56b..0a356b40 100644 --- a/main/audio_processing/wake_word_detect.h +++ b/main/audio_processing/wake_word_detect.h @@ -22,14 +22,14 @@ public: ~WakeWordDetect(); void Initialize(int channels, bool reference); - void Feed(std::vector& data); + void Feed(const std::vector& data); void OnWakeWordDetected(std::function callback); void OnVadStateChange(std::function callback); void StartDetection(); void StopDetection(); bool IsDetectionRunning(); void EncodeWakeWordData(); - bool GetWakeWordOpus(std::string& opus); + bool GetWakeWordOpus(std::vector& opus); const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; } private: @@ -49,7 +49,7 @@ private: StaticTask_t wake_word_encode_task_buffer_; StackType_t* wake_word_encode_task_stack_ = nullptr; std::list> wake_word_pcm_; - std::list wake_word_opus_; + std::list> wake_word_opus_; std::mutex wake_word_mutex_; std::condition_variable wake_word_cv_; diff --git a/main/background_task.cc b/main/background_task.cc index 5e0146ca..44481ffe 100644 --- a/main/background_task.cc +++ b/main/background_task.cc @@ -31,12 +31,16 @@ void BackgroundTask::Schedule(std::function callback) { ESP_LOGW(TAG, "active_tasks_ == %u", active_tasks_.load()); } active_tasks_++; - auto wrapped_callback = [this, callback]() { - callback(); - active_tasks_--; - condition_variable_.notify_all(); - }; - main_tasks_.push_back(wrapped_callback); + main_tasks_.emplace_back([this, cb = std::move(callback)]() { + cb(); + { + std::lock_guard lock(mutex_); + active_tasks_--; + if (main_tasks_.empty() && active_tasks_ == 0) { + condition_variable_.notify_all(); + } + } + }); condition_variable_.notify_all(); } diff --git a/main/idf_component.yml b/main/idf_component.yml index 03f6aff1..5d6ffcdb 100644 --- a/main/idf_component.yml +++ b/main/idf_component.yml @@ -1,7 +1,7 @@ ## IDF Component Manager Manifest File dependencies: 78/esp-wifi-connect: "~1.4.1" - 78/esp-opus-encoder: "~1.1.0" + 78/esp-opus-encoder: "~2.0.0" 78/esp-ml307: "~1.7.0" espressif/led_strip: "^2.4.1" espressif/esp_codec_dev: "^1.3.1" diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index ba6d9d06..70455e18 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -105,7 +105,7 @@ void MqttProtocol::SendText(const std::string& text) { mqtt_->Publish(publish_topic_, text); } -void MqttProtocol::SendAudio(const std::string& data) { +void MqttProtocol::SendAudio(const std::vector& data) { std::lock_guard lock(channel_mutex_); if (udp_ == nullptr) { return; @@ -202,7 +202,7 @@ bool MqttProtocol::OpenAudioChannel() { ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1); } - std::string decrypted; + std::vector decrypted; size_t decrypted_size = data.size() - aes_nonce_.size(); size_t nc_off = 0; uint8_t stream_block[16] = {0}; @@ -215,7 +215,7 @@ bool MqttProtocol::OpenAudioChannel() { return; } if (on_incoming_audio_ != nullptr) { - on_incoming_audio_(decrypted); + on_incoming_audio_(std::move(decrypted)); } remote_sequence_ = sequence; }); diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index c6da3efe..4dd84ea3 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -25,7 +25,7 @@ public: MqttProtocol(); ~MqttProtocol(); - void SendAudio(const std::string& data) override; + void SendAudio(const std::vector& data) override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc index 5c55141e..12ecd951 100644 --- a/main/protocols/protocol.cc +++ b/main/protocols/protocol.cc @@ -8,7 +8,7 @@ void Protocol::OnIncomingJson(std::function callback) { on_incoming_json_ = callback; } -void Protocol::OnIncomingAudio(std::function callback) { +void Protocol::OnIncomingAudio(std::function&& data)> callback) { on_incoming_audio_ = callback; } diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h index 5b6216ab..f51d549e 100644 --- a/main/protocols/protocol.h +++ b/main/protocols/protocol.h @@ -31,7 +31,7 @@ public: return server_sample_rate_; } - void OnIncomingAudio(std::function callback); + void OnIncomingAudio(std::function&& data)> callback); void OnIncomingJson(std::function callback); void OnAudioChannelOpened(std::function callback); void OnAudioChannelClosed(std::function callback); @@ -40,7 +40,7 @@ public: virtual bool OpenAudioChannel() = 0; virtual void CloseAudioChannel() = 0; virtual bool IsAudioChannelOpened() const = 0; - virtual void SendAudio(const std::string& data) = 0; + virtual void SendAudio(const std::vector& data) = 0; virtual void SendWakeWordDetected(const std::string& wake_word); virtual void SendStartListening(ListeningMode mode); virtual void SendStopListening(); @@ -48,7 +48,7 @@ public: protected: std::function on_incoming_json_; - std::function on_incoming_audio_; + std::function&& data)> on_incoming_audio_; std::function on_audio_channel_opened_; std::function on_audio_channel_closed_; std::function on_network_error_; diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc index 500615be..156276e5 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -23,7 +23,7 @@ WebsocketProtocol::~WebsocketProtocol() { vEventGroupDelete(event_group_handle_); } -void WebsocketProtocol::SendAudio(const std::string& data) { +void WebsocketProtocol::SendAudio(const std::vector& data) { if (websocket_ == nullptr) { return; } @@ -65,7 +65,7 @@ bool WebsocketProtocol::OpenAudioChannel() { websocket_->OnData([this](const char* data, size_t len, bool binary) { if (binary) { if (on_incoming_audio_ != nullptr) { - on_incoming_audio_(std::string(data, len)); + on_incoming_audio_(std::vector((uint8_t*)data, (uint8_t*)data + len)); } } else { // Parse JSON data diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h index f62b04f0..5629e8af 100644 --- a/main/protocols/websocket_protocol.h +++ b/main/protocols/websocket_protocol.h @@ -15,7 +15,7 @@ public: WebsocketProtocol(); ~WebsocketProtocol(); - void SendAudio(const std::string& data) override; + void SendAudio(const std::vector& data) override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override;