net/http: reduce memory usage when hijacking

Previously, Hijack allocated a new write buffer and the existing
connection write buffer used an extra 4KiB of memory until the handler
finished and the "conn" was garbage collected. Now, hijack re-uses the
existing write buffer and re-attaches it to the raw connection to avoid
referencing the net/http "conn" after returning.

After a handler that hijacked exited, the "conn" reference in
"connReader" will now be unset. This allows all of the "conn",
"response" and "Request" to get garbage collected.
Overall, this is reducing the memory usage by 43% or 6.7KiB per hijacked
connection (see BenchmarkServerHijackMemoryUsage in an earlier revision
of the CL).

CloseNotify will continue to work _before_ the handler has exited
(i.e. while the "conn" is still referenced in "connReader"). This aligns
with the documentation of CloseNotifier:
> After the Handler has returned, there is no guarantee that the channel
> receives a value.

goos: linux
goarch: amd64
pkg: net/http
cpu: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
               │   before    │             after              │
               │   sec/op    │    sec/op     vs base          │
ServerHijack-8   42.59µ ± 8%   39.47µ ± 16%  ~ (p=0.481 n=10)

               │    before    │                after                 │
               │     B/op     │     B/op      vs base                │
ServerHijack-8   16.12Ki ± 0%   12.06Ki ± 0%  -25.16% (p=0.000 n=10)

               │   before   │               after               │
               │ allocs/op  │ allocs/op   vs base               │
ServerHijack-8   51.00 ± 0%   49.00 ± 0%  -3.92% (p=0.000 n=10)

Change-Id: I20a37ee314ed0d47463a4657d712154e78e48138
GitHub-Last-Rev: 80f09dfa273035f53cdd72845e5c5fb129c3e230
GitHub-Pull-Request: golang/go#70756
Reviewed-on: https://go-review.googlesource.com/c/go/+/634855
Reviewed-by: Sean Liao <sean@liao.dev>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Sean Liao <sean@liao.dev>
This commit is contained in:
Jakob Ackermann 2025-03-03 22:16:38 +00:00 committed by Gopher Robot
parent cac276f81a
commit cfb78d63bd
2 changed files with 79 additions and 20 deletions

View File

@ -6139,6 +6139,50 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
<-done <-done
} }
// Test that the bufio.Reader returned by Hijack yields the entire body.
func TestServerHijackGetsFullBody(t *testing.T) {
run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
}
func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/18657")
}
done := make(chan struct{})
needle := strings.Repeat("x", 100*1024) // assume: larger than net/http bufio size
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(done)
conn, buf, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
defer conn.Close()
got := make([]byte, len(needle))
n, err := io.ReadFull(buf.Reader, got)
if n != len(needle) || string(got) != needle || err != nil {
t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
}
})).ts
cn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer cn.Close()
buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
buf = append(buf, []byte(needle)...)
if _, err := cn.Write(buf); err != nil {
t.Fatal(err)
}
if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
t.Fatal(err)
}
<-done
}
// Like TestServerHijackGetsBackgroundByte above but sending a // Like TestServerHijackGetsBackgroundByte above but sending a
// immediate 1MB of data to the server to fill up the server's 4KB // immediate 1MB of data to the server to fill up the server's 4KB
// buffer. // buffer.

View File

