2011-03-25 00:46:17 +01:00
|
|
|
// Copyright 2011 The Go Authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
|
|
|
|
package httptest
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2016-07-22 20:15:38 +02:00
|
|
|
"io/ioutil"
|
2011-12-07 02:11:29 +01:00
|
|
|
"net/http"
|
2017-01-14 01:05:42 +01:00
|
|
|
"strconv"
|
|
|
|
"strings"
|
2011-03-25 00:46:17 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// ResponseRecorder is an http.ResponseWriter implementation that keeps a
// record of everything a Handler does to it, so tests can inspect the
// status code, headers, and body afterwards.
type ResponseRecorder struct {
	// Code holds the status code passed to WriteHeader.
	//
	// If a Handler never calls WriteHeader or Write, this may be left
	// at 0 instead of the implicit http.StatusOK. Call Result to get the
	// effective status code.
	Code int

	// HeaderMap holds only the headers the Handler set explicitly.
	//
	// Headers that the server would add implicitly (such as an
	// auto-detected Content-Type) are visible through Result instead.
	HeaderMap http.Header

	// Body receives the bytes from the Handler's Write calls.
	// When Body is nil, those writes are dropped silently.
	Body *bytes.Buffer

	// Flushed reports whether the Handler called Flush.
	Flushed bool

	result      *http.Response // memoized return value of Result
	snapHeader  http.Header    // copy of HeaderMap taken when the header is written
	wroteHeader bool           // set once WriteHeader has taken effect
}
|
|
|
|
|
|
|
|
// NewRecorder returns an initialized ResponseRecorder.
|
|
|
|
func NewRecorder() *ResponseRecorder {
|
|
|
|
return &ResponseRecorder{
|
|
|
|
HeaderMap: make(http.Header),
|
|
|
|
Body: new(bytes.Buffer),
|
2012-10-23 06:31:11 +02:00
|
|
|
Code: 200,
|
2011-03-25 00:46:17 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// DefaultRemoteAddr is a placeholder client address ("1.2.3.4") for tests
// that need a value for a request's RemoteAddr but have no real peer.
//
// ResponseRecorder itself has no remote-address field and does not read
// this constant. Presumably it is used by request-building helpers
// elsewhere in this package; confirm before relying on that.
const DefaultRemoteAddr = "1.2.3.4"
|
|
|
|
|
|
|
|
// Header returns the response headers.
|
|
|
|
func (rw *ResponseRecorder) Header() http.Header {
|
2012-10-23 06:31:11 +02:00
|
|
|
m := rw.HeaderMap
|
|
|
|
if m == nil {
|
|
|
|
m = make(http.Header)
|
|
|
|
rw.HeaderMap = m
|
|
|
|
}
|
|
|
|
return m
|
2011-03-25 00:46:17 +01:00
|
|
|
}
|
|
|
|
|
2016-02-03 22:58:02 +01:00
|
|
|
// writeHeader writes a header if it was not written yet and
|
|
|
|
// detects Content-Type if needed.
|
|
|
|
//
|
|
|
|
// bytes or str are the beginning of the response body.
|
|
|
|
// We pass both to avoid unnecessarily generate garbage
|
|
|
|
// in rw.WriteString which was created for performance reasons.
|
|
|
|
// Non-nil bytes win.
|
|
|
|
func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
|
|
|
|
if rw.wroteHeader {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if len(str) > 512 {
|
|
|
|
str = str[:512]
|
|
|
|
}
|
|
|
|
|
2016-07-22 20:15:38 +02:00
|
|
|
m := rw.Header()
|
|
|
|
|
|
|
|
_, hasType := m["Content-Type"]
|
|
|
|
hasTE := m.Get("Transfer-Encoding") != ""
|
2016-02-03 22:58:02 +01:00
|
|
|
if !hasType && !hasTE {
|
|
|
|
if b == nil {
|
|
|
|
b = []byte(str)
|
|
|
|
}
|
2016-07-22 20:15:38 +02:00
|
|
|
m.Set("Content-Type", http.DetectContentType(b))
|
2016-02-03 22:58:02 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
rw.WriteHeader(200)
|
|
|
|
}
|
|
|
|
|
2011-03-25 00:46:17 +01:00
|
|
|
// Write always succeeds and writes to rw.Body, if not nil.
|
2011-12-03 03:17:34 +01:00
|
|
|
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
|
2016-02-03 22:58:02 +01:00
|
|
|
rw.writeHeader(buf, "")
|
2011-03-25 00:46:17 +01:00
|
|
|
if rw.Body != nil {
|
|
|
|
rw.Body.Write(buf)
|
|
|
|
}
|
|
|
|
return len(buf), nil
|
|
|
|
}
|
|
|
|
|
2016-02-03 22:58:02 +01:00
|
|
|
// WriteString always succeeds and writes to rw.Body, if not nil.
|
|
|
|
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
|
|
|
|
rw.writeHeader(nil, str)
|
|
|
|
if rw.Body != nil {
|
|
|
|
rw.Body.WriteString(str)
|
|
|
|
}
|
|
|
|
return len(str), nil
|
|
|
|
}
|
|
|
|
|
2016-07-22 20:15:38 +02:00
|
|
|
// WriteHeader sets rw.Code. After it is called, changing rw.Header
|
|
|
|
// will not affect rw.HeaderMap.
|
2011-03-25 00:46:17 +01:00
|
|
|
func (rw *ResponseRecorder) WriteHeader(code int) {
|
2016-07-22 20:15:38 +02:00
|
|
|
if rw.wroteHeader {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
rw.Code = code
|
|
|
|
rw.wroteHeader = true
|
|
|
|
if rw.HeaderMap == nil {
|
|
|
|
rw.HeaderMap = make(http.Header)
|
2012-10-23 06:31:11 +02:00
|
|
|
}
|
2016-07-22 20:15:38 +02:00
|
|
|
rw.snapHeader = cloneHeader(rw.HeaderMap)
|
|
|
|
}
|
|
|
|
|
|
|
|
// cloneHeader returns a deep copy of h: a new map whose value slices are
// fresh copies, so mutating either header afterwards leaves the other
// unchanged.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, values := range h {
		out[key] = append(make([]string, 0, len(values)), values...)
	}
	return out
}
|
|
|
|
|
|
|
|
// Flush sets rw.Flushed to true.
|
|
|
|
func (rw *ResponseRecorder) Flush() {
|
2012-10-23 06:31:11 +02:00
|
|
|
if !rw.wroteHeader {
|
|
|
|
rw.WriteHeader(200)
|
|
|
|
}
|
2011-03-25 00:46:17 +01:00
|
|
|
rw.Flushed = true
|
|
|
|
}
|
2016-07-22 20:15:38 +02:00
|
|
|
|
|
|
|
// Result returns the response generated by the handler.
|
|
|
|
//
|
|
|
|
// The returned Response will have at least its StatusCode,
|
|
|
|
// Header, Body, and optionally Trailer populated.
|
|
|
|
// More fields may be populated in the future, so callers should
|
|
|
|
// not DeepEqual the result in tests.
|
|
|
|
//
|
|
|
|
// The Response.Header is a snapshot of the headers at the time of the
|
|
|
|
// first write call, or at the time of this call, if the handler never
|
|
|
|
// did a write.
|
|
|
|
//
|
2017-01-14 01:05:42 +01:00
|
|
|
// The Response.Body is guaranteed to be non-nil and Body.Read call is
|
|
|
|
// guaranteed to not return any error other than io.EOF.
|
|
|
|
//
|
2016-07-22 20:15:38 +02:00
|
|
|
// Result must only be called after the handler has finished running.
|
|
|
|
func (rw *ResponseRecorder) Result() *http.Response {
|
|
|
|
if rw.result != nil {
|
|
|
|
return rw.result
|
|
|
|
}
|
|
|
|
if rw.snapHeader == nil {
|
|
|
|
rw.snapHeader = cloneHeader(rw.HeaderMap)
|
|
|
|
}
|
|
|
|
res := &http.Response{
|
|
|
|
Proto: "HTTP/1.1",
|
|
|
|
ProtoMajor: 1,
|
|
|
|
ProtoMinor: 1,
|
|
|
|
StatusCode: rw.Code,
|
|
|
|
Header: rw.snapHeader,
|
|
|
|
}
|
|
|
|
rw.result = res
|
|
|
|
if res.StatusCode == 0 {
|
|
|
|
res.StatusCode = 200
|
|
|
|
}
|
|
|
|
res.Status = http.StatusText(res.StatusCode)
|
|
|
|
if rw.Body != nil {
|
|
|
|
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
|
|
|
|
}
|
2017-01-14 01:05:42 +01:00
|
|
|
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
|
2016-07-22 20:15:38 +02:00
|
|
|
|
|
|
|
if trailers, ok := rw.snapHeader["Trailer"]; ok {
|
|
|
|
res.Trailer = make(http.Header, len(trailers))
|
|
|
|
for _, k := range trailers {
|
|
|
|
// TODO: use http2.ValidTrailerHeader, but we can't
|
|
|
|
// get at it easily because it's bundled into net/http
|
|
|
|
// unexported. This is good enough for now:
|
|
|
|
switch k {
|
|
|
|
case "Transfer-Encoding", "Content-Length", "Trailer":
|
|
|
|
// Ignore since forbidden by RFC 2616 14.40.
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
k = http.CanonicalHeaderKey(k)
|
|
|
|
vv, ok := rw.HeaderMap[k]
|
|
|
|
if !ok {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
vv2 := make([]string, len(vv))
|
|
|
|
copy(vv2, vv)
|
|
|
|
res.Trailer[k] = vv2
|
|
|
|
}
|
|
|
|
}
|
2017-01-14 01:05:42 +01:00
|
|
|
for k, vv := range rw.HeaderMap {
|
|
|
|
if !strings.HasPrefix(k, http.TrailerPrefix) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if res.Trailer == nil {
|
|
|
|
res.Trailer = make(http.Header)
|
|
|
|
}
|
|
|
|
for _, v := range vv {
|
|
|
|
res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
|
|
|
|
}
|
|
|
|
}
|
2016-07-22 20:15:38 +02:00
|
|
|
return res
|
|
|
|
}
|
2017-01-14 01:05:42 +01:00
|
|
|
|
|
|
|
// parseContentLength trims surrounding whitespace from cl and returns the
// length it encodes. It returns -1 when the value is empty or is not a
// valid non-negative decimal integer.
//
// This is a modified version of the same function found in
// net/http/transfer.go. This one simply ignores an invalid header instead
// of reporting it.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	// ParseUint rejects sign prefixes, so "-5" and "+5" count as invalid.
	// A bit size of 63 keeps every accepted value representable as int64.
	n, err := strconv.ParseUint(cl, 10, 63)
	if err != nil {
		return -1
	}
	return int64(n)
}
|