Live video on the AT Protocol
1package aqhttp
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "net/http"
8 "time"
9)
10
11// TrustedTransport is a basic transport that adds User-Agent headers.
12// Use this for trusted infrastructure endpoints where SSRF is not a concern.
13type TrustedTransport struct {
14 Base http.RoundTripper
15}
16
17func (t *TrustedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
18 req.Header.Add("User-Agent", UserAgent)
19 return t.Base.RoundTrip(req)
20}
21
22// NewTrustedTransport creates a transport for trusted endpoints.
23func NewTrustedTransport() *TrustedTransport {
24 return &TrustedTransport{
25 Base: &http.Transport{
26 MaxIdleConns: 100,
27 IdleConnTimeout: 90 * time.Second,
28 TLSHandshakeTimeout: 10 * time.Second,
29 },
30 }
31}
32
33// UntrustedTransport validates destination IPs using DNS-over-HTTPS before connecting.
34// Prevents SSRF attacks by blocking private, loopback, and bogon IP ranges.
35type UntrustedTransport struct {
36 Base http.RoundTripper
37 resolver *DoHResolver
38}
39
40func (t *UntrustedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
41 req.Header.Add("User-Agent", UserAgent)
42 return t.Base.RoundTrip(req)
43}
44
45// NewUntrustedTransport creates a transport that validates all destination IPs.
46func NewUntrustedTransport() *UntrustedTransport {
47 resolver := NewDoHResolver("")
48
49 dialer := &net.Dialer{
50 Timeout: 30 * time.Second,
51 KeepAlive: 30 * time.Second,
52 }
53
54 transport := &http.Transport{
55 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
56 host, port, err := net.SplitHostPort(addr)
57 if err != nil {
58 return nil, fmt.Errorf("failed to parse address: %w", err)
59 }
60
61 // Resolve IPv4 addresses using DoH
62 ipv4Addrs, _ := resolver.Resolve(host, TypeA)
63 var validIP string
64
65 // Check IPv4 addresses first
66 for _, ip := range ipv4Addrs {
67 if !resolver.IsInvalidIP(ip) {
68 validIP = ip
69 break
70 }
71 }
72
73 // Fall back to IPv6 if no valid IPv4
74 if validIP == "" {
75 ipv6Addrs, _ := resolver.Resolve(host, TypeAAAA)
76 for _, ip := range ipv6Addrs {
77 if !resolver.IsInvalidIP(ip) {
78 validIP = ip
79 break
80 }
81 }
82 }
83
84 if validIP == "" {
85 return nil, fmt.Errorf("all resolved IPs for %s are private/invalid", host)
86 }
87
88 // Dial using the validated IP
89 targetAddr := net.JoinHostPort(validIP, port)
90 return dialer.DialContext(ctx, network, targetAddr)
91 },
92 MaxIdleConns: 100,
93 IdleConnTimeout: 90 * time.Second,
94 TLSHandshakeTimeout: 10 * time.Second,
95 }
96
97 return &UntrustedTransport{
98 Base: transport,
99 resolver: resolver,
100 }
101}