this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

listests: use x/tools/go/ast/inspector

+207 -156
+207 -156
listests/main.go
··· 19 19 "text/template" 20 20 "unicode" 21 21 22 + "golang.org/x/tools/go/ast/inspector" 22 23 "golang.org/x/tools/go/packages" 23 24 ) 24 25 ··· 53 54 flagVerbose bool 54 55 flagVimgrep bool 55 56 flagFormat string 57 + flagDir string 56 58 ) 57 59 58 60 fs.StringVar(&flagTags, "tags", "", "comma-separated list of build tags to apply") 59 61 fs.BoolVar(&flagVerbose, "v", false, "verbose mode") 60 62 fs.BoolVar(&flagVimgrep, "vimgrep", false, "output in ripgrep's vimgrep format") 61 63 fs.StringVar(&flagFormat, "format", "", "output format") 64 + fs.StringVar(&flagDir, "dir", ".", "directory to run in") 62 65 63 66 fs.Usage = func() { 64 67 fmt.Fprintf(fs.Output(), "Usage: %s [options] [packages...]\n", fs.Name()) ··· 90 93 logger("Discovering tests...\n") 91 94 tests, err := findTestsInPackages( 92 95 ctx, 93 - ".", 96 + flagDir, 94 97 patterns, 95 98 buildTags, 96 99 logger, ··· 106 109 return fmt.Errorf("cannot use -vimgrep and -format together") 107 110 } 108 111 109 - flagFormat = "{{.RelativeFileName}}:{{.Range.Start.Line}}:{{.Range.Start.Column}}:{{.PackageName}}:{{.FullName}}" 112 + flagFormat = "{{.RelativeFileName}}:{{.Range.Start.Line}}:{{.Range.Start.Column}}:{{.Package}}:{{.FullName}}" 110 113 } 111 114 112 115 if flagFormat == "" { ··· 234 237 continue 235 238 } 236 239 237 - // TODO: 238 - // fmt.Println("Errors:", pkg.Errors) 239 - // fmt.Println("TypeErrors:", pkg.TypeErrors) 240 - 240 + var testFiles []*ast.File 241 241 for _, file := range pkg.Syntax { 242 242 filename := pkg.Fset.Position(file.Pos()).Filename 243 - if !strings.HasSuffix(filename, "_test.go") { 244 - continue 243 + if strings.HasSuffix(filename, "_test.go") { 244 + testFiles = append(testFiles, file) 245 245 } 246 + } 246 247 247 - moduleName := pkg.Module.Path 248 - pkgPath := pkg.PkgPath 249 - packageName := strings.TrimPrefix(pkgPath, moduleName+"/") 250 - directory := pkg.Dir 248 + if len(testFiles) == 0 { 249 + continue 250 + } 251 + 252 + moduleName := pkg.Module.Path 253 + pkgPath := pkg.PkgPath 254 + packageName := strings.TrimPrefix(pkgPath, moduleName+"/") 255 + directory := pkg.Dir 256 + 257 + inspect := inspector.New(testFiles) 251 258 252 - logger("Processing %s in package %s...\n", filename, packageName) 253 - tests := findTestsInFile(file, pkg.Fset, filename, packageName, directory) 254 - allTests = append(allTests, tests...) 255 - } 259 + tests := func() []*TestInfo { 260 + finder := newTestFinder(pkg.Fset, packageName, directory, logger) 261 + return finder.find(inspect) 262 + }() 263 + allTests = append(allTests, tests...) 256 264 } 257 265 258 266 return allTests, nil 259 267 } 260 268 261 - func findTestsInFile(file *ast.File, fset *token.FileSet, filename, pkgName string, dir string) []*TestInfo { 262 - var tests []*TestInfo 269 + type testFinder struct { 270 + fset *token.FileSet 271 + pkgName string 272 + directory string 273 + logger func(string, ...any) 274 + 275 + allTests []*TestInfo 276 + testMap map[ast.Node]*TestInfo 277 + } 278 + 279 + func newTestFinder(fset *token.FileSet, pkgName, dir string, logger func(string, ...any)) *testFinder { 280 + return &testFinder{ 281 + fset: fset, 282 + pkgName: pkgName, 283 + directory: dir, 284 + logger: logger, 285 + allTests: []*TestInfo{}, 286 + testMap: make(map[ast.Node]*TestInfo), 287 + } 288 + } 289 + 290 + func (tf *testFinder) find(inspect *inspector.Inspector) []*TestInfo { 291 + nodeFilter := []ast.Node{ 292 + (*ast.FuncDecl)(nil), 293 + (*ast.CallExpr)(nil), 294 + } 263 295 264 - for _, decl := range file.Decls { 265 - funcDecl, ok := decl.(*ast.FuncDecl) 266 - if !ok || funcDecl.Name == nil { 267 - continue 296 + inspect.WithStack(nodeFilter, func(node ast.Node, push bool, stack []ast.Node) bool { 297 + if !push { 298 + return true 268 299 } 269 300 270 - if strings.HasPrefix(funcDecl.Name.Name, "Test") { 271 - if isTestFunction(funcDecl) { 272 - testName := funcDecl.Name.Name 301 + switch n := node.(type) { 302 + case *ast.FuncDecl: 303 + tf.handleFuncDecl(n) 304 + case *ast.CallExpr: 305 + tf.handleCallExpr(n, stack) 306 + } 273 307 274 - start := fset.Position(funcDecl.Name.Pos()) 275 - end := fset.Position(funcDecl.End()) 308 + return true 309 + }) 276 310 277 - test := &TestInfo{ 278 - Name: testName, 279 - FullName: testName, 280 - Package: pkgName, 281 - Directory: dir, 282 - File: filename, 283 - // Start: start.Line, 284 - // Column: start.Column, 285 - Range: SourceRange{ 286 - Start: SourcePosition{ 287 - Line: start.Line, 288 - Column: start.Column, 289 - }, 290 - End: SourcePosition{ 291 - Line: end.Line, 292 - Column: end.Column, 293 - }, 294 - }, 295 - HasGeneratedName: false, 296 - IsSubtest: false, 297 - SubTests: nil, 298 - } 311 + return tf.allTests 312 + } 313 + 314 + func (tf *testFinder) handleFuncDecl(n *ast.FuncDecl) { 315 + if n.Name == nil || !strings.HasPrefix(n.Name.Name, "Test") || !isTestFunction(n) { 316 + return 317 + } 318 + 319 + filename := tf.fset.Position(n.Pos()).Filename 320 + tf.logger("Processing %s in package %s...\n", filename, tf.pkgName) 321 + 322 + start := tf.fset.Position(n.Name.Pos()) 323 + end := tf.fset.Position(n.End()) 324 + 325 + test := &TestInfo{ 326 + Name: n.Name.Name, 327 + FullName: n.Name.Name, 328 + Package: tf.pkgName, 329 + Directory: tf.directory, 330 + File: filename, 331 + Range: SourceRange{ 332 + Start: SourcePosition{ 333 + Line: start.Line, 334 + Column: start.Column, 335 + }, 336 + End: SourcePosition{ 337 + Line: end.Line, 338 + Column: end.Column, 339 + }, 340 + }, 341 + HasGeneratedName: false, 342 + IsSubtest: false, 343 + SubTests: nil, 344 + } 345 + 346 + tf.testMap[n] = test 347 + tf.allTests = append(tf.allTests, test) 348 + } 349 + 350 + func (tf *testFinder) handleCallExpr(n *ast.CallExpr, stack []ast.Node) { 351 + if !tf.isRunCall(n) { 352 + return 353 + } 354 + 355 + parentTest := tf.findParentTest(stack) 356 + if parentTest == nil { 357 + return 358 + } 359 + 360 + subTest := tf.createSubTest(n, parentTest) 361 + if subTest == nil { 362 + return 363 + } 364 + 365 + parentTest.SubTests = append(parentTest.SubTests, subTest) 366 + 367 + // Map the function literal body to this subtest so nested t.Run calls can find it 368 + if funcLit, ok := n.Args[1].(*ast.FuncLit); ok && funcLit.Body != nil { 369 + tf.testMap[funcLit.Body] = subTest 370 + } 371 + } 372 + 373 + func (tf *testFinder) isRunCall(n *ast.CallExpr) bool { 374 + selExpr, ok := n.Fun.(*ast.SelectorExpr) 375 + return ok && selExpr.Sel.Name == "Run" && len(n.Args) >= 2 376 + } 377 + 378 + func (tf *testFinder) findParentTest(stack []ast.Node) *TestInfo { 379 + for i := len(stack) - 1; i >= 0; i-- { 380 + if test, exists := tf.testMap[stack[i]]; exists { 381 + return test 382 + } 383 + } 384 + return nil 385 + } 299 386 300 - if funcDecl.Body != nil { 301 - findSubtests(funcDecl.Body, test, fset, filename, pkgName) 302 - } 387 + func (tf *testFinder) createSubTest(n *ast.CallExpr, parentTest *TestInfo) *TestInfo { 388 + filename := tf.fset.Position(n.Pos()).Filename 389 + start := tf.fset.Position(n.Pos()) 390 + end := tf.fset.Position(n.End()) 303 391 304 - tests = append(tests, test) 305 - } 392 + switch arg := n.Args[0].(type) { 393 + case *ast.BasicLit: 394 + if arg.Kind == token.STRING { 395 + return tf.createLiteralSubTest(arg, parentTest, filename, start, end) 306 396 } 397 + default: 398 + return tf.createGeneratedSubTest(arg, parentTest, filename, start, end) 307 399 } 400 + return nil 401 + } 308 402 309 - return tests 403 + func (tf *testFinder) createLiteralSubTest(arg *ast.BasicLit, parentTest *TestInfo, filename string, start, end token.Position) *TestInfo { 404 + subtestName := strings.Trim(arg.Value, "\"'`") 405 + sanitizedSubtestName := rewriteSubTestName(subtestName) 406 + 407 + fullName := parentTest.FullName + "/" + subtestName 408 + sanitizedFullName := parentTest.FullName + "/" + sanitizedSubtestName 409 + 410 + return &TestInfo{ 411 + Name: subtestName, 412 + DisplayName: sanitizedSubtestName, 413 + FullName: fullName, 414 + FullDisplayName: sanitizedFullName, 415 + Package: tf.pkgName, 416 + Directory: tf.directory, 417 + File: filename, 418 + Range: SourceRange{ 419 + Start: SourcePosition{ 420 + Line: start.Line, 421 + Column: start.Column, 422 + }, 423 + End: SourcePosition{ 424 + Line: end.Line, 425 + Column: end.Column, 426 + }, 427 + }, 428 + HasGeneratedName: false, 429 + IsSubtest: true, 430 + } 431 + } 432 + 433 + func (tf *testFinder) createGeneratedSubTest(arg ast.Expr, parentTest *TestInfo, filename string, start, end token.Position) *TestInfo { 434 + var buf bytes.Buffer 435 + err := printer.Fprint(&buf, tf.fset, arg) 436 + if err != nil { 437 + fmt.Fprintf(os.Stderr, "Error printing argument: %v\n", err) 438 + return nil 439 + } 440 + subtestName := fmt.Sprintf("<%s>", strings.TrimSpace(buf.String())) 441 + fullName := fmt.Sprintf("%s/%s", parentTest.FullName, subtestName) 442 + 443 + return &TestInfo{ 444 + Name: subtestName, 445 + FullName: fullName, 446 + Package: tf.pkgName, 447 + Directory: tf.directory, 448 + File: filename, 449 + Range: SourceRange{ 450 + Start: SourcePosition{ 451 + Line: start.Line, 452 + Column: start.Column, 453 + }, 454 + End: SourcePosition{ 455 + Line: end.Line, 456 + Column: end.Column, 457 + }, 458 + }, 459 + HasGeneratedName: true, 460 + IsSubtest: true, 461 + } 310 462 } 311 463 312 464 // Can do better by checking the parameter type but this is faster and good ··· 333 485 } 334 486 335 487 return false 336 - } 337 - 338 - func findSubtests(block *ast.BlockStmt, parentTest *TestInfo, fset *token.FileSet, filename, pkgName string) { 339 - ast.Inspect(block, func(n ast.Node) bool { 340 - callExpr, ok := n.(*ast.CallExpr) 341 - if !ok { 342 - return true 343 - } 344 - 345 - selExpr, ok := callExpr.Fun.(*ast.SelectorExpr) 346 - if !ok { 347 - return true 348 - } 349 - 350 - if selExpr.Sel.Name != "Run" { 351 - return true 352 - } 353 - 354 - // NOTE: no type check here. 355 - if len(callExpr.Args) < 2 { 356 - return true 357 - } 358 - 359 - start := fset.Position(callExpr.Pos()) 360 - end := fset.Position(callExpr.End()) 361 - 362 - var subTest *TestInfo 363 - switch arg := callExpr.Args[0].(type) { 364 - case *ast.BasicLit: 365 - if arg.Kind == token.STRING { 366 - subtestName := strings.Trim(arg.Value, "\"'`") 367 - sanitizedSubtestName := rewriteSubTestName(subtestName) 368 - 369 - fullName := parentTest.FullName + "/" + subtestName 370 - sanitizedFullName := parentTest.FullName + "/" + sanitizedSubtestName 371 - 372 - subTest = &TestInfo{ 373 - Name: subtestName, 374 - DisplayName: sanitizedSubtestName, 375 - FullName: fullName, 376 - FullDisplayName: sanitizedFullName, 377 - Package: pkgName, 378 - Directory: parentTest.Directory, 379 - File: filename, 380 - Range: SourceRange{ 381 - Start: SourcePosition{ 382 - Line: start.Line, 383 - Column: start.Column, 384 - }, 385 - End: SourcePosition{ 386 - Line: end.Line, 387 - Column: end.Column, 388 - }, 389 - }, 390 - HasGeneratedName: false, 391 - IsSubtest: true, 392 - } 393 - } 394 - default: 395 - // TODO: how to report runtime generated names? 396 - var buf bytes.Buffer 397 - err := printer.Fprint(&buf, fset, arg) 398 - if err != nil { 399 - fmt.Fprintf(os.Stderr, "Error printing argument: %v\n", err) 400 - return true 401 - } 402 - subtestName := fmt.Sprintf("<%s>", strings.TrimSpace(buf.String())) 403 - fullName := fmt.Sprintf("%s/%s", parentTest.FullName, subtestName) 404 - subTest = &TestInfo{ 405 - Name: subtestName, 406 - FullName: fullName, 407 - Package: pkgName, 408 - Directory: parentTest.Directory, 409 - File: filename, 410 - Range: SourceRange{ 411 - Start: SourcePosition{ 412 - Line: start.Line, 413 - Column: start.Column, 414 - }, 415 - End: SourcePosition{ 416 - Line: end.Line, 417 - Column: end.Column, 418 - }, 419 - }, 420 - HasGeneratedName: true, 421 - IsSubtest: true, 422 - } 423 - } 424 - 425 - if subTest != nil { 426 - parentTest.SubTests = append(parentTest.SubTests, subTest) 427 - 428 - if funcLit, ok := callExpr.Args[1].(*ast.FuncLit); ok && funcLit.Body != nil { 429 - findSubtests(funcLit.Body, subTest, fset, filename, pkgName) 430 - } 431 - 432 - return false 433 - } 434 - 435 - return true 436 - }) 437 488 } 438 489 439 490 func iterTests(tests []*TestInfo) iter.Seq[*TestInfo] {