This commit is contained in:
Terrence
2024-09-25 03:44:28 +08:00
parent 53b08843d4
commit 0396b4a91c
5 changed files with 75 additions and 51 deletions

View File

@@ -49,8 +49,8 @@ Application::~Application() {
for (auto& pcm : wake_word_pcm_) {
free(pcm.iov_base);
}
for (auto& opus : wake_word_opus_) {
free(opus.iov_base);
for (auto& packet : wake_word_opus_) {
heap_caps_free(packet);
}
if (opus_decoder_ != nullptr) {
@@ -133,46 +133,50 @@ void Application::Start() {
}
void Application::SetChatState(ChatState state) {
auto& builtin_led = BuiltinLed::GetInstance();
const char* state_str[] = {
"idle",
"connecting",
"listening",
"speaking",
"wake_word_detected",
"testing",
"upgrading",
"unknown"
};
chat_state_ = state;
ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);
auto& builtin_led = BuiltinLed::GetInstance();
switch (chat_state_) {
case kChatStateIdle:
ESP_LOGI(TAG, "Chat state: idle");
builtin_led.TurnOff();
break;
case kChatStateConnecting:
ESP_LOGI(TAG, "Chat state: connecting");
builtin_led.SetBlue();
builtin_led.TurnOn();
break;
case kChatStateListening:
ESP_LOGI(TAG, "Chat state: listening");
builtin_led.SetRed();
builtin_led.TurnOn();
break;
case kChatStateSpeaking:
ESP_LOGI(TAG, "Chat state: speaking");
builtin_led.SetGreen();
builtin_led.TurnOn();
break;
case kChatStateWakeWordDetected:
ESP_LOGI(TAG, "Chat state: wake word detected");
builtin_led.SetBlue();
builtin_led.TurnOn();
break;
case kChatStateTesting:
ESP_LOGI(TAG, "Chat state: testing");
builtin_led.SetRed();
builtin_led.TurnOn();
break;
case kChatStateUpgrading:
ESP_LOGI(TAG, "Chat state: upgrading");
builtin_led.SetGreen();
builtin_led.StartContinuousBlink(100);
break;
}
const char* state_str[] = { "idle", "connecting", "listening", "speaking", "wake_word_detected", "testing", "unknown" };
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (ws_client_ && ws_client_->IsConnected()) {
cJSON* root = cJSON_CreateObject();
@@ -291,7 +295,7 @@ void Application::AudioFeedTask() {
}
void Application::StoreWakeWordData(uint8_t* data, size_t size) {
// store audio data to detect_packets_
// store audio data to wake_word_pcm_
auto iov = (iovec){
.iov_base = heap_caps_malloc(size, MALLOC_CAP_SPIRAM),
.iov_len = size
@@ -320,12 +324,8 @@ void Application::EncodeWakeWordData() {
for (auto& pcm: app->wake_word_pcm_) {
encoder->Encode(pcm, [app](const iovec opus) {
iovec iov = {
.iov_base = heap_caps_malloc(opus.iov_len, MALLOC_CAP_SPIRAM),
.iov_len = opus.iov_len
};
memcpy(iov.iov_base, opus.iov_base, opus.iov_len);
app->wake_word_opus_.push_back(iov);
auto protocol = app->AllocateBinaryProtocol(opus.iov_base, opus.iov_len);
app->wake_word_opus_.push_back(protocol);
});
heap_caps_free(pcm.iov_base);
}
@@ -333,20 +333,33 @@ void Application::EncodeWakeWordData() {
auto end_time = esp_timer_get_time();
ESP_LOGI(TAG, "Encode wake word data opus packets: %d in %lld ms", app->wake_word_opus_.size(), (end_time - start_time) / 1000);
xEventGroupSetBits(app->event_group_, DETECT_PACKETS_ENCODED);
xEventGroupSetBits(app->event_group_, WAKE_WORD_ENCODED);
delete encoder;
vTaskDelete(NULL);
}, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_);
}
// Sends every packet stored in wake_word_opus_ through ws_client_, frees each
// packet's buffer with heap_caps_free(), and finally empties the container.
// NOTE(review): two loop headers appear below (an iovec-based one and a
// BinaryProtocol*-based one) with a single closing brace. This looks like the
// removed and added lines of a diff rendered together; confirm which loop is live.
// NOTE(review): ws_client_ is dereferenced without a null/connected check here;
// presumably the caller guarantees it — verify against the call site.
void Application::SendWakeWordData() {
for (auto& opus: wake_word_opus_) {
ws_client_->Send(opus.iov_base, opus.iov_len, true);
heap_caps_free(opus.iov_base);
for (auto& protocol: wake_word_opus_) {
// Total length = fixed header + payload; payload_size is stored big-endian, hence ntohl().
ws_client_->Send(protocol, sizeof(BinaryProtocol) + ntohl(protocol->payload_size), true);
heap_caps_free(protocol);
}
wake_word_opus_.clear();
}
/// Builds one outgoing binary packet: a BinaryProtocol header followed by a copy of `payload`.
///
/// Multi-byte header fields are stored in network byte order (htons/htonl). The
/// timestamp is audio_device_.last_timestamp() while audio_device_.playing() is
/// true, and 0 otherwise.
///
/// @param payload       Bytes to copy behind the header; only read when payload_size > 0.
/// @param payload_size  Number of payload bytes.
/// @return Buffer of sizeof(BinaryProtocol) + payload_size bytes obtained from
///         heap_caps_malloc(..., MALLOC_CAP_SPIRAM). The caller owns it and must
///         release it with heap_caps_free(). Never null: the function aborts if
///         the allocation fails, because the visible callers use the result unchecked.
BinaryProtocol* Application::AllocateBinaryProtocol(void* payload, size_t payload_size) {
    // The code relies on a 16-byte header; check it at compile time instead of
    // asserting on every call after the header has already been written.
    static_assert(sizeof(BinaryProtocol) == 16, "BinaryProtocol header must be exactly 16 bytes");

    auto last_timestamp = audio_device_.playing() ? audio_device_.last_timestamp() : 0;

    const size_t total_size = sizeof(BinaryProtocol) + payload_size;
    auto* protocol = static_cast<BinaryProtocol*>(heap_caps_malloc(total_size, MALLOC_CAP_SPIRAM));
    if (protocol == nullptr) {
        // Writing through a null pointer would crash anyway; fail with a clear message instead.
        ESP_LOGE(TAG, "Failed to allocate %u bytes for binary protocol packet",
                 static_cast<unsigned>(total_size));
        abort();
    }

    protocol->version = htons(PROTOCOL_VERSION);
    protocol->type = htons(0);
    protocol->reserved = 0;
    protocol->timestamp = htonl(last_timestamp);
    protocol->payload_size = htonl(payload_size);

    // Skip memcpy for empty payloads so a null `payload` with size 0 is harmless.
    if (payload_size > 0) {
        memcpy(protocol->payload, payload, payload_size);
    }
    return protocol;
}
void Application::CheckTestButton() {
if (gpio_get_level(GPIO_NUM_1) == 0) {
if (chat_state_ == kChatStateIdle) {
@@ -437,7 +450,7 @@ void Application::AudioDetectionTask() {
StartWebSocketClient();
// Here the websocket is done, and we also wait for the wake word data to be encoded
xEventGroupWaitBits(event_group_, DETECT_PACKETS_ENCODED, pdTRUE, pdTRUE, portMAX_DELAY);
xEventGroupWaitBits(event_group_, WAKE_WORD_ENCODED, pdTRUE, pdTRUE, portMAX_DELAY);
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (ws_client_ && ws_client_->IsConnected()) {
@@ -448,8 +461,6 @@ void Application::AudioDetectionTask() {
opus_encoder_.ResetState();
// If connected, the hello message is already sent, so we can start communication
xEventGroupSetBits(event_group_, COMMUNICATION_RUNNING);
ESP_LOGI(TAG, "Start communication after wake word detected");
} else {
SetChatState(kChatStateIdle);
xEventGroupSetBits(event_group_, DETECTION_RUNNING);
@@ -521,10 +532,12 @@ void Application::AudioEncodeTask() {
// Encode audio data
opus_encoder_.Encode(pcm, [this](const iovec opus) {
auto protocol = AllocateBinaryProtocol(opus.iov_base, opus.iov_len);
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (ws_client_ && ws_client_->IsConnected()) {
ws_client_->Send(opus.iov_base, opus.iov_len, true);
ws_client_->Send(protocol, sizeof(BinaryProtocol) + opus.iov_len, true);
}
heap_caps_free(protocol);
});
free(pcm.iov_base);
@@ -548,9 +561,9 @@ void Application::AudioDecodeTask() {
}
if (opus_decode_sample_rate_ != CONFIG_AUDIO_OUTPUT_SAMPLE_RATE) {
int target_size = test_resampler_.GetOutputSamples(frame_size);
int target_size = opus_resampler_.GetOutputSamples(frame_size);
std::vector<int16_t> resampled(target_size);
test_resampler_.Process(packet->pcm.data(), frame_size, resampled.data());
opus_resampler_.Process(packet->pcm.data(), frame_size, resampled.data());
packet->pcm = std::move(resampled);
}
}
@@ -581,6 +594,7 @@ void Application::StartWebSocketClient() {
ws_client_ = new WebSocketClient();
ws_client_->SetHeader("Authorization", token.c_str());
ws_client_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
ws_client_->SetHeader("Protocol-Version", std::to_string(PROTOCOL_VERSION).c_str());
ws_client_->OnConnected([this]() {
ESP_LOGI(TAG, "Websocket connected");
@@ -588,7 +602,7 @@ void Application::StartWebSocketClient() {
// Send hello message to describe the client
// keys: message type, version, wakeup_model, audio_params (format, sample_rate, channels)
std::string message = "{";
message += "\"type\":\"hello\", \"version\":\"1.0\",";
message += "\"type\":\"hello\",";
message += "\"wakeup_model\":\"" + std::string(wakenet_model_) + "\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":" + std::to_string(CONFIG_AUDIO_INPUT_SAMPLE_RATE) + ", \"channels\":1";
@@ -597,21 +611,23 @@ void Application::StartWebSocketClient() {
});
ws_client_->OnData([this](const char* data, size_t len, bool binary) {
auto packet = new AudioPacket();
if (binary) {
auto header = (AudioDataHeader*)data;
packet->type = kAudioPacketTypeData;
packet->timestamp = ntohl(header->timestamp);
auto protocol = (BinaryProtocol*)data;
auto payload_size = ntohl(header->payload_size);
auto packet = new AudioPacket();
packet->type = kAudioPacketTypeData;
packet->timestamp = ntohl(protocol->timestamp);
auto payload_size = ntohl(protocol->payload_size);
packet->opus.resize(payload_size);
memcpy(packet->opus.data(), data + sizeof(AudioDataHeader), payload_size);
memcpy(packet->opus.data(), protocol->payload, payload_size);
xQueueSend(audio_decode_queue_, &packet, portMAX_DELAY);
} else {
// Parse JSON data
auto root = cJSON_Parse(data);
auto type = cJSON_GetObjectItem(root, "type");
if (type != NULL) {
if (strcmp(type->valuestring, "tts") == 0) {
auto packet = new AudioPacket();
auto state = cJSON_GetObjectItem(root, "state");
if (strcmp(state->valuestring, "start") == 0) {
packet->type = kAudioPacketTypeStart;
@@ -627,21 +643,22 @@ void Application::StartWebSocketClient() {
packet->type = kAudioPacketTypeSentenceStart;
packet->text = cJSON_GetObjectItem(root, "text")->valuestring;
}
xQueueSend(audio_decode_queue_, &packet, portMAX_DELAY);
} else if (strcmp(type->valuestring, "stt") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
ESP_LOGI(TAG, ">> %s", text->valuestring);
}
}
}
cJSON_Delete(root);
}
xQueueSend(audio_decode_queue_, &packet, portMAX_DELAY);
});
ws_client_->OnError([this](int error) {
ESP_LOGE(TAG, "Websocket error: %d", error);
});
ws_client_->OnClosed([this]() {
ESP_LOGI(TAG, "Websocket closed");
});
if (!ws_client_->Connect(CONFIG_WEBSOCKET_URL)) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
return;