1// Copyright 2020 The Gitea Authors. All rights reserved.
2// SPDX-License-Identifier: MIT
3
4package proxyprotocol
5
6import (
7 "bufio"
8 "bytes"
9 "encoding/binary"
10 "io"
11 "net"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16
17 "forgejo.org/modules/log"
18)
19
var (
	// v1Prefix opens a human-readable (version 1) PROXY protocol header;
	// finding it at the start of a connection means the peer speaks v1.
	v1Prefix    = []byte("PROXY ")
	v1PrefixLen = len(v1Prefix)

	// v2Prefix is the 12-byte signature that opens a binary (version 2)
	// header, i.e. "\r\n\r\n\x00\r\nQUIT\n". It contains a NUL byte, so it
	// has to be compared as raw bytes, never as a C-style string.
	v2Prefix = []byte{
		0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
	}
	v2PrefixLen = len(v2Prefix)
)
28
// Conn wraps an underlying net.Conn whose peer speaks the PROXY protocol.
// Once the header has been parsed, RemoteAddr and LocalAddr report the
// addresses recorded from the header rather than those of the proxy's socket.
type Conn struct {
	// conn is the wrapped connection. bufReader reads from it and holds any
	// bytes peeked while looking for the header.
	conn      net.Conn
	bufReader *bufio.Reader

	// Addresses recorded while parsing the header. A nil value makes the
	// corresponding accessor fall back to conn's own address.
	remoteAddr net.Addr
	localAddr  net.Addr

	// proxyHeaderTimeout, when non-zero, bounds how long the header parse may
	// block waiting for data.
	proxyHeaderTimeout time.Duration
	// acceptUnknown lets v1 "UNKNOWN" headers and v2 headers with an
	// unspecified family/protocol be accepted instead of rejected.
	acceptUnknown bool

	// once ensures the header is parsed at most one time.
	once sync.Once
}
41
42// NewConn is used to wrap a net.Conn speaking the proxy protocol into
43// a proxyprotocol.Conn
44func NewConn(conn net.Conn, timeout time.Duration) *Conn {
45 pConn := &Conn{
46 bufReader: bufio.NewReader(conn),
47 conn: conn,
48 proxyHeaderTimeout: timeout,
49 }
50 return pConn
51}
52
53// Read reads data from the connection.
54// It will initially read the proxy protocol header.
55// If there is an error parsing the header, it is returned and the socket is closed.
56func (p *Conn) Read(b []byte) (int, error) {
57 if err := p.readProxyHeaderOnce(); err != nil {
58 return 0, err
59 }
60 return p.bufReader.Read(b)
61}
62
63// ReadFrom reads data from a provided reader and copies it to the connection.
64func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
65 if err := p.readProxyHeaderOnce(); err != nil {
66 return 0, err
67 }
68 if rf, ok := p.conn.(io.ReaderFrom); ok {
69 return rf.ReadFrom(r)
70 }
71 return io.Copy(p.conn, r)
72}
73
74// WriteTo reads data from the connection and writes it to the writer.
75// It will initially read the proxy protocol header.
76// If there is an error parsing the header, it is returned and the socket is closed.
77func (p *Conn) WriteTo(w io.Writer) (int64, error) {
78 if err := p.readProxyHeaderOnce(); err != nil {
79 return 0, err
80 }
81 return p.bufReader.WriteTo(w)
82}
83
84// Write writes data to the connection.
85// Write can be made to time out and return an error after a fixed
86// time limit; see SetDeadline and SetWriteDeadline.
87func (p *Conn) Write(b []byte) (int, error) {
88 if err := p.readProxyHeaderOnce(); err != nil {
89 return 0, err
90 }
91 return p.conn.Write(b)
92}
93
94// Close closes the connection.
95// Any blocked Read or Write operations will be unblocked and return errors.
96func (p *Conn) Close() error {
97 return p.conn.Close()
98}
99
100// LocalAddr returns the local network address.
101func (p *Conn) LocalAddr() net.Addr {
102 _ = p.readProxyHeaderOnce()
103 if p.localAddr != nil {
104 return p.localAddr
105 }
106 return p.conn.LocalAddr()
107}
108
109// RemoteAddr returns the address of the client if the proxy
110// protocol is being used, otherwise just returns the address of
111// the socket peer. If there is an error parsing the header, the
112// address of the client is not returned, and the socket is closed.
113// One implication of this is that the call could block if the
114// client is slow. Using a Deadline is recommended if this is called
115// before Read()
116func (p *Conn) RemoteAddr() net.Addr {
117 _ = p.readProxyHeaderOnce()
118 if p.remoteAddr != nil {
119 return p.remoteAddr
120 }
121 return p.conn.RemoteAddr()
122}
123
124// SetDeadline sets the read and write deadlines associated
125// with the connection. It is equivalent to calling both
126// SetReadDeadline and SetWriteDeadline.
127//
128// A deadline is an absolute time after which I/O operations
129// fail instead of blocking. The deadline applies to all future
130// and pending I/O, not just the immediately following call to
131// Read or Write. After a deadline has been exceeded, the
132// connection can be refreshed by setting a deadline in the future.
133//
134// If the deadline is exceeded a call to Read or Write or to other
135// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
136// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
137// The error's Timeout method will return true, but note that there
138// are other possible errors for which the Timeout method will
139// return true even if the deadline has not been exceeded.
140//
141// An idle timeout can be implemented by repeatedly extending
142// the deadline after successful Read or Write calls.
143//
144// A zero value for t means I/O operations will not time out.
145func (p *Conn) SetDeadline(t time.Time) error {
146 return p.conn.SetDeadline(t)
147}
148
149// SetReadDeadline sets the deadline for future Read calls
150// and any currently-blocked Read call.
151// A zero value for t means Read will not time out.
152func (p *Conn) SetReadDeadline(t time.Time) error {
153 return p.conn.SetReadDeadline(t)
154}
155
156// SetWriteDeadline sets the deadline for future Write calls
157// and any currently-blocked Write call.
158// Even if write times out, it may return n > 0, indicating that
159// some of the data was successfully written.
160// A zero value for t means Write will not time out.
161func (p *Conn) SetWriteDeadline(t time.Time) error {
162 return p.conn.SetWriteDeadline(t)
163}
164
165// readProxyHeaderOnce will ensure that the proxy header has been read
166func (p *Conn) readProxyHeaderOnce() (err error) {
167 p.once.Do(func() {
168 if err = p.readProxyHeader(); err != nil && err != io.EOF {
169 log.Error("Failed to read proxy prefix: %v", err)
170 p.Close()
171 p.bufReader = bufio.NewReader(p.conn)
172 }
173 })
174 return err
175}
176
177func (p *Conn) readProxyHeader() error {
178 if p.proxyHeaderTimeout != 0 {
179 readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
180 _ = p.conn.SetReadDeadline(readDeadLine)
181 defer func() {
182 _ = p.conn.SetReadDeadline(time.Time{})
183 }()
184 }
185
186 inp, err := p.bufReader.Peek(v1PrefixLen)
187 if err != nil {
188 return err
189 }
190
191 if bytes.Equal(inp, v1Prefix) {
192 return p.readV1ProxyHeader()
193 }
194
195 inp, err = p.bufReader.Peek(v2PrefixLen)
196 if err != nil {
197 return err
198 }
199 if bytes.Equal(inp, v2Prefix) {
200 return p.readV2ProxyHeader()
201 }
202
203 return &ErrBadHeader{inp}
204}
205
206func (p *Conn) readV2ProxyHeader() error {
207 // The binary header format starts with a constant 12 bytes block containing the
208 // protocol signature :
209 //
210 // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
211 //
212 // Note that this block contains a null byte at the 5th position, so it must not
213 // be handled as a null-terminated string.
214
215 if _, err := p.bufReader.Discard(v2PrefixLen); err != nil {
216 // This shouldn't happen as we have already asserted that there should be enough in the buffer
217 return err
218 }
219
220 // The next byte (the 13th one) is the protocol version and command.
221 version, err := p.bufReader.ReadByte()
222 if err != nil {
223 return err
224 }
225
226 // The 14th byte contains the transport protocol and address family.otocol.
227 familyByte, err := p.bufReader.ReadByte()
228 if err != nil {
229 return err
230 }
231
232 // The 15th and 16th bytes is the address length in bytes in network endian order.
233 var addressLen uint16
234 if err := binary.Read(p.bufReader, binary.BigEndian, &addressLen); err != nil {
235 return err
236 }
237
238 // Now handle the version byte: (14th byte).
239 // The highest four bits contains the version. As of this specification, it must
240 // always be sent as \x2 and the receiver must only accept this value.
241 if version>>4 != 0x2 {
242 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
243 }
244
245 // The lowest four bits represents the command :
246 switch version & 0xf {
247 case 0x0:
248 // - \x0 : LOCAL : the connection was established on purpose by the proxy
249 // without being relayed. The connection endpoints are the sender and the
250 // receiver. Such connections exist when the proxy sends health-checks to the
251 // server. The receiver must accept this connection as valid and must use the
252 // real connection endpoints and discard the protocol block including the
253 // family which is ignored.
254
255 // We therefore ignore the 14th, 15th and 16th bytes
256 p.remoteAddr = p.conn.LocalAddr()
257 p.localAddr = p.conn.RemoteAddr()
258 return nil
259 case 0x1:
260 // - \x1 : PROXY : the connection was established on behalf of another node,
261 // and reflects the original connection endpoints. The receiver must then use
262 // the information provided in the protocol block to get original the address.
263 default:
264 // - other values are unassigned and must not be emitted by senders. Receivers
265 // must drop connections presenting unexpected values here.
266 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
267 }
268
269 // Now handle the familyByte byte: (15th byte).
270 // The highest 4 bits contain the address family, the lowest 4 bits contain the protocol
271
272 // The address family maps to the original socket family without necessarily
273 // matching the values internally used by the system. It may be one of :
274 //
275 // - 0x0 : AF_UNSPEC : the connection is forwarded for an unknown, unspecified
276 // or unsupported protocol. The sender should use this family when sending
277 // LOCAL commands or when dealing with unsupported protocol families. The
278 // receiver is free to accept the connection anyway and use the real endpoint
279 // addresses or to reject it. The receiver should ignore address information.
280 //
281 // - 0x1 : AF_INET : the forwarded connection uses the AF_INET address family
282 // (IPv4). The addresses are exactly 4 bytes each in network byte order,
283 // followed by transport protocol information (typically ports).
284 //
285 // - 0x2 : AF_INET6 : the forwarded connection uses the AF_INET6 address family
286 // (IPv6). The addresses are exactly 16 bytes each in network byte order,
287 // followed by transport protocol information (typically ports).
288 //
289 // - 0x3 : AF_UNIX : the forwarded connection uses the AF_UNIX address family
290 // (UNIX). The addresses are exactly 108 bytes each.
291 //
292 // - other values are unspecified and must not be emitted in version 2 of this
293 // protocol and must be rejected as invalid by receivers.
294
295 // The transport protocol is specified in the lowest 4 bits of the 14th byte :
296 //
297 // - 0x0 : UNSPEC : the connection is forwarded for an unknown, unspecified
298 // or unsupported protocol. The sender should use this family when sending
299 // LOCAL commands or when dealing with unsupported protocol families. The
300 // receiver is free to accept the connection anyway and use the real endpoint
301 // addresses or to reject it. The receiver should ignore address information.
302 //
303 // - 0x1 : STREAM : the forwarded connection uses a SOCK_STREAM protocol (eg:
304 // TCP or UNIX_STREAM). When used with AF_INET/AF_INET6 (TCP), the addresses
305 // are followed by the source and destination ports represented on 2 bytes
306 // each in network byte order.
307 //
308 // - 0x2 : DGRAM : the forwarded connection uses a SOCK_DGRAM protocol (eg:
309 // UDP or UNIX_DGRAM). When used with AF_INET/AF_INET6 (UDP), the addresses
310 // are followed by the source and destination ports represented on 2 bytes
311 // each in network byte order.
312 //
313 // - other values are unspecified and must not be emitted in version 2 of this
314 // protocol and must be rejected as invalid by receivers.
315
316 if familyByte>>4 == 0x0 || familyByte&0xf == 0x0 {
317 // - hi 0x0 : AF_UNSPEC : the connection is forwarded for an unknown address type
318 // or
319 // - lo 0x0 : UNSPEC : the connection is forwarded for an unspecified protocol
320 if !p.acceptUnknown {
321 p.conn.Close()
322 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
323 }
324 p.remoteAddr = p.conn.LocalAddr()
325 p.localAddr = p.conn.RemoteAddr()
326 _, err = p.bufReader.Discard(int(addressLen))
327 return err
328 }
329
330 // other address or protocol
331 if (familyByte>>4) > 0x3 || (familyByte&0xf) > 0x2 {
332 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
333 }
334
335 // Handle AF_UNIX addresses
336 if familyByte>>4 == 0x3 {
337 // - \x31 : UNIX stream : the forwarded connection uses SOCK_STREAM over the
338 // AF_UNIX protocol family. Address length is 2*108 = 216 bytes.
339 // - \x32 : UNIX datagram : the forwarded connection uses SOCK_DGRAM over the
340 // AF_UNIX protocol family. Address length is 2*108 = 216 bytes.
341 if addressLen != 216 {
342 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
343 }
344 remoteName := make([]byte, 108)
345 localName := make([]byte, 108)
346 if _, err := p.bufReader.Read(remoteName); err != nil {
347 return err
348 }
349 if _, err := p.bufReader.Read(localName); err != nil {
350 return err
351 }
352 protocol := "unix"
353 if familyByte&0xf == 2 {
354 protocol = "unixgram"
355 }
356
357 p.remoteAddr = &net.UnixAddr{
358 Name: string(remoteName),
359 Net: protocol,
360 }
361 p.localAddr = &net.UnixAddr{
362 Name: string(localName),
363 Net: protocol,
364 }
365 return nil
366 }
367
368 var remoteIP []byte
369 var localIP []byte
370 var remotePort uint16
371 var localPort uint16
372
373 if familyByte>>4 == 0x1 {
374 // AF_INET
375 // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
376 // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
377 // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
378 // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
379 if addressLen != 12 {
380 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
381 }
382
383 remoteIP = make([]byte, 4)
384 localIP = make([]byte, 4)
385 } else {
386 // AF_INET6
387 // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
388 // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
389 // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
390 // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
391 if addressLen != 36 {
392 return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
393 }
394
395 remoteIP = make([]byte, 16)
396 localIP = make([]byte, 16)
397 }
398
399 if _, err := p.bufReader.Read(remoteIP); err != nil {
400 return err
401 }
402 if _, err := p.bufReader.Read(localIP); err != nil {
403 return err
404 }
405 if err := binary.Read(p.bufReader, binary.BigEndian, &remotePort); err != nil {
406 return err
407 }
408 if err := binary.Read(p.bufReader, binary.BigEndian, &localPort); err != nil {
409 return err
410 }
411
412 if familyByte&0xf == 1 {
413 p.remoteAddr = &net.TCPAddr{
414 IP: remoteIP,
415 Port: int(remotePort),
416 }
417 p.localAddr = &net.TCPAddr{
418 IP: localIP,
419 Port: int(localPort),
420 }
421 } else {
422 p.remoteAddr = &net.UDPAddr{
423 IP: remoteIP,
424 Port: int(remotePort),
425 }
426 p.localAddr = &net.UDPAddr{
427 IP: localIP,
428 Port: int(localPort),
429 }
430 }
431 return nil
432}
433
434func (p *Conn) readV1ProxyHeader() error {
435 // Read until a newline
436 header, err := p.bufReader.ReadString('\n')
437 if err != nil {
438 p.conn.Close()
439 return err
440 }
441
442 if header[len(header)-2] != '\r' {
443 return &ErrBadHeader{[]byte(header)}
444 }
445
446 // Strip the carriage return and new line
447 header = header[:len(header)-2]
448
449 // Split on spaces, should be (PROXY <type> <remote addr> <local addr> <remote port> <local port>)
450 parts := strings.Split(header, " ")
451 if len(parts) < 2 {
452 p.conn.Close()
453 return &ErrBadHeader{[]byte(header)}
454 }
455
456 // Verify the type is known
457 switch parts[1] {
458 case "UNKNOWN":
459 if !p.acceptUnknown || len(parts) != 2 {
460 p.conn.Close()
461 return &ErrBadHeader{[]byte(header)}
462 }
463 p.remoteAddr = p.conn.LocalAddr()
464 p.localAddr = p.conn.RemoteAddr()
465 return nil
466 case "TCP4":
467 case "TCP6":
468 default:
469 p.conn.Close()
470 return &ErrBadAddressType{parts[1]}
471 }
472
473 if len(parts) != 6 {
474 p.conn.Close()
475 return &ErrBadHeader{[]byte(header)}
476 }
477
478 // Parse out the remote address
479 ip := net.ParseIP(parts[2])
480 if ip == nil {
481 p.conn.Close()
482 return &ErrBadRemote{parts[2], parts[4]}
483 }
484 port, err := strconv.Atoi(parts[4])
485 if err != nil {
486 p.conn.Close()
487 return &ErrBadRemote{parts[2], parts[4]}
488 }
489 p.remoteAddr = &net.TCPAddr{IP: ip, Port: port}
490
491 // Parse out the destination address
492 ip = net.ParseIP(parts[3])
493 if ip == nil {
494 p.conn.Close()
495 return &ErrBadLocal{parts[3], parts[5]}
496 }
497 port, err = strconv.Atoi(parts[5])
498 if err != nil {
499 p.conn.Close()
500 return &ErrBadLocal{parts[3], parts[5]}
501 }
502 p.localAddr = &net.TCPAddr{IP: ip, Port: port}
503
504 return nil
505}