+1
src-tauri/Cargo.lock
+1
src-tauri/Cargo.lock
+1
src-tauri/Cargo.toml
+1
src-tauri/Cargo.toml
+9
-12
src-tauri/src/cartesia/tts.rs
+9
-12
src-tauri/src/cartesia/tts.rs
···
76
76
},
77
77
}
78
78
79
+
#[derive(Clone)]
79
80
pub struct TtsContext {
80
81
pub id: String,
81
-
reader: Mutex<UnboundedReceiver<TtsMessage>>,
82
+
pub reader: Arc<Mutex<UnboundedReceiver<TtsMessage>>>,
82
83
writer: Arc<Mutex<UnboundedSender<TtsInputMessage>>>,
83
84
}
84
85
impl TtsContext {
85
-
async fn send(&self, content: String, is_final: bool) {
86
+
pub async fn send(&self, content: String, is_final: bool) {
86
87
let id = self.id.clone();
87
88
let tx = self.writer.lock().await;
88
89
···
93
94
})
94
95
.expect(format!("failed to send content to context {id}").as_str())
95
96
}
96
-
97
-
async fn recv(&self) -> Option<TtsMessage> {
98
-
self.reader.lock().await.recv().await
99
-
}
100
97
}
101
98
102
99
#[derive(Serialize, Deserialize, TS)]
103
100
#[ts(export)]
104
101
pub struct TtsInputMessage {
105
-
context_id: String,
106
-
content: String,
107
-
done: bool,
102
+
pub context_id: String,
103
+
pub content: String,
104
+
pub done: bool,
108
105
}
109
106
110
107
pub struct TtsManager {
···
121
118
}
122
119
}
123
120
124
-
async fn connect(&mut self) {
121
+
pub async fn connect(&mut self) {
125
122
match &self.input {
126
123
Some(_) => (),
127
124
None => {
···
137
134
}
138
135
}
139
136
140
-
async fn new_context(&self, id: String) -> Result<TtsContext, ()> {
137
+
pub async fn new_context(&self, id: String) -> Result<TtsContext, ()> {
141
138
match &self.input {
142
139
Some(input) => {
143
140
let (tx_out, rx_out) = unbounded_channel::<TtsMessage>();
···
149
146
150
147
Ok(TtsContext {
151
148
id,
152
-
reader: Mutex::new(rx_out),
149
+
reader: Arc::new(Mutex::new(rx_out)),
153
150
writer: input.clone(),
154
151
})
155
152
}
+178
src-tauri/src/conversation/mod.rs
+178
src-tauri/src/conversation/mod.rs
···
1
+
use std::{io::Cursor, sync::Arc};
2
+
3
+
use base64::prelude::{Engine as _, BASE64_STANDARD};
4
+
use tauri::async_runtime::spawn;
5
+
use tokio::sync::{
6
+
mpsc::{unbounded_channel, UnboundedReceiver},
7
+
Mutex, RwLock,
8
+
};
9
+
10
+
use crate::{
11
+
cartesia::tts::{TtsContext, TtsInputMessage, TtsManager, TtsMessage},
12
+
conversation::types::{Turn, TurnMessage},
13
+
letta::{
14
+
types::{LettaCompletionMessage, LettaMessageContent},
15
+
LettaManager,
16
+
},
17
+
};
18
+
19
+
mod types;
20
+
21
+
pub struct ConversationManager {
22
+
turn: Option<Arc<RwLock<Turn>>>,
23
+
current_msg_index: RwLock<usize>,
24
+
letta_manager: Arc<LettaManager>,
25
+
tts_manager: Arc<TtsManager>,
26
+
}
27
+
impl ConversationManager {
28
+
pub fn new(letta_manager: Arc<LettaManager>, tts_manager: Arc<TtsManager>) -> Self {
29
+
Self {
30
+
turn: None,
31
+
current_msg_index: RwLock::new(0),
32
+
letta_manager,
33
+
tts_manager,
34
+
}
35
+
}
36
+
37
+
pub fn is_idle(&self) -> bool {
38
+
self.turn.is_none()
39
+
}
40
+
41
+
/// Start a new conversation turn
42
+
pub async fn start_turn(&mut self, prompt: String) {
43
+
if !self.is_idle() {
44
+
return;
45
+
};
46
+
47
+
let turn = Arc::new(RwLock::new(Turn::new()));
48
+
49
+
{
50
+
let mut idx = self.current_msg_index.write().await;
51
+
52
+
*idx = 0;
53
+
self.turn = Some(turn.clone());
54
+
}
55
+
56
+
spawn(handle_letta_messages(
57
+
prompt,
58
+
self.letta_manager.clone(),
59
+
self.tts_manager.clone(),
60
+
turn.clone(),
61
+
));
62
+
}
63
+
}
64
+
65
+
async fn handle_letta_messages(
66
+
prompt: String,
67
+
letta: Arc<LettaManager>,
68
+
tts: Arc<TtsManager>,
69
+
turn: Arc<RwLock<Turn>>,
70
+
) {
71
+
match letta.start_completion(prompt).await {
72
+
Ok(mut iter) => {
73
+
while let Some(msg) = iter.recv().await {
74
+
let turn = turn.read().await;
75
+
76
+
match turn.latest() {
77
+
Some(latest) => match latest {
78
+
TurnMessage::TextMessage {
79
+
id,
80
+
reader: _,
81
+
writer,
82
+
} if id == msg.id => {
83
+
// Add current message to chunks
84
+
}
85
+
TurnMessage::AudioMessage {
86
+
id,
87
+
reader: _,
88
+
writer,
89
+
context,
90
+
cursor: _,
91
+
timestamps: _,
92
+
} => {
93
+
// Add current message to chunks
94
+
}
95
+
},
96
+
None => {
97
+
// Create a new message
98
+
// Append message to chunks
99
+
}
100
+
}
101
+
}
102
+
}
103
+
Err(err) => eprintln!("failed to start completion: {}", err),
104
+
}
105
+
}
106
+
107
+
async fn create_turn_message_from_letta(
108
+
src: LettaCompletionMessage,
109
+
tts: Arc<TtsManager>,
110
+
) -> Option<TurnMessage> {
111
+
let (writer, reader) = unbounded_channel::<LettaCompletionMessage>();
112
+
113
+
writer.send(src.clone());
114
+
115
+
match src {
116
+
LettaCompletionMessage::ApprovalRequestMessage { id, .. }
117
+
| LettaCompletionMessage::ApprovalResponseMessage { id, .. }
118
+
| LettaCompletionMessage::HiddenReasoningMessage { id, .. }
119
+
| LettaCompletionMessage::SystemMessage { id, .. }
120
+
| LettaCompletionMessage::ToolCallMessage { id, .. }
121
+
| LettaCompletionMessage::ToolReturnMessage { id, .. } => {
122
+
Some(TurnMessage::TextMessage { id, reader, writer })
123
+
}
124
+
LettaCompletionMessage::AssistantMessage {
125
+
id,
126
+
content: blocks,
127
+
..
128
+
} => {
129
+
let cursor = Cursor::new(Vec::new());
130
+
let mut content = "".to_owned();
131
+
132
+
for b in blocks {
133
+
match b {
134
+
LettaMessageContent::Text { text } => content.push_str(&text),
135
+
_ => (),
136
+
}
137
+
}
138
+
139
+
let context = tts
140
+
.new_context(id.clone())
141
+
.await
142
+
.expect("failed to create new TTS context");
143
+
144
+
context.send(content, false).await;
145
+
146
+
// Spawn task for handling audio generation
147
+
spawn((async |reader: Arc<
148
+
Mutex<UnboundedReceiver<TtsMessage>>,
149
+
>| {
150
+
// Listen to context and append to cursor
151
+
while let Some(msg) = reader.lock().await.recv().await {
152
+
match msg {
153
+
TtsMessage::Chunk { data, .. } => {
154
+
// Decode chunk and write to cursor
155
+
let out = BASE64_STANDARD.decode_vec(data, cursor.get_mut());
156
+
}
157
+
TtsMessage::Timestamps {
158
+
word_timestamps, ..
159
+
} => {
160
+
// Append timestamps
161
+
}
162
+
_ => (),
163
+
}
164
+
}
165
+
})(context.reader.clone()));
166
+
167
+
Some(TurnMessage::AudioMessage {
168
+
id: id.clone(),
169
+
reader,
170
+
writer,
171
+
context,
172
+
cursor,
173
+
timestamps: Vec::new(),
174
+
})
175
+
}
176
+
_ => None,
177
+
}
178
+
}
+52
src-tauri/src/conversation/types.rs
+52
src-tauri/src/conversation/types.rs
···
1
+
use std::io::Cursor;
2
+
3
+
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
4
+
5
+
use crate::{
6
+
cartesia::tts::{TtsContext, TtsTimestamp},
7
+
letta::types::LettaCompletionMessage,
8
+
};
9
+
10
+
pub enum TurnMessage {
11
+
/// Represents a message to be displayed without audio
12
+
TextMessage {
13
+
id: String,
14
+
reader: UnboundedReceiver<LettaCompletionMessage>,
15
+
writer: UnboundedSender<LettaCompletionMessage>,
16
+
},
17
+
18
+
/// Represents a message with associated audio data
19
+
AudioMessage {
20
+
id: String,
21
+
reader: UnboundedReceiver<LettaCompletionMessage>,
22
+
writer: UnboundedSender<LettaCompletionMessage>,
23
+
context: TtsContext,
24
+
cursor: Cursor<Vec<u8>>,
25
+
timestamps: Vec<TtsTimestamp>,
26
+
},
27
+
}
28
+
29
+
pub struct Turn {
30
+
messages: Vec<TurnMessage>,
31
+
done: bool,
32
+
}
33
+
impl Turn {
34
+
pub fn new() -> Self {
35
+
Self {
36
+
messages: Vec::new(),
37
+
done: false,
38
+
}
39
+
}
40
+
41
+
pub fn complete(&mut self) {
42
+
self.done = true;
43
+
}
44
+
45
+
pub fn latest(&self) -> Option<&TurnMessage> {
46
+
self.messages.last()
47
+
}
48
+
49
+
pub fn add_message(&mut self, msg: TurnMessage) {
50
+
self.messages.push(msg);
51
+
}
52
+
}
+6
-6
src-tauri/src/letta/types.rs
+6
-6
src-tauri/src/letta/types.rs
···
56
56
name: String,
57
57
}
58
58
59
-
#[derive(Serialize, Deserialize, TS, Debug)]
59
+
#[derive(Serialize, Deserialize, TS, Debug, Clone)]
60
60
#[serde(tag = "type", rename_all = "lowercase")]
61
61
pub enum LettaMessageContent {
62
62
Text { text: String },
63
63
Image { source: String },
64
64
}
65
65
66
-
#[derive(TS, Debug, SerializeLabeledStringEnum, DeserializeLabeledStringEnum)]
66
+
#[derive(TS, Debug, SerializeLabeledStringEnum, DeserializeLabeledStringEnum, Clone, Copy)]
67
67
pub enum LettaReasoningSource {
68
68
#[string = "reasoner_model"]
69
69
ReasonerModel,
···
72
72
NonReasonerModel,
73
73
}
74
74
75
-
#[derive(SerializeLabeledStringEnum, DeserializeLabeledStringEnum, TS, Debug)]
75
+
#[derive(SerializeLabeledStringEnum, DeserializeLabeledStringEnum, TS, Debug, Clone, Copy)]
76
76
pub enum LettaHiddenReasoningState {
77
77
#[string = "redacted"]
78
78
Redacted,
···
81
81
Omitted,
82
82
}
83
83
84
-
#[derive(Serialize, Deserialize, TS, Debug)]
84
+
#[derive(Serialize, Deserialize, TS, Debug, Clone)]
85
85
#[serde(untagged)]
86
86
pub enum LettaToolCall {
87
87
Call {
···
96
96
},
97
97
}
98
98
99
-
#[derive(SerializeLabeledStringEnum, DeserializeLabeledStringEnum, TS, Debug)]
99
+
#[derive(SerializeLabeledStringEnum, DeserializeLabeledStringEnum, TS, Debug, Clone, Copy)]
100
100
pub enum LettaToolReturnStatus {
101
101
#[string = "success"]
102
102
Success,
···
105
105
Error,
106
106
}
107
107
108
-
#[derive(Serialize, Deserialize, TS, Debug)]
108
+
#[derive(Serialize, Deserialize, TS, Debug, Clone)]
109
109
#[serde(tag = "message_type", rename_all = "snake_case")]
110
110
#[ts(export)]
111
111
pub enum LettaCompletionMessage {