From 5da66773d5ba6a5cf41dcc0c722af80c0f2f7eb7 Mon Sep 17 00:00:00 2001 From: Terrence Date: Thu, 22 May 2025 19:19:36 +0800 Subject: [PATCH] Add MCP server --- main/CMakeLists.txt | 1 + main/Kconfig.projbuild | 23 +- main/application.cc | 44 ++- main/application.h | 1 + .../bread-compact-wifi/compact_wifi_board.cc | 1 + main/boards/common/board.h | 1 + main/boards/common/dual_network_board.cc | 6 +- main/boards/common/dual_network_board.h | 2 +- main/boards/common/ml307_board.cc | 83 ++++++ main/boards/common/ml307_board.h | 1 + main/boards/common/wifi_board.cc | 80 +++++ main/boards/common/wifi_board.h | 1 + main/display/lcd_display.cc | 1 + main/idf_component.yml | 2 +- main/mcp_server.cc | 269 +++++++++++++++++ main/mcp_server.h | 278 ++++++++++++++++++ main/ota.cc | 4 +- main/protocols/mqtt_protocol.cc | 39 ++- main/protocols/mqtt_protocol.h | 1 + main/protocols/protocol.cc | 5 + main/protocols/protocol.h | 1 + main/protocols/websocket_protocol.cc | 41 ++- main/protocols/websocket_protocol.h | 1 + 23 files changed, 845 insertions(+), 41 deletions(-) create mode 100644 main/mcp_server.cc create mode 100644 main/mcp_server.h diff --git a/main/CMakeLists.txt b/main/CMakeLists.txt index a6086278..c3f5e431 100644 --- a/main/CMakeLists.txt +++ b/main/CMakeLists.txt @@ -15,6 +15,7 @@ set(SOURCES "audio_codecs/audio_codec.cc" "protocols/websocket_protocol.cc" "iot/thing.cc" "iot/thing_manager.cc" + "mcp_server.cc" "system_info.cc" "application.cc" "ota.cc" diff --git a/main/Kconfig.projbuild b/main/Kconfig.projbuild index 9b2ecae6..686bcf3b 100644 --- a/main/Kconfig.projbuild +++ b/main/Kconfig.projbuild @@ -8,7 +8,7 @@ config OTA_URL choice - prompt "语言选择" + prompt "Default Language" default LANGUAGE_ZH_CN help Select device display language @@ -249,37 +249,48 @@ choice DISPLAY_ESP32S3_KORVO2_V3 endchoice config USE_WECHAT_MESSAGE_STYLE - bool "使用微信聊天界面风格" + bool "Enable WeChat Message Style" default n help 使用微信聊天界面风格 config USE_WAKE_WORD_DETECT - bool "启用唤醒词检测" + 
bool "Enable Wake Word Detection" default y depends on IDF_TARGET_ESP32S3 || IDF_TARGET_ESP32P4 && SPIRAM help 需要 ESP32 S3 与 AFE 支持 config USE_AUDIO_PROCESSOR - bool "启用音频降噪、增益处理" + bool "Enable Audio Noise Reduction" default y depends on IDF_TARGET_ESP32S3 || IDF_TARGET_ESP32P4 && SPIRAM help 需要 ESP32 S3 与 AFE 支持 config USE_DEVICE_AEC - bool "在通话过程中启用设备端 AEC" + bool "Enable Device-Side AEC" default n depends on USE_AUDIO_PROCESSOR && (BOARD_TYPE_ESP_BOX_3 || BOARD_TYPE_ESP_BOX || BOARD_TYPE_ESP_BOX_LITE || BOARD_TYPE_LICHUANG_DEV || BOARD_TYPE_ESP32S3_KORVO2_V3 || BOARD_TYPE_ESP32S3_Touch_AMOLED_1_75) help 因为性能不够,不建议和微信聊天界面风格同时开启 config USE_SERVER_AEC - bool "在通话过程中启用服务器端 AEC" + bool "Enable Server-Side AEC" default n depends on USE_AUDIO_PROCESSOR help 启用服务器端 AEC,需要服务器支持 +choice IOT_PROTOCOL + prompt "IoT Protocol" + default IOT_PROTOCOL_XIAOZHI + help + IoT 协议,用于获取设备状态与发送控制指令 + config IOT_PROTOCOL_MCP + bool "MCP协议 2024-11-05" + config IOT_PROTOCOL_XIAOZHI + bool "小智IoT协议 1.0" +endchoice + endmenu diff --git a/main/application.cc b/main/application.cc index 54841a45..9a3d139c 100644 --- a/main/application.cc +++ b/main/application.cc @@ -9,6 +9,7 @@ #include "font_awesome_symbols.h" #include "iot/thing_manager.h" #include "assets/lang_config.h" +#include "mcp_server.h" #if CONFIG_USE_AUDIO_PROCESSOR #include "afe_audio_processor.h" @@ -41,7 +42,7 @@ static const char* const STATE_STRINGS[] = { Application::Application() { event_group_ = xEventGroupCreate(); - background_task_ = new BackgroundTask(4096 * 8); + background_task_ = new BackgroundTask(4096 * 7); #if CONFIG_USE_AUDIO_PROCESSOR audio_processor_ = std::make_unique(); @@ -425,12 +426,15 @@ void Application::Start() { protocol_->server_sample_rate(), codec->output_sample_rate()); } SetDecodeSampleRate(protocol_->server_sample_rate(), protocol_->server_frame_duration()); + +#if CONFIG_IOT_PROTOCOL_XIAOZHI auto& thing_manager = iot::ThingManager::GetInstance(); 
protocol_->SendIotDescriptors(thing_manager.GetDescriptorsJson()); std::string states; if (thing_manager.GetStatesJson(states, false)) { protocol_->SendIotStates(states); } +#endif }); protocol_->OnAudioChannelClosed([this, &board]() { board.SetPowerSaveMode(true); @@ -465,7 +469,7 @@ void Application::Start() { }); } else if (strcmp(state->valuestring, "sentence_start") == 0) { auto text = cJSON_GetObjectItem(root, "text"); - if (text != NULL) { + if (cJSON_IsString(text)) { ESP_LOGI(TAG, "<< %s", text->valuestring); Schedule([this, display, message = std::string(text->valuestring)]() { display->SetChatMessage("assistant", message.c_str()); @@ -474,7 +478,7 @@ void Application::Start() { } } else if (strcmp(type->valuestring, "stt") == 0) { auto text = cJSON_GetObjectItem(root, "text"); - if (text != NULL) { + if (cJSON_IsString(text)) { ESP_LOGI(TAG, ">> %s", text->valuestring); Schedule([this, display, message = std::string(text->valuestring)]() { display->SetChatMessage("user", message.c_str()); @@ -482,23 +486,32 @@ void Application::Start() { } } else if (strcmp(type->valuestring, "llm") == 0) { auto emotion = cJSON_GetObjectItem(root, "emotion"); - if (emotion != NULL) { + if (cJSON_IsString(emotion)) { Schedule([this, display, emotion_str = std::string(emotion->valuestring)]() { display->SetEmotion(emotion_str.c_str()); }); } +#if CONFIG_IOT_PROTOCOL_MCP + } else if (strcmp(type->valuestring, "mcp") == 0) { + auto payload = cJSON_GetObjectItem(root, "payload"); + if (cJSON_IsObject(payload)) { + McpServer::GetInstance().ParseMessage(payload); + } +#endif +#if CONFIG_IOT_PROTOCOL_XIAOZHI } else if (strcmp(type->valuestring, "iot") == 0) { auto commands = cJSON_GetObjectItem(root, "commands"); - if (commands != NULL) { + if (cJSON_IsArray(commands)) { auto& thing_manager = iot::ThingManager::GetInstance(); for (int i = 0; i < cJSON_GetArraySize(commands); ++i) { auto command = cJSON_GetArrayItem(commands, i); thing_manager.Invoke(command); } } +#endif } else 
if (strcmp(type->valuestring, "system") == 0) { auto command = cJSON_GetObjectItem(root, "command"); - if (command != NULL) { + if (cJSON_IsString(command)) { ESP_LOGI(TAG, "System command: %s", command->valuestring); if (strcmp(command->valuestring, "reboot") == 0) { // Do a reboot if user requests a OTA update @@ -513,7 +526,7 @@ void Application::Start() { auto status = cJSON_GetObjectItem(root, "status"); auto message = cJSON_GetObjectItem(root, "message"); auto emotion = cJSON_GetObjectItem(root, "emotion"); - if (status != NULL && message != NULL && emotion != NULL) { + if (cJSON_IsString(status) && cJSON_IsString(message) && cJSON_IsString(emotion)) { Alert(status->valuestring, message->valuestring, emotion->valuestring, Lang::Sounds::P3_VIBRATION); } else { ESP_LOGW(TAG, "Alert command requires status, message and emotion"); @@ -620,7 +633,10 @@ void Application::OnClockTimer() { clock_ticks_++; // Print the debug info every 10 seconds - if (clock_ticks_ % 10 == 0) { + if (clock_ticks_ % 3 == 0) { + // char buffer[500]; + // vTaskList(buffer); + // ESP_LOGI(TAG, "Task list: \n%s", buffer); // SystemInfo::PrintRealTimeStats(pdMS_TO_TICKS(1000)); int free_sram = heap_caps_get_free_size(MALLOC_CAP_INTERNAL); @@ -850,7 +866,9 @@ void Application::SetDeviceState(DeviceState state) { display->SetStatus(Lang::Strings::LISTENING); display->SetEmotion("neutral"); // Update the IoT states before sending the start listening command +#if CONFIG_IOT_PROTOCOL_XIAOZHI UpdateIotStates(); +#endif // Make sure the audio processor is running if (!audio_processor_->IsRunning()) { @@ -910,11 +928,13 @@ void Application::SetDecodeSampleRate(int sample_rate, int frame_duration) { } void Application::UpdateIotStates() { +#if CONFIG_IOT_PROTOCOL_XIAOZHI auto& thing_manager = iot::ThingManager::GetInstance(); std::string states; if (thing_manager.GetStatesJson(states, true)) { protocol_->SendIotStates(states); } +#endif } void Application::Reboot() { @@ -955,3 +975,11 @@ bool 
Application::CanEnterSleepMode() { // Now it is safe to enter sleep mode return true; } + +void Application::SendMcpMessage(const std::string& payload) { + Schedule([this, payload]() { + if (protocol_) { + protocol_->SendMcpMessage(payload); + } + }); +} diff --git a/main/application.h b/main/application.h index 7dd32409..6914ecf0 100644 --- a/main/application.h +++ b/main/application.h @@ -72,6 +72,7 @@ public: void WakeWordInvoke(const std::string& wake_word); void PlaySound(const std::string_view& sound); bool CanEnterSleepMode(); + void SendMcpMessage(const std::string& payload); private: Application(); diff --git a/main/boards/bread-compact-wifi/compact_wifi_board.cc b/main/boards/bread-compact-wifi/compact_wifi_board.cc index 04dc0455..cfdc1afa 100644 --- a/main/boards/bread-compact-wifi/compact_wifi_board.cc +++ b/main/boards/bread-compact-wifi/compact_wifi_board.cc @@ -94,6 +94,7 @@ private: display_ = new NoDisplay(); return; } + ESP_ERROR_CHECK(esp_lcd_panel_invert_color(panel_, false)); // Set the display to on ESP_LOGI(TAG, "Turning display on"); diff --git a/main/boards/common/board.h b/main/boards/common/board.h index b869fbad..2daabec1 100644 --- a/main/boards/common/board.h +++ b/main/boards/common/board.h @@ -49,6 +49,7 @@ public: virtual std::string GetJson(); virtual void SetPowerSaveMode(bool enabled) = 0; virtual std::string GetBoardJson() = 0; + virtual std::string GetDeviceStatusJson() = 0; }; #define DECLARE_BOARD(BOARD_CLASS_NAME) \ diff --git a/main/boards/common/dual_network_board.cc b/main/boards/common/dual_network_board.cc index 79ac00ce..41dbef55 100644 --- a/main/boards/common/dual_network_board.cc +++ b/main/boards/common/dual_network_board.cc @@ -98,4 +98,8 @@ void DualNetworkBoard::SetPowerSaveMode(bool enabled) { std::string DualNetworkBoard::GetBoardJson() { return current_board_->GetBoardJson(); -} \ No newline at end of file +} + +std::string DualNetworkBoard::GetDeviceStatusJson() { + return 
current_board_->GetDeviceStatusJson(); +} diff --git a/main/boards/common/dual_network_board.h b/main/boards/common/dual_network_board.h index 5fd1d39b..780c877a 100644 --- a/main/boards/common/dual_network_board.h +++ b/main/boards/common/dual_network_board.h @@ -56,7 +56,7 @@ public: virtual const char* GetNetworkStateIcon() override; virtual void SetPowerSaveMode(bool enabled) override; virtual std::string GetBoardJson() override; - + virtual std::string GetDeviceStatusJson() override; }; #endif // DUAL_NETWORK_BOARD_H \ No newline at end of file diff --git a/main/boards/common/ml307_board.cc b/main/boards/common/ml307_board.cc index e150b750..8a8346a3 100644 --- a/main/boards/common/ml307_board.cc +++ b/main/boards/common/ml307_board.cc @@ -120,3 +120,86 @@ std::string Ml307Board::GetBoardJson() { void Ml307Board::SetPowerSaveMode(bool enabled) { // TODO: Implement power save mode for ML307 } + +std::string Ml307Board::GetDeviceStatusJson() { + /* + * 返回设备状态JSON + * + * 返回的JSON结构如下: + * { + * "audio_speaker": { + * "volume": 70 + * }, + * "screen": { + * "brightness": 100, + * "theme": "light" + * }, + * "battery": { + * "level": 50, + * "charging": true + * }, + * "network": { + * "type": "cellular", + * "carrier": "CHINA MOBILE", + * "csq": 10 + * } + * } + */ + auto& board = Board::GetInstance(); + auto root = cJSON_CreateObject(); + + // Audio speaker + auto audio_speaker = cJSON_CreateObject(); + auto audio_codec = board.GetAudioCodec(); + if (audio_codec) { + cJSON_AddNumberToObject(audio_speaker, "volume", audio_codec->output_volume()); + } + cJSON_AddItemToObject(root, "audio_speaker", audio_speaker); + + // Screen brightness + auto backlight = board.GetBacklight(); + auto screen = cJSON_CreateObject(); + if (backlight) { + cJSON_AddNumberToObject(screen, "brightness", backlight->brightness()); + } + auto display = board.GetDisplay(); + if (display && display->height() > 64) { // For LCD display only + cJSON_AddStringToObject(screen, "theme", 
display->GetTheme().c_str()); + } + cJSON_AddItemToObject(root, "screen", screen); + + // Battery + int battery_level = 0; + bool charging = false; + bool discharging = false; + if (board.GetBatteryLevel(battery_level, charging, discharging)) { + cJSON* battery = cJSON_CreateObject(); + cJSON_AddNumberToObject(battery, "level", battery_level); + cJSON_AddBoolToObject(battery, "charging", charging); + cJSON_AddItemToObject(root, "battery", battery); + } + + // Network + auto network = cJSON_CreateObject(); + cJSON_AddStringToObject(network, "type", "cellular"); + cJSON_AddStringToObject(network, "carrier", modem_.GetCarrierName().c_str()); + int csq = modem_.GetCsq(); + if (csq == -1) { + cJSON_AddStringToObject(network, "signal", "unknown"); + } else if (csq >= 0 && csq <= 14) { + cJSON_AddStringToObject(network, "signal", "very weak"); + } else if (csq >= 15 && csq <= 19) { + cJSON_AddStringToObject(network, "signal", "weak"); + } else if (csq >= 20 && csq <= 24) { + cJSON_AddStringToObject(network, "signal", "medium"); + } else if (csq >= 25 && csq <= 31) { + cJSON_AddStringToObject(network, "signal", "strong"); + } + cJSON_AddItemToObject(root, "network", network); + + auto json_str = cJSON_PrintUnformatted(root); + std::string json(json_str); + cJSON_free(json_str); + cJSON_Delete(root); + return json; +} diff --git a/main/boards/common/ml307_board.h b/main/boards/common/ml307_board.h index 4dd6cb09..05321084 100644 --- a/main/boards/common/ml307_board.h +++ b/main/boards/common/ml307_board.h @@ -21,6 +21,7 @@ public: virtual const char* GetNetworkStateIcon() override; virtual void SetPowerSaveMode(bool enabled) override; virtual AudioCodec* GetAudioCodec() override { return nullptr; } + virtual std::string GetDeviceStatusJson() override; }; #endif // ML307_BOARD_H diff --git a/main/boards/common/wifi_board.cc b/main/boards/common/wifi_board.cc index 569d6b4a..494e2c46 100644 --- a/main/boards/common/wifi_board.cc +++ b/main/boards/common/wifi_board.cc @@ 
-181,3 +181,83 @@ void WifiBoard::ResetWifiConfiguration() { // Reboot the device esp_restart(); } + +std::string WifiBoard::GetDeviceStatusJson() { + /* + * 返回设备状态JSON + * + * 返回的JSON结构如下: + * { + * "audio_speaker": { + * "volume": 70 + * }, + * "screen": { + * "brightness": 100, + * "theme": "light" + * }, + * "battery": { + * "level": 50, + * "charging": true + * }, + * "network": { + * "type": "wifi", + * "ssid": "Xiaozhi", + * "rssi": -60 + * } + * } + */ + auto& board = Board::GetInstance(); + auto root = cJSON_CreateObject(); + + // Audio speaker + auto audio_speaker = cJSON_CreateObject(); + auto audio_codec = board.GetAudioCodec(); + if (audio_codec) { + cJSON_AddNumberToObject(audio_speaker, "volume", audio_codec->output_volume()); + } + cJSON_AddItemToObject(root, "audio_speaker", audio_speaker); + + // Screen brightness + auto backlight = board.GetBacklight(); + auto screen = cJSON_CreateObject(); + if (backlight) { + cJSON_AddNumberToObject(screen, "brightness", backlight->brightness()); + } + auto display = board.GetDisplay(); + if (display && display->height() > 64) { // For LCD display only + cJSON_AddStringToObject(screen, "theme", display->GetTheme().c_str()); + } + cJSON_AddItemToObject(root, "screen", screen); + + // Battery + int battery_level = 0; + bool charging = false; + bool discharging = false; + if (board.GetBatteryLevel(battery_level, charging, discharging)) { + cJSON* battery = cJSON_CreateObject(); + cJSON_AddNumberToObject(battery, "level", battery_level); + cJSON_AddBoolToObject(battery, "charging", charging); + cJSON_AddItemToObject(root, "battery", battery); + } + + // Network + auto network = cJSON_CreateObject(); + auto& wifi_station = WifiStation::GetInstance(); + cJSON_AddStringToObject(network, "type", "wifi"); + cJSON_AddStringToObject(network, "ssid", wifi_station.GetSsid().c_str()); + int rssi = wifi_station.GetRssi(); + if (rssi >= -60) { + cJSON_AddStringToObject(network, "signal", "strong"); + } else if (rssi >= -70) { 
+ cJSON_AddStringToObject(network, "signal", "medium"); + } else { + cJSON_AddStringToObject(network, "signal", "weak"); + } + cJSON_AddItemToObject(root, "network", network); + + auto json_str = cJSON_PrintUnformatted(root); + std::string json(json_str); + cJSON_free(json_str); + cJSON_Delete(root); + return json; +} diff --git a/main/boards/common/wifi_board.h b/main/boards/common/wifi_board.h index a701b2b7..6827b58e 100644 --- a/main/boards/common/wifi_board.h +++ b/main/boards/common/wifi_board.h @@ -21,6 +21,7 @@ public: virtual void SetPowerSaveMode(bool enabled) override; virtual void ResetWifiConfiguration(); virtual AudioCodec* GetAudioCodec() override { return nullptr; } + virtual std::string GetDeviceStatusJson() override; }; #endif // WIFI_BOARD_H diff --git a/main/display/lcd_display.cc b/main/display/lcd_display.cc index 3dd400ee..5d38adc3 100644 --- a/main/display/lcd_display.cc +++ b/main/display/lcd_display.cc @@ -1,6 +1,7 @@ #include "lcd_display.h" #include +#include #include #include #include diff --git a/main/idf_component.yml b/main/idf_component.yml index c967007b..47820c16 100644 --- a/main/idf_component.yml +++ b/main/idf_component.yml @@ -11,7 +11,7 @@ dependencies: 78/esp_lcd_nv3023: ~1.0.0 78/esp-wifi-connect: ~2.4.2 78/esp-opus-encoder: ~2.3.2 - 78/esp-ml307: ~2.0.1 + 78/esp-ml307: ~2.0.2 78/xiaozhi-fonts: ~1.3.2 espressif/led_strip: ^2.5.5 espressif/esp_codec_dev: ~1.3.2 diff --git a/main/mcp_server.cc b/main/mcp_server.cc new file mode 100644 index 00000000..feece557 --- /dev/null +++ b/main/mcp_server.cc @@ -0,0 +1,269 @@ +/* + * MCP Server Implementation + * Reference: https://modelcontextprotocol.io/specification/2024-11-05 + */ + +#include "mcp_server.h" +#include +#include +#include +#include + +#include "application.h" +#include "display.h" +#include "board.h" + +#define TAG "MCP" + +McpServer::McpServer() { + AddCommonTools(); +} + +void McpServer::AddCommonTools() { + AddTool("self.get_device_status", + "Provides the 
real-time information of the device, including the current status of the audio speaker, screen, battery, network, etc.\n" + "Use this tool for: \n" + "1. Answering questions about current condition (e.g. what is the current volume of the audio speaker?)\n" + "2. As the first step to control the device (e.g. turn up / down the volume of the audio speaker, etc.)", + PropertyList(), + [](const PropertyList& properties) -> ReturnValue { + return Board::GetInstance().GetDeviceStatusJson(); + }); + + AddTool("self.speaker.set_volume", + "Set the volume of the audio speaker. If the current volume is unknown, you must call `self.get_device_status` tool first and then call this tool.", + PropertyList({ + Property("volume", kPropertyTypeInteger, 0, 100) + }), + [](const PropertyList& properties) -> ReturnValue { + auto codec = Board::GetInstance().GetAudioCodec(); + codec->SetOutputVolume(properties["volume"].value()); + return true; + }); + + AddTool("self.screen.set_brightness", + "Set the brightness of the screen.", + PropertyList({ + Property("brightness", kPropertyTypeInteger, 0, 100) + }), + [](const PropertyList& properties) -> ReturnValue { + uint8_t brightness = static_cast(properties["brightness"].value()); + auto backlight = Board::GetInstance().GetBacklight(); + if (backlight) { + backlight->SetBrightness(brightness, true); + } + return true; + }); + + AddTool("self.screen.set_theme", + "Set the theme of the screen. 
The theme can be 'light' or 'dark'.", + PropertyList({ + Property("theme", kPropertyTypeString) + }), + [](const PropertyList& properties) -> ReturnValue { + auto display = Board::GetInstance().GetDisplay(); + if (display) { + display->SetTheme(properties["theme"].value().c_str()); + } + return true; + }); + +} + +void McpServer::AddTool(McpTool* tool) { + tools_.push_back(tool); +} + +void McpServer::AddTool(const std::string& name, const std::string& description, const PropertyList& properties, std::function callback) { + tools_.push_back(new McpTool(name, description, properties, callback)); +} + +void McpServer::ParseMessage(const std::string& message) { + cJSON* json = cJSON_Parse(message.c_str()); + if (json == nullptr) { + ESP_LOGE(TAG, "Failed to parse MCP message: %s", message.c_str()); + return; + } + ParseMessage(json); + cJSON_Delete(json); +} + +void McpServer::ParseMessage(const cJSON* json) { + // Check JSONRPC version + auto version = cJSON_GetObjectItem(json, "jsonrpc"); + if (version == nullptr || !cJSON_IsString(version) || strcmp(version->valuestring, "2.0") != 0) { + ESP_LOGE(TAG, "Invalid JSONRPC version: %s", version ? 
version->valuestring : "null"); + return; + } + + // Check method + auto method = cJSON_GetObjectItem(json, "method"); + if (method == nullptr || !cJSON_IsString(method)) { + ESP_LOGE(TAG, "Missing method"); + return; + } + + auto method_str = std::string(method->valuestring); + if (method_str.find("notifications") == 0) { + return; + } + + // Check params + auto params = cJSON_GetObjectItem(json, "params"); + if (params != nullptr && !cJSON_IsObject(params)) { + ESP_LOGE(TAG, "Invalid params for method: %s", method_str.c_str()); + return; + } + + auto id = cJSON_GetObjectItem(json, "id"); + if (id == nullptr || !cJSON_IsNumber(id)) { + ESP_LOGE(TAG, "Invalid id for method: %s", method_str.c_str()); + return; + } + auto id_int = id->valueint; + + if (method_str == "initialize") { + auto app_desc = esp_app_get_description(); + ReplyResult(id_int, "{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{\"tools\":{}}," + "\"serverInfo\":{\"name\":\"" BOARD_NAME "\",\"version\":\"" + std::string(app_desc->version) + "\"}}"); + } else if (method_str == "tools/list") { + std::string cursor_str = ""; + if (params != nullptr) { + auto cursor = cJSON_GetObjectItem(params, "cursor"); + if (cJSON_IsString(cursor)) { + cursor_str = std::string(cursor->valuestring); + } + } + GetToolsList(id_int, cursor_str); + } else if (method_str == "tools/call") { + if (!cJSON_IsObject(params)) { + ESP_LOGE(TAG, "tools/call: Missing params"); + ReplyError(id_int, "Missing params"); + return; + } + auto tool_name = cJSON_GetObjectItem(params, "name"); + if (!cJSON_IsString(tool_name)) { + ESP_LOGE(TAG, "tools/call: Missing name"); + ReplyError(id_int, "Missing name"); + return; + } + auto tool_arguments = cJSON_GetObjectItem(params, "arguments"); + if (tool_arguments != nullptr && !cJSON_IsObject(tool_arguments)) { + ESP_LOGE(TAG, "tools/call: Invalid arguments"); + ReplyError(id_int, "Invalid arguments"); + return; + } + DoToolCall(id_int, std::string(tool_name->valuestring), 
tool_arguments); + } else { + ESP_LOGE(TAG, "Method not implemented: %s", method_str.c_str()); + ReplyError(id_int, "Method not implemented: " + method_str); + } +} + +void McpServer::ReplyResult(int id, const std::string& result) { + std::string payload = "{\"jsonrpc\":\"2.0\",\"id\":" + std::to_string(id) + ",\"result\":" + result + "}"; + Application::GetInstance().SendMcpMessage(payload); +} + +void McpServer::ReplyError(int id, const std::string& message) { + std::string payload = "{\"jsonrpc\":\"2.0\",\"id\":"; + payload += std::to_string(id) + ",\"error\":{\"message\":\"" + message + "\"}}"; + Application::GetInstance().SendMcpMessage(payload); +} + +void McpServer::GetToolsList(int id, const std::string& cursor) { + const int max_payload_size = 1400; // ML307 MQTT publish size limit + std::string json = "{\"tools\":["; + + bool found_cursor = cursor.empty(); + auto it = tools_.begin(); + std::string next_cursor = ""; + + while (it != tools_.end()) { + // 如果我们还没有找到起始位置,继续搜索 + if (!found_cursor) { + if ((*it)->name() == cursor) { + found_cursor = true; + } else { + ++it; + continue; + } + } + + // 添加tool前检查大小 + std::string tool_json = (*it)->to_json() + ","; + if (json.length() + tool_json.length() + 30 > max_payload_size) { + // 如果添加这个tool会超出大小限制,设置next_cursor并退出循环 + next_cursor = (*it)->name(); + break; + } + + json += tool_json; + ++it; + } + + if (json.back() == ',') { + json.pop_back(); + } + + if (json.back() == '[' && !tools_.empty()) { + // 如果没有添加任何tool,返回错误 + ESP_LOGE(TAG, "tools/list: Failed to add tool %s because of payload size limit", next_cursor.c_str()); + ReplyError(id, "Failed to add tool " + next_cursor + " because of payload size limit"); + return; + } + + if (next_cursor.empty()) { + json += "]}"; + } else { + json += "],\"nextCursor\":\"" + next_cursor + "\"}"; + } + + ReplyResult(id, json); +} + +void McpServer::DoToolCall(int id, const std::string& tool_name, const cJSON* tool_arguments) { + auto tool_iter = std::find_if(tools_.begin(), 
tools_.end(), + [&tool_name](const McpTool* tool) { + return tool->name() == tool_name; + }); + + if (tool_iter == tools_.end()) { + ESP_LOGE(TAG, "tools/call: Unknown tool: %s", tool_name.c_str()); + ReplyError(id, "Unknown tool: " + tool_name); + return; + } + + PropertyList arguments = (*tool_iter)->properties(); + for (auto& argument : arguments) { + bool found = false; + if (cJSON_IsObject(tool_arguments)) { + auto value = cJSON_GetObjectItem(tool_arguments, argument.name().c_str()); + if (argument.type() == kPropertyTypeBoolean && cJSON_IsBool(value)) { + argument.set_value(value->valueint == 1); + found = true; + } else if (argument.type() == kPropertyTypeInteger && cJSON_IsNumber(value)) { + argument.set_value(value->valueint); + found = true; + } else if (argument.type() == kPropertyTypeString && cJSON_IsString(value)) { + argument.set_value(value->valuestring); + found = true; + } + } + + if (!argument.has_default_value() && !found) { + ESP_LOGE(TAG, "tools/call: Missing valid argument: %s", argument.name().c_str()); + ReplyError(id, "Missing valid argument: " + argument.name()); + return; + } + } + + Application::GetInstance().Schedule([this, id, tool_iter, arguments = std::move(arguments)]() { + try { + ReplyResult(id, (*tool_iter)->Call(arguments)); + } catch (const std::runtime_error& e) { + ESP_LOGE(TAG, "tools/call: %s", e.what()); + ReplyError(id, e.what()); + } + }); +} \ No newline at end of file diff --git a/main/mcp_server.h b/main/mcp_server.h new file mode 100644 index 00000000..b5c45f80 --- /dev/null +++ b/main/mcp_server.h @@ -0,0 +1,278 @@ +#ifndef MCP_SERVER_H +#define MCP_SERVER_H + +#include +#include +#include +#include +#include +#include +#include + +#include + +// 添加类型别名 +using ReturnValue = std::variant; + +enum PropertyType { + kPropertyTypeBoolean, + kPropertyTypeInteger, + kPropertyTypeString +}; + +class Property { +private: + std::string name_; + PropertyType type_; + std::variant value_; + bool has_default_value_; + 
std::optional min_value_; // 新增:整数最小值 + std::optional max_value_; // 新增:整数最大值 + +public: + // Required field constructor + Property(const std::string& name, PropertyType type) + : name_(name), type_(type), has_default_value_(false) {} + + // Optional field constructor with default value + template + Property(const std::string& name, PropertyType type, const T& default_value) + : name_(name), type_(type), has_default_value_(true) { + value_ = default_value; + } + + Property(const std::string& name, PropertyType type, int min_value, int max_value) + : name_(name), type_(type), has_default_value_(false), min_value_(min_value), max_value_(max_value) { + if (type != kPropertyTypeInteger) { + throw std::invalid_argument("Range limits only apply to integer properties"); + } + } + + Property(const std::string& name, PropertyType type, int default_value, int min_value, int max_value) + : name_(name), type_(type), has_default_value_(true), min_value_(min_value), max_value_(max_value) { + if (type != kPropertyTypeInteger) { + throw std::invalid_argument("Range limits only apply to integer properties"); + } + if (default_value < min_value || default_value > max_value) { + throw std::invalid_argument("Default value must be within the specified range"); + } + value_ = default_value; + } + + inline const std::string& name() const { return name_; } + inline PropertyType type() const { return type_; } + inline bool has_default_value() const { return has_default_value_; } + inline bool has_range() const { return min_value_.has_value() && max_value_.has_value(); } + inline int min_value() const { return min_value_.value_or(0); } + inline int max_value() const { return max_value_.value_or(0); } + + template + inline T value() const { + return std::get(value_); + } + + template + inline void set_value(const T& value) { + // 添加对设置的整数值进行范围检查 + if constexpr (std::is_same_v) { + if (min_value_.has_value() && value < min_value_.value()) { + throw std::invalid_argument("Value is below 
minimum allowed: " + std::to_string(min_value_.value())); + } + if (max_value_.has_value() && value > max_value_.value()) { + throw std::invalid_argument("Value exceeds maximum allowed: " + std::to_string(max_value_.value())); + } + } + value_ = value; + } + + std::string to_json() const { + cJSON *json = cJSON_CreateObject(); + + if (type_ == kPropertyTypeBoolean) { + cJSON_AddStringToObject(json, "type", "boolean"); + if (has_default_value_) { + cJSON_AddBoolToObject(json, "default", value()); + } + } else if (type_ == kPropertyTypeInteger) { + cJSON_AddStringToObject(json, "type", "integer"); + if (has_default_value_) { + cJSON_AddNumberToObject(json, "default", value()); + } + if (min_value_.has_value()) { + cJSON_AddNumberToObject(json, "minimum", min_value_.value()); + } + if (max_value_.has_value()) { + cJSON_AddNumberToObject(json, "maximum", max_value_.value()); + } + } else if (type_ == kPropertyTypeString) { + cJSON_AddStringToObject(json, "type", "string"); + if (has_default_value_) { + cJSON_AddStringToObject(json, "default", value().c_str()); + } + } + + char *json_str = cJSON_PrintUnformatted(json); + std::string result(json_str); + cJSON_free(json_str); + cJSON_Delete(json); + + return result; + } +}; + +class PropertyList { +private: + std::vector properties_; + +public: + PropertyList() = default; + PropertyList(const std::vector& properties) : properties_(properties) {} + void AddProperty(const Property& property) { + properties_.push_back(property); + } + + const Property& operator[](const std::string& name) const { + for (const auto& property : properties_) { + if (property.name() == name) { + return property; + } + } + throw std::runtime_error("Property not found: " + name); + } + + auto begin() { return properties_.begin(); } + auto end() { return properties_.end(); } + + std::vector GetRequired() const { + std::vector required; + for (auto& property : properties_) { + if (!property.has_default_value()) { + 
required.push_back(property.name()); + } + } + return required; + } + + std::string to_json() const { + cJSON *json = cJSON_CreateObject(); + + for (const auto& property : properties_) { + cJSON *prop_json = cJSON_Parse(property.to_json().c_str()); + cJSON_AddItemToObject(json, property.name().c_str(), prop_json); + } + + char *json_str = cJSON_PrintUnformatted(json); + std::string result(json_str); + cJSON_free(json_str); + cJSON_Delete(json); + + return result; + } +}; + +class McpTool { +private: + std::string name_; + std::string description_; + PropertyList properties_; + std::function callback_; + +public: + McpTool(const std::string& name, + const std::string& description, + const PropertyList& properties, + std::function callback) + : name_(name), + description_(description), + properties_(properties), + callback_(callback) {} + + inline const std::string& name() const { return name_; } + inline const std::string& description() const { return description_; } + inline const PropertyList& properties() const { return properties_; } + + std::string to_json() const { + std::vector required = properties_.GetRequired(); + + cJSON *json = cJSON_CreateObject(); + cJSON_AddStringToObject(json, "name", name_.c_str()); + cJSON_AddStringToObject(json, "description", description_.c_str()); + + cJSON *input_schema = cJSON_CreateObject(); + cJSON_AddStringToObject(input_schema, "type", "object"); + + cJSON *properties = cJSON_Parse(properties_.to_json().c_str()); + cJSON_AddItemToObject(input_schema, "properties", properties); + + if (!required.empty()) { + cJSON *required_array = cJSON_CreateArray(); + for (const auto& property : required) { + cJSON_AddItemToArray(required_array, cJSON_CreateString(property.c_str())); + } + cJSON_AddItemToObject(input_schema, "required", required_array); + } + + cJSON_AddItemToObject(json, "inputSchema", input_schema); + + char *json_str = cJSON_PrintUnformatted(json); + std::string result(json_str); + cJSON_free(json_str); + 
cJSON_Delete(json);
+
+        return result;
+    }
+
+    std::string Call(const PropertyList& properties) {  // Runs the callback and returns its value wrapped as a JSON result string
+        ReturnValue return_value = callback_(properties);
+        // Build the result: {"content":[{"type":"text","text":<value>}],"isError":false}
+        cJSON* result = cJSON_CreateObject();
+        cJSON* content = cJSON_CreateArray();
+        cJSON* text = cJSON_CreateObject();
+        cJSON_AddStringToObject(text, "type", "text");
+        if (std::holds_alternative(return_value)) {  // this alternative is passed through via c_str()
+            cJSON_AddStringToObject(text, "text", std::get(return_value).c_str());
+        } else if (std::holds_alternative(return_value)) {  // this alternative is rendered as "true"/"false"
+            cJSON_AddStringToObject(text, "text", std::get(return_value) ? "true" : "false");
+        } else if (std::holds_alternative(return_value)) {  // this alternative is rendered with std::to_string
+            cJSON_AddStringToObject(text, "text", std::to_string(std::get(return_value)).c_str());
+        }
+        cJSON_AddItemToArray(content, text);
+        cJSON_AddItemToObject(result, "content", content);
+        cJSON_AddBoolToObject(result, "isError", false);  // this path always reports success
+
+        auto json_str = cJSON_PrintUnformatted(result);  // heap buffer we must free, or NULL if printing failed
+        std::string result_str(json_str != nullptr ? json_str : "");  // constructing std::string from NULL is undefined behaviour
+        cJSON_free(json_str);
+        cJSON_Delete(result);
+        return result_str;
+    }
+};
+
+class McpServer {  // Singleton holding the registered tools; its member functions are only declared in this header
+public:
+    static McpServer& GetInstance() {  // function-local static: the instance is constructed on the first call
+        static McpServer instance;
+        return instance;
+    }
+
+    void AddTool(McpTool* tool);  // NOTE(review): raw pointer; who owns `tool` is not visible here - confirm in mcp_server.cc
+    void AddTool(const std::string& name, const std::string& description, const PropertyList& properties, std::function callback);
+    void ParseMessage(const cJSON* json);
+    void ParseMessage(const std::string& message);
+
+private:
+    McpServer();  // private: instances are obtained through GetInstance()
+    ~McpServer() = default;
+
+    void AddCommonTools();
+
+    void ReplyResult(int id, const std::string& result);
+    void ReplyError(int id, const std::string& message);
+
+    void GetToolsList(int id, const std::string& cursor);
+    void DoToolCall(int id, const std::string& tool_name, const cJSON* tool_arguments);
+
+    std::vector tools_;
+};
+
+#endif // MCP_SERVER_H
diff --git a/main/ota.cc b/main/ota.cc
index f5005d7f..6a60d216 100644
--- a/main/ota.cc
+++ b/main/ota.cc
@@ -415,7 +415,9 @@ std::string Ota::GetActivationPayload() {
     cJSON_AddStringToObject(payload, "serial_number", serial_number_.c_str());
     cJSON_AddStringToObject(payload, "challenge", activation_challenge_.c_str());
     cJSON_AddStringToObject(payload, "hmac", hmac_hex.c_str());
-    std::string json = cJSON_Print(payload);
+    auto json_str = cJSON_PrintUnformatted(payload);  // keep the raw buffer so it can be released below
+    std::string json(json_str != nullptr ? json_str : "");  // constructing std::string from NULL is undefined behaviour
+    cJSON_free(json_str);
     cJSON_Delete(payload);
 
     ESP_LOGI(TAG, "Activation payload: %s", json.c_str());
diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc
index 67a40045..89284dc7 100644
--- a/main/protocols/mqtt_protocol.cc
+++ b/main/protocols/mqtt_protocol.cc
@@ -182,17 +182,7 @@ bool MqttProtocol::OpenAudioChannel() {
     session_id_ = "";
     xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
 
-    // 发送 hello 消息申请 UDP 通道
-    std::string message = "{";
-    message += "\"type\":\"hello\",";
-    message += "\"version\": 3,";
-    message += "\"transport\":\"udp\",";
-#if CONFIG_USE_SERVER_AEC
-    message += "\"features\":{\"aec\":true},";
-#endif
-    message += "\"audio_params\":{";
-    message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
-    message += "}}";
+    auto message = GetHelloMessage();  // hello JSON is now built by a dedicated helper
     if (!SendText(message)) {
         return false;
     }
@@ -262,6 +252,33 @@ bool MqttProtocol::OpenAudioChannel() {
     return true;
 }
 
+std::string MqttProtocol::GetHelloMessage() {
+    // Build the hello message that requests a UDP channel: type, version, transport, features, audio_params
+    cJSON* root = cJSON_CreateObject();
+    cJSON_AddStringToObject(root, "type", "hello");
+    cJSON_AddNumberToObject(root, "version", 3);
+    cJSON_AddStringToObject(root, "transport", "udp");
+    cJSON* features = cJSON_CreateObject();  // always attached below, so it is an empty object when no flag is enabled
+#if CONFIG_USE_SERVER_AEC
+    cJSON_AddBoolToObject(features, "aec", true);
+#endif
+#if CONFIG_IOT_PROTOCOL_MCP
+    cJSON_AddBoolToObject(features, "mcp", true);
+#endif
+    cJSON_AddItemToObject(root, "features", features);
+    cJSON* audio_params = cJSON_CreateObject();
+    cJSON_AddStringToObject(audio_params, "format", "opus");
+
cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
+    cJSON_AddNumberToObject(audio_params, "channels", 1);
+    cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
+    cJSON_AddItemToObject(root, "audio_params", audio_params);
+    auto json_str = cJSON_PrintUnformatted(root);  // heap buffer we must free, or NULL if printing failed
+    std::string message(json_str != nullptr ? json_str : "");  // constructing std::string from NULL is undefined behaviour
+    cJSON_free(json_str);
+    cJSON_Delete(root);  // audio_params was attached to root above, so this single delete releases it too
+    return message;
+}
+
 void MqttProtocol::ParseServerHello(const cJSON* root) {
     auto transport = cJSON_GetObjectItem(root, "transport");
     if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h
index df3da3aa..d7d712de 100644
--- a/main/protocols/mqtt_protocol.h
+++ b/main/protocols/mqtt_protocol.h
@@ -51,6 +51,7 @@ private:
     std::string DecodeHexString(const std::string& hex_string);
 
     bool SendText(const std::string& text) override;
+    std::string GetHelloMessage();  // builds the JSON hello text sent by OpenAudioChannel()
 };
 
 
diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc
index a9515494..cc35b698 100644
--- a/main/protocols/protocol.cc
+++ b/main/protocols/protocol.cc
@@ -115,6 +115,11 @@ void Protocol::SendIotStates(const std::string& states) {
     SendText(message);
 }
 
+void Protocol::SendMcpMessage(const std::string& payload) {  // Wraps `payload` in a {"session_id","type":"mcp","payload"} envelope and sends it
+    std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"mcp\",\"payload\":" + payload + "}";  // payload is spliced in verbatim, so it must already be valid JSON
+    SendText(message);  // SendText's result is not checked here
+}
+
 bool Protocol::IsTimeout() const {
     const int kTimeoutSeconds = 120;
     auto now = std::chrono::steady_clock::now();
diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h
index 7f9f541b..c08802e6 100644
--- a/main/protocols/protocol.h
+++ b/main/protocols/protocol.h
@@ -71,6 +71,7 @@ public:
     virtual void SendAbortSpeaking(AbortReason reason);
     virtual void SendIotDescriptors(const std::string& descriptors);
     virtual void SendIotStates(const std::string& states);
+    virtual void SendMcpMessage(const std::string& message);  // NOTE(review): the definition in protocol.cc names this parameter "payload"
 
 protected:
     std::function on_incoming_json_;
diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc
index 83527296..6cf76504 100644
--- a/main/protocols/websocket_protocol.cc
+++ b/main/protocols/websocket_protocol.cc
@@ -180,22 +180,12 @@ bool WebsocketProtocol::OpenAudioChannel() {
     ESP_LOGI(TAG, "Connecting to websocket server: %s with version: %d", url.c_str(), version_);
     if (!websocket_->Connect(url.c_str())) {
         ESP_LOGE(TAG, "Failed to connect to websocket server");
-        SetError(Lang::Strings::SERVER_NOT_FOUND);
+        SetError(Lang::Strings::SERVER_NOT_CONNECTED);  // a failed connect is now reported with the "not connected" string
         return false;
     }
 
     // Send hello message to describe the client
-    // keys: message type, version, audio_params (format, sample_rate, channels)
-    std::string message = "{";
-    message += "\"type\":\"hello\",";
-    message += "\"version\": " + std::to_string(version_) + ",";
-#if CONFIG_USE_SERVER_AEC
-    message += "\"features\":{\"aec\":true},";
-#endif
-    message += "\"transport\":\"websocket\",";
-    message += "\"audio_params\":{";
-    message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
-    message += "}}";
+    auto message = GetHelloMessage();  // hello JSON is now built by a dedicated helper
     if (!SendText(message)) {
         return false;
     }
@@ -215,6 +205,33 @@ bool WebsocketProtocol::OpenAudioChannel() {
     return true;
 }
 
+std::string WebsocketProtocol::GetHelloMessage() {
+    // Build the client hello: type, version, features, transport, audio_params (format, sample_rate, channels, frame_duration); "" if cJSON cannot render it
+    cJSON* root = cJSON_CreateObject();
+    cJSON_AddStringToObject(root, "type", "hello");
+    cJSON_AddNumberToObject(root, "version", version_);
+    cJSON* features = cJSON_CreateObject();  // always attached below, so it is an empty object when no flag is enabled
+#if CONFIG_USE_SERVER_AEC
+    cJSON_AddBoolToObject(features, "aec", true);
+#endif
+#if CONFIG_IOT_PROTOCOL_MCP
+    cJSON_AddBoolToObject(features, "mcp", true);
+#endif
+    cJSON_AddItemToObject(root, "features", features);
+    cJSON_AddStringToObject(root, "transport", "websocket");
+    cJSON* audio_params = cJSON_CreateObject();
+    cJSON_AddStringToObject(audio_params, "format", "opus");
+    cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
+    cJSON_AddNumberToObject(audio_params, "channels", 1);
+    cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
+    cJSON_AddItemToObject(root, "audio_params", audio_params);
+    auto json_str = cJSON_PrintUnformatted(root);  // heap buffer we must free, or NULL if printing failed
+    std::string message(json_str != nullptr ? json_str : "");  // constructing std::string from NULL is undefined behaviour
+    cJSON_free(json_str);
+    cJSON_Delete(root);  // features and audio_params were attached to root, so this single delete releases them too
+    return message;
+}
+
 void WebsocketProtocol::ParseServerHello(const cJSON* root) {
     auto transport = cJSON_GetObjectItem(root, "transport");
     if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {
diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h
index 73ae693e..9d16fe18 100644
--- a/main/protocols/websocket_protocol.h
+++ b/main/protocols/websocket_protocol.h
@@ -28,6 +28,7 @@ private:
 
     void ParseServerHello(const cJSON* root);
     bool SendText(const std::string& text) override;
+    std::string GetHelloMessage();  // builds the JSON hello text sent by OpenAudioChannel()
 };
 
 #endif