1/*
2 * Written in 2019 by Andrew Ayer.
3 * Patched 2025, Bluesky Social PBC.
4 *
5 * Original: https://www.agwa.name/blog/post/preventing_server_side_request_forgery_in_golang
6 *
7 * To the extent possible under law, the author(s) have dedicated all
8 * copyright and related and neighboring rights to this software to the
9 * public domain worldwide. This software is distributed without any
10 * warranty.
11 *
12 * You should have received a copy of the CC0 Public
13 * Domain Dedication along with this software. If not, see
14 * <https://creativecommons.org/publicdomain/zero/1.0/>.
15 */
16package ssrf
17
18import (
19 "fmt"
20 "net"
21 "net/http"
22 "syscall"
23 "time"
24)
25
// ipv4Net builds the IPv4 network a.b.c.d/subnetPrefixLen.
//
// [net.IPv4] returns the 16-byte IPv4-in-IPv6 form of the address, so the
// prefix length is shifted by 96 bits to build a 128-bit mask whose set bits
// cover the same leading bits of the trailing four (IPv4) bytes.
func ipv4Net(a, b, c, d byte, subnetPrefixLen int) net.IPNet {
	mask := net.CIDRMask(subnetPrefixLen+96, 8*net.IPv6len)
	return net.IPNet{IP: net.IPv4(a, b, c, d), Mask: mask}
}
32
// reservedIPv4Nets lists the IPv4 ranges that are treated as non-public
// (reserved, private, loopback, link-local, documentation, multicast, ...).
// Every entry is built with ipv4Net, so each network is stored in 16-byte
// IPv4-in-IPv6 form with a correspondingly widened mask.
var reservedIPv4Nets = []net.IPNet{
	ipv4Net(0, 0, 0, 0, 8),       // "This" / current network
	ipv4Net(10, 0, 0, 0, 8),      // Private (RFC 1918)
	ipv4Net(100, 64, 0, 0, 10),   // Shared address space / carrier-grade NAT (RFC 6598)
	ipv4Net(127, 0, 0, 0, 8),     // Loopback
	ipv4Net(169, 254, 0, 0, 16),  // Link-local (e.g. cloud metadata at 169.254.169.254)
	ipv4Net(172, 16, 0, 0, 12),   // Private (RFC 1918)
	ipv4Net(192, 0, 0, 0, 24),    // IETF protocol assignments (RFC 6890)
	ipv4Net(192, 0, 2, 0, 24),    // TEST-NET-1: documentation/examples
	ipv4Net(192, 88, 99, 0, 24),  // Former 6to4 (IPv6-to-IPv4) relay anycast
	ipv4Net(192, 168, 0, 0, 16),  // Private (RFC 1918)
	ipv4Net(198, 18, 0, 0, 15),   // Benchmarking tests
	ipv4Net(198, 51, 100, 0, 24), // TEST-NET-2: documentation/examples
	ipv4Net(203, 0, 113, 0, 24),  // TEST-NET-3: documentation/examples
	ipv4Net(224, 0, 0, 0, 4),     // Multicast
	ipv4Net(240, 0, 0, 0, 4),     // Reserved (includes broadcast / 255.255.255.255)
}
50
// globalUnicastIPv6Net is 2000::/3, the block of IPv6 space allocated for
// global unicast addresses. Anything outside it (loopback, link-local,
// unique-local, multicast, ...) is not considered public.
var globalUnicastIPv6Net = net.IPNet{
	IP:   net.IP{0: 0x20, 15: 0}, // 16 bytes: 0x20 followed by zeros
	Mask: net.CIDRMask(3, 8*net.IPv6len),
}

// isIPv6GlobalUnicast reports whether address lies inside 2000::/3.
//
// IPv4 and IPv4-mapped inputs never match, because [net.IPNet.Contains]
// compares them in 4-byte form against this 16-byte network.
func isIPv6GlobalUnicast(address net.IP) bool {
	inRange := globalUnicastIPv6Net.Contains(address)
	return inRange
}
59
60func isIPv4Reserved(address net.IP) bool {
61 for _, reservedNet := range reservedIPv4Nets {
62 if reservedNet.Contains(address) {
63 return true
64 }
65 }
66 return false
67}
68
69func IsPublicIPAddress(address net.IP) bool {
70 if address.To4() != nil {
71 return !isIPv4Reserved(address)
72 } else {
73 return isIPv6GlobalUnicast(address)
74 }
75}
76
77// Implementation of [net.Dialer] `Control` field (a function) which avoids some SSRF attacks by rejecting local IPv4 and IPv6 address ranges, and only allowing ports 80 or 443.
78func PublicOnlyControl(network string, address string, conn syscall.RawConn) error {
79 if !(network == "tcp4" || network == "tcp6") {
80 return fmt.Errorf("%s is not a safe network type", network)
81 }
82
83 host, port, err := net.SplitHostPort(address)
84 if err != nil {
85 return fmt.Errorf("%s is not a valid host/port pair: %s", address, err)
86 }
87
88 ipaddress := net.ParseIP(host)
89 if ipaddress == nil {
90 return fmt.Errorf("%s is not a valid IP address", host)
91 }
92
93 if !IsPublicIPAddress(ipaddress) {
94 return fmt.Errorf("%s is not a public IP address", ipaddress)
95 }
96
97 if !(port == "80" || port == "443") {
98 return fmt.Errorf("%s is not a safe port number", port)
99 }
100
101 return nil
102}
103
104// [net.Dialer] with [PublicOnlyControl] for `Control` function (for SSRF protection). Other fields are same default values as standard library.
105func PublicOnlyDialer() *net.Dialer {
106 return &net.Dialer{
107 Timeout: 30 * time.Second,
108 KeepAlive: 30 * time.Second,
109 DualStack: true,
110 Control: PublicOnlyControl,
111 }
112}
113
114// [http.Transport] with [PublicOnlyDialer] for `DialContext` field (for SSRF protection). Other fields are same default values as standard library.
115//
116// Use this in an [http.Client] like: `c := http.Client{ Transport: PublicOnlyTransport() }`
117func PublicOnlyTransport() *http.Transport {
118 dialer := PublicOnlyDialer()
119 return &http.Transport{
120 Proxy: http.ProxyFromEnvironment,
121 DialContext: dialer.DialContext,
122 ForceAttemptHTTP2: true,
123 MaxIdleConns: 100,
124 IdleConnTimeout: 90 * time.Second,
125 TLSHandshakeTimeout: 10 * time.Second,
126 ExpectContinueTimeout: 1 * time.Second,
127 }
128}