porting all github actions from bluesky-social/indigo to tangled CI
at main 11 kB view raw
1// Package lex generates Go code for lexicons. 2// 3// (It is not a lexer.) 4package lex 5 6import ( 7 "bytes" 8 "encoding/json" 9 "fmt" 10 "io" 11 "os" 12 "path/filepath" 13 "sort" 14 "strings" 15 16 "golang.org/x/tools/imports" 17) 18 19const ( 20 EncodingCBOR = "application/cbor" 21 EncodingJSON = "application/json" 22 EncodingJSONL = "application/jsonl" 23 EncodingCAR = "application/vnd.ipld.car" 24 EncodingMP4 = "video/mp4" 25 EncodingANY = "*/*" 26) 27 28type outputType struct { 29 Name string 30 Type *TypeSchema 31 NeedsCbor bool 32 NeedsType bool 33} 34 35// Build total map of all types defined inside schemas. 36// Return map from fully qualified type name to its *TypeSchema 37func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef { 38 out := make(map[string]*ExtDef) 39 for _, s := range ss { 40 for k, d := range s.Defs { 41 d.defMap = out 42 d.id = s.ID 43 d.defName = k 44 45 var pref string 46 for _, pkg := range packages { 47 if strings.HasPrefix(s.ID, pkg.Prefix) { 48 pref = pkg.Prefix 49 break 50 } 51 } 52 d.prefix = pref 53 54 n := s.ID 55 if k != "main" { 56 n = s.ID + "#" + k 57 } 58 out[n] = &ExtDef{ 59 Type: d, 60 } 61 } 62 } 63 return out 64} 65 66type ExtDef struct { 67 Type *TypeSchema 68} 69 70// TODO: this method is necessary because in lexicon there is no way to know if 71// a type needs to be marshaled with a "$type" field up front, you can only 72// know for sure by seeing where the type is used. 73func FixRecordReferences(schemas []*Schema, defmap map[string]*ExtDef, prefix string) { 74 for _, s := range schemas { 75 if !strings.HasPrefix(s.ID, prefix) { 76 continue 77 } 78 79 tps := s.AllTypes(prefix, defmap) 80 for _, t := range tps { 81 if t.Type.Type == "record" { 82 t.NeedsType = true 83 t.Type.needsType = true 84 } 85 86 if t.Type.Type == "union" { 87 for _, r := range t.Type.Refs { 88 if r[0] == '#' { 89 r = s.ID + r 90 } 91 92 if _, known := defmap[r]; known != true { 93 panic(fmt.Sprintf("reference to unknown record type: %s", r)) 94 } 95 96 if t.NeedsCbor { 97 defmap[r].Type.needsCbor = true 98 } 99 } 100 } 101 } 102 } 103} 104 105func printerf(w io.Writer) func(format string, args ...any) { 106 return func(format string, args ...any) { 107 fmt.Fprintf(w, format, args...) 108 } 109} 110 111func GenCodeForSchema(pkg Package, reqcode bool, s *Schema, packages []Package, defmap map[string]*ExtDef) error { 112 err := os.MkdirAll(pkg.Outdir, 0755) 113 if err != nil { 114 return fmt.Errorf("%s: could not mkdir, %w", pkg.Outdir, err) 115 } 116 fname := filepath.Join(pkg.Outdir, s.Name()+".go") 117 buf := new(bytes.Buffer) 118 pf := printerf(buf) 119 120 s.prefix = pkg.Prefix 121 for _, d := range s.Defs { 122 d.prefix = pkg.Prefix 123 } 124 125 // Add the standard Go generated code header as recognized by GitHub, VS Code, etc. 126 // See https://golang.org/s/generatedcode. 127 pf("// Code generated by cmd/lexgen (see Makefile's lexgen); DO NOT EDIT.\n\n") 128 129 pf("package %s\n\n", pkg.GoPackage) 130 131 pf("// schema: %s\n\n", s.ID) 132 133 pf("import (\n") 134 pf("\t\"context\"\n") 135 pf("\t\"fmt\"\n") 136 pf("\t\"encoding/json\"\n") 137 pf("\tcbg \"github.com/whyrusleeping/cbor-gen\"\n") 138 pf("\t\"github.com/bluesky-social/indigo/lex/util\"\n") 139 for _, xpkg := range packages { 140 if xpkg.Prefix != pkg.Prefix { 141 pf("\t%s %q\n", importNameForPrefix(xpkg.Prefix), xpkg.Import) 142 } 143 } 144 pf(")\n\n") 145 146 tps := s.AllTypes(pkg.Prefix, defmap) 147 148 if err := writeDecoderRegister(buf, tps); err != nil { 149 return err 150 } 151 152 sort.Slice(tps, func(i, j int) bool { 153 return tps[i].Name < tps[j].Name 154 }) 155 for _, ot := range tps { 156 fmt.Println("TYPE: ", ot.Name, ot.NeedsCbor, ot.NeedsType) 157 if err := ot.Type.WriteType(ot.Name, buf); err != nil { 158 return err 159 } 160 } 161 162 // reqcode is always True 163 if reqcode { 164 name := nameFromID(s.ID, pkg.Prefix) 165 main, ok := s.Defs["main"] 166 if ok { 167 if err := writeMethods(name, main, buf); err != nil { 168 return err 169 } 170 } 171 } 172 173 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 174 return err 175 } 176 177 return nil 178} 179 180func writeDecoderRegister(w io.Writer, tps []outputType) error { 181 var buf bytes.Buffer 182 outf := printerf(&buf) 183 184 for _, t := range tps { 185 if t.Type.needsType && !strings.Contains(t.Name, "_") { 186 id := t.Type.id 187 if t.Type.defName != "" { 188 id = id + "#" + t.Type.defName 189 } 190 if buf.Len() == 0 { 191 outf("func init() {\n") 192 } 193 outf("util.RegisterType(%q, &%s{})\n", id, t.Name) 194 } 195 } 196 if buf.Len() == 0 { 197 return nil 198 } 199 outf("}") 200 _, err := w.Write(buf.Bytes()) 201 return err 202} 203 204func writeCodeFile(b []byte, fname string) error { 205 fixed, err := imports.Process(fname, b, nil) 206 if err != nil { 207 werr := os.WriteFile("temp", b, 0664) 208 if werr != nil { 209 return werr 210 } 211 return fmt.Errorf("failed to format output of %q with goimports: %w (wrote failed file to ./temp)", fname, err) 212 } 213 214 if err := os.WriteFile(fname, fixed, 0664); err != nil { 215 return err 216 } 217 218 return nil 219} 220 221func writeMethods(typename string, ts *TypeSchema, w io.Writer) error { 222 switch ts.Type { 223 case "token": 224 n := ts.id 225 if ts.defName != "main" { 226 n += "#" + ts.defName 227 } 228 229 fmt.Fprintf(w, "const %s = %q\n", typename, n) 230 return nil 231 case "record": 232 return nil 233 case "query": 234 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) 235 case "procedure": 236 if ts.Input == nil || ts.Input.Schema == nil || ts.Input.Schema.Type == "object" { 237 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) 238 } else if ts.Input.Schema.Type == "ref" { 239 inputname, _ := ts.namesFromRef(ts.Input.Schema.Ref) 240 return ts.WriteRPC(w, typename, inputname) 241 } else { 242 return fmt.Errorf("unhandled input type: %s", ts.Input.Schema.Type) 243 } 244 case "object", "string": 245 return nil 246 case "subscription": 247 // TODO: should probably have some methods generated for this 248 return nil 249 default: 250 return fmt.Errorf("unrecognized lexicon type %q", ts.Type) 251 } 252} 253 254func nameFromID(id, prefix string) string { 255 parts := strings.Split(strings.TrimPrefix(id, prefix), ".") 256 var tname string 257 for _, s := range parts { 258 tname += strings.Title(s) 259 } 260 261 return tname 262 263} 264 265func orderedMapIter[T any](m map[string]T, cb func(string, T) error) error { 266 var keys []string 267 for k := range m { 268 keys = append(keys, k) 269 } 270 271 sort.Strings(keys) 272 273 for _, k := range keys { 274 if err := cb(k, m[k]); err != nil { 275 return err 276 } 277 } 278 return nil 279} 280 281func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error { 282 buf := new(bytes.Buffer) 283 284 if err := WriteXrpcServer(buf, schemas, pkg, impmap); err != nil { 285 return err 286 } 287 288 fname := filepath.Join(dir, "stubs.go") 289 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 290 return err 291 } 292 293 if handlers { 294 buf := new(bytes.Buffer) 295 296 if err := WriteServerHandlers(buf, schemas, pkg, impmap); err != nil { 297 return err 298 } 299 300 fname := filepath.Join(dir, "handlers.go") 301 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 302 return err 303 } 304 305 } 306 307 return nil 308} 309 310func importNameForPrefix(prefix string) string { 311 return strings.Join(strings.Split(prefix, "."), "") + "types" 312} 313 314func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 315 pf := printerf(w) 316 pf("package %s\n\n", pkg) 317 pf("import (\n") 318 pf("\t\"context\"\n") 319 pf("\t\"fmt\"\n") 320 pf("\t\"encoding/json\"\n") 321 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n") 322 for k, v := range impmap { 323 pf("\t%s\"%s\"\n", importNameForPrefix(k), v) 324 } 325 pf(")\n\n") 326 327 for _, s := range schemas { 328 329 var prefix string 330 for k := range impmap { 331 if strings.HasPrefix(s.ID, k) { 332 prefix = k 333 break 334 } 335 } 336 337 main, ok := s.Defs["main"] 338 if !ok { 339 fmt.Printf("WARNING: schema %q doesn't have a main def\n", s.ID) 340 continue 341 } 342 343 if main.Type == "procedure" || main.Type == "query" { 344 fname := idToTitle(s.ID) 345 tname := nameFromID(s.ID, prefix) 346 impname := importNameForPrefix(prefix) 347 if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil { 348 return err 349 } 350 } 351 } 352 353 return nil 354} 355 356func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 357 pf := printerf(w) 358 pf("package %s\n\n", pkg) 359 pf("import (\n") 360 pf("\t\"context\"\n") 361 pf("\t\"fmt\"\n") 362 pf("\t\"encoding/json\"\n") 363 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n") 364 pf("\t\"github.com/labstack/echo/v4\"\n") 365 366 var prefixes []string 367 orderedMapIter[string](impmap, func(k, v string) error { 368 prefixes = append(prefixes, k) 369 pf("\t%s\"%s\"\n", importNameForPrefix(k), v) 370 return nil 371 }) 372 pf(")\n\n") 373 374 ssets := make(map[string][]*Schema) 375 for _, s := range schemas { 376 var pref string 377 for _, p := range prefixes { 378 if strings.HasPrefix(s.ID, p) { 379 pref = p 380 break 381 } 382 } 383 if pref == "" { 384 return fmt.Errorf("no matching prefix for schema %q (tried %s)", s.ID, prefixes) 385 } 386 387 ssets[pref] = append(ssets[pref], s) 388 } 389 390 for _, p := range prefixes { 391 ss := ssets[p] 392 393 pf("func (s *Server) RegisterHandlers%s(e *echo.Echo) error {\n", idToTitle(p)) 394 for _, s := range ss { 395 396 main, ok := s.Defs["main"] 397 if !ok { 398 continue 399 } 400 401 var verb string 402 switch main.Type { 403 case "query": 404 verb = "GET" 405 case "procedure": 406 verb = "POST" 407 default: 408 continue 409 } 410 411 pf("e.%s(\"/xrpc/%s\", s.Handle%s)\n", verb, s.ID, idToTitle(s.ID)) 412 } 413 414 pf("return nil\n}\n\n") 415 416 for _, s := range ss { 417 418 var prefix string 419 for k := range impmap { 420 if strings.HasPrefix(s.ID, k) { 421 prefix = k 422 break 423 } 424 } 425 426 main, ok := s.Defs["main"] 427 if !ok { 428 continue 429 } 430 431 if main.Type == "procedure" || main.Type == "query" { 432 fname := idToTitle(s.ID) 433 tname := nameFromID(s.ID, prefix) 434 impname := importNameForPrefix(prefix) 435 if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil { 436 return fmt.Errorf("writing handler for %s: %w", s.ID, err) 437 } 438 } 439 } 440 } 441 442 return nil 443} 444 445func idToTitle(id string) string { 446 var fname string 447 for _, p := range strings.Split(id, ".") { 448 fname += strings.Title(p) 449 } 450 return fname 451} 452 453type Package struct { 454 GoPackage string `json:"package"` 455 Prefix string `json:"prefix"` 456 Outdir string `json:"outdir"` 457 Import string `json:"import"` 458} 459 460// ParsePackages reads a json blob which should be an array of Package{} objects. 461func ParsePackages(jsonBytes []byte) ([]Package, error) { 462 var packages []Package 463 err := json.Unmarshal(jsonBytes, &packages) 464 if err != nil { 465 return nil, err 466 } 467 return packages, nil 468} 469 470func Run(schemas []*Schema, externalSchemas []*Schema, packages []Package) error { 471 defmap := BuildExtDefMap(append(schemas, externalSchemas...), packages) 472 473 for _, pkg := range packages { 474 prefix := pkg.Prefix 475 FixRecordReferences(schemas, defmap, prefix) 476 } 477 478 for _, pkg := range packages { 479 for _, s := range schemas { 480 if !strings.HasPrefix(s.ID, pkg.Prefix) { 481 continue 482 } 483 484 if err := GenCodeForSchema(pkg, true, s, packages, defmap); err != nil { 485 return fmt.Errorf("failed to process schema %q: %w", s.path, err) 486 } 487 } 488 } 489 return nil 490}