diff --git a/main/application.cc b/main/application.cc index b11deb81..291170c0 100644 --- a/main/application.cc +++ b/main/application.cc @@ -302,12 +302,12 @@ void Application::StartListening() { } void Application::StopListening() { - if (device_state_ == kDeviceStateListening) { - Schedule([this]() { + Schedule([this]() { + if (device_state_ == kDeviceStateListening) { protocol_->SendStopListening(); SetDeviceState(kDeviceStateIdle); - }); - } + } + }); } void Application::Start() { diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index bd6ac00a..3a8d07e3 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -28,10 +28,10 @@ MqttProtocol::~MqttProtocol() { } void MqttProtocol::Start() { - StartMqttClient(); + StartMqttClient(false); } -bool MqttProtocol::StartMqttClient() { +bool MqttProtocol::StartMqttClient(bool report_error) { if (mqtt_ != nullptr) { ESP_LOGW(TAG, "Mqtt client already started"); delete mqtt_; @@ -45,9 +45,9 @@ bool MqttProtocol::StartMqttClient() { publish_topic_ = settings.GetString("publish_topic"); if (endpoint_.empty()) { - ESP_LOGE(TAG, "MQTT endpoint is not specified"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_NOT_FOUND); + ESP_LOGW(TAG, "MQTT endpoint is not specified"); + if (report_error) { + SetError(Lang::Strings::SERVER_NOT_FOUND); } return false; } @@ -76,6 +76,7 @@ bool MqttProtocol::StartMqttClient() { ParseServerHello(root); } else if (strcmp(type->valuestring, "goodbye") == 0) { auto session_id = cJSON_GetObjectItem(root, "session_id"); + ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null"); if (session_id == nullptr || session_id_ == session_id->valuestring) { Application::GetInstance().Schedule([this]() { CloseAudioChannel(); @@ -85,14 +86,13 @@ bool MqttProtocol::StartMqttClient() { on_incoming_json_(root); } cJSON_Delete(root); + last_incoming_time_ = std::chrono::steady_clock::now(); }); ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint_.c_str()); if (!mqtt_->Connect(endpoint_, 8883, client_id_, username_, password_)) { ESP_LOGE(TAG, "Failed to connect to endpoint"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_NOT_CONNECTED); - } + SetError(Lang::Strings::SERVER_NOT_CONNECTED); return false; } @@ -105,10 +105,8 @@ void MqttProtocol::SendText(const std::string& text) { return; } if (!mqtt_->Publish(publish_topic_, text)) { - ESP_LOGE(TAG, "Failed to publish message"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_ERROR); - } + ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str()); + SetError(Lang::Strings::SERVER_ERROR); } } @@ -159,11 +157,12 @@ void MqttProtocol::CloseAudioChannel() { bool MqttProtocol::OpenAudioChannel() { if (mqtt_ == nullptr || !mqtt_->IsConnected()) { ESP_LOGI(TAG, "MQTT is not connected, try to connect now"); - if (!StartMqttClient()) { + if (!StartMqttClient(true)) { return false; } } + error_occurred_ = false; session_id_ = ""; xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT); @@ -181,9 +180,7 @@ bool MqttProtocol::OpenAudioChannel() { EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000)); if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) { ESP_LOGE(TAG, "Failed to receive server hello"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_TIMEOUT); - } + SetError(Lang::Strings::SERVER_TIMEOUT); return false; } @@ -226,6 +223,7 @@ bool MqttProtocol::OpenAudioChannel() { on_incoming_audio_(std::move(decrypted)); } remote_sequence_ = sequence; + last_incoming_time_ = std::chrono::steady_clock::now(); }); udp_->Connect(udp_server_, udp_port_); @@ -298,5 +296,5 @@ std::string MqttProtocol::DecodeHexString(const std::string& hex_string) { } bool MqttProtocol::IsAudioChannelOpened() const { - return udp_ != nullptr; + return udp_ != nullptr && !error_occurred_ && !IsTimeout(); } diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index 5f3938e9..d7253fe6 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -50,7 +50,7 @@ private: uint32_t local_sequence_; uint32_t remote_sequence_; - bool StartMqttClient(); + bool StartMqttClient(bool report_error=false); void ParseServerHello(const cJSON* root); std::string DecodeHexString(const std::string& hex_string); diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc index dae93237..999a29ea 100644 --- a/main/protocols/protocol.cc +++ b/main/protocols/protocol.cc @@ -24,6 +24,13 @@ void Protocol::OnNetworkError(std::function ca on_network_error_ = callback; } +void Protocol::SetError(const std::string& message) { + error_occurred_ = true; + if (on_network_error_ != nullptr) { + on_network_error_(message); + } +} + void Protocol::SendAbortSpeaking(AbortReason reason) { std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\""; if (reason == kAbortReasonWakeWordDetected) { @@ -68,3 +75,14 @@ void Protocol::SendIotStates(const std::string& states) { SendText(message); } +bool Protocol::IsTimeout() const { + const int kTimeoutSeconds = 120; + auto now = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast(now - last_incoming_time_); + bool timeout = duration.count() > kTimeoutSeconds; + if (timeout) { + ESP_LOGE(TAG, "Channel timeout %lld seconds", duration.count()); + } + return timeout; +} + diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h index 5328368d..73f7b989 100644 --- a/main/protocols/protocol.h +++ b/main/protocols/protocol.h @@ -4,6 +4,7 @@ #include #include #include +#include struct BinaryProtocol3 { uint8_t type; @@ -60,9 +61,13 @@ protected: std::function on_network_error_; int server_sample_rate_ = 16000; + bool error_occurred_ = false; std::string session_id_; + std::chrono::time_point last_incoming_time_; virtual void SendText(const std::string& text) = 0; + virtual void SetError(const std::string& message); + virtual bool IsTimeout() const; }; #endif // PROTOCOL_H diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc index 43d7e0b5..9a6b997b 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -38,11 +38,14 @@ void WebsocketProtocol::SendText(const std::string& text) { return; } - websocket_->Send(text); + if (!websocket_->Send(text)) { + ESP_LOGE(TAG, "Failed to send text: %s", text.c_str()); + SetError(Lang::Strings::SERVER_ERROR); + } } bool WebsocketProtocol::IsAudioChannelOpened() const { - return websocket_ != nullptr && websocket_->IsConnected(); + return websocket_ != nullptr && websocket_->IsConnected() && !error_occurred_ && !IsTimeout(); } void WebsocketProtocol::CloseAudioChannel() { @@ -57,6 +60,7 @@ bool WebsocketProtocol::OpenAudioChannel() { delete websocket_; } + error_occurred_ = false; std::string url = CONFIG_WEBSOCKET_URL; std::string token = "Bearer " + std::string(CONFIG_WEBSOCKET_ACCESS_TOKEN); websocket_ = Board::GetInstance().CreateWebSocket(); @@ -87,6 +91,7 @@ bool WebsocketProtocol::OpenAudioChannel() { } cJSON_Delete(root); } + last_incoming_time_ = std::chrono::steady_clock::now(); }); websocket_->OnDisconnected([this]() { @@ -98,9 +103,7 @@ bool WebsocketProtocol::OpenAudioChannel() { if (!websocket_->Connect(url.c_str())) { ESP_LOGE(TAG, "Failed to connect to websocket server"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_NOT_FOUND); - } + SetError(Lang::Strings::SERVER_NOT_FOUND); return false; } @@ -119,9 +122,7 @@ bool WebsocketProtocol::OpenAudioChannel() { EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000)); if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) { ESP_LOGE(TAG, "Failed to receive server hello"); - if (on_network_error_ != nullptr) { - on_network_error_(Lang::Strings::SERVER_TIMEOUT); - } + SetError(Lang::Strings::SERVER_TIMEOUT); return false; }