Add MCP server

This commit is contained in:
Terrence
2025-05-22 19:19:36 +08:00
parent f142c5469c
commit 5da66773d5
23 changed files with 845 additions and 41 deletions

View File

@@ -15,6 +15,7 @@ set(SOURCES "audio_codecs/audio_codec.cc"
"protocols/websocket_protocol.cc"
"iot/thing.cc"
"iot/thing_manager.cc"
"mcp_server.cc"
"system_info.cc"
"application.cc"
"ota.cc"

View File

@@ -8,7 +8,7 @@ config OTA_URL
choice
prompt "语言选择"
prompt "Default Language"
default LANGUAGE_ZH_CN
help
Select device display language
@@ -249,37 +249,48 @@ choice DISPLAY_ESP32S3_KORVO2_V3
endchoice
config USE_WECHAT_MESSAGE_STYLE
bool "使用微信聊天界面风格"
bool "Enable WeChat Message Style"
default n
help
使用微信聊天界面风格
config USE_WAKE_WORD_DETECT
bool "启用唤醒词检测"
bool "Enable Wake Word Detection"
default y
depends on IDF_TARGET_ESP32S3 || IDF_TARGET_ESP32P4 && SPIRAM
help
需要 ESP32 S3 与 AFE 支持
config USE_AUDIO_PROCESSOR
bool "启用音频降噪、增益处理"
bool "Enable Audio Noise Reduction"
default y
depends on IDF_TARGET_ESP32S3 || IDF_TARGET_ESP32P4 && SPIRAM
help
需要 ESP32 S3 与 AFE 支持
config USE_DEVICE_AEC
bool "在通话过程中启用设备端 AEC"
bool "Enable Device-Side AEC"
default n
depends on USE_AUDIO_PROCESSOR && (BOARD_TYPE_ESP_BOX_3 || BOARD_TYPE_ESP_BOX || BOARD_TYPE_ESP_BOX_LITE || BOARD_TYPE_LICHUANG_DEV || BOARD_TYPE_ESP32S3_KORVO2_V3 || BOARD_TYPE_ESP32S3_Touch_AMOLED_1_75)
help
因为性能不够,不建议和微信聊天界面风格同时开启
config USE_SERVER_AEC
bool "在通话过程中启用服务器端 AEC"
bool "Enable Server-Side AEC"
default n
depends on USE_AUDIO_PROCESSOR
help
启用服务器端 AEC需要服务器支持
choice IOT_PROTOCOL
prompt "IoT Protocol"
default IOT_PROTOCOL_XIAOZHI
help
IoT 协议,用于获取设备状态与发送控制指令
config IOT_PROTOCOL_MCP
bool "MCP协议 2024-11-05"
config IOT_PROTOCOL_XIAOZHI
bool "小智IoT协议 1.0"
endchoice
endmenu

View File

@@ -9,6 +9,7 @@
#include "font_awesome_symbols.h"
#include "iot/thing_manager.h"
#include "assets/lang_config.h"
#include "mcp_server.h"
#if CONFIG_USE_AUDIO_PROCESSOR
#include "afe_audio_processor.h"
@@ -41,7 +42,7 @@ static const char* const STATE_STRINGS[] = {
Application::Application() {
event_group_ = xEventGroupCreate();
background_task_ = new BackgroundTask(4096 * 8);
background_task_ = new BackgroundTask(4096 * 7);
#if CONFIG_USE_AUDIO_PROCESSOR
audio_processor_ = std::make_unique<AfeAudioProcessor>();
@@ -425,12 +426,15 @@ void Application::Start() {
protocol_->server_sample_rate(), codec->output_sample_rate());
}
SetDecodeSampleRate(protocol_->server_sample_rate(), protocol_->server_frame_duration());
#if CONFIG_IOT_PROTOCOL_XIAOZHI
auto& thing_manager = iot::ThingManager::GetInstance();
protocol_->SendIotDescriptors(thing_manager.GetDescriptorsJson());
std::string states;
if (thing_manager.GetStatesJson(states, false)) {
protocol_->SendIotStates(states);
}
#endif
});
protocol_->OnAudioChannelClosed([this, &board]() {
board.SetPowerSaveMode(true);
@@ -465,7 +469,7 @@ void Application::Start() {
});
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
if (cJSON_IsString(text)) {
ESP_LOGI(TAG, "<< %s", text->valuestring);
Schedule([this, display, message = std::string(text->valuestring)]() {
display->SetChatMessage("assistant", message.c_str());
@@ -474,7 +478,7 @@ void Application::Start() {
}
} else if (strcmp(type->valuestring, "stt") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
if (cJSON_IsString(text)) {
ESP_LOGI(TAG, ">> %s", text->valuestring);
Schedule([this, display, message = std::string(text->valuestring)]() {
display->SetChatMessage("user", message.c_str());
@@ -482,23 +486,32 @@ void Application::Start() {
}
} else if (strcmp(type->valuestring, "llm") == 0) {
auto emotion = cJSON_GetObjectItem(root, "emotion");
if (emotion != NULL) {
if (cJSON_IsString(emotion)) {
Schedule([this, display, emotion_str = std::string(emotion->valuestring)]() {
display->SetEmotion(emotion_str.c_str());
});
}
#if CONFIG_IOT_PROTOCOL_MCP
} else if (strcmp(type->valuestring, "mcp") == 0) {
auto payload = cJSON_GetObjectItem(root, "payload");
if (cJSON_IsObject(payload)) {
McpServer::GetInstance().ParseMessage(payload);
}
#endif
#if CONFIG_IOT_PROTOCOL_XIAOZHI
} else if (strcmp(type->valuestring, "iot") == 0) {
auto commands = cJSON_GetObjectItem(root, "commands");
if (commands != NULL) {
if (cJSON_IsArray(commands)) {
auto& thing_manager = iot::ThingManager::GetInstance();
for (int i = 0; i < cJSON_GetArraySize(commands); ++i) {
auto command = cJSON_GetArrayItem(commands, i);
thing_manager.Invoke(command);
}
}
#endif
} else if (strcmp(type->valuestring, "system") == 0) {
auto command = cJSON_GetObjectItem(root, "command");
if (command != NULL) {
if (cJSON_IsString(command)) {
ESP_LOGI(TAG, "System command: %s", command->valuestring);
if (strcmp(command->valuestring, "reboot") == 0) {
// Do a reboot if user requests a OTA update
@@ -513,7 +526,7 @@ void Application::Start() {
auto status = cJSON_GetObjectItem(root, "status");
auto message = cJSON_GetObjectItem(root, "message");
auto emotion = cJSON_GetObjectItem(root, "emotion");
if (status != NULL && message != NULL && emotion != NULL) {
if (cJSON_IsString(status) && cJSON_IsString(message) && cJSON_IsString(emotion)) {
Alert(status->valuestring, message->valuestring, emotion->valuestring, Lang::Sounds::P3_VIBRATION);
} else {
ESP_LOGW(TAG, "Alert command requires status, message and emotion");
@@ -620,7 +633,10 @@ void Application::OnClockTimer() {
clock_ticks_++;
// Print the debug info every 3 seconds
if (clock_ticks_ % 10 == 0) {
if (clock_ticks_ % 3 == 0) {
// char buffer[500];
// vTaskList(buffer);
// ESP_LOGI(TAG, "Task list: \n%s", buffer);
// SystemInfo::PrintRealTimeStats(pdMS_TO_TICKS(1000));
int free_sram = heap_caps_get_free_size(MALLOC_CAP_INTERNAL);
@@ -850,7 +866,9 @@ void Application::SetDeviceState(DeviceState state) {
display->SetStatus(Lang::Strings::LISTENING);
display->SetEmotion("neutral");
// Update the IoT states before sending the start listening command
#if CONFIG_IOT_PROTOCOL_XIAOZHI
UpdateIotStates();
#endif
// Make sure the audio processor is running
if (!audio_processor_->IsRunning()) {
@@ -910,11 +928,13 @@ void Application::SetDecodeSampleRate(int sample_rate, int frame_duration) {
}
// Sends the IoT thing states to the server through the active protocol.
// Compiled to a no-op unless the XiaoZhi IoT protocol is selected.
// NOTE(review): the `true` flag passed to GetStatesJson() presumably requests
// only the changed states - confirm against ThingManager.
void Application::UpdateIotStates() {
#if CONFIG_IOT_PROTOCOL_XIAOZHI
    std::string states_json;
    if (!iot::ThingManager::GetInstance().GetStatesJson(states_json, true)) {
        // Nothing to report.
        return;
    }
    protocol_->SendIotStates(states_json);
#endif
}
void Application::Reboot() {
@@ -955,3 +975,11 @@ bool Application::CanEnterSleepMode() {
// Now it is safe to enter sleep mode
return true;
}
// Hands an MCP payload to the protocol layer.
// The payload is copied into the closure and the actual send is deferred
// through Schedule(); it is dropped silently when no protocol exists at the
// time the closure runs.
void Application::SendMcpMessage(const std::string& payload) {
    Schedule([this, payload]() {
        if (!protocol_) {
            return;
        }
        protocol_->SendMcpMessage(payload);
    });
}

View File

@@ -72,6 +72,7 @@ public:
void WakeWordInvoke(const std::string& wake_word);
void PlaySound(const std::string_view& sound);
bool CanEnterSleepMode();
void SendMcpMessage(const std::string& payload);
private:
Application();

View File

@@ -94,6 +94,7 @@ private:
display_ = new NoDisplay();
return;
}
ESP_ERROR_CHECK(esp_lcd_panel_invert_color(panel_, false));
// Set the display to on
ESP_LOGI(TAG, "Turning display on");

View File

@@ -49,6 +49,7 @@ public:
virtual std::string GetJson();
virtual void SetPowerSaveMode(bool enabled) = 0;
virtual std::string GetBoardJson() = 0;
virtual std::string GetDeviceStatusJson() = 0;
};
#define DECLARE_BOARD(BOARD_CLASS_NAME) \

View File

@@ -99,3 +99,7 @@ void DualNetworkBoard::SetPowerSaveMode(bool enabled) {
// Delegates to the board implementation that is currently selected.
std::string DualNetworkBoard::GetBoardJson() {
    auto& active_board = *current_board_;
    return active_board.GetBoardJson();
}
// Delegates the device status report to the board implementation that is
// currently selected.
std::string DualNetworkBoard::GetDeviceStatusJson() {
    auto& active_board = *current_board_;
    return active_board.GetDeviceStatusJson();
}

View File

@@ -56,7 +56,7 @@ public:
virtual const char* GetNetworkStateIcon() override;
virtual void SetPowerSaveMode(bool enabled) override;
virtual std::string GetBoardJson() override;
virtual std::string GetDeviceStatusJson() override;
};
#endif // DUAL_NETWORK_BOARD_H

View File

@@ -120,3 +120,86 @@ std::string Ml307Board::GetBoardJson() {
// Power-save hook for the ML307 board. It is not implemented yet: the body is
// empty, so the call does nothing and `enabled` is ignored.
void Ml307Board::SetPowerSaveMode(bool enabled) {
// TODO: Implement power save mode for ML307
}
std::string Ml307Board::GetDeviceStatusJson() {
    /*
     * Serializes the current device status to a compact JSON string.
     *
     * Layout of the returned JSON:
     * {
     *     "audio_speaker": {
     *         "volume": 70
     *     },
     *     "screen": {
     *         "brightness": 100,
     *         "theme": "light"
     *     },
     *     "battery": {
     *         "level": 50,
     *         "charging": true
     *     },
     *     "network": {
     *         "type": "cellular",
     *         "carrier": "CHINA MOBILE",
     *         "signal": "medium"
     *     }
     * }
     *
     * "volume", "brightness" and "theme" are only emitted when the matching
     * peripheral is present; "battery" is only emitted when GetBatteryLevel()
     * returns true. "signal" is always emitted and is one of
     * "unknown", "very weak", "weak", "medium" or "strong".
     */
    auto& board = Board::GetInstance();
    auto root = cJSON_CreateObject();

    // Audio speaker
    auto audio_speaker = cJSON_CreateObject();
    auto audio_codec = board.GetAudioCodec();
    if (audio_codec) {
        cJSON_AddNumberToObject(audio_speaker, "volume", audio_codec->output_volume());
    }
    cJSON_AddItemToObject(root, "audio_speaker", audio_speaker);

    // Screen brightness
    auto backlight = board.GetBacklight();
    auto screen = cJSON_CreateObject();
    if (backlight) {
        cJSON_AddNumberToObject(screen, "brightness", backlight->brightness());
    }
    auto display = board.GetDisplay();
    if (display && display->height() > 64) { // For LCD display only
        cJSON_AddStringToObject(screen, "theme", display->GetTheme().c_str());
    }
    cJSON_AddItemToObject(root, "screen", screen);

    // Battery
    int battery_level = 0;
    bool charging = false;
    bool discharging = false;
    if (board.GetBatteryLevel(battery_level, charging, discharging)) {
        cJSON* battery = cJSON_CreateObject();
        cJSON_AddNumberToObject(battery, "level", battery_level);
        cJSON_AddBoolToObject(battery, "charging", charging);
        cJSON_AddItemToObject(root, "battery", battery);
    }

    // Network
    auto network = cJSON_CreateObject();
    cJSON_AddStringToObject(network, "type", "cellular");
    cJSON_AddStringToObject(network, "carrier", modem_.GetCarrierName().c_str());

    // Map the CSQ reading to a coarse label. Anything outside 0..31 (the -1
    // sentinel, or 99 which AT+CSQ uses for "not known or not detectable")
    // is reported as "unknown" so the "signal" key is never missing.
    const int csq = modem_.GetCsq();
    const char* signal = "unknown";
    if (csq >= 0 && csq <= 14) {
        signal = "very weak";
    } else if (csq >= 15 && csq <= 19) {
        signal = "weak";
    } else if (csq >= 20 && csq <= 24) {
        signal = "medium";
    } else if (csq >= 25 && csq <= 31) {
        signal = "strong";
    }
    cJSON_AddStringToObject(network, "signal", signal);
    cJSON_AddItemToObject(root, "network", network);

    // cJSON_PrintUnformatted() returns NULL when it cannot allocate; building
    // a std::string from NULL is undefined behaviour, so fall back to "{}".
    std::string json = "{}";
    auto json_str = cJSON_PrintUnformatted(root);
    if (json_str != nullptr) {
        json = json_str;
        cJSON_free(json_str);
    }
    cJSON_Delete(root);
    return json;
}

View File

@@ -21,6 +21,7 @@ public:
virtual const char* GetNetworkStateIcon() override;
virtual void SetPowerSaveMode(bool enabled) override;
virtual AudioCodec* GetAudioCodec() override { return nullptr; }
virtual std::string GetDeviceStatusJson() override;
};
#endif // ML307_BOARD_H

View File

@@ -181,3 +181,83 @@ void WifiBoard::ResetWifiConfiguration() {
// Reboot the device
esp_restart();
}
std::string WifiBoard::GetDeviceStatusJson() {
    /*
     * Serializes the current device status to a compact JSON string.
     *
     * Layout of the returned JSON:
     * {
     *     "audio_speaker": {
     *         "volume": 70
     *     },
     *     "screen": {
     *         "brightness": 100,
     *         "theme": "light"
     *     },
     *     "battery": {
     *         "level": 50,
     *         "charging": true
     *     },
     *     "network": {
     *         "type": "wifi",
     *         "ssid": "Xiaozhi",
     *         "signal": "strong"
     *     }
     * }
     *
     * "volume", "brightness" and "theme" are only emitted when the matching
     * peripheral is present; "battery" is only emitted when GetBatteryLevel()
     * returns true. "signal" is "strong" (RSSI >= -60), "medium" (>= -70)
     * or "weak" (anything lower).
     */
    auto& board = Board::GetInstance();
    auto root = cJSON_CreateObject();

    // Audio speaker
    auto audio_speaker = cJSON_CreateObject();
    auto audio_codec = board.GetAudioCodec();
    if (audio_codec) {
        cJSON_AddNumberToObject(audio_speaker, "volume", audio_codec->output_volume());
    }
    cJSON_AddItemToObject(root, "audio_speaker", audio_speaker);

    // Screen brightness
    auto backlight = board.GetBacklight();
    auto screen = cJSON_CreateObject();
    if (backlight) {
        cJSON_AddNumberToObject(screen, "brightness", backlight->brightness());
    }
    auto display = board.GetDisplay();
    if (display && display->height() > 64) { // For LCD display only
        cJSON_AddStringToObject(screen, "theme", display->GetTheme().c_str());
    }
    cJSON_AddItemToObject(root, "screen", screen);

    // Battery
    int battery_level = 0;
    bool charging = false;
    bool discharging = false;
    if (board.GetBatteryLevel(battery_level, charging, discharging)) {
        cJSON* battery = cJSON_CreateObject();
        cJSON_AddNumberToObject(battery, "level", battery_level);
        cJSON_AddBoolToObject(battery, "charging", charging);
        cJSON_AddItemToObject(root, "battery", battery);
    }

    // Network
    auto network = cJSON_CreateObject();
    auto& wifi_station = WifiStation::GetInstance();
    cJSON_AddStringToObject(network, "type", "wifi");
    cJSON_AddStringToObject(network, "ssid", wifi_station.GetSsid().c_str());
    const int rssi = wifi_station.GetRssi();
    const char* signal = "weak";
    if (rssi >= -60) {
        signal = "strong";
    } else if (rssi >= -70) {
        signal = "medium";
    }
    cJSON_AddStringToObject(network, "signal", signal);
    cJSON_AddItemToObject(root, "network", network);

    // cJSON_PrintUnformatted() returns NULL when it cannot allocate; building
    // a std::string from NULL is undefined behaviour, so fall back to "{}".
    std::string json = "{}";
    auto json_str = cJSON_PrintUnformatted(root);
    if (json_str != nullptr) {
        json = json_str;
        cJSON_free(json_str);
    }
    cJSON_Delete(root);
    return json;
}

View File

@@ -21,6 +21,7 @@ public:
virtual void SetPowerSaveMode(bool enabled) override;
virtual void ResetWifiConfiguration();
virtual AudioCodec* GetAudioCodec() override { return nullptr; }
virtual std::string GetDeviceStatusJson() override;
};
#endif // WIFI_BOARD_H

View File

@@ -1,6 +1,7 @@
#include "lcd_display.h"
#include <vector>
#include <algorithm>
#include <font_awesome_symbols.h>
#include <esp_log.h>
#include <esp_err.h>

View File

@@ -11,7 +11,7 @@ dependencies:
78/esp_lcd_nv3023: ~1.0.0
78/esp-wifi-connect: ~2.4.2
78/esp-opus-encoder: ~2.3.2
78/esp-ml307: ~2.0.1
78/esp-ml307: ~2.0.2
78/xiaozhi-fonts: ~1.3.2
espressif/led_strip: ^2.5.5
espressif/esp_codec_dev: ~1.3.2

269
main/mcp_server.cc Normal file
View File

@@ -0,0 +1,269 @@
/*
* MCP Server Implementation
* Reference: https://modelcontextprotocol.io/specification/2024-11-05
*/
#include "mcp_server.h"
#include <esp_log.h>
#include <esp_app_desc.h>
#include <algorithm>
#include <cstring>
#include "application.h"
#include "display.h"
#include "board.h"
#define TAG "MCP"
// Registers the built-in device tools as soon as the server is constructed.
McpServer::McpServer() {
    this->AddCommonTools();
}
// Registers the tools every board exposes: device status query, speaker
// volume, screen brightness and screen theme.
// Tool callbacks run inside McpTool::Call(); a std::runtime_error thrown from
// a callback is turned into an MCP error reply by DoToolCall().
void McpServer::AddCommonTools() {
    AddTool("self.get_device_status",
        "Provides the real-time information of the device, including the current status of the audio speaker, screen, battery, network, etc.\n"
        "Use this tool for: \n"
        "1. Answering questions about current condition (e.g. what is the current volume of the audio speaker?)\n"
        "2. As the first step to control the device (e.g. turn up / down the volume of the audio speaker, etc.)",
        PropertyList(),
        [](const PropertyList& properties) -> ReturnValue {
            return Board::GetInstance().GetDeviceStatusJson();
        });

    AddTool("self.speaker.set_volume",
        "Set the volume of the audio speaker. If the current volume is unknown, you must call `self.get_device_status` tool first and then call this tool.",
        PropertyList({
            Property("volume", kPropertyTypeInteger, 0, 100)
        }),
        [](const PropertyList& properties) -> ReturnValue {
            auto codec = Board::GetInstance().GetAudioCodec();
            // GetAudioCodec() may return nullptr (the base network boards do),
            // so report an error instead of dereferencing a null pointer.
            if (codec == nullptr) {
                throw std::runtime_error("Audio codec is not available");
            }
            codec->SetOutputVolume(properties["volume"].value<int>());
            return true;
        });

    AddTool("self.screen.set_brightness",
        "Set the brightness of the screen.",
        PropertyList({
            Property("brightness", kPropertyTypeInteger, 0, 100)
        }),
        [](const PropertyList& properties) -> ReturnValue {
            // The property is declared with the range 0..100, so the narrowing cast is safe.
            uint8_t brightness = static_cast<uint8_t>(properties["brightness"].value<int>());
            auto backlight = Board::GetInstance().GetBacklight();
            if (backlight) {
                backlight->SetBrightness(brightness, true);
            }
            return true;
        });

    AddTool("self.screen.set_theme",
        "Set the theme of the screen. The theme can be 'light' or 'dark'.",
        PropertyList({
            Property("theme", kPropertyTypeString)
        }),
        [](const PropertyList& properties) -> ReturnValue {
            auto display = Board::GetInstance().GetDisplay();
            if (display) {
                display->SetTheme(properties["theme"].value<std::string>().c_str());
            }
            return true;
        });
}
// Registers an already constructed tool. The pointer is stored as-is and is
// never deleted by the server.
// A null pointer is rejected: every stored entry is dereferenced later by
// GetToolsList() and DoToolCall().
void McpServer::AddTool(McpTool* tool) {
    if (tool == nullptr) {
        ESP_LOGW(TAG, "AddTool: ignoring null tool");
        return;
    }
    tools_.push_back(tool);
}
void McpServer::AddTool(const std::string& name, const std::string& description, const PropertyList& properties, std::function<ReturnValue(const PropertyList&)> callback) {
tools_.push_back(new McpTool(name, description, properties, callback));
}
void McpServer::ParseMessage(const std::string& message) {
cJSON* json = cJSON_Parse(message.c_str());
if (json == nullptr) {
ESP_LOGE(TAG, "Failed to parse MCP message: %s", message.c_str());
return;
}
ParseMessage(json);
cJSON_Delete(json);
}
// Handles one JSON-RPC 2.0 message addressed to the MCP server.
//
// Supported methods: "initialize", "tools/list" and "tools/call". Methods
// starting with "notifications" are accepted and ignored. Any other method is
// answered with an error reply. Messages with a bad "jsonrpc" version, a
// missing method, non-object params or a non-numeric id are logged and
// dropped without a reply.
void McpServer::ParseMessage(const cJSON* json) {
    // Check JSONRPC version
    auto version = cJSON_GetObjectItem(json, "jsonrpc");
    if (!cJSON_IsString(version) || strcmp(version->valuestring, "2.0") != 0) {
        // valuestring is only meaningful for string items; never hand a null
        // pointer to the %s formatter.
        const char* shown = "null";
        if (version != nullptr) {
            shown = cJSON_IsString(version) ? version->valuestring : "(not a string)";
        }
        ESP_LOGE(TAG, "Invalid JSONRPC version: %s", shown);
        return;
    }

    // Check method
    auto method = cJSON_GetObjectItem(json, "method");
    if (!cJSON_IsString(method)) {
        ESP_LOGE(TAG, "Missing method");
        return;
    }

    auto method_str = std::string(method->valuestring);
    // Notifications carry no id and expect no reply. rfind(..., 0) only tests
    // the prefix instead of scanning the whole string.
    if (method_str.rfind("notifications", 0) == 0) {
        return;
    }

    // Check params
    auto params = cJSON_GetObjectItem(json, "params");
    if (params != nullptr && !cJSON_IsObject(params)) {
        ESP_LOGE(TAG, "Invalid params for method: %s", method_str.c_str());
        return;
    }

    auto id = cJSON_GetObjectItem(json, "id");
    if (!cJSON_IsNumber(id)) {
        ESP_LOGE(TAG, "Invalid id for method: %s", method_str.c_str());
        return;
    }
    auto id_int = id->valueint;

    if (method_str == "initialize") {
        auto app_desc = esp_app_get_description();
        ReplyResult(id_int, "{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{\"tools\":{}},"
            "\"serverInfo\":{\"name\":\"" BOARD_NAME "\",\"version\":\"" + std::string(app_desc->version) + "\"}}");
    } else if (method_str == "tools/list") {
        // An absent cursor means "start from the first tool".
        std::string cursor_str = "";
        if (params != nullptr) {
            auto cursor = cJSON_GetObjectItem(params, "cursor");
            if (cJSON_IsString(cursor)) {
                cursor_str = std::string(cursor->valuestring);
            }
        }
        GetToolsList(id_int, cursor_str);
    } else if (method_str == "tools/call") {
        if (!cJSON_IsObject(params)) {
            ESP_LOGE(TAG, "tools/call: Missing params");
            ReplyError(id_int, "Missing params");
            return;
        }
        auto tool_name = cJSON_GetObjectItem(params, "name");
        if (!cJSON_IsString(tool_name)) {
            ESP_LOGE(TAG, "tools/call: Missing name");
            ReplyError(id_int, "Missing name");
            return;
        }
        // "arguments" is optional, but when present it must be an object.
        auto tool_arguments = cJSON_GetObjectItem(params, "arguments");
        if (tool_arguments != nullptr && !cJSON_IsObject(tool_arguments)) {
            ESP_LOGE(TAG, "tools/call: Invalid arguments");
            ReplyError(id_int, "Invalid arguments");
            return;
        }
        DoToolCall(id_int, std::string(tool_name->valuestring), tool_arguments);
    } else {
        ESP_LOGE(TAG, "Method not implemented: %s", method_str.c_str());
        ReplyError(id_int, "Method not implemented: " + method_str);
    }
}
// Wraps `result` (already serialized JSON) in a JSON-RPC 2.0 success envelope
// for request `id` and hands it to the application for sending.
void McpServer::ReplyResult(int id, const std::string& result) {
    std::string envelope = "{\"jsonrpc\":\"2.0\",\"id\":";
    envelope += std::to_string(id);
    envelope += ",\"result\":";
    envelope += result;
    envelope += "}";
    Application::GetInstance().SendMcpMessage(envelope);
}
void McpServer::ReplyError(int id, const std::string& message) {
std::string payload = "{\"jsonrpc\":\"2.0\",\"id\":";
payload += std::to_string(id) + ",\"error\":{\"message\":\"" + message + "\"}}";
Application::GetInstance().SendMcpMessage(payload);
}
// Replies to "tools/list" with as many tool descriptions as fit into one
// payload, starting at the tool named by `cursor` (or at the first tool when
// `cursor` is empty). When tools remain, "nextCursor" carries the name of the
// first tool that did not fit.
void McpServer::GetToolsList(int id, const std::string& cursor) {
    const size_t max_payload_size = 1400; // ML307 MQTT publish size limit
    std::string json = "{\"tools\":[";

    bool found_cursor = cursor.empty();
    auto it = tools_.begin();
    std::string next_cursor = "";

    while (it != tools_.end()) {
        // Keep skipping until the tool named by the cursor is reached
        if (!found_cursor) {
            if ((*it)->name() == cursor) {
                found_cursor = true;
            } else {
                ++it;
                continue;
            }
        }

        // Check the size before appending the tool
        std::string tool_json = (*it)->to_json() + ",";
        if (json.length() + tool_json.length() + 30 > max_payload_size) {
            // This tool would exceed the limit: remember it as the next cursor and stop
            next_cursor = (*it)->name();
            break;
        }

        json += tool_json;
        ++it;
    }

    // A cursor that matches no tool used to be reported as a payload size
    // failure with an empty tool name; report it for what it is. The cursor
    // is only logged, not echoed back, because it is caller-supplied text.
    if (!found_cursor) {
        ESP_LOGE(TAG, "tools/list: Unknown cursor: %s", cursor.c_str());
        ReplyError(id, "Invalid cursor");
        return;
    }

    if (json.back() == ',') {
        json.pop_back();
    }

    if (json.back() == '[' && !tools_.empty()) {
        // Not even one tool could be added: report an error
        ESP_LOGE(TAG, "tools/list: Failed to add tool %s because of payload size limit", next_cursor.c_str());
        ReplyError(id, "Failed to add tool " + next_cursor + " because of payload size limit");
        return;
    }

    if (next_cursor.empty()) {
        json += "]}";
    } else {
        json += "],\"nextCursor\":\"" + next_cursor + "\"}";
    }

    ReplyResult(id, json);
}
void McpServer::DoToolCall(int id, const std::string& tool_name, const cJSON* tool_arguments) {
auto tool_iter = std::find_if(tools_.begin(), tools_.end(),
[&tool_name](const McpTool* tool) {
return tool->name() == tool_name;
});
if (tool_iter == tools_.end()) {
ESP_LOGE(TAG, "tools/call: Unknown tool: %s", tool_name.c_str());
ReplyError(id, "Unknown tool: " + tool_name);
return;
}
PropertyList arguments = (*tool_iter)->properties();
for (auto& argument : arguments) {
bool found = false;
if (cJSON_IsObject(tool_arguments)) {
auto value = cJSON_GetObjectItem(tool_arguments, argument.name().c_str());
if (argument.type() == kPropertyTypeBoolean && cJSON_IsBool(value)) {
argument.set_value<bool>(value->valueint == 1);
found = true;
} else if (argument.type() == kPropertyTypeInteger && cJSON_IsNumber(value)) {
argument.set_value<int>(value->valueint);
found = true;
} else if (argument.type() == kPropertyTypeString && cJSON_IsString(value)) {
argument.set_value<std::string>(value->valuestring);
found = true;
}
}
if (!argument.has_default_value() && !found) {
ESP_LOGE(TAG, "tools/call: Missing valid argument: %s", argument.name().c_str());
ReplyError(id, "Missing valid argument: " + argument.name());
return;
}
}
Application::GetInstance().Schedule([this, id, tool_iter, arguments = std::move(arguments)]() {
try {
ReplyResult(id, (*tool_iter)->Call(arguments));
} catch (const std::runtime_error& e) {
ESP_LOGE(TAG, "tools/call: %s", e.what());
ReplyError(id, e.what());
}
});
}

278
main/mcp_server.h Normal file
View File

@@ -0,0 +1,278 @@
#ifndef MCP_SERVER_H
#define MCP_SERVER_H
#include <string>
#include <vector>
#include <map>
#include <functional>
#include <variant>
#include <optional>
#include <stdexcept>
#include <cJSON.h>
// Type alias for the value a tool callback returns: a boolean, an integer or a string.
using ReturnValue = std::variant<bool, int, std::string>;
// Value types a tool property can have; Property::to_json() serializes them
// as "boolean", "integer" and "string" respectively.
enum PropertyType {
kPropertyTypeBoolean,
kPropertyTypeInteger,
kPropertyTypeString
};
class Property {
private:
std::string name_;
PropertyType type_;
std::variant<bool, int, std::string> value_;
bool has_default_value_;
std::optional<int> min_value_; // 新增:整数最小值
std::optional<int> max_value_; // 新增:整数最大值
public:
// Required field constructor
Property(const std::string& name, PropertyType type)
: name_(name), type_(type), has_default_value_(false) {}
// Optional field constructor with default value
template<typename T>
Property(const std::string& name, PropertyType type, const T& default_value)
: name_(name), type_(type), has_default_value_(true) {
value_ = default_value;
}
Property(const std::string& name, PropertyType type, int min_value, int max_value)
: name_(name), type_(type), has_default_value_(false), min_value_(min_value), max_value_(max_value) {
if (type != kPropertyTypeInteger) {
throw std::invalid_argument("Range limits only apply to integer properties");
}
}
Property(const std::string& name, PropertyType type, int default_value, int min_value, int max_value)
: name_(name), type_(type), has_default_value_(true), min_value_(min_value), max_value_(max_value) {
if (type != kPropertyTypeInteger) {
throw std::invalid_argument("Range limits only apply to integer properties");
}
if (default_value < min_value || default_value > max_value) {
throw std::invalid_argument("Default value must be within the specified range");
}
value_ = default_value;
}
inline const std::string& name() const { return name_; }
inline PropertyType type() const { return type_; }
inline bool has_default_value() const { return has_default_value_; }
inline bool has_range() const { return min_value_.has_value() && max_value_.has_value(); }
inline int min_value() const { return min_value_.value_or(0); }
inline int max_value() const { return max_value_.value_or(0); }
template<typename T>
inline T value() const {
return std::get<T>(value_);
}
template<typename T>
inline void set_value(const T& value) {
// 添加对设置的整数值进行范围检查
if constexpr (std::is_same_v<T, int>) {
if (min_value_.has_value() && value < min_value_.value()) {
throw std::invalid_argument("Value is below minimum allowed: " + std::to_string(min_value_.value()));
}
if (max_value_.has_value() && value > max_value_.value()) {
throw std::invalid_argument("Value exceeds maximum allowed: " + std::to_string(max_value_.value()));
}
}
value_ = value;
}
std::string to_json() const {
cJSON *json = cJSON_CreateObject();
if (type_ == kPropertyTypeBoolean) {
cJSON_AddStringToObject(json, "type", "boolean");
if (has_default_value_) {
cJSON_AddBoolToObject(json, "default", value<bool>());
}
} else if (type_ == kPropertyTypeInteger) {
cJSON_AddStringToObject(json, "type", "integer");
if (has_default_value_) {
cJSON_AddNumberToObject(json, "default", value<int>());
}
if (min_value_.has_value()) {
cJSON_AddNumberToObject(json, "minimum", min_value_.value());
}
if (max_value_.has_value()) {
cJSON_AddNumberToObject(json, "maximum", max_value_.value());
}
} else if (type_ == kPropertyTypeString) {
cJSON_AddStringToObject(json, "type", "string");
if (has_default_value_) {
cJSON_AddStringToObject(json, "default", value<std::string>().c_str());
}
}
char *json_str = cJSON_PrintUnformatted(json);
std::string result(json_str);
cJSON_free(json_str);
cJSON_Delete(json);
return result;
}
};
class PropertyList {
private:
std::vector<Property> properties_;
public:
PropertyList() = default;
PropertyList(const std::vector<Property>& properties) : properties_(properties) {}
void AddProperty(const Property& property) {
properties_.push_back(property);
}
const Property& operator[](const std::string& name) const {
for (const auto& property : properties_) {
if (property.name() == name) {
return property;
}
}
throw std::runtime_error("Property not found: " + name);
}
auto begin() { return properties_.begin(); }
auto end() { return properties_.end(); }
std::vector<std::string> GetRequired() const {
std::vector<std::string> required;
for (auto& property : properties_) {
if (!property.has_default_value()) {
required.push_back(property.name());
}
}
return required;
}
std::string to_json() const {
cJSON *json = cJSON_CreateObject();
for (const auto& property : properties_) {
cJSON *prop_json = cJSON_Parse(property.to_json().c_str());
cJSON_AddItemToObject(json, property.name().c_str(), prop_json);
}
char *json_str = cJSON_PrintUnformatted(json);
std::string result(json_str);
cJSON_free(json_str);
cJSON_Delete(json);
return result;
}
};
class McpTool {
private:
std::string name_;
std::string description_;
PropertyList properties_;
std::function<ReturnValue(const PropertyList&)> callback_;
public:
McpTool(const std::string& name,
const std::string& description,
const PropertyList& properties,
std::function<ReturnValue(const PropertyList&)> callback)
: name_(name),
description_(description),
properties_(properties),
callback_(callback) {}
inline const std::string& name() const { return name_; }
inline const std::string& description() const { return description_; }
inline const PropertyList& properties() const { return properties_; }
std::string to_json() const {
std::vector<std::string> required = properties_.GetRequired();
cJSON *json = cJSON_CreateObject();
cJSON_AddStringToObject(json, "name", name_.c_str());
cJSON_AddStringToObject(json, "description", description_.c_str());
cJSON *input_schema = cJSON_CreateObject();
cJSON_AddStringToObject(input_schema, "type", "object");
cJSON *properties = cJSON_Parse(properties_.to_json().c_str());
cJSON_AddItemToObject(input_schema, "properties", properties);
if (!required.empty()) {
cJSON *required_array = cJSON_CreateArray();
for (const auto& property : required) {
cJSON_AddItemToArray(required_array, cJSON_CreateString(property.c_str()));
}
cJSON_AddItemToObject(input_schema, "required", required_array);
}
cJSON_AddItemToObject(json, "inputSchema", input_schema);
char *json_str = cJSON_PrintUnformatted(json);
std::string result(json_str);
cJSON_free(json_str);
cJSON_Delete(json);
return result;
}
std::string Call(const PropertyList& properties) {
ReturnValue return_value = callback_(properties);
// 返回结果
cJSON* result = cJSON_CreateObject();
cJSON* content = cJSON_CreateArray();
cJSON* text = cJSON_CreateObject();
cJSON_AddStringToObject(text, "type", "text");
if (std::holds_alternative<std::string>(return_value)) {
cJSON_AddStringToObject(text, "text", std::get<std::string>(return_value).c_str());
} else if (std::holds_alternative<bool>(return_value)) {
cJSON_AddStringToObject(text, "text", std::get<bool>(return_value) ? "true" : "false");
} else if (std::holds_alternative<int>(return_value)) {
cJSON_AddStringToObject(text, "text", std::to_string(std::get<int>(return_value)).c_str());
}
cJSON_AddItemToArray(content, text);
cJSON_AddItemToObject(result, "content", content);
cJSON_AddBoolToObject(result, "isError", false);
auto json_str = cJSON_PrintUnformatted(result);
std::string result_str(json_str);
cJSON_free(json_str);
cJSON_Delete(result);
return result_str;
}
};
// Process-wide MCP server: holds the tool registry and handles incoming
// JSON-RPC messages.
class McpServer {
public:
    // Returns the single server instance, created on first use.
    static McpServer& GetInstance() {
        static McpServer singleton;
        return singleton;
    }

    // Register a tool, either prebuilt or described by its parts.
    void AddTool(McpTool* tool);
    void AddTool(const std::string& name, const std::string& description, const PropertyList& properties, std::function<ReturnValue(const PropertyList&)> callback);

    // Handle one incoming message, given as raw JSON text or as a parsed tree.
    void ParseMessage(const std::string& message);
    void ParseMessage(const cJSON* json);

private:
    McpServer();
    ~McpServer() = default;

    void AddCommonTools();

    // Reply helpers for a request with the given id.
    void ReplyResult(int id, const std::string& result);
    void ReplyError(int id, const std::string& message);

    // Request handlers.
    void GetToolsList(int id, const std::string& cursor);
    void DoToolCall(int id, const std::string& tool_name, const cJSON* tool_arguments);

    // Registered tools; the defaulted destructor does not delete the pointers.
    std::vector<McpTool*> tools_;
};
#endif // MCP_SERVER_H

View File

@@ -415,7 +415,9 @@ std::string Ota::GetActivationPayload() {
cJSON_AddStringToObject(payload, "serial_number", serial_number_.c_str());
cJSON_AddStringToObject(payload, "challenge", activation_challenge_.c_str());
cJSON_AddStringToObject(payload, "hmac", hmac_hex.c_str());
std::string json = cJSON_Print(payload);
auto json_str = cJSON_PrintUnformatted(payload);
std::string json(json_str);
cJSON_free(json_str);
cJSON_Delete(payload);
ESP_LOGI(TAG, "Activation payload: %s", json.c_str());

View File

@@ -182,17 +182,7 @@ bool MqttProtocol::OpenAudioChannel() {
session_id_ = "";
xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
// 发送 hello 消息申请 UDP 通道
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"version\": 3,";
message += "\"transport\":\"udp\",";
#if CONFIG_USE_SERVER_AEC
message += "\"features\":{\"aec\":true},";
#endif
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}";
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
@@ -262,6 +252,33 @@ bool MqttProtocol::OpenAudioChannel() {
return true;
}
std::string MqttProtocol::GetHelloMessage() {
// 发送 hello 消息申请 UDP 通道
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", 3);
cJSON_AddStringToObject(root, "transport", "udp");
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
#if CONFIG_IOT_PROTOCOL_MCP
cJSON_AddBoolToObject(features, "mcp", true);
#endif
cJSON_AddItemToObject(root, "features", features);
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {

View File

@@ -51,6 +51,7 @@ private:
std::string DecodeHexString(const std::string& hex_string);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};

View File

@@ -115,6 +115,11 @@ void Protocol::SendIotStates(const std::string& states) {
SendText(message);
}
// Wraps an MCP payload (already serialized JSON) in a session envelope of
// type "mcp" and sends it as text.
void Protocol::SendMcpMessage(const std::string& payload) {
    std::string envelope = "{\"session_id\":\"";
    envelope += session_id_;
    envelope += "\",\"type\":\"mcp\",\"payload\":";
    envelope += payload;
    envelope += "}";
    SendText(envelope);
}
bool Protocol::IsTimeout() const {
const int kTimeoutSeconds = 120;
auto now = std::chrono::steady_clock::now();

View File

@@ -71,6 +71,7 @@ public:
virtual void SendAbortSpeaking(AbortReason reason);
virtual void SendIotDescriptors(const std::string& descriptors);
virtual void SendIotStates(const std::string& states);
virtual void SendMcpMessage(const std::string& message);
protected:
std::function<void(const cJSON* root)> on_incoming_json_;

View File

@@ -180,22 +180,12 @@ bool WebsocketProtocol::OpenAudioChannel() {
ESP_LOGI(TAG, "Connecting to websocket server: %s with version: %d", url.c_str(), version_);
if (!websocket_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
SetError(Lang::Strings::SERVER_NOT_FOUND);
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
// Send hello message to describe the client
// keys: message type, version, audio_params (format, sample_rate, channels)
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"version\": " + std::to_string(version_) + ",";
#if CONFIG_USE_SERVER_AEC
message += "\"features\":{\"aec\":true},";
#endif
message += "\"transport\":\"websocket\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}";
auto message = GetHelloMessage();
if (!SendText(message)) {
return false;
}
@@ -215,6 +205,33 @@ bool WebsocketProtocol::OpenAudioChannel() {
return true;
}
std::string WebsocketProtocol::GetHelloMessage() {
// keys: message type, version, audio_params (format, sample_rate, channels)
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "hello");
cJSON_AddNumberToObject(root, "version", version_);
cJSON* features = cJSON_CreateObject();
#if CONFIG_USE_SERVER_AEC
cJSON_AddBoolToObject(features, "aec", true);
#endif
#if CONFIG_IOT_PROTOCOL_MCP
cJSON_AddBoolToObject(features, "mcp", true);
#endif
cJSON_AddItemToObject(root, "features", features);
cJSON_AddStringToObject(root, "transport", "websocket");
cJSON* audio_params = cJSON_CreateObject();
cJSON_AddStringToObject(audio_params, "format", "opus");
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
cJSON_AddNumberToObject(audio_params, "channels", 1);
cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
cJSON_AddItemToObject(root, "audio_params", audio_params);
auto json_str = cJSON_PrintUnformatted(root);
std::string message(json_str);
cJSON_free(json_str);
cJSON_Delete(root);
return message;
}
void WebsocketProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {

View File

@@ -28,6 +28,7 @@ private:
void ParseServerHello(const cJSON* root);
bool SendText(const std::string& text) override;
std::string GetHelloMessage();
};
#endif