mirror of
https://github.com/golang/go.git
synced 2025-05-29 11:25:43 +00:00
net/http/httputil: make ReverseProxy close response body if ModifyResponse returns an error
Fixes #22658 Change-Id: I00e2b007d77b6f54798f7755d0b08e4fea824392 Reviewed-on: https://go-review.googlesource.com/77170 Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com> Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
f01b928aad
commit
d96ebf8a6d
@ -207,6 +207,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
if err := p.ModifyResponse(res); err != nil {
|
if err := p.ModifyResponse(res); err != nil {
|
||||||
p.logf("http: proxy error: %v", err)
|
p.logf("http: proxy error: %v", err)
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
res.Body.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -769,3 +769,47 @@ type roundTripperFunc func(req *http.Request) (*http.Response, error)
|
|||||||
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return fn(req)
|
return fn(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModifyResponseClosesBody(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:56789"
|
||||||
|
closeCheck := new(checkCloser)
|
||||||
|
logBuf := new(bytes.Buffer)
|
||||||
|
outErr := errors.New("ModifyResponse error")
|
||||||
|
rp := &ReverseProxy{
|
||||||
|
Director: func(req *http.Request) {},
|
||||||
|
Transport: &staticTransport{&http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: closeCheck,
|
||||||
|
}},
|
||||||
|
ErrorLog: log.New(logBuf, "", 0),
|
||||||
|
ModifyResponse: func(*http.Response) error {
|
||||||
|
return outErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rp.ServeHTTP(rec, req)
|
||||||
|
res := rec.Result()
|
||||||
|
if g, e := res.StatusCode, http.StatusBadGateway; g != e {
|
||||||
|
t.Errorf("got res.StatusCode %d; expected %d", g, e)
|
||||||
|
}
|
||||||
|
if !closeCheck.closed {
|
||||||
|
t.Errorf("body should have been closed")
|
||||||
|
}
|
||||||
|
if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
|
||||||
|
t.Errorf("ErrorLog %q does not contain %q", g, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type checkCloser struct {
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *checkCloser) Close() error {
|
||||||
|
cc.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *checkCloser) Read(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user