Files
xiaozhi-esp32/main/protocols/mqtt_protocol.cc

306 lines
9.7 KiB
C++
Raw Normal View History

2024-11-14 23:15:43 +08:00
#include "mqtt_protocol.h"
#include "board.h"
2024-11-15 04:44:53 +08:00
#include "application.h"
#include "settings.h"
2024-11-14 23:15:43 +08:00
#include <esp_log.h>
#include <ml307_mqtt.h>
#include <ml307_udp.h>
#include <cstring>
#include <arpa/inet.h>
#define TAG "MQTT"
2024-11-15 04:44:53 +08:00
MqttProtocol::MqttProtocol() {
2024-11-14 23:15:43 +08:00
event_group_handle_ = xEventGroupCreate();
StartMqttClient();
}
MqttProtocol::~MqttProtocol() {
ESP_LOGI(TAG, "MqttProtocol deinit");
if (udp_ != nullptr) {
delete udp_;
}
if (mqtt_ != nullptr) {
delete mqtt_;
}
vEventGroupDelete(event_group_handle_);
}
bool MqttProtocol::StartMqttClient() {
if (mqtt_ != nullptr) {
ESP_LOGW(TAG, "Mqtt client already started");
delete mqtt_;
}
2024-11-15 04:44:53 +08:00
Settings settings("mqtt", false);
endpoint_ = settings.GetString("endpoint");
client_id_ = settings.GetString("client_id");
username_ = settings.GetString("username");
password_ = settings.GetString("password");
subscribe_topic_ = settings.GetString("subscribe_topic");
publish_topic_ = settings.GetString("publish_topic");
if (endpoint_.empty()) {
ESP_LOGE(TAG, "MQTT endpoint is not specified");
return false;
}
2024-11-14 23:15:43 +08:00
mqtt_ = Board::GetInstance().CreateMqtt();
mqtt_->SetKeepAlive(90);
mqtt_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Disconnected from endpoint");
});
mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
cJSON* root = cJSON_Parse(payload.c_str());
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
return;
}
cJSON* type = cJSON_GetObjectItem(root, "type");
if (type == nullptr) {
ESP_LOGE(TAG, "Message type is not specified");
cJSON_Delete(root);
return;
}
2024-11-15 04:44:53 +08:00
2024-11-14 23:15:43 +08:00
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) {
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (session_id == nullptr || session_id_ == session_id->valuestring) {
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
}
2024-11-15 04:44:53 +08:00
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
2024-11-14 23:15:43 +08:00
}
cJSON_Delete(root);
});
ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint_.c_str());
if (!mqtt_->Connect(endpoint_, 8883, client_id_, username_, password_)) {
ESP_LOGE(TAG, "Failed to connect to endpoint");
return false;
}
ESP_LOGI(TAG, "Connected to endpoint");
if (!subscribe_topic_.empty()) {
mqtt_->Subscribe(subscribe_topic_, 2);
}
return true;
}
void MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) {
return;
}
mqtt_->Publish(publish_topic_, text);
}
void MqttProtocol::SendAudio(const std::string& data) {
2024-11-15 04:44:53 +08:00
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
2024-11-14 23:15:43 +08:00
std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(data.size());
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
std::string encrypted;
encrypted.resize(aes_nonce_.size() + data.size());
memcpy(encrypted.data(), nonce.data(), nonce.size());
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
if (mbedtls_aes_crypt_ctr(&aes_ctx_, data.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
(uint8_t*)data.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
ESP_LOGE(TAG, "Failed to encrypt audio data");
return;
}
udp_->Send(encrypted);
}
2024-11-15 04:44:53 +08:00
void MqttProtocol::SendState(const std::string& state) {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"state\",";
message += "\"state\":\"" + state + "\"";
message += "}";
SendText(message);
}
void MqttProtocol::SendAbort() {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"abort\"";
message += "}";
SendText(message);
}
2024-11-14 23:15:43 +08:00
void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ != nullptr) {
delete udp_;
udp_ = nullptr;
}
}
std::string message = "{";
2024-11-15 04:44:53 +08:00
message += "\"session_id\":\"" + session_id_ + "\",";
2024-11-14 23:15:43 +08:00
message += "\"type\":\"goodbye\"";
message += "}";
SendText(message);
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
}
bool MqttProtocol::OpenAudioChannel() {
2024-11-15 04:44:53 +08:00
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
2024-11-14 23:15:43 +08:00
if (!StartMqttClient()) {
ESP_LOGE(TAG, "Failed to connect to MQTT");
return false;
}
}
session_id_ = "";
// 发送 hello 消息申请 UDP 通道
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"version\": 3,";
message += "\"transport\":\"udp\",";
message += "\"audio_params\":{";
2024-11-15 04:44:53 +08:00
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
2024-11-14 23:15:43 +08:00
message += "}}";
SendText(message);
// 等待服务器响应
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
return false;
}
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ != nullptr) {
delete udp_;
}
udp_ = Board::GetInstance().CreateUdp();
udp_->OnMessage([this](const std::string& data) {
if (data.size() < sizeof(aes_nonce_)) {
ESP_LOGE(TAG, "Invalid audio packet size: %zu", data.size());
return;
}
if (data[0] != 0x01) {
ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
return;
}
uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
if (sequence < remote_sequence_) {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return;
}
2024-11-15 04:44:53 +08:00
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
2024-11-14 23:15:43 +08:00
std::string decrypted;
size_t decrypted_size = data.size() - aes_nonce_.size();
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
decrypted.resize(decrypted_size);
auto nonce = (uint8_t*)data.data();
auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)decrypted.data());
if (ret != 0) {
ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
return;
}
if (on_incoming_audio_ != nullptr) {
on_incoming_audio_(decrypted);
}
remote_sequence_ = sequence;
});
udp_->Connect(udp_server_, udp_port_);
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (session_id != nullptr) {
session_id_ = session_id->valuestring;
}
2024-11-15 04:44:53 +08:00
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
server_sample_rate_ = sample_rate->valueint;
}
}
2024-11-14 23:15:43 +08:00
auto udp = cJSON_GetObjectItem(root, "udp");
if (udp == nullptr) {
ESP_LOGE(TAG, "UDP is not specified");
return;
}
udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
// auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
// ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
aes_nonce_ = DecodeHexString(nonce);
mbedtls_aes_init(&aes_ctx_);
mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
local_sequence_ = 0;
remote_sequence_ = 0;
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
static inline uint8_t CharToHex(char c) {
if (c >= '0' && c <= '9') return c - '0';
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
return 0; // 对于无效输入返回0
}
std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
std::string decoded;
decoded.reserve(hex_string.size() / 2);
for (size_t i = 0; i < hex_string.size(); i += 2) {
char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
decoded.push_back(byte);
}
return decoded;
}
2024-11-15 04:44:53 +08:00
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr;
}