+72
-10
src/handlers/messages.ts
+72
-10
src/handlers/messages.ts
···
1
1
import modelPrompt from "../model/prompt.txt";
2
-
import { ChatMessage, Conversation } from "@skyware/bot";
2
+
import { ChatMessage, Conversation, RichText } from "@skyware/bot";
3
3
import * as c from "../core";
4
4
import * as tools from "../tools";
5
5
import consola from "consola";
···
37
37
parts: [
38
38
{
39
39
text: modelPrompt
40
-
.replace("{{ handle }}", env.HANDLE),
40
+
.replace("$handle", env.HANDLE),
41
41
},
42
42
],
43
43
},
···
102
102
return inference;
103
103
}
104
104
105
+
function addCitations(
106
+
inference: Awaited<ReturnType<typeof c.ai.models.generateContent>>,
107
+
) {
108
+
let originalText = inference.text ?? "";
109
+
if (!inference.candidates) {
110
+
return originalText;
111
+
}
112
+
const supports = inference.candidates[0]?.groundingMetadata
113
+
?.groundingSupports;
114
+
const chunks = inference.candidates[0]?.groundingMetadata?.groundingChunks;
115
+
116
+
const richText = new RichText();
117
+
118
+
if (!supports || !chunks || originalText === "") {
119
+
return richText.addText(originalText);
120
+
}
121
+
122
+
const sortedSupports = [...supports].sort(
123
+
(a, b) => (b.segment?.endIndex ?? 0) - (a.segment?.endIndex ?? 0),
124
+
);
125
+
126
+
let currentText = originalText;
127
+
128
+
for (const support of sortedSupports) {
129
+
const endIndex = support.segment?.endIndex;
130
+
if (endIndex === undefined || !support.groundingChunkIndices?.length) {
131
+
continue;
132
+
}
133
+
134
+
const citationLinks = support.groundingChunkIndices
135
+
.map((i) => {
136
+
const uri = chunks[i]?.web?.uri;
137
+
if (uri) {
138
+
return { index: i + 1, uri };
139
+
}
140
+
return null;
141
+
})
142
+
.filter(Boolean);
143
+
144
+
if (citationLinks.length > 0) {
145
+
richText.addText(currentText.slice(endIndex));
146
+
147
+
citationLinks.forEach((citation, idx) => {
148
+
if (citation) {
149
+
richText.addLink(`[${citation.index}]`, citation.uri);
150
+
if (idx < citationLinks.length - 1) {
151
+
richText.addText(", ");
152
+
}
153
+
}
154
+
});
155
+
156
+
currentText = currentText.slice(0, endIndex);
157
+
}
158
+
}
159
+
160
+
richText.addText(currentText);
161
+
162
+
return richText;
163
+
}
164
+
105
165
export async function handler(message: ChatMessage): Promise<void> {
106
166
const conversation = await message.getConversation();
107
167
// ? Conversation should always be able to be found, but just in case:
···
115
175
: env.AUTHORIZED_USERS.includes(message.senderDid as any);
116
176
117
177
if (!authorized) {
118
-
conversation.sendMessage({
178
+
await conversation.sendMessage({
119
179
text: c.UNAUTHORIZED_MESSAGE,
120
180
});
121
181
···
157
217
parsedConversation.messages,
158
218
);
159
219
if (!inference) {
160
-
throw new Error("Failed to generate text. Returned undefined.");
220
+
logger.error("Failed to generate text. Returned undefined.");
221
+
return;
161
222
}
162
223
163
224
const responseText = inference.text;
225
+
const responseWithCitations = addCitations(inference);
164
226
165
-
if (responseText) {
166
-
logger.success("Generated text:", inference.text);
167
-
saveMessage(conversation, env.DID, inference.text!);
227
+
if (responseWithCitations) {
228
+
logger.success("Generated text:", responseText);
229
+
saveMessage(conversation, env.DID, responseText!);
168
230
169
-
if (exceedsGraphemes(responseText)) {
170
-
multipartResponse(conversation, responseText);
231
+
if (exceedsGraphemes(responseWithCitations)) {
232
+
multipartResponse(conversation, responseWithCitations);
171
233
} else {
172
234
conversation.sendMessage({
173
-
text: responseText,
235
+
text: responseWithCitations,
174
236
});
175
237
}
176
238
}
+1
-1
src/model/prompt.txt
+1
-1
src/model/prompt.txt
+42
src/utils/cache.ts
+42
src/utils/cache.ts
···
1
+
interface CacheEntry<T> {
2
+
value: T;
3
+
expiry: number;
4
+
}
5
+
6
+
class TimedCache<T> {
7
+
private cache = new Map<string, CacheEntry<T>>();
8
+
private ttl: number; // Time to live in milliseconds
9
+
10
+
constructor(ttl: number) {
11
+
this.ttl = ttl;
12
+
}
13
+
14
+
get(key: string): T | undefined {
15
+
const entry = this.cache.get(key);
16
+
if (!entry) {
17
+
return undefined;
18
+
}
19
+
20
+
if (Date.now() > entry.expiry) {
21
+
this.cache.delete(key); // Entry expired
22
+
return undefined;
23
+
}
24
+
25
+
return entry.value;
26
+
}
27
+
28
+
set(key: string, value: T): void {
29
+
const expiry = Date.now() + this.ttl;
30
+
this.cache.set(key, { value, expiry });
31
+
}
32
+
33
+
delete(key: string): void {
34
+
this.cache.delete(key);
35
+
}
36
+
37
+
clear(): void {
38
+
this.cache.clear();
39
+
}
40
+
}
41
+
42
+
export const postCache = new TimedCache<any>(2 * 60 * 1000); // 2 minutes cache
+44
-9
src/utils/conversation.ts
+44
-9
src/utils/conversation.ts
···
2
2
type ChatMessage,
3
3
type Conversation,
4
4
graphemeLength,
5
+
RichText,
5
6
} from "@skyware/bot";
6
7
import * as yaml from "js-yaml";
7
8
import db from "../db";
···
10
11
import { env } from "../env";
11
12
import { bot, ERROR_MESSAGE, MAX_GRAPHEMES } from "../core";
12
13
import { parsePost, parsePostImages, traverseThread } from "./post";
14
+
import { postCache } from "../utils/cache";
13
15
14
16
/*
15
17
Utilities
···
31
33
32
34
const postUri = await parseMessagePostUri(initialMessage);
33
35
if (!postUri) {
34
-
convo.sendMessage({
36
+
await convo.sendMessage({
35
37
text:
36
38
"Please send a post for me to make sense of the noise for you.",
37
39
});
40
+
38
41
throw new Error("No post reference in initial message.");
39
42
}
40
43
···
60
63
did: user.did,
61
64
postUri,
62
65
revision: _convo.revision,
63
-
text: initialMessage.text,
66
+
text:
67
+
!initialMessage.text ||
68
+
initialMessage.text.trim().length == 0
69
+
? "Explain this post."
70
+
: initialMessage.text,
64
71
});
65
72
66
73
return _convo!;
···
109
116
did: getUserDid(convo).did,
110
117
postUri: row.postUri,
111
118
revision: row.revision,
112
-
text: latestMessage!.text,
119
+
text: postUri &&
120
+
(!latestMessage.text ||
121
+
latestMessage.text.trim().length == 0)
122
+
? "Explain this post."
123
+
: latestMessage.text,
113
124
});
114
125
}
115
126
116
-
const post = await bot.getPost(row.postUri);
127
+
let post = postCache.get(row.postUri);
128
+
if (!post) {
129
+
post = await bot.getPost(row.postUri);
130
+
postCache.set(row.postUri, post);
131
+
}
117
132
const convoMessages = await getRelevantMessages(row!);
118
133
119
134
let parseResult = null;
120
135
try {
136
+
const parsedPost = await parsePost(post, true, new Set());
121
137
parseResult = {
122
138
context: yaml.dump({
123
-
post: await parsePost(post, true),
139
+
post: parsedPost || null,
124
140
}),
125
141
messages: convoMessages.map((message) => {
126
142
const role = message.did == env.DID ? "model" : "user";
···
136
152
}),
137
153
};
138
154
} catch (e) {
139
-
convo.sendMessage({
155
+
await convo.sendMessage({
140
156
text: ERROR_MESSAGE,
141
157
});
142
158
···
195
211
/*
196
212
Reponse Utilities
197
213
*/
198
-
export function exceedsGraphemes(content: string) {
214
+
export function exceedsGraphemes(content: string | RichText) {
215
+
if (content instanceof RichText) {
216
+
return graphemeLength(content.text) > MAX_GRAPHEMES;
217
+
}
199
218
return graphemeLength(content) > MAX_GRAPHEMES;
200
219
}
201
220
···
223
242
return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
224
243
}
225
244
226
-
export async function multipartResponse(convo: Conversation, content: string) {
227
-
const parts = splitResponse(content).filter((p) => p.trim().length > 0);
245
+
export async function multipartResponse(
246
+
convo: Conversation,
247
+
content: string | RichText,
248
+
) {
249
+
let parts: (string | RichText)[];
250
+
251
+
if (content instanceof RichText) {
252
+
if (exceedsGraphemes(content)) {
253
+
// If RichText exceeds grapheme limit, convert to plain text for splitting
254
+
parts = splitResponse(content.text);
255
+
} else {
256
+
// Otherwise, send the RichText directly as a single part
257
+
parts = [content];
258
+
}
259
+
} else {
260
+
// If content is a string, behave as before
261
+
parts = splitResponse(content);
262
+
}
228
263
229
264
for (const segment of parts) {
230
265
await convo.sendMessage({
+23
-8
src/utils/post.ts
+23
-8
src/utils/post.ts
···
8
8
import * as c from "../core";
9
9
import * as yaml from "js-yaml";
10
10
import type { ParsedPost } from "../types";
11
+
import { postCache } from "../utils/cache";
11
12
12
13
export async function parsePost(
13
14
post: Post,
14
15
includeThread: boolean,
15
-
): Promise<ParsedPost> {
16
+
seenUris: Set<string> = new Set(),
17
+
): Promise<ParsedPost | undefined> {
18
+
if (seenUris.has(post.uri)) {
19
+
return undefined;
20
+
}
21
+
seenUris.add(post.uri);
22
+
16
23
const [images, quotePost, ancestorPosts] = await Promise.all([
17
24
parsePostImages(post),
18
-
parseQuote(post),
25
+
parseQuote(post, seenUris),
19
26
includeThread ? traverseThread(post) : Promise.resolve(null),
20
27
]);
21
28
···
28
35
...(quotePost && { quotePost }),
29
36
...(ancestorPosts && {
30
37
thread: {
31
-
ancestors: await Promise.all(
32
-
ancestorPosts.map((ancestor) => parsePost(ancestor, false)),
33
-
),
38
+
ancestors: (await Promise.all(
39
+
ancestorPosts.map((ancestor) => parsePost(ancestor, false, seenUris)),
40
+
)).filter((post): post is ParsedPost => post !== undefined),
34
41
},
35
42
}),
36
43
};
37
44
}
38
45
39
-
async function parseQuote(post: Post) {
46
+
async function parseQuote(post: Post, seenUris: Set<string>) {
40
47
if (
41
48
!post.embed || (!post.embed.isRecord() && !post.embed.isRecordWithMedia())
42
49
) return undefined;
43
50
44
51
const record = (post.embed as RecordEmbed || RecordWithMediaEmbed).record;
45
-
const embedPost = await c.bot.getPost(record.uri);
52
+
if (seenUris.has(record.uri)) {
53
+
return undefined;
54
+
}
55
+
56
+
let embedPost = postCache.get(record.uri);
57
+
if (!embedPost) {
58
+
embedPost = await c.bot.getPost(record.uri);
59
+
postCache.set(record.uri, embedPost);
60
+
}
46
61
47
-
return await parsePost(embedPost, false);
62
+
return await parsePost(embedPost, false, seenUris);
48
63
}
49
64
50
65
export function parsePostImages(post: Post) {