diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 5d27880735..8d3e20c302 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -802,9 +802,6 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R if err == nil { err = <-errc } - if err != nil && err != errCopyDone { - p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err)) - } } var errCopyDone = errors.New("hijacked connection copy complete") diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 1acbc296c3..62c93fb261 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -2104,10 +2104,56 @@ func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, } } +// Issue #72954: We should not call WriteHeader on a ResponseWriter after hijacking +// the connection. +func TestReverseProxyHijackCopyError(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Upgrade", "someproto") + w.WriteHeader(http.StatusSwitchingProtocols) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + ModifyResponse: func(resp *http.Response) error { + resp.Body = &testReadWriteCloser{ + read: func([]byte) (int, error) { + return 0, errors.New("read error") + }, + } + return nil + }, + } + + hijacked := false + rw := &testResponseWriter{ + writeHeader: func(statusCode int) { + if hijacked { + t.Errorf("WriteHeader(%v) called after Hijack", statusCode) + } + }, + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + hijacked = true + cli, srv := net.Pipe() + go io.Copy(io.Discard, cli) + return srv, bufio.NewReadWriter(bufio.NewReader(srv), bufio.NewWriter(srv)), nil + }, + } + req, _ := http.NewRequest("GET", "http://example.tld/", nil) + req.Header.Set("Upgrade", "someproto") + proxyHandler.ServeHTTP(rw, req) +} + type testResponseWriter struct { h http.Header writeHeader func(int) write func([]byte) (int, error) + hijack func() (net.Conn, *bufio.ReadWriter, error) } func (rw *testResponseWriter) Header() http.Header { @@ -2129,3 +2175,37 @@ func (rw *testResponseWriter) Write(p []byte) (int, error) { } return len(p), nil } + +func (rw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if rw.hijack != nil { + return rw.hijack() + } + return nil, nil, errors.ErrUnsupported +} + +type testReadWriteCloser struct { + read func([]byte) (int, error) + write func([]byte) (int, error) + close func() error +} + +func (rc *testReadWriteCloser) Read(p []byte) (int, error) { + if rc.read != nil { + return rc.read(p) + } + return 0, io.EOF +} + +func (rc *testReadWriteCloser) Write(p []byte) (int, error) { + if rc.write != nil { + return rc.write(p) + } + return len(p), nil +} + +func (rc *testReadWriteCloser) Close() error { + if rc.close != nil { + return rc.close() + } + return nil +}