go/src/net/http/httptest/recorder.go
David Url 7662e6588c net/http: use RFC 723x as normative reference in docs
Replace references to the obsoleted RFC 2616 with references to RFC
7230 through 7235, to avoid unnecessary confusion.
Obvious inconsistencies are marked with todo comments.

Updates #21974

Change-Id: I8fb4fcdd1333fc5193b93a2f09598f18c45e7a00
Reviewed-on: https://go-review.googlesource.com/94095
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
2018-02-20 01:41:06 +00:00

238 lines
6.1 KiB
Go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httptest
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"strings"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
//
// A ResponseRecorder that was not created by NewRecorder still works:
// Header and WriteHeader allocate HeaderMap on demand, and a nil Body
// simply causes written data to be dropped.
type ResponseRecorder struct {
// Code is the HTTP response code set by WriteHeader.
//
// NewRecorder initializes Code to 200. For a recorder that was not
// created by NewRecorder, Code stays 0 until WriteHeader runs (either
// directly or through Write, WriteString or Flush). To get the
// implicit http.StatusOK in that case, use the Result method, which
// substitutes 200 for 0.
Code int
// HeaderMap contains the headers explicitly set by the Handler.
//
// It is the live map returned by Header. Result reports a snapshot of
// this map taken at the first WriteHeader call (or at the Result call
// if the header was never written), so later changes show up here but
// not in the Header of the Response returned by Result.
HeaderMap http.Header
// Body is the buffer to which the Handler's Write calls are sent.
// If nil, the Writes are silently discarded.
Body *bytes.Buffer
// Flushed is whether the Handler called Flush.
Flushed bool
result *http.Response // cache of Result's return value; set by the first Result call
snapHeader http.Header // copy of HeaderMap made by the first WriteHeader (or by Result if still nil)
wroteHeader bool // true once WriteHeader has run; later WriteHeader calls are ignored
}
// NewRecorder returns a ResponseRecorder that is ready for use: it starts
// with an empty header map, an empty body buffer and a status code of 200.
func NewRecorder() *ResponseRecorder {
	rec := new(ResponseRecorder)
	rec.Code = http.StatusOK
	rec.HeaderMap = http.Header{}
	rec.Body = &bytes.Buffer{}
	return rec
}
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
//
// NOTE(review): ResponseRecorder as declared in this file has no
// remote-address field, and nothing in this file reads this constant;
// confirm the intended use before relying on the sentence above.
const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers as the live rw.HeaderMap, so changes
// made through the returned map are visible in rw.HeaderMap. The map is
// allocated on first use if the recorder was created without one.
func (rw *ResponseRecorder) Header() http.Header {
	if rw.HeaderMap == nil {
		rw.HeaderMap = make(http.Header)
	}
	return rw.HeaderMap
}
// writeHeader commits the response header on the first body write; once the
// header has been written it does nothing.
//
// If the handler set no Content-Type key and no non-empty Transfer-Encoding,
// a Content-Type is sniffed from the beginning of the response body. The
// header is then written with an implicit 200 status.
//
// The beginning of the body is supplied either as b or as str, and a
// non-nil b wins. Taking both forms lets WriteString skip the
// string-to-[]byte conversion (and the garbage it creates) whenever no
// sniffing is required.
func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
	if rw.wroteHeader {
		return
	}

	hdr := rw.Header()
	if _, typed := hdr["Content-Type"]; !typed && hdr.Get("Transfer-Encoding") == "" {
		sniff := b
		if sniff == nil {
			// Content sniffing never looks past the first 512 bytes,
			// so there is no point converting more of the string.
			if len(str) > 512 {
				str = str[:512]
			}
			sniff = []byte(str)
		}
		hdr.Set("Content-Type", http.DetectContentType(sniff))
	}

	rw.WriteHeader(http.StatusOK)
}
// Write records buf as response body data. It first writes the header (with
// Content-Type sniffing) if that has not happened yet, then appends buf to
// rw.Body unless Body is nil. It always reports len(buf) bytes written and
// a nil error.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
	rw.writeHeader(buf, "")
	if body := rw.Body; body != nil {
		body.Write(buf)
	}
	return len(buf), nil
}
// WriteString is the string counterpart of Write. It writes the header if
// that has not happened yet, appends str to rw.Body unless Body is nil, and
// always reports len(str) bytes written and a nil error.
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
	rw.writeHeader(nil, str)
	if body := rw.Body; body != nil {
		body.WriteString(str)
	}
	return len(str), nil
}
// WriteHeader records code in rw.Code and takes a snapshot of the current
// headers for later use by Result. Only the first call has any effect.
//
// Header changes made after that first call still modify rw.HeaderMap, but
// they are not reflected in the snapshot.
func (rw *ResponseRecorder) WriteHeader(code int) {
	if rw.wroteHeader {
		return
	}
	rw.wroteHeader = true
	rw.Code = code
	// Header allocates rw.HeaderMap when it is still nil, so the snapshot
	// is always taken from an initialized map.
	rw.snapHeader = cloneHeader(rw.Header())
}
// cloneHeader returns a deep copy of h: a new map in which every value slice
// is freshly allocated, so later changes to h or to its slices do not show
// through in the copy. A nil h yields an empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	dup := make(http.Header, len(h))
	for key, values := range h {
		dup[key] = append(make([]string, 0, len(values)), values...)
	}
	return dup
}
// Flush marks the recorder as flushed by setting rw.Flushed. If the header
// has not been written yet, it is written first with an implicit 200 status.
func (rw *ResponseRecorder) Flush() {
	if headerPending := !rw.wroteHeader; headerPending {
		rw.WriteHeader(http.StatusOK)
	}
	rw.Flushed = true
}
// Result returns the response generated by the handler.
//
// The returned Response will have at least its StatusCode,
// Header, Body, and optionally Trailer populated.
// More fields may be populated in the future, so callers should
// not DeepEqual the result in tests.
//
// The Response.Header is a snapshot of the headers at the time of the
// first write call, or at the time of this call, if the handler never
// did a write.
//
// The Response.Body is guaranteed to be non-nil and Body.Read call is
// guaranteed to not return any error other than io.EOF. If rw.Body is
// nil, the returned Body is http.NoBody.
//
// The result is computed once and cached; later calls return the same
// *http.Response.
//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
	if rw.result != nil {
		return rw.result
	}
	if rw.snapHeader == nil {
		// The handler never wrote a header: snapshot the headers now.
		rw.snapHeader = cloneHeader(rw.HeaderMap)
	}
	res := &http.Response{
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		StatusCode: rw.Code,
		Header:     rw.snapHeader,
	}
	rw.result = res
	if res.StatusCode == 0 {
		// Code was never set; report the implicit status.
		res.StatusCode = 200
	}
	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
	if rw.Body != nil {
		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
	} else {
		// Keep the documented non-nil Body guarantee even when the
		// recorder discards writes.
		res.Body = http.NoBody
	}
	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))

	// Trailers announced in the "Trailer" header before the header was
	// written take their values from the live HeaderMap.
	if trailers, ok := rw.snapHeader["Trailer"]; ok {
		res.Trailer = make(http.Header, len(trailers))
		for _, k := range trailers {
			// Canonicalize before filtering so that spellings such as
			// "content-length" are rejected as well.
			k = http.CanonicalHeaderKey(k)
			// TODO: use http2.ValidTrailerHeader, but we can't
			// get at it easily because it's bundled into net/http
			// unexported. This is good enough for now:
			switch k {
			case "Transfer-Encoding", "Content-Length", "Trailer":
				// Not allowed in a trailer. RFC 7230, section 4.1.2
				// forbids these and more; this short list only
				// covers the fields that were filtered before.
				continue
			}
			vv, ok := rw.HeaderMap[k]
			if !ok {
				continue
			}
			vv2 := make([]string, len(vv))
			copy(vv2, vv)
			res.Trailer[k] = vv2
		}
	}

	// Keys carrying http.TrailerPrefix are trailers as well; they are
	// stored in Trailer with the prefix removed.
	for k, vv := range rw.HeaderMap {
		if !strings.HasPrefix(k, http.TrailerPrefix) {
			continue
		}
		if res.Trailer == nil {
			res.Trailer = make(http.Header)
		}
		for _, v := range vv {
			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
		}
	}
	return res
}
// parseContentLength trims whitespace from cl and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This is a modified version of the function of the same name found in
// net/http/transfer.go. This one just ignores an invalid header: anything
// that is not a plain non-negative decimal number that fits in an int64
// (including values with a leading "+" or "-") yields -1.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	// ParseUint rejects any sign, and a bit size of 63 limits the value
	// to the non-negative int64 range, so the conversion below is safe.
	n, err := strconv.ParseUint(cl, 10, 63)
	if err != nil {
		return -1
	}
	return int64(n)
}