Live video on the AT Protocol
at natb/command-errors 204 lines 5.0 kB view raw
1package devenv 2 3import ( 4 "bufio" 5 "context" 6 "encoding/json" 7 "fmt" 8 "net" 9 "net/http" 10 "os/exec" 11 "path/filepath" 12 "runtime" 13 "strings" 14 "testing" 15 "time" 16 17 comatproto "github.com/bluesky-social/indigo/api/atproto" 18 "github.com/bluesky-social/indigo/atproto/identity" 19 "github.com/bluesky-social/indigo/xrpc" 20 "github.com/cenkalti/backoff" 21 "github.com/google/uuid" 22 "github.com/stretchr/testify/require" 23 "stream.place/streamplace/pkg/aqhttp" 24 "stream.place/streamplace/pkg/log" 25) 26 27type DevEnv struct { 28 PDSURL string `json:"pds-url"` 29 PLCURL string `json:"plc-url"` 30 Accounts []*DevEnvAccount `json:"accounts"` 31} 32 33func WithDevEnv(t *testing.T) *DevEnv { 34 _, filename, _, _ := runtime.Caller(0) 35 cmd := exec.Command("node", "../../js/dev-env/run.mjs") 36 cmd.Dir = filepath.Dir(filename) 37 38 // Start the command and get pipes for streaming output 39 stdout, err := cmd.StdoutPipe() 40 if err != nil { 41 t.Logf("Error getting stdout pipe: %v", err) 42 t.FailNow() 43 } 44 45 if err := cmd.Start(); err != nil { 46 t.Logf("Error starting dev env: %v", err) 47 t.FailNow() 48 } 49 50 var env DevEnv 51 52 scanner := bufio.NewScanner(stdout) 53 scanner.Scan() 54 err = json.Unmarshal(scanner.Bytes(), &env) 55 if err != nil { 56 t.Logf("Error unmarshalling dev-env stdout: %v", err) 57 t.FailNow() 58 } 59 env.Accounts = []*DevEnvAccount{} 60 61 go func() { 62 scanner := bufio.NewScanner(stdout) 63 for scanner.Scan() { 64 t.Logf("dev-env stdout: %s", scanner.Text()) 65 if scanner.Err() != nil { 66 return 67 } 68 } 69 }() 70 71 // Ensure cleanup happens when test finishes 72 t.Cleanup(func() { 73 t.Logf("killing dev env") 74 if cmd.Process != nil { 75 _ = cmd.Process.Kill() 76 _ = cmd.Wait() 77 } 78 }) 79 80 return &env 81} 82 83type DevEnvAccount struct { 84 Handle string 85 Email string 86 Password string 87 DID string 88 XRPC *xrpc.Client 89} 90 91func (d *DevEnv) CreateAccount(t *testing.T) *DevEnvAccount { 92 93 xrpcc := &xrpc.Client{ 94 Host: d.PDSURL, 95 Client: &aqhttp.Client, 96 } 97 98 uu, err := uuid.NewRandom() 99 require.NoError(t, err) 100 101 handle := fmt.Sprintf("sp-%s.test", uu.String()[:8]) 102 email := fmt.Sprintf("%s@example.com", handle) 103 password := "test" 104 105 out, err := comatproto.ServerCreateAccount(context.Background(), xrpcc, &comatproto.ServerCreateAccount_Input{ 106 Handle: handle, 107 Email: &email, 108 Password: &password, 109 }) 110 require.NoError(t, err) 111 log.Log(context.Background(), "created account", "did", out.Did, "handle", out.Handle) 112 113 session, err := comatproto.ServerCreateSession(context.Background(), xrpcc, &comatproto.ServerCreateSession_Input{ 114 Identifier: out.Handle, 115 Password: password, 116 }) 117 require.NoError(t, err) 118 119 xrpcc = &xrpc.Client{ 120 Host: d.PDSURL, 121 Client: &aqhttp.Client, 122 Auth: &xrpc.AuthInfo{ 123 Did: out.Did, 124 AccessJwt: session.AccessJwt, 125 RefreshJwt: session.RefreshJwt, 126 Handle: out.Handle, 127 }, 128 } 129 acct := &DevEnvAccount{ 130 Handle: out.Handle, 131 Email: email, 132 Password: password, 133 DID: out.Did, 134 XRPC: xrpcc, 135 } 136 d.Accounts = append(d.Accounts, acct) 137 return acct 138} 139 140// Custom RoundTripper for intercepting .test domain requests 141type TestRoundTripper struct { 142 DevEnv *DevEnv 143} 144 145func (d *DevEnv) TestHTTPClient() *http.Client { 146 return &http.Client{ 147 Transport: d.TestRoundTripper(), 148 } 149} 150 151func (d *DevEnv) TestRoundTripper() *TestRoundTripper { 152 return &TestRoundTripper{DevEnv: d} 153} 154 155func (rt *TestRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 156 if strings.HasSuffix(req.URL.Hostname(), ".test") { 157 log.Log(context.Background(), "intercepting .test domain request", "url", req.URL.String()) 158 upstreamURL := fmt.Sprintf("%s%s", rt.DevEnv.PDSURL, req.URL.Path) 159 upstreamReq, err := http.NewRequest(req.Method, upstreamURL, req.Body) 160 if err != nil { 161 return nil, err 162 } 163 upstreamReq.Header = req.Header 164 upstreamReq.Host = req.URL.Hostname() 165 upstreamResp, err := http.DefaultTransport.RoundTrip(upstreamReq) 166 if err != nil { 167 return nil, err 168 } 169 return upstreamResp, nil 170 } 171 // For non-.test domains, use the default transport 172 return http.DefaultTransport.RoundTrip(req) 173} 174 175func (d *DevEnv) TestDirectory() identity.Directory { 176 // We need to create a new directory with our custom client 177 base := identity.BaseDirectory{ 178 PLCURL: d.PLCURL, 179 HTTPClient: *d.TestHTTPClient(), 180 Resolver: net.Resolver{ 181 Dial: func(ctx context.Context, network, address string) (net.Conn, error) { 182 d := net.Dialer{Timeout: time.Second * 3} 183 return d.DialContext(ctx, network, address) 184 }, 185 }, 186 TryAuthoritativeDNS: true, 187 SkipDNSDomainSuffixes: []string{".bsky.social"}, 188 } 189 return &base 190} 191 192// More aggressive backoff for tests 193func NewExponentialBackOff() *backoff.ExponentialBackOff { 194 b := &backoff.ExponentialBackOff{ 195 InitialInterval: 100 * time.Millisecond, 196 RandomizationFactor: backoff.DefaultRandomizationFactor, 197 Multiplier: backoff.DefaultMultiplier, 198 MaxInterval: 2 * time.Second, 199 MaxElapsedTime: 10 * time.Second, 200 Clock: backoff.SystemClock, 201 } 202 b.Reset() 203 return b 204}