Automated labeller that classifies news content using a fine-tuned BERT model
1import "dotenv/config";
2import { Jetstream } from "@skyware/jetstream";
3import { LABELS } from "./const";
4import { pipeline, TextClassificationOutput } from "@huggingface/transformers";
5import pThrottle from "p-throttle";
6import { Agent, CredentialSession } from "@atproto/api";
7import slugify from "slugify";
8
// The account password is read from the environment (loaded by `dotenv/config`).
// Fail fast with an explicit message instead of sending `undefined` to the
// login endpoint and surfacing an opaque authentication error.
const bskyPassword = process.env.BSKY_PASS;
if (!bskyPassword) {
  throw new Error("BSKY_PASS environment variable is not set");
}

// Authenticated session and API agent used to emit moderation events.
const session = new CredentialSession(new URL("https://bsky.social"));
await session.login({
  identifier: "news.aendra.dev",
  password: bskyPassword,
});
const agent = new Agent(session);

// Rate limiter factory: functions wrapped with `throttle` run at most
// `limit` time(s) per `interval` (milliseconds, per p-throttle's options).
const throttle = pThrottle({
  limit: 1,
  interval: 5000,
});

// Text-classification pipeline loaded from the local model directory.
const infer = await pipeline("text-classification", "./models/bert-news-class");

// Model identifier recorded in the comment of every emitted label event.
// NOTE(review): presumably the upstream id of the model stored under
// ./models/bert-news-class — confirm they are the same checkpoint.
const MODEL_NAME = "cssupport/bert-news-class_v1";
24
/**
 * Emits an Ozone `modEventLabel` event that attaches a `class-<slug>` label
 * to a post.
 *
 * @param post - Strong reference (`uri` + `cid`) of the post to label.
 * @param label - Human-readable class name; it is slugified and lower-cased
 *   to build the label value.
 * @param score - Model confidence, recorded in the event comment only.
 * @param model - Model identifier recorded in the event comment.
 * @returns The `emitEvent` response, or `undefined` when the request (or
 *   building its payload) failed; failures are logged, not thrown.
 */
const createLabel = async (
  post: { uri: string; cid: string },
  label: string,
  score: number | string,
  model: string = MODEL_NAME,
) => {
  try {
    const { uri, cid } = post;

    // `return await` (rather than a bare `return`) keeps the pending request
    // inside this `try`, so a rejected call is actually handled by the
    // `catch` below instead of escaping it.
    return await agent.tools.ozone.moderation.emitEvent(
      {
        // The label event itself.
        event: {
          $type: "tools.ozone.moderation.defs#modEventLabel",
          createLabelVals: [`class-${slugify(label)}`.toLowerCase()],
          negateLabelVals: [],
          comment: `Inferred by model ${model} (${score})`,
        },
        // The labelled post, identified by strong reference.
        subject: {
          $type: "com.atproto.repo.strongRef",
          uri,
          cid,
        },
        createdBy: session.did!,
        subjectBlobCids: [],
      },
      {
        encoding: "application/json",
        headers: {
          // Route the request to this account's labeler service.
          "atproto-proxy": `${session.did}#atproto_labeler`,
        },
      },
    );
  } catch (e) {
    console.error(e);
  }
};
63
// Feed whose posts are streamed, and the websocket endpoint that serves it.
const FEED_URI = "https://bsky.app/profile/aendra.com/feed/verified-news";
const CONTRAILS_ENDPOINT =
  "wss://api.graze.social/app/contrail?feed=" + FEED_URI;

// Streaming client for post records.
// - `endpoint` restricts the stream to the News feed above; NOTE(review):
//   removing it presumably falls back to the library's default endpoint.
// - `wantedCollections` may be omitted to receive all collections.
const jetstream = new Jetstream({
  endpoint: CONTRAILS_ENDPOINT,
  wantedCollections: ["app.bsky.feed.post"],
});

// Open the connection; event handlers are registered right after this call.
jetstream.start();
73
/**
 * Rate-limited lookup of link-card metadata for an external URL.
 * Created once at module level rather than once per incoming event.
 * The target URL is percent-encoded so that any `&`, `#` or `?` it contains
 * stays inside the `url` query parameter instead of truncating it.
 */
const fetchDescription = throttle(async (uri: string) =>
  fetch(`https://cardyb.bsky.app/v1/extract?url=${encodeURIComponent(uri)}`),
);

// Label values containing "UNUSED"; predictions matching them are discarded.
// NOTE(review): predictions are compared against LABELS *values* here but are
// later looked up as LABELS *keys* — confirm this matches the shape of LABELS.
const unusedLabels = Object.values(LABELS).filter((value) =>
  value.includes("UNUSED"),
);

/**
 * Classifies every new English post that embeds an external link and, when
 * the best remaining prediction scores above 0.5, emits a label for the post.
 * Errors are logged and never propagate out of the handler.
 */
jetstream.onCreate("app.bsky.feed.post", async (event) => {
  const { embed, langs } = event.commit.record;

  // Only English posts carrying an external link card are classified.
  if (embed?.$type !== "app.bsky.embed.external" || !langs?.includes("en")) {
    return;
  }

  try {
    const response = await fetchDescription(embed.external.uri);
    const { description, title }: { description?: string; title?: string } =
      await response.json();

    // Without a description there is nothing meaningful to classify.
    if (!description) {
      return;
    }

    // Avoid feeding the literal text "undefined: ..." to the model when the
    // link card has no title.
    const text = title ? `${title}: ${description}` : description;

    const predictions = (await infer(text, {
      top_k: 5,
    })) as TextClassificationOutput;

    // Drop unused classes, rank by confidence, and translate the label
    // through LABELS.
    const [best] = (predictions ?? [])
      .filter((prediction) => !unusedLabels.includes(prediction.label as LABELS))
      .sort((a, b) => b.score - a.score)
      .map(({ label, score }: { score: number; label: string }) => ({
        label: LABELS[label as keyof typeof LABELS],
        score,
      }));

    // `best` is undefined when every prediction was filtered out.
    if (best && best.score > 0.5) {
      const emitted = await createLabel(
        {
          uri: `at://${event.did}/${event.commit.collection}/${event.commit.rkey}`,
          cid: event.commit.cid,
        },
        best.label,
        best.score,
      );

      // Only report success when a label event was actually emitted.
      if (emitted) {
        console.log(
          `https://bsky.app/profile/${event.did}/post/${event.commit.rkey}`,
          best.label,
          best.score,
        );
      }
    }
  } catch (e) {
    console.error(e);
  }
});