diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f745a60..0d30752b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ # CMakeLists in this exact order for cmake to work correctly cmake_minimum_required(VERSION 3.16) -set(PROJECT_VER "0.8.2") +set(PROJECT_VER "0.8.3") include($ENV{IDF_PATH}/tools/cmake/project.cmake) project(xiaozhi) diff --git a/main/CMakeLists.txt b/main/CMakeLists.txt index b8311ac2..f8532e2a 100755 --- a/main/CMakeLists.txt +++ b/main/CMakeLists.txt @@ -8,6 +8,7 @@ set(SOURCES "audio_codec.cc" "board.cc" "protocol.cc" "protocols/mqtt_protocol.cc" + "protocols/websocket_protocol.cc" "system_info.cc" "application.cc" "button.cc" diff --git a/main/Kconfig.projbuild b/main/Kconfig.projbuild index 25e05e7d..ff548c98 100644 --- a/main/Kconfig.projbuild +++ b/main/Kconfig.projbuild @@ -6,13 +6,26 @@ config OTA_VERSION_URL help The application will access this URL to check for updates. +choice CONNECTION_TYPE + prompt "Connection Type" + default CONNECTION_TYPE_MQTT_UDP + help + 网络数据传输协议 + config CONNECTION_TYPE_MQTT_UDP + bool "MQTT + UDP" + config CONNECTION_TYPE_WEBSOCKET + bool "Websocket" +endchoice + config WEBSOCKET_URL + depends on CONNECTION_TYPE_WEBSOCKET string "Websocket URL" default "wss://api.tenclass.net/xiaozhi/v1/" help Communication with the server through websocket after wake up. config WEBSOCKET_ACCESS_TOKEN + depends on CONNECTION_TYPE_WEBSOCKET string "Websocket Access Token" default "test-token" help diff --git a/main/application.cc b/main/application.cc index 3232f3db..6f8674ed 100644 --- a/main/application.cc +++ b/main/application.cc @@ -3,6 +3,7 @@ #include "ml307_ssl_transport.h" #include "audio_codec.h" #include "protocols/mqtt_protocol.h" +#include "protocols/websocket_protocol.h" #include #include @@ -268,19 +269,23 @@ void Application::Start() { #endif // Initialize the protocol - display->SetText("Starting\nProtocol..."); + display->SetText("Starting protocol..."); +#ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET + protocol_ = new WebsocketProtocol(); +#else protocol_ = new MqttProtocol(); +#endif protocol_->OnIncomingAudio([this](const std::string& data) { std::lock_guard lock(mutex_); audio_decode_queue_.emplace_back(std::move(data)); cv_.notify_all(); }); protocol_->OnAudioChannelOpened([this, codec, &board]() { - if (protocol_->GetServerSampleRate() != codec->output_sample_rate()) { + if (protocol_->server_sample_rate() != codec->output_sample_rate()) { ESP_LOGW(TAG, "服务器的音频采样率 %d 与设备输出的采样率 %d 不一致,重采样后可能会失真", - protocol_->GetServerSampleRate(), codec->output_sample_rate()); + protocol_->server_sample_rate(), codec->output_sample_rate()); } - SetDecodeSampleRate(protocol_->GetServerSampleRate()); + SetDecodeSampleRate(protocol_->server_sample_rate()); board.SetPowerSaveMode(false); }); protocol_->OnAudioChannelClosed([this, &board]() { diff --git a/main/application.h b/main/application.h index f347d601..fe774eca 100644 --- a/main/application.h +++ b/main/application.h @@ -22,12 +22,6 @@ #include "audio_processor.h" #endif -struct BinaryProtocol3 { - uint8_t type; - uint8_t reserved; - uint16_t payload_size; - uint8_t payload[]; -} __attribute__((packed)); enum ChatState { kChatStateUnknown, diff --git a/main/boards/common/wifi_board.cc b/main/boards/common/wifi_board.cc index b3566511..3115b053 100644 --- a/main/boards/common/wifi_board.cc +++ b/main/boards/common/wifi_board.cc @@ -66,12 +66,15 @@ Http* WifiBoard::CreateHttp() { } WebSocket* WifiBoard::CreateWebSocket() { +#ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET std::string url = CONFIG_WEBSOCKET_URL; if (url.find("wss://") == 0) { - return new WebSocket(new TlsTransport()); - } else { - return new WebSocket(new TcpTransport()); - } + return new WebSocket(new TlsTransport()); + } else { + return new WebSocket(new TcpTransport()); + } +#endif + return nullptr; } Mqtt* WifiBoard::CreateMqtt() { diff --git a/main/protocol.cc b/main/protocol.cc index e69de29b..aeff91fd 100644 --- a/main/protocol.cc +++ b/main/protocol.cc @@ -0,0 +1,21 @@ +#include "protocol.h" + +#include + +#define TAG "Protocol" + +void Protocol::OnIncomingJson(std::function callback) { + on_incoming_json_ = callback; +} + +void Protocol::OnIncomingAudio(std::function callback) { + on_incoming_audio_ = callback; +} + +void Protocol::OnAudioChannelOpened(std::function callback) { + on_audio_channel_opened_ = callback; +} + +void Protocol::OnAudioChannelClosed(std::function callback) { + on_audio_channel_closed_ = callback; +} diff --git a/main/protocol.h b/main/protocol.h index 28b3b376..063405f3 100644 --- a/main/protocol.h +++ b/main/protocol.h @@ -5,22 +5,42 @@ #include #include +struct BinaryProtocol3 { + uint8_t type; + uint8_t reserved; + uint16_t payload_size; + uint8_t payload[]; +} __attribute__((packed)); + + class Protocol { public: virtual ~Protocol() = default; - virtual void OnIncomingAudio(std::function callback) = 0; - virtual void OnIncomingJson(std::function callback) = 0; + inline int server_sample_rate() const { + return server_sample_rate_; + } + + void OnIncomingAudio(std::function callback); + void OnIncomingJson(std::function callback); + void OnAudioChannelOpened(std::function callback); + void OnAudioChannelClosed(std::function callback); + virtual void SendAudio(const std::string& data) = 0; virtual void SendText(const std::string& text) = 0; virtual void SendState(const std::string& state) = 0; virtual void SendAbort() = 0; virtual bool OpenAudioChannel() = 0; virtual void CloseAudioChannel() = 0; - virtual void OnAudioChannelOpened(std::function callback) = 0; - virtual void OnAudioChannelClosed(std::function callback) = 0; virtual bool IsAudioChannelOpened() const = 0; - virtual int GetServerSampleRate() const = 0; + +protected: + std::function on_incoming_json_; + std::function on_incoming_audio_; + std::function on_audio_channel_opened_; + std::function on_audio_channel_closed_; + + int server_sample_rate_ = 16000; }; #endif // PROTOCOL_H diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index 22a92152..c17880a7 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -240,22 +240,6 @@ bool MqttProtocol::OpenAudioChannel() { return true; } -void MqttProtocol::OnIncomingJson(std::function callback) { - on_incoming_json_ = callback; -} - -void MqttProtocol::OnIncomingAudio(std::function callback) { - on_incoming_audio_ = callback; -} - -void MqttProtocol::OnAudioChannelOpened(std::function callback) { - on_audio_channel_opened_ = callback; -} - -void MqttProtocol::OnAudioChannelClosed(std::function callback) { - on_audio_channel_closed_ = callback; -} - void MqttProtocol::ParseServerHello(const cJSON* root) { auto transport = cJSON_GetObjectItem(root, "transport"); if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) { @@ -297,11 +281,6 @@ void MqttProtocol::ParseServerHello(const cJSON* root) { xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT); } -int MqttProtocol::GetServerSampleRate() const { - return server_sample_rate_; -} - - static const char hex_chars[] = "0123456789ABCDEF"; // 辅助函数,将单个十六进制字符转换为对应的数值 static inline uint8_t CharToHex(char c) { diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index a5de6251..6cc6be89 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -14,6 +14,7 @@ #include #include #include + #define MQTT_PING_INTERVAL_SECONDS 90 #define MQTT_RECONNECT_INTERVAL_MS 10000 @@ -24,27 +25,17 @@ public: MqttProtocol(); ~MqttProtocol(); - void OnIncomingAudio(std::function callback); - void OnIncomingJson(std::function callback); - void SendAudio(const std::string& data); - void SendText(const std::string& text); - void SendState(const std::string& state); - void SendAbort(); - bool OpenAudioChannel(); - void CloseAudioChannel(); - void OnAudioChannelOpened(std::function callback); - void OnAudioChannelClosed(std::function callback); - bool IsAudioChannelOpened() const; - int GetServerSampleRate() const; + void SendAudio(const std::string& data) override; + void SendText(const std::string& text) override; + void SendState(const std::string& state) override; + void SendAbort() override; + bool OpenAudioChannel() override; + void CloseAudioChannel() override; + bool IsAudioChannelOpened() const override; private: EventGroupHandle_t event_group_handle_; - std::function on_incoming_json_; - std::function on_incoming_audio_; - std::function on_audio_channel_opened_; - std::function on_audio_channel_closed_; - std::string endpoint_; std::string client_id_; std::string username_; @@ -62,7 +53,6 @@ private: uint32_t local_sequence_; uint32_t remote_sequence_; std::string session_id_; - int server_sample_rate_ = 16000; bool StartMqttClient(); void ParseServerHello(const cJSON* root); diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc new file mode 100644 index 00000000..70e433d6 --- /dev/null +++ b/main/protocols/websocket_protocol.cc @@ -0,0 +1,159 @@ +#include "websocket_protocol.h" +#include "board.h" +#include "system_info.h" +#include "application.h" + +#include +#include +#include +#include + +#define TAG "WS" + +#ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET + +WebsocketProtocol::WebsocketProtocol() { + event_group_handle_ = xEventGroupCreate(); +} + +WebsocketProtocol::~WebsocketProtocol() { + if (websocket_ != nullptr) { + delete websocket_; + } + vEventGroupDelete(event_group_handle_); +} + +void WebsocketProtocol::SendAudio(const std::string& data) { + if (websocket_ == nullptr) { + return; + } + + websocket_->Send(data.data(), data.size(), true); +} + +void WebsocketProtocol::SendText(const std::string& text) { + if (websocket_ == nullptr) { + return; + } + + websocket_->Send(text); +} + +void WebsocketProtocol::SendState(const std::string& state) { + std::string message = "{"; + message += "\"type\":\"state\","; + message += "\"state\":\"" + state + "\""; + message += "}"; + SendText(message); +} + +void WebsocketProtocol::SendAbort() { + std::string message = "{"; + message += "\"type\":\"abort\""; + message += "}"; + SendText(message); +} + +bool WebsocketProtocol::IsAudioChannelOpened() const { + return websocket_ != nullptr; +} + +void WebsocketProtocol::CloseAudioChannel() { + if (websocket_ != nullptr) { + delete websocket_; + websocket_ = nullptr; + } +} + +bool WebsocketProtocol::OpenAudioChannel() { + if (websocket_ != nullptr) { + delete websocket_; + } + + std::string url = CONFIG_WEBSOCKET_URL; + std::string token = "Bearer " + std::string(CONFIG_WEBSOCKET_ACCESS_TOKEN); + websocket_ = Board::GetInstance().CreateWebSocket(); + websocket_->SetHeader("Authorization", token.c_str()); + websocket_->SetHeader("Protocol-Version", "1"); + websocket_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str()); + + websocket_->OnData([this](const char* data, size_t len, bool binary) { + if (binary) { + if (on_incoming_audio_ != nullptr) { + on_incoming_audio_(std::string(data, len)); + } + } else { + // Parse JSON data + auto root = cJSON_Parse(data); + auto type = cJSON_GetObjectItem(root, "type"); + if (type != NULL) { + if (strcmp(type->valuestring, "hello") == 0) { + ParseServerHello(root); + } else { + if (on_incoming_json_ != nullptr) { + on_incoming_json_(root); + } + } + } else { + ESP_LOGE(TAG, "Missing message type, data: %s", data); + } + cJSON_Delete(root); + } + }); + + websocket_->OnDisconnected([this]() { + ESP_LOGI(TAG, "Websocket disconnected"); + if (on_audio_channel_closed_ != nullptr) { + on_audio_channel_closed_(); + } + }); + + if (!websocket_->Connect(url.c_str())) { + ESP_LOGE(TAG, "Failed to connect to websocket server"); + return false; + } + + // Send hello message to describe the client + // keys: message type, version, audio_params (format, sample_rate, channels) + std::string message = "{"; + message += "\"type\":\"hello\","; + message += "\"version\": 1,"; + message += "\"transport\":\"websocket\","; + message += "\"audio_params\":{"; + message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS); + message += "}}"; + websocket_->Send(message); + + // Wait for server hello + EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000)); + if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) { + ESP_LOGE(TAG, "Failed to receive server hello"); + return false; + } + + if (on_audio_channel_opened_ != nullptr) { + on_audio_channel_opened_(); + } + + return true; +} + +void WebsocketProtocol::ParseServerHello(const cJSON* root) { + auto transport = cJSON_GetObjectItem(root, "transport"); + if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) { + ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring); + return; + } + + auto audio_params = cJSON_GetObjectItem(root, "audio_params"); + if (audio_params != NULL) { + auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate"); + if (sample_rate != NULL) { + server_sample_rate_ = sample_rate->valueint; + } + } + + xEventGroupSetBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT); +} + +#endif diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h new file mode 100644 index 00000000..b4bd7670 --- /dev/null +++ b/main/protocols/websocket_protocol.h @@ -0,0 +1,33 @@ +#ifndef _WEBSOCKET_PROTOCOL_H_ +#define _WEBSOCKET_PROTOCOL_H_ + + +#include "protocol.h" + +#include +#include +#include + +#define WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT (1 << 0) + +class WebsocketProtocol : public Protocol { +public: + WebsocketProtocol(); + ~WebsocketProtocol(); + + void SendAudio(const std::string& data) override; + void SendText(const std::string& text) override; + void SendState(const std::string& state) override; + void SendAbort() override; + bool OpenAudioChannel() override; + void CloseAudioChannel() override; + bool IsAudioChannelOpened() const override; + +private: + EventGroupHandle_t event_group_handle_; + WebSocket* websocket_ = nullptr; + + void ParseServerHello(const cJSON* root); +}; + +#endif