diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index b1d9f0b3e3..c603c201d5 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6139,6 +6139,50 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { <-done } +// Test that the bufio.Reader returned by Hijack yields the entire body. +func TestServerHijackGetsFullBody(t *testing.T) { + run(t, testServerHijackGetsFullBody, []testMode{http1Mode}) +} +func testServerHijackGetsFullBody(t *testing.T, mode testMode) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see https://golang.org/issue/18657") + } + done := make(chan struct{}) + needle := strings.Repeat("x", 100*1024) // assume: larger than net/http bufio size + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(done) + + conn, buf, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + got := make([]byte, len(needle)) + n, err := io.ReadFull(buf.Reader, got) + if n != len(needle) || string(got) != needle || err != nil { + t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err) + } + })).ts + + cn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cn.Close() + buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") + buf = append(buf, []byte(needle)...) + if _, err := cn.Write(buf); err != nil { + t.Fatal(err) + } + + if err := cn.(*net.TCPConn).CloseWrite(); err != nil { + t.Fatal(err) + } + <-done +} + // Like TestServerHijackGetsBackgroundByte above but sending a // immediate 1MB of data to the server to fill up the server's 4KB // buffer. diff --git a/src/net/http/server.go b/src/net/http/server.go index be25e9a450..49a9d30207 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -324,12 +324,14 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { rwc = c.rwc rwc.SetDeadline(time.Time{}) - buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) if c.r.hasByte { if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil { return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) } } + c.bufw.Reset(rwc) + buf = bufio.NewReadWriter(c.bufr, c.bufw) + c.setState(rwc, StateHijacked, runHooks) return } @@ -652,10 +654,13 @@ type readResult struct { // read sizes) with support for selectively keeping an io.Reader.Read // call blocked in a background goroutine to wait for activity and // trigger a CloseNotifier channel. +// After a Handler has hijacked the conn and exited, connReader behaves like a +// proxy for the net.Conn and the aforementioned behavior is bypassed. type connReader struct { - conn *conn + rwc net.Conn // rwc is the underlying network connection. mu sync.Mutex // guards following + conn *conn // conn is nil after handler exit. hasByte bool byteBuf [1]byte cond *sync.Cond @@ -673,6 +678,12 @@ func (cr *connReader) lock() { func (cr *connReader) unlock() { cr.mu.Unlock() } +func (cr *connReader) releaseConn() { + cr.lock() + defer cr.unlock() + cr.conn = nil +} + func (cr *connReader) startBackgroundRead() { cr.lock() defer cr.unlock() @@ -683,12 +694,12 @@ func (cr *connReader) startBackgroundRead() { return } cr.inRead = true - cr.conn.rwc.SetReadDeadline(time.Time{}) + cr.rwc.SetReadDeadline(time.Time{}) go cr.backgroundRead() } func (cr *connReader) backgroundRead() { - n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + n, err := cr.rwc.Read(cr.byteBuf[:]) cr.lock() if n == 1 { cr.hasByte = true @@ -719,7 +730,7 @@ func (cr *connReader) backgroundRead() { // Ignore this error. It's the expected error from // another goroutine calling abortPendingRead. } else if err != nil { - cr.handleReadError(err) + cr.handleReadErrorLocked(err) } cr.aborted = false cr.inRead = false @@ -734,18 +745,18 @@ func (cr *connReader) abortPendingRead() { return } cr.aborted = true - cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + cr.rwc.SetReadDeadline(aLongTimeAgo) for cr.inRead { cr.cond.Wait() } - cr.conn.rwc.SetReadDeadline(time.Time{}) + cr.rwc.SetReadDeadline(time.Time{}) } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } -// handleReadError is called whenever a Read from the client returns a +// handleReadErrorLocked is called whenever a Read from the client returns a // non-nil error. // // The provided non-nil err is almost always io.EOF or a "use of @@ -754,14 +765,12 @@ func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } // development. Any error means the connection is dead and we should // down its context. // -// It may be called from multiple goroutines. -func (cr *connReader) handleReadError(_ error) { +// The caller must hold connReader.mu. +func (cr *connReader) handleReadErrorLocked(_ error) { + if cr.conn == nil { + return + } cr.conn.cancelCtx() - cr.closeNotify() -} - -// may be called from multiple goroutines. -func (cr *connReader) closeNotify() { if res := cr.conn.curReq.Load(); res != nil { res.closeNotify() } @@ -769,9 +778,14 @@ func (cr *connReader) closeNotify() { func (cr *connReader) Read(p []byte) (n int, err error) { cr.lock() - if cr.inRead { + if cr.conn == nil { cr.unlock() - if cr.conn.hijacked() { + return cr.rwc.Read(p) + } + if cr.inRead { + hijacked := cr.conn.hijacked() + cr.unlock() + if hijacked { panic("invalid Body.Read call. After hijacked, the original Request must not be used") } panic("invalid concurrent Body.Read call") @@ -795,12 +809,12 @@ func (cr *connReader) Read(p []byte) (n int, err error) { } cr.inRead = true cr.unlock() - n, err = cr.conn.rwc.Read(p) + n, err = cr.rwc.Read(p) cr.lock() cr.inRead = false if err != nil { - cr.handleReadError(err) + cr.handleReadErrorLocked(err) } cr.remain -= int64(n) cr.unlock() @@ -1986,7 +2000,7 @@ func (c *conn) serve(ctx context.Context) { c.cancelCtx = cancelCtx defer cancelCtx() - c.r = &connReader{conn: c} + c.r = &connReader{conn: c, rwc: c.rwc} c.bufr = newBufioReader(c.r) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) @@ -2083,6 +2097,7 @@ func (c *conn) serve(ctx context.Context) { inFlightResponse = nil w.cancelCtx() if c.hijacked() { + c.r.releaseConn() return } w.finishRequest()