🔧 Where my dotfiles lives in harmony and peace, most of the time

✨ Add pi output schema extension

+399
+399
agents/pi/extensions/output-schema.ts
··· 1 + import { readFileSync } from "node:fs"; 2 + import { resolve } from "node:path"; 3 + import type { ExtensionAPI } from "@mariozechner/pi-coding-agent"; 4 + import { Type } from "@sinclair/typebox"; 5 + 6 + const FLAG_NAME = "output-schema"; 7 + const TOOL_NAME = "submit_output"; 8 + const MAX_REPAIR_ATTEMPTS = 2; 9 + 10 + type Json = null | boolean | number | string | Json[] | { [key: string]: Json }; 11 + 12 + type JsonSchema = { 13 + type?: string; 14 + properties?: Record<string, JsonSchema>; 15 + required?: string[]; 16 + additionalProperties?: boolean; 17 + items?: JsonSchema; 18 + enum?: Json[]; 19 + description?: string; 20 + [key: string]: unknown; 21 + }; 22 + 23 + type SubmitOutputParams = Record<string, Json>; 24 + 25 + type AgentEndEvent = { 26 + messages?: Array<{ 27 + role?: string; 28 + content?: Array<{ type?: string; text?: string }>; 29 + }>; 30 + }; 31 + 32 + function isRecord(value: unknown): value is Record<string, unknown> { 33 + return typeof value === "object" && value !== null && !Array.isArray(value); 34 + } 35 + 36 + function asJsonSchema(value: unknown): JsonSchema { 37 + if (!isRecord(value)) { 38 + throw new Error("output schema must be a JSON object"); 39 + } 40 + return value as JsonSchema; 41 + } 42 + 43 + function assertSupportedSchema( 44 + schema: JsonSchema, 45 + path = "$", 46 + isRoot = true, 47 + ): void { 48 + for (const key of [ 49 + "$ref", 50 + "oneOf", 51 + "anyOf", 52 + "allOf", 53 + "not", 54 + "if", 55 + "then", 56 + "else", 57 + "patternProperties", 58 + ]) { 59 + if (key in schema) { 60 + throw new Error( 61 + `${FLAG_NAME}: unsupported schema feature at ${path}: ${key}`, 62 + ); 63 + } 64 + } 65 + 66 + if (schema.enum !== undefined) { 67 + if (!Array.isArray(schema.enum) || schema.enum.length === 0) { 68 + throw new Error( 69 + `${FLAG_NAME}: enum at ${path} must be a non-empty array`, 70 + ); 71 + } 72 + return; 73 + } 74 + 75 + const type = schema.type; 76 + if (typeof type !== "string") { 77 + throw new Error( 78 + `${FLAG_NAME}: schema at ${path} must declare a string type`, 79 + ); 80 + } 81 + 82 + if (isRoot && type !== "object") { 83 + throw new Error(`${FLAG_NAME}: root schema must be an object`); 84 + } 85 + 86 + if ( 87 + ![ 88 + "object", 89 + "array", 90 + "string", 91 + "number", 92 + "integer", 93 + "boolean", 94 + "null", 95 + ].includes(type) 96 + ) { 97 + throw new Error(`${FLAG_NAME}: unsupported type at ${path}: ${type}`); 98 + } 99 + 100 + if (type === "object") { 101 + const properties = schema.properties; 102 + if (!isRecord(properties)) { 103 + throw new Error( 104 + `${FLAG_NAME}: object schema at ${path} must define properties`, 105 + ); 106 + } 107 + 108 + if (schema.required !== undefined && !Array.isArray(schema.required)) { 109 + throw new Error( 110 + `${FLAG_NAME}: required at ${path} must be an array of strings`, 111 + ); 112 + } 113 + 114 + for (const [key, child] of Object.entries(properties)) { 115 + assertSupportedSchema(asJsonSchema(child), `${path}.${key}`, false); 116 + } 117 + return; 118 + } 119 + 120 + if (type === "array") { 121 + if (schema.items === undefined) { 122 + throw new Error( 123 + `${FLAG_NAME}: array schema at ${path} must define items`, 124 + ); 125 + } 126 + assertSupportedSchema(asJsonSchema(schema.items), `${path}[]`, false); 127 + } 128 + } 129 + 130 + function isJson(value: unknown): value is Json { 131 + if (value === null) return true; 132 + if (typeof value === "string" || typeof value === "boolean") return true; 133 + if (typeof value === "number") return Number.isFinite(value); 134 + if (Array.isArray(value)) return value.every(isJson); 135 + if (!isRecord(value)) return false; 136 + return Object.values(value).every(isJson); 137 + } 138 + 139 + function describeValue(value: unknown): string { 140 + if (value === null) return "null"; 141 + if (Array.isArray(value)) return "array"; 142 + return typeof value; 143 + } 144 + 145 + function jsonEquals(left: Json, right: Json): boolean { 146 + return JSON.stringify(left) === JSON.stringify(right); 147 + } 148 + 149 + function validateValue( 150 + schema: JsonSchema, 151 + value: unknown, 152 + path = "$", 153 + errors: string[] = [], 154 + ): string[] { 155 + if (schema.enum !== undefined) { 156 + if ( 157 + !isJson(value) || 158 + !schema.enum.some((item) => jsonEquals(item, value)) 159 + ) { 160 + errors.push(`${path} must be one of the enum values`); 161 + } 162 + return errors; 163 + } 164 + 165 + switch (schema.type) { 166 + case "object": { 167 + if (!isRecord(value) || Array.isArray(value)) { 168 + errors.push(`${path} must be an object`); 169 + return errors; 170 + } 171 + 172 + const properties = isRecord(schema.properties) ? schema.properties : {}; 173 + const required = Array.isArray(schema.required) ? schema.required : []; 174 + 175 + for (const key of required) { 176 + if (!(key in value)) { 177 + errors.push(`${path}.${key} is required`); 178 + } 179 + } 180 + 181 + if (schema.additionalProperties === false) { 182 + for (const key of Object.keys(value)) { 183 + if (!(key in properties)) { 184 + errors.push(`${path}.${key} is not allowed`); 185 + } 186 + } 187 + } 188 + 189 + for (const [key, childSchema] of Object.entries(properties)) { 190 + if (!(key in value)) continue; 191 + validateValue( 192 + asJsonSchema(childSchema), 193 + value[key], 194 + `${path}.${key}`, 195 + errors, 196 + ); 197 + } 198 + return errors; 199 + } 200 + case "array": { 201 + if (!Array.isArray(value)) { 202 + errors.push(`${path} must be an array`); 203 + return errors; 204 + } 205 + 206 + const itemSchema = asJsonSchema(schema.items); 207 + for (let index = 0; index < value.length; index += 1) { 208 + validateValue(itemSchema, value[index], `${path}[${index}]`, errors); 209 + } 210 + return errors; 211 + } 212 + case "string": { 213 + if (typeof value !== "string") errors.push(`${path} must be a string`); 214 + return errors; 215 + } 216 + case "number": { 217 + if (typeof value !== "number" || !Number.isFinite(value)) { 218 + errors.push(`${path} must be a finite number`); 219 + } 220 + return errors; 221 + } 222 + case "integer": { 223 + if (typeof value !== "number" || !Number.isInteger(value)) { 224 + errors.push(`${path} must be an integer`); 225 + } 226 + return errors; 227 + } 228 + case "boolean": { 229 + if (typeof value !== "boolean") errors.push(`${path} must be a boolean`); 230 + return errors; 231 + } 232 + case "null": { 233 + if (value !== null) errors.push(`${path} must be null`); 234 + return errors; 235 + } 236 + default: { 237 + errors.push( 238 + `${path} has unsupported schema type ${describeValue(schema.type)}`, 239 + ); 240 + return errors; 241 + } 242 + } 243 + } 244 + 245 + function getFinalAssistantText(event: AgentEndEvent): string | undefined { 246 + const messages = event.messages ?? []; 247 + for (let index = messages.length - 1; index >= 0; index -= 1) { 248 + const message = messages[index]; 249 + if (message.role !== "assistant") continue; 250 + const text = (message.content ?? []) 251 + .filter((part) => part.type === "text" && typeof part.text === "string") 252 + .map((part) => part.text ?? "") 253 + .join("\n") 254 + .trim(); 255 + return text; 256 + } 257 + return undefined; 258 + } 259 + 260 + function buildSystemPrompt(basePrompt: string, schemaPath: string): string { 261 + return `${basePrompt}\n\n[Structured output contract]\nThe user started pi with --${FLAG_NAME} ${schemaPath}.\nYou must finish by calling ${TOOL_NAME} exactly once with arguments that match the provided JSON schema.\n- ${TOOL_NAME} must be your final tool call.\n- Do not end with prose, markdown, or explanations.\n- After the tool call, your final assistant text must be exactly the same JSON object and nothing else.\n- If the tool rejects your arguments, fix them and call ${TOOL_NAME} again.`; 262 + } 263 + 264 + export default function outputSchemaExtension(pi: ExtensionAPI): void { 265 + let enabled = false; 266 + let schemaPath: string | undefined; 267 + let schema: JsonSchema | undefined; 268 + let acceptedJson: string | undefined; 269 + let repairAttempts = 0; 270 + let repairing = false; 271 + let toolRegistered = false; 272 + 273 + function isActive(): boolean { 274 + return enabled && schema !== undefined && schemaPath !== undefined; 275 + } 276 + 277 + function ensureToolRegistered(): void { 278 + if (toolRegistered || !schema) return; 279 + 280 + pi.registerTool({ 281 + name: TOOL_NAME, 282 + label: "Submit Output", 283 + description: 284 + "Submit the final JSON output. Use this exactly once as the final tool call.", 285 + promptSnippet: 286 + "Submit the final response as JSON matching the requested schema.", 287 + promptGuidelines: [ 288 + `Call ${TOOL_NAME} exactly once when the task is complete.`, 289 + "Pass arguments that match the active output schema exactly.", 290 + "After the tool call, emit exactly the same JSON and nothing else.", 291 + ], 292 + parameters: Type.Unsafe<Record<string, Json>>(schema as never), 293 + async execute(_toolCallId: string, params: SubmitOutputParams) { 294 + if (!isActive() || !schema) { 295 + throw new Error( 296 + `${TOOL_NAME} is only available when --${FLAG_NAME} is set`, 297 + ); 298 + } 299 + 300 + const errors = validateValue(schema, params); 301 + if (errors.length > 0) { 302 + throw new Error( 303 + `output does not match schema:\n- ${errors.join("\n- ")}`, 304 + ); 305 + } 306 + 307 + acceptedJson = JSON.stringify(params); 308 + repairing = false; 309 + return { 310 + content: [{ type: "text", text: acceptedJson }], 311 + details: { output: params }, 312 + }; 313 + }, 314 + }); 315 + 316 + toolRegistered = true; 317 + } 318 + 319 + function ensureToolActive(): void { 320 + if (!toolRegistered) return; 321 + const activeTools = new Set(pi.getActiveTools()); 322 + if (activeTools.has(TOOL_NAME)) return; 323 + activeTools.add(TOOL_NAME); 324 + pi.setActiveTools([...activeTools]); 325 + } 326 + 327 + function resetState(): void { 328 + enabled = false; 329 + schemaPath = undefined; 330 + schema = undefined; 331 + acceptedJson = undefined; 332 + repairAttempts = 0; 333 + repairing = false; 334 + } 335 + 336 + pi.registerFlag(FLAG_NAME, { 337 + description: 338 + "Path to a JSON Schema file that the final response must match", 339 + type: "string", 340 + }); 341 + 342 + pi.on("session_start", async (_event, ctx) => { 343 + resetState(); 344 + 345 + const flagValue = pi.getFlag(FLAG_NAME); 346 + if (typeof flagValue !== "string" || flagValue.trim().length === 0) { 347 + return; 348 + } 349 + 350 + schemaPath = resolve(ctx.cwd, flagValue); 351 + const raw = readFileSync(schemaPath, "utf8"); 352 + schema = asJsonSchema(JSON.parse(raw)); 353 + assertSupportedSchema(schema); 354 + enabled = true; 355 + ensureToolRegistered(); 356 + ensureToolActive(); 357 + }); 358 + 359 + pi.on("before_agent_start", async (event) => { 360 + if (!isActive() || !schemaPath) return; 361 + ensureToolActive(); 362 + 363 + if (!repairing) { 364 + acceptedJson = undefined; 365 + repairAttempts = 0; 366 + } 367 + 368 + return { 369 + systemPrompt: buildSystemPrompt(event.systemPrompt, schemaPath), 370 + }; 371 + }); 372 + 373 + pi.on("agent_end", async (event: AgentEndEvent, ctx) => { 374 + if (!isActive() || !schemaPath) return; 375 + if (repairAttempts >= MAX_REPAIR_ATTEMPTS) return; 376 + 377 + if (!acceptedJson) { 378 + repairAttempts += 1; 379 + repairing = true; 380 + const followUp = `You stopped without calling ${TOOL_NAME}. Call ${TOOL_NAME} now with JSON that matches ${schemaPath}. After the tool call, output exactly the same JSON and nothing else.`; 381 + if (ctx.isIdle()) pi.sendUserMessage(followUp); 382 + else pi.sendUserMessage(followUp, { deliverAs: "followUp" }); 383 + return; 384 + } 385 + 386 + const finalAssistantText = getFinalAssistantText(event); 387 + if (finalAssistantText === acceptedJson) { 388 + repairAttempts = 0; 389 + repairing = false; 390 + return; 391 + } 392 + 393 + repairAttempts += 1; 394 + repairing = true; 395 + const correction = `Your final response must be exactly this JSON and nothing else:\n${acceptedJson}`; 396 + if (ctx.isIdle()) pi.sendUserMessage(correction); 397 + else pi.sendUserMessage(correction, { deliverAs: "followUp" }); 398 + }); 399 + }