@ -324,12 +324,14 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
rwc = c.rwc rwc = c.rwc
rwc.SetDeadline(time.Time{}) rwc.SetDeadline(time.Time{})
buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
if c.r.hasByte { if c.r.hasByte {
if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil { if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
} }
} }
c.bufw.Reset(rwc)
buf = bufio.NewReadWriter(c.bufr, c.bufw)
c.setState(rwc, StateHijacked, runHooks) c.setState(rwc, StateHijacked, runHooks)
return return
} }
@ -652,10 +654,13 @@ type readResult struct {
// read sizes) with support for selectively keeping an io.Reader.Read // read sizes) with support for selectively keeping an io.Reader.Read
// call blocked in a background goroutine to wait for activity and // call blocked in a background goroutine to wait for activity and
// trigger a CloseNotifier channel. // trigger a CloseNotifier channel.
// After a Handler has hijacked the conn and exited, connReader behaves like a
// proxy for the net.Conn and the aforementioned behavior is bypassed.
type connReader struct { type connReader struct {
conn *conn rwc net.Conn // rwc is the underlying network connection.
mu sync.Mutex // guards following mu sync.Mutex // guards following
conn *conn // conn is nil after handler exit.
hasByte bool hasByte bool
byteBuf [1]byte byteBuf [1]byte
cond *sync.Cond cond *sync.Cond
@ -673,6 +678,12 @@ func (cr *connReader) lock() {
func (cr *connReader) unlock() { cr.mu.Unlock() } func (cr *connReader) unlock() { cr.mu.Unlock() }
func (cr *connReader) releaseConn() {
cr.lock()
defer cr.unlock()
cr.conn = nil
}
func (cr *connReader) startBackgroundRead() { func (cr *connReader) startBackgroundRead() {
cr.lock() cr.lock()
defer cr.unlock() defer cr.unlock()
@ -683,12 +694,12 @@ func (cr *connReader) startBackgroundRead() {
return return
} }
cr.inRead = true cr.inRead = true
cr.conn.rwc.SetReadDeadline(time.Time{}) cr.rwc.SetReadDeadline(time.Time{})
go cr.backgroundRead() go cr.backgroundRead()
} }
func (cr *connReader) backgroundRead() { func (cr *connReader) backgroundRead() {
n, err := cr.conn.rwc.Read(cr.byteBuf[:]) n, err := cr.rwc.Read(cr.byteBuf[:])
cr.lock() cr.lock()
if n == 1 { if n == 1 {
cr.hasByte = true cr.hasByte = true
@ -719,7 +730,7 @@ func (cr *connReader) backgroundRead() {
// Ignore this error. It's the expected error from // Ignore this error. It's the expected error from
// another goroutine calling abortPendingRead. // another goroutine calling abortPendingRead.
} else if err != nil { } else if err != nil {
cr.handleReadError(err) cr.handleReadErrorLocked(err)
} }
cr.aborted = false cr.aborted = false
cr.inRead = false cr.inRead = false
@ -734,18 +745,18 @@ func (cr *connReader) abortPendingRead() {
return return
} }
cr.aborted = true cr.aborted = true
cr.conn.rwc.SetReadDeadline(aLongTimeAgo) cr.rwc.SetReadDeadline(aLongTimeAgo)
for cr.inRead { for cr.inRead {
cr.cond.Wait() cr.cond.Wait()
} }
cr.conn.rwc.SetReadDeadline(time.Time{}) cr.rwc.SetReadDeadline(time.Time{})
} }
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
// handleReadError is called whenever a Read from the client returns a // handleReadErrorLocked is called whenever a Read from the client returns a
// non-nil error. // non-nil error.
// //
// The provided non-nil err is almost always io.EOF or a "use of // The provided non-nil err is almost always io.EOF or a "use of
@ -754,14 +765,12 @@ func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
// development. Any error means the connection is dead and we should // development. Any error means the connection is dead and we should
// down its context. // down its context.
// //
// It may be called from multiple goroutines. // The caller must hold connReader.mu.
func (cr *connReader) handleReadError(_ error) { func (cr *connReader) handleReadErrorLocked(_ error) {
if cr.conn == nil {
return
}
cr.conn.cancelCtx() cr.conn.cancelCtx()
cr.closeNotify()
}
// may be called from multiple goroutines.
func (cr *connReader) closeNotify() {
if res := cr.conn.curReq.Load(); res != nil { if res := cr.conn.curReq.Load(); res != nil {
res.closeNotify() res.closeNotify()
} }
@ -769,9 +778,14 @@ func (cr *connReader) closeNotify() {
func (cr *connReader) Read(p []byte) (n int, err error) { func (cr *connReader) Read(p []byte) (n int, err error) {
cr.lock() cr.lock()
if cr.inRead { if cr.conn == nil {
cr.unlock() cr.unlock()
if cr.conn.hijacked() { return cr.rwc.Read(p)
}
if cr.inRead {
hijacked := cr.conn.hijacked()
cr.unlock()
if hijacked {
panic("invalid Body.Read call. After hijacked, the original Request must not be used") panic("invalid Body.Read call. After hijacked, the original Request must not be used")
} }
panic("invalid concurrent Body.Read call") panic("invalid concurrent Body.Read call")
@ -795,12 +809,12 @@ func (cr *connReader) Read(p []byte) (n int, err error) {
} }
cr.inRead = true cr.inRead = true
cr.unlock() cr.unlock()
n, err = cr.conn.rwc.Read(p) n, err = cr.rwc.Read(p)
cr.lock() cr.lock()
cr.inRead = false cr.inRead = false
if err != nil { if err != nil {
cr.handleReadError(err) cr.handleReadErrorLocked(err)
} }
cr.remain -= int64(n) cr.remain -= int64(n)
cr.unlock() cr.unlock()
@ -1986,7 +2000,7 @@ func (c *conn) serve(ctx context.Context) {
c.cancelCtx = cancelCtx c.cancelCtx = cancelCtx
defer cancelCtx() defer cancelCtx()
c.r = &connReader{conn: c} c.r = &connReader{conn: c, rwc: c.rwc}
c.bufr = newBufioReader(c.r) c.bufr = newBufioReader(c.r)
c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
@ -2083,6 +2097,7 @@ func (c *conn) serve(ctx context.Context) {
inFlightResponse = nil inFlightResponse = nil
w.cancelCtx() w.cancelCtx()
if c.hijacked() { if c.hijacked() {
c.r.releaseConn()
return return
} }
w.finishRequest() w.finishRequest()