add websocket protocol

This commit is contained in:
Terrence
2024-11-16 05:49:35 +08:00
parent cabb29a1bb
commit 794e6f4bef
12 changed files with 277 additions and 59 deletions

View File

@@ -240,22 +240,6 @@ bool MqttProtocol::OpenAudioChannel() {
return true;
}
void MqttProtocol::OnIncomingJson(std::function<void(const cJSON* root)> callback) {
on_incoming_json_ = callback;
}
void MqttProtocol::OnIncomingAudio(std::function<void(const std::string& data)> callback) {
on_incoming_audio_ = callback;
}
void MqttProtocol::OnAudioChannelOpened(std::function<void()> callback) {
on_audio_channel_opened_ = callback;
}
void MqttProtocol::OnAudioChannelClosed(std::function<void()> callback) {
on_audio_channel_closed_ = callback;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
@@ -297,11 +281,6 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
int MqttProtocol::GetServerSampleRate() const {
return server_sample_rate_;
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
static inline uint8_t CharToHex(char c) {

View File

@@ -14,6 +14,7 @@
#include <string>
#include <map>
#include <mutex>
#define MQTT_PING_INTERVAL_SECONDS 90
#define MQTT_RECONNECT_INTERVAL_MS 10000
@@ -24,27 +25,17 @@ public:
MqttProtocol();
~MqttProtocol();
void OnIncomingAudio(std::function<void(const std::string& data)> callback);
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
void SendAudio(const std::string& data);
void SendText(const std::string& text);
void SendState(const std::string& state);
void SendAbort();
bool OpenAudioChannel();
void CloseAudioChannel();
void OnAudioChannelOpened(std::function<void()> callback);
void OnAudioChannelClosed(std::function<void()> callback);
bool IsAudioChannelOpened() const;
int GetServerSampleRate() const;
void SendAudio(const std::string& data) override;
void SendText(const std::string& text) override;
void SendState(const std::string& state) override;
void SendAbort() override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
std::function<void(const cJSON* root)> on_incoming_json_;
std::function<void(const std::string& data)> on_incoming_audio_;
std::function<void()> on_audio_channel_opened_;
std::function<void()> on_audio_channel_closed_;
std::string endpoint_;
std::string client_id_;
std::string username_;
@@ -62,7 +53,6 @@ private:
uint32_t local_sequence_;
uint32_t remote_sequence_;
std::string session_id_;
int server_sample_rate_ = 16000;
bool StartMqttClient();
void ParseServerHello(const cJSON* root);

View File

@@ -0,0 +1,159 @@
#include "websocket_protocol.h"
#include "board.h"
#include "system_info.h"
#include "application.h"
#include <cstring>
#include <cJSON.h>
#include <esp_log.h>
#include <arpa/inet.h>
#define TAG "WS"
#ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET
WebsocketProtocol::WebsocketProtocol() {
event_group_handle_ = xEventGroupCreate();
}
WebsocketProtocol::~WebsocketProtocol() {
if (websocket_ != nullptr) {
delete websocket_;
}
vEventGroupDelete(event_group_handle_);
}
void WebsocketProtocol::SendAudio(const std::string& data) {
if (websocket_ == nullptr) {
return;
}
websocket_->Send(data.data(), data.size(), true);
}
void WebsocketProtocol::SendText(const std::string& text) {
if (websocket_ == nullptr) {
return;
}
websocket_->Send(text);
}
void WebsocketProtocol::SendState(const std::string& state) {
std::string message = "{";
message += "\"type\":\"state\",";
message += "\"state\":\"" + state + "\"";
message += "}";
SendText(message);
}
void WebsocketProtocol::SendAbort() {
std::string message = "{";
message += "\"type\":\"abort\"";
message += "}";
SendText(message);
}
bool WebsocketProtocol::IsAudioChannelOpened() const {
return websocket_ != nullptr;
}
void WebsocketProtocol::CloseAudioChannel() {
if (websocket_ != nullptr) {
delete websocket_;
websocket_ = nullptr;
}
}
bool WebsocketProtocol::OpenAudioChannel() {
if (websocket_ != nullptr) {
delete websocket_;
}
std::string url = CONFIG_WEBSOCKET_URL;
std::string token = "Bearer " + std::string(CONFIG_WEBSOCKET_ACCESS_TOKEN);
websocket_ = Board::GetInstance().CreateWebSocket();
websocket_->SetHeader("Authorization", token.c_str());
websocket_->SetHeader("Protocol-Version", "1");
websocket_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
websocket_->OnData([this](const char* data, size_t len, bool binary) {
if (binary) {
if (on_incoming_audio_ != nullptr) {
on_incoming_audio_(std::string(data, len));
}
} else {
// Parse JSON data
auto root = cJSON_Parse(data);
auto type = cJSON_GetObjectItem(root, "type");
if (type != NULL) {
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else {
if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
}
} else {
ESP_LOGE(TAG, "Missing message type, data: %s", data);
}
cJSON_Delete(root);
}
});
websocket_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Websocket disconnected");
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
});
if (!websocket_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
return false;
}
// Send hello message to describe the client
// keys: message type, version, audio_params (format, sample_rate, channels)
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"version\": 1,";
message += "\"transport\":\"websocket\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}";
websocket_->Send(message);
// Wait for server hello
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
return false;
}
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
void WebsocketProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
server_sample_rate_ = sample_rate->valueint;
}
}
xEventGroupSetBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT);
}
#endif

View File

@@ -0,0 +1,33 @@
#ifndef _WEBSOCKET_PROTOCOL_H_
#define _WEBSOCKET_PROTOCOL_H_
#include "protocol.h"
#include <web_socket.h>
#include <freertos/FreeRTOS.h>
#include <freertos/event_groups.h>
#define WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT (1 << 0)
class WebsocketProtocol : public Protocol {
public:
WebsocketProtocol();
~WebsocketProtocol();
void SendAudio(const std::string& data) override;
void SendText(const std::string& text) override;
void SendState(const std::string& state) override;
void SendAbort() override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
private:
EventGroupHandle_t event_group_handle_;
WebSocket* websocket_ = nullptr;
void ParseServerHello(const cJSON* root);
};
#endif