Automated labeller that classifies news content using a fine-tuned BERT model
at main 126 lines 3.6 kB view raw
import "dotenv/config";
import { Jetstream } from "@skyware/jetstream";
import { LABELS } from "./const";
import { pipeline, TextClassificationOutput } from "@huggingface/transformers";
import pThrottle from "p-throttle";
import { Agent, CredentialSession } from "@atproto/api";
import slugify from "slugify";

// Fail fast with an explicit message instead of letting login reject with an
// unrelated-looking error when the password was never configured.
const bskyPassword = process.env.BSKY_PASS;
if (!bskyPassword) {
  throw new Error("BSKY_PASS environment variable is not set");
}

// Authenticated session for the labeller account; `agent` issues all API calls.
const session = new CredentialSession(new URL("https://bsky.social"));
await session.login({
  identifier: "news.aendra.dev",
  password: bskyPassword,
});
const agent = new Agent(session);

// Throttle factory configured for one call per 5000 ms interval; wrap a
// function with `throttle(fn)` to rate-limit it.
const throttle = pThrottle({
  limit: 1,
  interval: 5000,
});

// Text-classification pipeline loaded from the local model directory.
const infer = await pipeline("text-classification", "./models/bert-news-class");

// Model name recorded in the comment of every emitted label event.
const MODEL_NAME = "cssupport/bert-news-class_v1";

/**
 * Emits a moderation label event for a post.
 *
 * The label value is `class-<slugified label>` in lower case, and the event
 * comment records the model name and score.
 *
 * @param post - Strong reference (`uri` + `cid`) of the post being labelled.
 * @param label - Human-readable class name; slugified before being emitted.
 * @param score - Model confidence, included verbatim in the event comment.
 * @param model - Model name for the event comment; defaults to `MODEL_NAME`.
 * @returns The `emitEvent` response, or `undefined` when the request failed
 *   (the error is logged rather than rethrown).
 */
const createLabel = async (
  post: { uri: string; cid: string },
  label: string,
  score: number | string,
  model: string = MODEL_NAME,
) => {
  try {
    const { uri, cid } = post;

    // Without a DID the proxy header below would read "undefined#atproto_labeler".
    const did = session.did;
    if (!did) {
      throw new Error("Session has no DID; cannot emit label event");
    }

    // `return await` (rather than a bare `return`) keeps the promise inside
    // this try block, so a rejected request reaches the catch below.
    return await agent.tools.ozone.moderation.emitEvent(
      {
        // specify the label event
        event: {
          $type: "tools.ozone.moderation.defs#modEventLabel",
          createLabelVals: [`class-${slugify(label)}`.toLowerCase()],
          negateLabelVals: [],
          comment: `Inferred by model ${model} (${score})`,
        },
        // specify the labeled post by strongRef
        subject: {
          $type: "com.atproto.repo.strongRef",
          uri,
          cid,
        },
        createdBy: did,
        subjectBlobCids: [],
      },
      {
        encoding: "application/json",
        headers: {
          // Proxy the request to this account's `#atproto_labeler` service.
          "atproto-proxy": `${did}#atproto_labeler`,
        },
      },
    );
  } catch (e) {
    console.error(e);
    return undefined;
  }
};

const FEED_URI = "https://bsky.app/profile/aendra.com/feed/verified-news";
const CONTRAILS_ENDPOINT = `wss://api.graze.social/app/contrail?feed=${FEED_URI}`;

const jetstream = new Jetstream({
  wantedCollections: ["app.bsky.feed.post"], // omit to receive all collections
  // Stream only posts from the feed above; remove this line to use the
  // library's default endpoint instead (presumably the full firehose — confirm).
  endpoint: CONTRAILS_ENDPOINT,
});

jetstream.start();

// Wrapped once at module scope so the throttled fetcher is not rebuilt for
// every incoming post. The target URL is percent-encoded: an unencoded link
// containing its own `?`, `&` or `#` would be cut short or split into extra
// query parameters.
const fetchDescription = throttle(async (uri: string) =>
  fetch(`https://cardyb.bsky.app/v1/extract?url=${encodeURIComponent(uri)}`),
);

// LABELS values marked as unused; model outputs matching one are discarded.
// NOTE(review): this set holds LABELS *values*, while the lookup further down
// treats the model's raw label as a LABELS *key*. Confirm in ./const that the
// two coincide for the UNUSED entries, otherwise this filter never matches.
const UNUSED_LABELS = new Set<string>(
  Object.values(LABELS).filter((value) => value.includes("UNUSED")),
);

/**
 * Handles each newly created post: for English posts with an external link
 * embed, fetches the link card's title/description, classifies that text, and
 * emits a label when the top remaining class scores above 0.5.
 *
 * Errors are logged and never rethrown, so one bad post cannot stop the stream.
 */
jetstream.onCreate("app.bsky.feed.post", async (event) => {
  const { record } = event.commit;
  const embed = record.embed;

  // Only English posts that embed an external link are classified.
  if (
    embed?.$type !== "app.bsky.embed.external" ||
    !record.langs?.includes("en")
  ) {
    return;
  }

  try {
    const linkUri = embed.external.uri;
    const cardResponse = await fetchDescription(linkUri);
    if (!cardResponse.ok) {
      console.error(
        `Card extraction failed with HTTP ${cardResponse.status} for ${linkUri}`,
      );
      return;
    }

    const { description, title } = (await cardResponse.json()) as {
      description?: string;
      title?: string;
    };
    if (!description) {
      return;
    }

    // Avoid feeding the literal text "undefined: ..." to the model when the
    // card has a description but no title.
    const text = title ? `${title}: ${description}` : description;

    const outputs = (await infer(text, { top_k: 5 })) as
      | TextClassificationOutput
      | undefined;

    // Highest-scoring class that is not marked unused, mapped to its LABELS entry.
    const [best] = (outputs ?? [])
      .filter(({ label }) => !UNUSED_LABELS.has(label))
      .sort((a, b) => b.score - a.score)
      .map(({ label, score }) => ({
        label: LABELS[label as keyof typeof LABELS],
        score,
      }));

    // `best` is undefined when every candidate was filtered out, and its label
    // is undefined when the model emitted a class that LABELS does not know.
    if (best === undefined || !best.label || !(best.score > 0.5)) {
      return;
    }

    const emitted = await createLabel(
      {
        uri: `at://${event.did}/${event.commit.collection}/${event.commit.rkey}`,
        cid: event.commit.cid,
      },
      best.label,
      best.score,
    );

    // Only report the post as labelled when an event response came back.
    if (emitted !== undefined) {
      console.log(
        `https://bsky.app/profile/${event.did}/post/${event.commit.rkey}`,
        best.label,
        best.score,
      );
    }
  } catch (e) {
    console.error(e);
  }
});