+26
src-tauri/src/cartesia/client.rs
+26
src-tauri/src/cartesia/client.rs
···
49
49
50
50
stream
51
51
}
52
+
53
+
pub async fn open_tts_connection(&self) -> WebSocketStream<MaybeTlsStream<TcpStream>> {
54
+
let mut request = Url::parse("wss://api.cartesia.ai/tts/websocket")
55
+
.expect("failed to parse TTS connection URL")
56
+
.as_str()
57
+
.into_client_request()
58
+
.expect("failed to instantiate TTS WebSocket request");
59
+
60
+
let headers = request.headers_mut();
61
+
let api_key = self
62
+
.secrets_manager
63
+
.get_secret(SecretName::CartesiaApiKey)
64
+
.expect("failed to retrieve API key");
65
+
66
+
headers.insert(
67
+
"X-API-Key",
68
+
HeaderValue::from_str(api_key.as_str()).expect("could not convert key to header value"),
69
+
);
70
+
headers.insert("Cartesia-Version", "2025-04-16".parse().unwrap());
71
+
72
+
let (stream, _) = connect_async(request)
73
+
.await
74
+
.expect("failed to open TTS websocket connection");
75
+
76
+
stream
77
+
}
52
78
}
+1
src-tauri/src/cartesia/mod.rs
+1
src-tauri/src/cartesia/mod.rs
+232
src-tauri/src/cartesia/tts.rs
+232
src-tauri/src/cartesia/tts.rs
···
1
+
use std::{collections::HashMap, sync::Arc};
2
+
3
+
use futures_util::{
4
+
stream::{SplitSink, SplitStream},
5
+
SinkExt, StreamExt,
6
+
};
7
+
use serde::{Deserialize, Serialize};
8
+
use serde_json::{from_str, json};
9
+
use tokio::{
10
+
net::TcpStream,
11
+
sync::{
12
+
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
13
+
Mutex,
14
+
},
15
+
};
16
+
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
17
+
use ts_rs::TS;
18
+
use tungstenite::Message;
19
+
20
+
use crate::cartesia::client::CartesiaClient;
21
+
22
+
#[derive(Serialize, Deserialize, TS)]
23
+
pub struct TtsTimestamp {
24
+
words: Vec<String>,
25
+
start: Vec<f32>,
26
+
end: Vec<f32>,
27
+
}
28
+
29
+
#[derive(Serialize, Deserialize, TS)]
30
+
pub struct TtsPhonemeTimestamp {
31
+
phonemes: Vec<String>,
32
+
start: Vec<f32>,
33
+
end: Vec<f32>,
34
+
}
35
+
36
+
#[derive(Serialize, Deserialize, TS)]
37
+
#[serde(tag = "type", rename_all = "snake_case")]
38
+
#[ts(export)]
39
+
pub enum TtsMessage {
40
+
Chunk {
41
+
data: String,
42
+
done: bool,
43
+
status_code: u16,
44
+
step_time: f32,
45
+
context_id: Option<String>,
46
+
},
47
+
FlushDone {
48
+
done: bool,
49
+
flush_done: bool,
50
+
flush_id: u16,
51
+
status_code: u16,
52
+
context_id: Option<String>,
53
+
},
54
+
Done {
55
+
done: bool,
56
+
status_code: u16,
57
+
context_id: Option<String>,
58
+
},
59
+
Timestamps {
60
+
done: bool,
61
+
status_code: u16,
62
+
context_id: Option<String>,
63
+
word_timestamps: Option<Vec<TtsTimestamp>>,
64
+
},
65
+
Error {
66
+
done: bool,
67
+
error: String,
68
+
status_code: u16,
69
+
context_id: Option<String>,
70
+
},
71
+
PhonemeTimestamps {
72
+
done: bool,
73
+
status_code: u16,
74
+
context_id: Option<String>,
75
+
phoneme_timestamps: Option<Vec<TtsPhonemeTimestamp>>,
76
+
},
77
+
}
78
+
79
+
pub struct TtsContext {
80
+
pub id: String,
81
+
reader: Mutex<UnboundedReceiver<TtsMessage>>,
82
+
writer: Arc<Mutex<UnboundedSender<TtsInputMessage>>>,
83
+
}
84
+
impl TtsContext {
85
+
async fn send(&self, content: String, is_final: bool) {
86
+
let id = self.id.clone();
87
+
let tx = self.writer.lock().await;
88
+
89
+
tx.send(TtsInputMessage {
90
+
context_id: id.clone(),
91
+
content,
92
+
done: is_final,
93
+
})
94
+
.expect(format!("failed to send content to context {id}").as_str())
95
+
}
96
+
97
+
async fn recv(&self) -> Option<TtsMessage> {
98
+
self.reader.lock().await.recv().await
99
+
}
100
+
}
101
+
102
+
#[derive(Serialize, Deserialize, TS)]
103
+
#[ts(export)]
104
+
pub struct TtsInputMessage {
105
+
context_id: String,
106
+
content: String,
107
+
done: bool,
108
+
}
109
+
110
+
pub struct TtsManager {
111
+
client: Arc<CartesiaClient>,
112
+
contexts: Arc<Mutex<HashMap<String, Arc<Mutex<UnboundedSender<TtsMessage>>>>>>,
113
+
input: Option<Arc<Mutex<UnboundedSender<TtsInputMessage>>>>,
114
+
}
115
+
impl TtsManager {
116
+
pub fn new(client: Arc<CartesiaClient>) -> Self {
117
+
Self {
118
+
client,
119
+
contexts: Arc::new(Mutex::new(HashMap::new().into())),
120
+
input: None,
121
+
}
122
+
}
123
+
124
+
async fn connect(&mut self) {
125
+
match &self.input {
126
+
Some(_) => (),
127
+
None => {
128
+
let stream = self.client.open_tts_connection().await;
129
+
let (tx_in, rx_in) = unbounded_channel::<TtsInputMessage>();
130
+
let (tx_ws, rx_ws) = stream.split();
131
+
132
+
tokio::spawn(handle_incoming(self.contexts.clone(), rx_ws));
133
+
tokio::spawn(handle_outgoing(tx_ws, rx_in));
134
+
135
+
self.input = Some(Arc::new(Mutex::new(tx_in)))
136
+
}
137
+
}
138
+
}
139
+
140
+
async fn new_context(&self, id: String) -> Result<TtsContext, ()> {
141
+
match &self.input {
142
+
Some(input) => {
143
+
let (tx_out, rx_out) = unbounded_channel::<TtsMessage>();
144
+
145
+
{
146
+
let mut map = self.contexts.lock().await;
147
+
map.insert(id.clone(), Arc::new(Mutex::new(tx_out)));
148
+
}
149
+
150
+
Ok(TtsContext {
151
+
id,
152
+
reader: Mutex::new(rx_out),
153
+
writer: input.clone(),
154
+
})
155
+
}
156
+
None => {
157
+
todo!()
158
+
}
159
+
}
160
+
}
161
+
}
162
+
163
+
async fn handle_incoming(
164
+
contexts: Arc<Mutex<HashMap<String, Arc<Mutex<UnboundedSender<TtsMessage>>>>>>,
165
+
mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
166
+
) {
167
+
while let Some(event) = reader.next().await {
168
+
match event {
169
+
Ok(msg) => {
170
+
let text = msg.into_text().expect("failed to decode websocket message");
171
+
172
+
println!("Got message: {}", text);
173
+
174
+
match from_str::<TtsMessage>(&text).expect("failed to decode TTS message") {
175
+
TtsMessage::Chunk {
176
+
data,
177
+
done,
178
+
status_code,
179
+
step_time,
180
+
context_id: Some(id),
181
+
} => match contexts.lock().await.get(&id) {
182
+
Some(ctx) => {
183
+
ctx.lock()
184
+
.await
185
+
.send(TtsMessage::Chunk {
186
+
data,
187
+
done,
188
+
status_code,
189
+
step_time,
190
+
context_id: Some(id),
191
+
})
192
+
.expect("failed to forward TTS chunk to contextual writer");
193
+
}
194
+
None => {
195
+
eprintln!("no matching TTS sender context")
196
+
}
197
+
},
198
+
_ => println!("got not chunk"),
199
+
}
200
+
}
201
+
Err(err) => {
202
+
eprintln!("got error from websocket: {}", err)
203
+
}
204
+
}
205
+
}
206
+
}
207
+
208
+
async fn handle_outgoing(
209
+
mut writer: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
210
+
mut reader: UnboundedReceiver<TtsInputMessage>,
211
+
) {
212
+
while let Some(msg) = reader.recv().await {
213
+
let payload = json!({
214
+
"model_id": "sonic-2",
215
+
"transcript": msg.content,
216
+
"voice": "TODO",
217
+
"output_format": {
218
+
"container": "raw",
219
+
"encoding": "pcm_f32le",
220
+
"sample_rate": 44100
221
+
},
222
+
"continue": msg.done,
223
+
"max_buffer_delay": 250,
224
+
"context_id": msg.context_id,
225
+
});
226
+
227
+
writer
228
+
.send(Message::Text(payload.to_string().into()))
229
+
.await
230
+
.expect("failed to send TTS generation request");
231
+
}
232
+
}
+4
-1
src-tauri/src/state.rs
+4
-1
src-tauri/src/state.rs
···
5
5
use tauri_plugin_store::Store;
6
6
7
7
use crate::{
8
-
cartesia::{client::CartesiaClient, stt::SttManager},
8
+
cartesia::{client::CartesiaClient, stt::SttManager, tts::TtsManager},
9
9
devices::{input::InputDeviceManager, output::OutputDeviceManager, types::AudioDeviceError},
10
10
letta::LettaManager,
11
11
secrets::SecretsManager,
···
14
14
pub struct AppState {
15
15
pub cartesia_client: Arc<CartesiaClient>,
16
16
pub stt_manager: Arc<SttManager>,
17
+
pub tts_manager: Arc<TtsManager>,
17
18
pub letta_manager: Arc<LettaManager>,
18
19
pub secrets_manager: Arc<SecretsManager>,
19
20
pub input_device_manager: Arc<InputDeviceManager>,
···
38
39
cartesia_client.clone(),
39
40
input_device_manager.clone(),
40
41
));
42
+
let tts_manager = Arc::new(TtsManager::new(cartesia_client.clone()));
41
43
42
44
Ok(AppState {
43
45
input_device_manager,
···
46
48
letta_manager,
47
49
cartesia_client,
48
50
stt_manager,
51
+
tts_manager,
49
52
})
50
53
}
51
54
}