From cfb78d63bd1e9a13e58920f7670da7b807b29a0d Mon Sep 17 00:00:00 2001 From: Jakob Ackermann Date: Mon, 3 Mar 2025 22:16:38 +0000 Subject: [PATCH] net/http: reduce memory usage when hijacking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, Hijack allocated a new write buffer and the existing connection write buffer used an extra 4KiB of memory until the handler finished and the "conn" was garbage collected. Now, hijack re-uses the existing write buffer and re-attaches it to the raw connection to avoid referencing the net/http "conn" after returning. After a handler that hijacked exited, the "conn" reference in "connReader" will now be unset. This allows all of the "conn", "response" and "Request" to get garbage collected. Overall, this is reducing the memory usage by 43% or 6.7KiB per hijacked connection (see BenchmarkServerHijackMemoryUsage in an earlier revision of the CL). CloseNotify will continue to work _before_ the handler has exited (i.e. while the "conn" is still referenced in "connReader"). This aligns with the documentation of CloseNotifier: > After the Handler has returned, there is no guarantee that the channel > receives a value. goos: linux goarch: amd64 pkg: net/http cpu: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz │ before │ after │ │ sec/op │ sec/op vs base │ ServerHijack-8 42.59µ ± 8% 39.47µ ± 16% ~ (p=0.481 n=10) │ before │ after │ │ B/op │ B/op vs base │ ServerHijack-8 16.12Ki ± 0% 12.06Ki ± 0% -25.16% (p=0.000 n=10) │ before │ after │ │ allocs/op │ allocs/op vs base │ ServerHijack-8 51.00 ± 0% 49.00 ± 0% -3.92% (p=0.000 n=10) Change-Id: I20a37ee314ed0d47463a4657d712154e78e48138 GitHub-Last-Rev: 80f09dfa273035f53cdd72845e5c5fb129c3e230 GitHub-Pull-Request: golang/go#70756 Reviewed-on: https://go-review.googlesource.com/c/go/+/634855 Reviewed-by: Sean Liao Reviewed-by: Dmitri Shuralyov LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Auto-Submit: Sean Liao --- src/net/http/serve_test.go | 44 ++++++++++++++++++++++++++++++ src/net/http/server.go | 55 ++++++++++++++++++++++++-------------- 2 files changed, 79 insertions(+), 20 deletions(-) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index b1d9f0b3e3..c603c201d5 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6139,6 +6139,50 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { <-done } +// Test that the bufio.Reader returned by Hijack yields the entire body. +func TestServerHijackGetsFullBody(t *testing.T) { + run(t, testServerHijackGetsFullBody, []testMode{http1Mode}) +} +func testServerHijackGetsFullBody(t *testing.T, mode testMode) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see https://golang.org/issue/18657") + } + done := make(chan struct{}) + needle := strings.Repeat("x", 100*1024) // assume: larger than net/http bufio size + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(done) + + conn, buf, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + got := make([]byte, len(needle)) + n, err := io.ReadFull(buf.Reader, got) + if n != len(needle) || string(got) != needle || err != nil { + t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err) + } + })).ts + + cn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cn.Close() + buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") + buf = append(buf, []byte(needle)...) + if _, err := cn.Write(buf); err != nil { + t.Fatal(err) + } + + if err := cn.(*net.TCPConn).CloseWrite(); err != nil { + t.Fatal(err) + } + <-done +} + // Like TestServerHijackGetsBackgroundByte above but sending a // immediate 1MB of data to the server to fill up the server's 4KB // buffer. diff --git a/src/net/http/server.go b/src/net/http/server.go index be25e9a450..49a9d30207 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -324,12 +324,14 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { rwc = c.rwc rwc.SetDeadline(time.Time{}) - buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) if c.r.hasByte { if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil { return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) } } + c.bufw.Reset(rwc) + buf = bufio.NewReadWriter(c.bufr, c.bufw) + c.setState(rwc, StateHijacked, runHooks) return } @@ -652,10 +654,13 @@ type readResult struct { // read sizes) with support for selectively keeping an io.Reader.Read // call blocked in a background goroutine to wait for activity and // trigger a CloseNotifier channel. +// After a Handler has hijacked the conn and exited, connReader behaves like a +// proxy for the net.Conn and the aforementioned behavior is bypassed. type connReader struct { - conn *conn + rwc net.Conn // rwc is the underlying network connection. mu sync.Mutex // guards following + conn *conn // conn is nil after handler exit. hasByte bool byteBuf [1]byte cond *sync.Cond @@ -673,6 +678,12 @@ func (cr *connReader) lock() { func (cr *connReader) unlock() { cr.mu.Unlock() } +func (cr *connReader) releaseConn() { + cr.lock() + defer cr.unlock() + cr.conn = nil +} + func (cr *connReader) startBackgroundRead() { cr.lock() defer cr.unlock() @@ -683,12 +694,12 @@ func (cr *connReader) startBackgroundRead() { return } cr.inRead = true - cr.conn.rwc.SetReadDeadline(time.Time{}) + cr.rwc.SetReadDeadline(time.Time{}) go cr.backgroundRead() } func (cr *connReader) backgroundRead() { - n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + n, err := cr.rwc.Read(cr.byteBuf[:]) cr.lock() if n == 1 { cr.hasByte = true @@ -719,7 +730,7 @@ func (cr *connReader) backgroundRead() { // Ignore this error. It's the expected error from // another goroutine calling abortPendingRead. } else if err != nil { - cr.handleReadError(err) + cr.handleReadErrorLocked(err) } cr.aborted = false cr.inRead = false @@ -734,18 +745,18 @@ func (cr *connReader) abortPendingRead() { return } cr.aborted = true - cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + cr.rwc.SetReadDeadline(aLongTimeAgo) for cr.inRead { cr.cond.Wait() } - cr.conn.rwc.SetReadDeadline(time.Time{}) + cr.rwc.SetReadDeadline(time.Time{}) } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } -// handleReadError is called whenever a Read from the client returns a +// handleReadErrorLocked is called whenever a Read from the client returns a // non-nil error. // // The provided non-nil err is almost always io.EOF or a "use of @@ -754,14 +765,12 @@ func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } // development. Any error means the connection is dead and we should // down its context. // -// It may be called from multiple goroutines. -func (cr *connReader) handleReadError(_ error) { +// The caller must hold connReader.mu. +func (cr *connReader) handleReadErrorLocked(_ error) { + if cr.conn == nil { + return + } cr.conn.cancelCtx() - cr.closeNotify() -} - -// may be called from multiple goroutines. -func (cr *connReader) closeNotify() { if res := cr.conn.curReq.Load(); res != nil { res.closeNotify() } @@ -769,9 +778,14 @@ func (cr *connReader) closeNotify() { func (cr *connReader) Read(p []byte) (n int, err error) { cr.lock() - if cr.inRead { + if cr.conn == nil { cr.unlock() - if cr.conn.hijacked() { + return cr.rwc.Read(p) + } + if cr.inRead { + hijacked := cr.conn.hijacked() + cr.unlock() + if hijacked { panic("invalid Body.Read call. After hijacked, the original Request must not be used") } panic("invalid concurrent Body.Read call") @@ -795,12 +809,12 @@ func (cr *connReader) Read(p []byte) (n int, err error) { } cr.inRead = true cr.unlock() - n, err = cr.conn.rwc.Read(p) + n, err = cr.rwc.Read(p) cr.lock() cr.inRead = false if err != nil { - cr.handleReadError(err) + cr.handleReadErrorLocked(err) } cr.remain -= int64(n) cr.unlock() @@ -1986,7 +2000,7 @@ func (c *conn) serve(ctx context.Context) { c.cancelCtx = cancelCtx defer cancelCtx() - c.r = &connReader{conn: c} + c.r = &connReader{conn: c, rwc: c.rwc} c.bufr = newBufioReader(c.r) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) @@ -2083,6 +2097,7 @@ func (c *conn) serve(ctx context.Context) { inFlightResponse = nil w.cancelCtx() if c.hijacked() { + c.r.releaseConn() return } w.finishRequest()