From b804343d30d6f2605684959436087e0d5886b375 Mon Sep 17 00:00:00 2001 From: Xiaoxia Date: Mon, 28 Apr 2025 23:10:24 +0800 Subject: [PATCH] Audio stream packet with timestamp --- main/application.cc | 41 +++++++++++++++---------- main/application.h | 3 +- main/protocols/mqtt_protocol.cc | 18 +++++------ main/protocols/mqtt_protocol.h | 2 +- main/protocols/protocol.cc | 2 +- main/protocols/protocol.h | 8 ++--- main/protocols/websocket_protocol.cc | 45 +++++++++++++++++----------- main/protocols/websocket_protocol.h | 2 +- 8 files changed, 70 insertions(+), 51 deletions(-) diff --git a/main/application.cc b/main/application.cc index 8d91ae11..102564a0 100644 --- a/main/application.cc +++ b/main/application.cc @@ -245,13 +245,13 @@ void Application::PlaySound(const std::string_view& sound) { p += sizeof(BinaryProtocol3); auto payload_size = ntohs(p3->payload_size); - std::vector opus; - opus.resize(payload_size); - memcpy(opus.data(), p3->payload, payload_size); + AudioStreamPacket packet; + packet.payload.resize(payload_size); + memcpy(packet.payload.data(), p3->payload, payload_size); p += payload_size; std::lock_guard lock(mutex_); - audio_decode_queue_.emplace_back(std::move(opus)); + audio_decode_queue_.emplace_back(std::move(packet)); } } @@ -391,11 +391,11 @@ void Application::Start() { SetDeviceState(kDeviceStateIdle); Alert(Lang::Strings::ERROR, message.c_str(), "sad", Lang::Sounds::P3_EXCLAMATION); }); - protocol_->OnIncomingAudio([this](std::vector&& data) { + protocol_->OnIncomingAudio([this](AudioStreamPacket&& packet) { const int max_packets_in_queue = 600 / OPUS_FRAME_DURATION_MS; std::lock_guard lock(mutex_); if (audio_decode_queue_.size() < max_packets_in_queue) { - audio_decode_queue_.emplace_back(std::move(data)); + audio_decode_queue_.emplace_back(std::move(packet)); } }); protocol_->OnAudioChannelOpened([this, codec, &board]() { @@ -510,8 +510,12 @@ void Application::Start() { return; } opus_encoder_->Encode(std::move(data), [this](std::vector&& opus) { - Schedule([this, opus = std::move(opus)]() { - protocol_->SendAudio(opus); + AudioStreamPacket packet; + packet.payload = std::move(opus); + packet.timestamp = last_output_timestamp_; + last_output_timestamp_ = 0; + Schedule([this, packet = std::move(packet)]() { + protocol_->SendAudio(packet); }); }); }); @@ -544,10 +548,10 @@ void Application::Start() { return; } - std::vector opus; + AudioStreamPacket packet; // Encode and send the wake word data to the server - while (wake_word_detect_.GetWakeWordOpus(opus)) { - protocol_->SendAudio(opus); + while (wake_word_detect_.GetWakeWordOpus(packet.payload)) { + protocol_->SendAudio(packet); } // Set the chat state to wake word detected protocol_->SendWakeWordDetected(wake_word); @@ -671,20 +675,20 @@ void Application::OnAudioOutput() { return; } - auto opus = std::move(audio_decode_queue_.front()); + auto packet = std::move(audio_decode_queue_.front()); audio_decode_queue_.pop_front(); lock.unlock(); audio_decode_cv_.notify_all(); busy_decoding_audio_ = true; - background_task_->Schedule([this, codec, opus = std::move(opus)]() mutable { + background_task_->Schedule([this, codec, packet = std::move(packet)]() mutable { busy_decoding_audio_ = false; if (aborted_) { return; } std::vector pcm; - if (!opus_decoder_->Decode(std::move(opus), pcm)) { + if (!opus_decoder_->Decode(std::move(packet.payload), pcm)) { return; } // Resample if the sample rate is different @@ -695,6 +699,7 @@ void Application::OnAudioOutput() { pcm = std::move(resampled); } codec->OutputData(pcm); + last_output_timestamp_ = packet.timestamp; last_output_time_ = std::chrono::steady_clock::now(); }); } @@ -730,8 +735,12 @@ void Application::OnAudioInput() { return; } opus_encoder_->Encode(std::move(data), [this](std::vector&& opus) { - Schedule([this, opus = std::move(opus)]() { - protocol_->SendAudio(opus); + AudioStreamPacket packet; + packet.payload = std::move(opus); + packet.timestamp = last_output_timestamp_; + last_output_timestamp_ = 0; + Schedule([this, packet = std::move(packet)]() { + protocol_->SendAudio(packet); }); }); }); diff --git a/main/application.h b/main/application.h index 0fe6d04c..a806c346 100644 --- a/main/application.h +++ b/main/application.h @@ -107,7 +107,8 @@ private: TaskHandle_t audio_loop_task_handle_ = nullptr; BackgroundTask* background_task_ = nullptr; std::chrono::steady_clock::time_point last_output_time_; - std::list> audio_decode_queue_; + std::atomic last_output_timestamp_ = 0; + std::list audio_decode_queue_; std::condition_variable audio_decode_cv_; std::unique_ptr opus_encoder_; diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index 19f94b83..f995304e 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -121,24 +121,24 @@ bool MqttProtocol::SendText(const std::string& text) { return true; } -void MqttProtocol::SendAudio(const std::vector& data) { +void MqttProtocol::SendAudio(const AudioStreamPacket& packet) { std::lock_guard lock(channel_mutex_); if (udp_ == nullptr) { return; } std::string nonce(aes_nonce_); - *(uint16_t*)&nonce[2] = htons(data.size()); + *(uint16_t*)&nonce[2] = htons(packet.payload.size()); *(uint32_t*)&nonce[12] = htonl(++local_sequence_); std::string encrypted; - encrypted.resize(aes_nonce_.size() + data.size()); + encrypted.resize(aes_nonce_.size() + packet.payload.size()); memcpy(encrypted.data(), nonce.data(), nonce.size()); size_t nc_off = 0; uint8_t stream_block[16] = {0}; - if (mbedtls_aes_crypt_ctr(&aes_ctx_, data.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block, - (uint8_t*)data.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) { + if (mbedtls_aes_crypt_ctr(&aes_ctx_, packet.payload.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block, + (uint8_t*)packet.payload.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) { ESP_LOGE(TAG, "Failed to encrypt audio data"); return; } @@ -229,20 +229,20 @@ bool MqttProtocol::OpenAudioChannel() { ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1); } - std::vector decrypted; size_t decrypted_size = data.size() - aes_nonce_.size(); size_t nc_off = 0; uint8_t stream_block[16] = {0}; - decrypted.resize(decrypted_size); auto nonce = (uint8_t*)data.data(); auto encrypted = (uint8_t*)data.data() + aes_nonce_.size(); - int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)decrypted.data()); + AudioStreamPacket packet; + packet.payload.resize(decrypted_size); + int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)packet.payload.data()); if (ret != 0) { ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret); return; } if (on_incoming_audio_ != nullptr) { - on_incoming_audio_(std::move(decrypted)); + on_incoming_audio_(std::move(packet)); } remote_sequence_ = sequence; last_incoming_time_ = std::chrono::steady_clock::now(); diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index e531b00a..f8e9fc00 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -26,7 +26,7 @@ public: ~MqttProtocol(); bool Start() override; - void SendAudio(const std::vector& data) override; + void SendAudio(const AudioStreamPacket& packet) override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc index b89bed6f..a9515494 100644 --- a/main/protocols/protocol.cc +++ b/main/protocols/protocol.cc @@ -8,7 +8,7 @@ void Protocol::OnIncomingJson(std::function callback) { on_incoming_json_ = callback; } -void Protocol::OnIncomingAudio(std::function&& data)> callback) { +void Protocol::OnIncomingAudio(std::function callback) { on_incoming_audio_ = callback; } diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h index 3a377f56..7f9f541b 100644 --- a/main/protocols/protocol.h +++ b/main/protocols/protocol.h @@ -8,7 +8,7 @@ #include struct AudioStreamPacket { - uint32_t timestamp; + uint32_t timestamp = 0; std::vector payload; }; @@ -53,7 +53,7 @@ public: return session_id_; } - void OnIncomingAudio(std::function&& data)> callback); + void OnIncomingAudio(std::function callback); void OnIncomingJson(std::function callback); void OnAudioChannelOpened(std::function callback); void OnAudioChannelClosed(std::function callback); @@ -64,7 +64,7 @@ public: virtual void CloseAudioChannel() = 0; virtual bool IsAudioChannelOpened() const = 0; virtual bool IsAudioChannelBusy() const; - virtual void SendAudio(const std::vector& data) = 0; + virtual void SendAudio(const AudioStreamPacket& packet) = 0; virtual void SendWakeWordDetected(const std::string& wake_word); virtual void SendStartListening(ListeningMode mode); virtual void SendStopListening(); @@ -74,7 +74,7 @@ public: protected: std::function on_incoming_json_; - std::function&& data)> on_incoming_audio_; + std::function on_incoming_audio_; std::function on_audio_channel_opened_; std::function on_audio_channel_closed_; std::function on_network_error_; diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc index 4e48d311..41bfbaa9 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -28,40 +28,40 @@ bool WebsocketProtocol::Start() { return true; } -void WebsocketProtocol::SendAudio(const std::vector& data) { +void WebsocketProtocol::SendAudio(const AudioStreamPacket& packet) { if (websocket_ == nullptr) { return; } if (version_ == 2) { - std::string packet; - packet.resize(sizeof(BinaryProtocol2) + data.size()); - auto bp2 = (BinaryProtocol2*)packet.data(); + std::string serialized; + serialized.resize(sizeof(BinaryProtocol2) + packet.payload.size()); + auto bp2 = (BinaryProtocol2*)serialized.data(); bp2->version = htons(version_); bp2->type = 0; bp2->reserved = 0; - bp2->timestamp = htonl(0); - bp2->payload_size = htonl(data.size()); - memcpy(bp2->payload, data.data(), data.size()); + bp2->timestamp = htonl(packet.timestamp); + bp2->payload_size = htonl(packet.payload.size()); + memcpy(bp2->payload, packet.payload.data(), packet.payload.size()); busy_sending_audio_ = true; - websocket_->Send(packet.data(), packet.size(), true); + websocket_->Send(serialized.data(), serialized.size(), true); busy_sending_audio_ = false; } else if (version_ == 3) { - std::string packet; - packet.resize(sizeof(BinaryProtocol3) + data.size()); - auto bp3 = (BinaryProtocol3*)packet.data(); + std::string serialized; + serialized.resize(sizeof(BinaryProtocol3) + packet.payload.size()); + auto bp3 = (BinaryProtocol3*)serialized.data(); bp3->type = 0; bp3->reserved = 0; - bp3->payload_size = htons(data.size()); - memcpy(bp3->payload, data.data(), data.size()); + bp3->payload_size = htons(packet.payload.size()); + memcpy(bp3->payload, packet.payload.data(), packet.payload.size()); busy_sending_audio_ = true; - websocket_->Send(packet.data(), packet.size(), true); + websocket_->Send(serialized.data(), serialized.size(), true); busy_sending_audio_ = false; } else { busy_sending_audio_ = true; - websocket_->Send(data.data(), data.size(), true); + websocket_->Send(packet.payload.data(), packet.payload.size(), true); busy_sending_audio_ = false; } } @@ -130,15 +130,24 @@ bool WebsocketProtocol::OpenAudioChannel() { bp2->timestamp = ntohl(bp2->timestamp); bp2->payload_size = ntohl(bp2->payload_size); auto payload = (uint8_t*)bp2->payload; - on_incoming_audio_(std::vector(payload, payload + bp2->payload_size)); + on_incoming_audio_(AudioStreamPacket{ + .timestamp = bp2->timestamp, + .payload = std::vector(payload, payload + bp2->payload_size) + }); } else if (version_ == 3) { BinaryProtocol3* bp3 = (BinaryProtocol3*)data; bp3->type = bp3->type; bp3->payload_size = ntohs(bp3->payload_size); auto payload = (uint8_t*)bp3->payload; - on_incoming_audio_(std::vector(payload, payload + bp3->payload_size)); + on_incoming_audio_(AudioStreamPacket{ + .timestamp = 0, + .payload = std::vector(payload, payload + bp3->payload_size) + }); } else { - on_incoming_audio_(std::vector((uint8_t*)data, (uint8_t*)data + len)); + on_incoming_audio_(AudioStreamPacket{ + .timestamp = 0, + .payload = std::vector((uint8_t*)data, (uint8_t*)data + len) + }); } } } else { diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h index db998cc1..73ae693e 100644 --- a/main/protocols/websocket_protocol.h +++ b/main/protocols/websocket_protocol.h @@ -16,7 +16,7 @@ public: ~WebsocketProtocol(); bool Start() override; - void SendAudio(const std::vector& data) override; + void SendAudio(const AudioStreamPacket& packet) override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override;