// Source dump of Claude Code
// Branch: main · 141 lines · 4.4 kB
import { randomUUID } from 'crypto'
import type { QuerySource } from '../../constants/querySource.js'
import { queryModelWithoutStreaming } from '../../services/api/claude.js'
import type { Message } from '../../types/message.js'
import { createAbortController } from '../../utils/abortController.js'
import { logError } from '../../utils/log.js'
import { toError } from '../errors.js'
import { extractTextContent } from '../messages.js'
import { asSystemPrompt } from '../systemPromptType.js'
import type { REPLHookContext } from './postSamplingHooks.js'

/**
 * Context passed to API-query hooks: the REPL post-sampling hook context plus
 * the number of messages the hook sent. The hook sets `queryMessageCount`
 * right before issuing its request, so `parseResponse` and `logResult` can
 * read it.
 */
export type ApiQueryHookContext = REPLHookContext & {
  queryMessageCount?: number
}

/**
 * Declarative description of a hook that makes one side query to the model
 * and reports a parsed result.
 */
export type ApiQueryHookConfig<TResult> = {
  /**
   * Hook identifier. It is sent as the request's `querySource` and echoed as
   * `queryName` in results.
   */
  name: QuerySource
  /** Gate evaluated first; when it resolves to `false` the hook does nothing. */
  shouldRun: (context: ApiQueryHookContext) => Promise<boolean>

  /** Builds the complete message list to send to the API. */
  buildMessages: (context: ApiQueryHookContext) => Message[]

  /**
   * Optional system prompt override. When unset or empty, the hook uses
   * `context.systemPrompt`.
   */
  systemPrompt?: string

  /**
   * Whether to send the tools from `context.toolUseContext.options.tools`.
   * Defaults to `true`; set to `false` to send an empty tools array.
   */
  useTools?: boolean

  /**
   * Converts the trimmed text content of the model reply into a result.
   * Throwing here produces an `error` result instead of aborting the hook.
   */
  parseResponse: (content: string, context: ApiQueryHookContext) => TResult
  /** Called once with the success or error outcome of each received response. */
  logResult: (
    result: ApiQueryResult<TResult>,
    context: ApiQueryHookContext,
  ) => void
  /**
   * Returns the model to query. It must be a function so the model loads
   * lazily: the config may be accessed before the model is allowed to be
   * read, so the hook resolves the model only when it runs. It receives the
   * context so callers can inherit the main loop model if desired.
   */
  getModel: (context: ApiQueryHookContext) => string
}

/**
 * Outcome passed to `logResult`.
 * - `success` carries the parsed result, the response message id and the model.
 * - `error` carries the error thrown by `parseResponse`.
 */
export type ApiQueryResult<TResult> =
  | {
      type: 'success'
      queryName: string
      result: TResult
      messageId: string
      model: string
      uuid: string
    }
  | {
      type: 'error'
      queryName: string
      error: Error
      uuid: string
    }

/**
 * Creates an async hook from an {@link ApiQueryHookConfig}.
 *
 * When run, the hook:
 * 1. Awaits `shouldRun`; if it resolves `true`, continues.
 * 2. Builds the messages.
 * 3. Sends one non-streaming request (thinking disabled, temperature 0, no
 *    MCP tools).
 * 4. Parses the trimmed text reply with `parseResponse`.
 * 5. Passes the outcome to `logResult`.
 *
 * Parse failures reach `logResult` as `error` results. Any other failure is
 * passed to `logError` rather than rethrown. That covers `shouldRun`,
 * `buildMessages`, `getModel`, the API call, and `logResult` itself.
 *
 * @param config - Hook definition.
 * @returns A hook function that takes an {@link ApiQueryHookContext}.
 */
export function createApiQueryHook<TResult>(
  config: ApiQueryHookConfig<TResult>,
): (context: ApiQueryHookContext) => Promise<void> {
  return async (context: ApiQueryHookContext): Promise<void> => {
    try {
      const shouldRun = await config.shouldRun(context)
      if (!shouldRun) {
        return
      }

      // Correlates the success/error result reported for this run.
      const uuid = randomUUID()

      const messages = config.buildMessages(context)
      context.queryMessageCount = messages.length

      // A non-empty config prompt overrides the context's prompt.
      // An empty string falls back to the context's prompt.
      const systemPrompt = config.systemPrompt
        ? asSystemPrompt([config.systemPrompt])
        : context.systemPrompt

      const useTools = config.useTools ?? true
      const tools = useTools ? context.toolUseContext.options.tools : []

      // Resolved here, at run time, so the model loads lazily.
      const model = config.getModel(context)

      const response = await queryModelWithoutStreaming({
        messages,
        systemPrompt,
        thinkingConfig: { type: 'disabled' as const },
        tools,
        // NOTE(review): this controller is created fresh and never aborted,
        // so this request cannot be cancelled from here. Confirm whether a
        // caller-owned signal should be used instead.
        signal: createAbortController().signal,
        options: {
          getToolPermissionContext: async () => {
            const appState = context.toolUseContext.getAppState()
            return appState.toolPermissionContext
          },
          model,
          toolChoice: undefined,
          isNonInteractiveSession:
            context.toolUseContext.options.isNonInteractiveSession,
          hasAppendSystemPrompt:
            !!context.toolUseContext.options.appendSystemPrompt,
          temperatureOverride: 0,
          agents: context.toolUseContext.options.agentDefinitions.activeAgents,
          querySource: config.name,
          mcpTools: [],
          agentId: context.toolUseContext.agentId,
        },
      })

      const content = extractTextContent(response.message.content).trim()

      // Only parsing is guarded here. logResult is called outside this try so
      // that an exception thrown by logResult is not misreported as a parse
      // error (and does not make logResult run twice). Such an exception
      // reaches the outer catch instead.
      let outcome: ApiQueryResult<TResult>
      try {
        outcome = {
          type: 'success',
          queryName: config.name,
          result: config.parseResponse(content, context),
          messageId: response.message.id,
          model,
          uuid,
        }
      } catch (error) {
        // toError normalizes non-Error throwables so `error` really is an
        // Error, as the result type promises.
        outcome = {
          type: 'error',
          queryName: config.name,
          error: toError(error),
          uuid,
        }
      }
      config.logResult(outcome, context)
    } catch (error) {
      logError(toError(error))
    }
  }
}