From 0bf515c8c4e4ce7ffdbcdb4c0701a3de1892af6c Mon Sep 17 00:00:00 2001 From: Nodir Turakulov Date: Mon, 19 Oct 2015 14:36:25 -0700 Subject: [PATCH] net/http/httptest: detect Content-Type in ResponseRecorder * detect Content-Type on ReponseRecorder.Write[String] call if header wasn't written yet, Content-Type header is not set and Transfer-Encoding is not set. * fix typos in serve_test.go Updates #12986 Change-Id: Id2ed8b1994e64657370fed71eb3882d611f76b31 Reviewed-on: https://go-review.googlesource.com/16096 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/httptest/recorder.go | 37 ++++++++++++++++++++----- src/net/http/httptest/recorder_test.go | 38 +++++++++++++++++++++++++- src/net/http/serve_test.go | 2 +- 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 30c5140dae..c813cf5021 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -44,11 +44,36 @@ func (rw *ResponseRecorder) Header() http.Header { return m } +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + _, hasType := rw.HeaderMap["Content-Type"] + hasTE := rw.HeaderMap.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + rw.HeaderMap.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { - if !rw.wroteHeader { - rw.WriteHeader(200) - } + rw.writeHeader(buf, "") if rw.Body != nil { rw.Body.Write(buf) } @@ -57,9 +82,7 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { // WriteString always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { - if !rw.wroteHeader { - rw.WriteHeader(200) - } + rw.writeHeader(nil, str) if rw.Body != nil { rw.Body.WriteString(str) } @@ -70,8 +93,8 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { func (rw *ResponseRecorder) WriteHeader(code int) { if !rw.wroteHeader { rw.Code = code + rw.wroteHeader = true } - rw.wroteHeader = true } // Flush sets rw.Flushed to true. diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index bc486e6b63..a5a1725fa9 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -39,6 +39,14 @@ func TestRecorder(t *testing.T) { return nil } } + hasHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.HeaderMap.Get(key); got != want { + return fmt.Errorf("header %s = %q; want %q", key, got, want) + } + return nil + } + } tests := []struct { name string @@ -73,7 +81,12 @@ func TestRecorder(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "hi first") }, - check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + check( + hasStatus(200), + hasContents("hi first"), + hasFlush(false), + hasHeader("Content-Type", "text/plain; charset=utf-8"), + ), }, { "flush", @@ -83,6 +96,29 @@ func TestRecorder(t *testing.T) { }, check(hasStatus(200), hasFlush(true)), }, + { + "Content-Type detection", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "no Content-Type detection with Transfer-Encoding", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "some encoding") + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "")), // no header + }, + { + "no Content-Type detection if set explicitly", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "some/type") + io.WriteString(w, "") + }, + check(hasHeader("Content-Type", "some/type")), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests { diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 7a008274e7..f9c2accc98 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2487,7 +2487,7 @@ func TestHeaderToWire(t *testing.T) { if !strings.Contains(got, "404") { return errors.New("wrong status") } - if strings.Contains(got, "Some-Header") { + if strings.Contains(got, "Too-Late") { return errors.New("shouldn't have seen Too-Late") } return nil