Source dump of Claude Code (branch `main`, 166 lines, 5.2 kB).
import {
  getModelStrings as getModelStringsState,
  setModelStrings as setModelStringsState,
} from 'src/bootstrap/state.js'
import { logError } from '../log.js'
import { sequential } from '../sequential.js'
import { getInitialSettings } from '../settings/settings.js'
import { findFirstMatch, getBedrockInferenceProfiles } from './bedrock.js'
import {
  ALL_MODEL_CONFIGS,
  CANONICAL_ID_TO_KEY,
  type CanonicalModelId,
  type ModelKey,
} from './configs.js'
import { type APIProvider, getAPIProvider } from './providers.js'

/**
 * Maps each model version to its provider-specific model ID string.
 * Derived from ALL_MODEL_CONFIGS — adding a model there extends this type.
 */
export type ModelStrings = Record<ModelKey, string>

const MODEL_KEYS = Object.keys(ALL_MODEL_CONFIGS) as ModelKey[]

/**
 * Build the hardcoded model strings for `provider` directly from
 * ALL_MODEL_CONFIGS. No network lookups and no user overrides are involved.
 */
function getBuiltinModelStrings(provider: APIProvider): ModelStrings {
  const out = {} as ModelStrings
  for (const key of MODEL_KEYS) {
    out[key] = ALL_MODEL_CONFIGS[key][provider]
  }
  return out
}

/**
 * Resolve Bedrock model strings against the user's inference profiles.
 *
 * If fetching the profiles throws, the error is logged and the hardcoded
 * Bedrock IDs are returned. The same fallback is used when the profile list
 * is empty, and per model when no profile matches.
 */
async function getBedrockModelStrings(): Promise<ModelStrings> {
  const fallback = getBuiltinModelStrings('bedrock')
  let profiles: string[] | undefined
  try {
    profiles = await getBedrockInferenceProfiles()
  } catch (error) {
    logError(error as Error)
    return fallback
  }
  if (!profiles?.length) {
    return fallback
  }
  // Each config's firstParty ID is the canonical substring we search for in the
  // user's inference profile list (e.g. "claude-opus-4-6" matches
  // "eu.anthropic.claude-opus-4-6-v1"). Fall back to the hardcoded bedrock ID
  // when no matching profile is found.
  const out = {} as ModelStrings
  for (const key of MODEL_KEYS) {
    const needle = ALL_MODEL_CONFIGS[key].firstParty
    out[key] = findFirstMatch(profiles, needle) || fallback[key]
  }
  return out
}

/**
 * Read `modelOverrides` from the initial settings.
 *
 * Returns undefined both when no overrides are configured and when
 * getInitialSettings() throws (e.g. settings not loaded yet during module
 * init). Callers can therefore treat both cases as "no overrides".
 */
function readModelOverrides(): Record<string, string> | undefined {
  try {
    return getInitialSettings().modelOverrides
  } catch {
    return undefined
  }
}

/**
 * Layer user-configured modelOverrides (from settings.json) on top of the
 * provider-derived model strings. Overrides are keyed by canonical first-party
 * model ID (e.g. "claude-opus-4-6") and map to arbitrary provider-specific
 * strings — typically Bedrock inference profile ARNs.
 *
 * Returns `ms` itself when there are no overrides (or settings can't be read
 * yet); otherwise returns a new object and leaves `ms` unmodified.
 */
function applyModelOverrides(ms: ModelStrings): ModelStrings {
  const overrides = readModelOverrides()
  if (!overrides) {
    return ms
  }
  const out = { ...ms }
  for (const [canonicalId, override] of Object.entries(overrides)) {
    const key = CANONICAL_ID_TO_KEY[canonicalId as CanonicalModelId]
    // Unknown canonical IDs and empty override strings are ignored.
    if (key && override) {
      out[key] = override
    }
  }
  return out
}

/**
 * Resolve an overridden model ID (e.g. a Bedrock ARN) back to its canonical
 * first-party model ID. If the input doesn't match any current override value,
 * it is returned unchanged. Safe to call during module init (no-ops if settings
 * aren't loaded yet).
 */
export function resolveOverriddenModel(modelId: string): string {
  const overrides = readModelOverrides()
  if (!overrides) {
    return modelId
  }
  for (const [canonicalId, override] of Object.entries(overrides)) {
    // applyModelOverrides never applies an empty override, so an empty
    // string must not be reverse-mapped to a canonical ID either.
    if (override && override === modelId) {
      return canonicalId
    }
  }
  return modelId
}

const updateBedrockModelStrings = sequential(async () => {
  if (getModelStringsState() !== null) {
    // Already initialized. Doing the check here, combined with
    // `sequential`, allows the test suite to reset the state
    // between tests while still preventing multiple API calls
    // in production.
    return
  }
  try {
    const ms = await getBedrockModelStrings()
    setModelStringsState(ms)
  } catch (error) {
    logError(error as Error)
  }
})

/**
 * Synchronously seed the model-strings state with the builtin values for the
 * current provider.
 *
 * Returns true when the state was seeded. Returns false, leaving the state
 * untouched, on Bedrock: there the strings come from the asynchronous
 * inference-profile lookup in updateBedrockModelStrings instead.
 */
function seedBuiltinModelStrings(): boolean {
  const provider = getAPIProvider()
  if (provider === 'bedrock') {
    return false
  }
  setModelStringsState(getBuiltinModelStrings(provider))
  return true
}

/**
 * Kick off model-string initialization without blocking.
 *
 * Non-Bedrock providers are seeded synchronously. On Bedrock, the profile
 * fetch is started in the background.
 */
function initModelStrings(): void {
  if (getModelStringsState() !== null) {
    // Already initialized
    return
  }
  if (seedBuiltinModelStrings()) {
    return
  }
  // On Bedrock, update model strings in the background without blocking.
  // Don't set the state in this case so that we can use `sequential` on
  // `updateBedrockModelStrings` and check for existing state on multiple
  // calls.
  void updateBedrockModelStrings()
}

/**
 * Return the provider-specific model strings with user modelOverrides applied.
 *
 * Never blocks. On Bedrock, calls made before the inference-profile lookup
 * finishes return the hardcoded Bedrock IDs (with overrides applied); use
 * ensureModelStringsInitialized() to wait for the lookup.
 */
export function getModelStrings(): ModelStrings {
  let ms = getModelStringsState()
  if (ms === null) {
    initModelStrings()
    // Non-Bedrock providers are seeded by now. On Bedrock the profile fetch is
    // still running in the background, so honor overrides on the interim
    // builtin defaults.
    ms = getModelStringsState() ?? getBuiltinModelStrings(getAPIProvider())
  }
  return applyModelOverrides(ms)
}

/**
 * Ensure model strings are fully initialized.
 * For Bedrock users, this waits for the profile fetch to complete.
 * Call this before generating model options to ensure correct region strings.
 */
export async function ensureModelStringsInitialized(): Promise<void> {
  if (getModelStringsState() !== null) {
    return
  }
  // For non-Bedrock, initialize synchronously
  if (seedBuiltinModelStrings()) {
    return
  }
  // For Bedrock, wait for the profile fetch
  await updateBedrockModelStrings()
}