gcc/libgo/go/net/http/transport_test.go

6486 lines
167 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests for transport.go.
//
// More tests are in clientserver_test.go (for things testing both client & server for both
// HTTP/1 and HTTP/2).
package http_test
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"go/token"
"internal/nettrace"
"io"
"log"
mrand "math/rand"
"net"
. "net/http"
"net/http/httptest"
"net/http/httptrace"
"net/http/httputil"
"net/http/internal"
"net/textproto"
"net/url"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"testing/iotest"
"time"
"golang.org/x/net/http/httpguts"
)
// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
// and then verify that the final 2 responses get errors back.
// hostPortHandler writes back the client's "host:port" (the request's
// RemoteAddr) as the response body.
//
// It also sets "Connection: close" on the response when the request carries
// the form value close=true, and always reports the request's Close field in
// the X-Saw-Close response header.
var hostPortHandler = HandlerFunc(func(rw ResponseWriter, req *Request) {
	hdr := rw.Header()
	if req.FormValue("close") == "true" {
		hdr.Set("Connection", "close")
	}
	sawClose := fmt.Sprint(req.Close)
	hdr.Set("X-Saw-Close", sawClose)
	rw.Write([]byte(req.RemoteAddr))
})
// testCloseConn is a net.Conn tracked by a testConnSet.
type testCloseConn struct {
	net.Conn
	set *testConnSet // set that is notified when this connection is closed
}

// Close records the connection as closed in its testConnSet and then closes
// the wrapped net.Conn, returning that Close call's error.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks a set of TCP connections and whether they've
// been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex // guards closed and list
	closed map[net.Conn]bool
	list   []net.Conn // in order created
}

// insert registers c as a new connection that has not been closed yet.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. The connection stays in list so that check can
// still report its creation order.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a new testConnSet and a dial function that records
// every connection it opens in that set. Some tests use this to manage raw
// TCP connections for later inspection.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports a test error for every tracked connection that has not been
// closed.
//
// Connections may be closed asynchronously, so check re-scans the set up to
// four more times, sleeping 50ms between passes (with the mutex released so
// that remove can make progress). Errors are only reported on the final
// pass, once per still-open connection.
func (tcs *testConnSet) check(t *testing.T) {
	const retryInterval = 50 * time.Millisecond

	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	// The retry counter must not share a name with the range index below:
	// shadowing it made the "last attempt?" test look at the connection
	// index instead of the number of attempts left.
	for retries := 4; retries >= 0; retries-- {
		allClosed := true
		for i, c := range tcs.list {
			if tcs.closed[c] {
				continue
			}
			allClosed = false
			if retries == 0 {
				t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
			}
		}
		if allClosed || retries == 0 {
			return
		}
		// Give pending Close calls a chance to run before looking again.
		tcs.mu.Unlock()
		time.Sleep(retryInterval)
		tcs.mu.Lock()
	}
}
// TestReuseRequest sends the very same *Request value through a client twice
// and requires both round trips, and both Body.Close calls, to succeed.
func TestReuseRequest(t *testing.T) {
	defer afterTest(t)
	handler := func(w ResponseWriter, r *Request) {
		w.Write([]byte("{}"))
	}
	ts := httptest.NewServer(HandlerFunc(handler))
	defer ts.Close()

	client := ts.Client()
	req, _ := NewRequest("GET", ts.URL, nil)
	for attempt := 0; attempt < 2; attempt++ {
		res, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	}
}
// TestTransportKeepAlives issues two subsequent requests and compares their
// responses. The response from the server is our own IP:port, so identical
// bodies mean the connection was reused. The bodies must differ exactly when
// DisableKeepAlives is set on the transport.
func TestTransportKeepAlives(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(hostPortHandler)
	defer ts.Close()

	c := ts.Client()
	for _, disableKeepAlive := range []bool{false, true} {
		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
		fetch := func(n int) string {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
			}
			// Always release the response body, including on the
			// ReadAll failure path below.
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		bodiesDiffer := body1 != body2
		if bodiesDiffer != disableKeepAlive {
			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				disableKeepAlive, bodiesDiffer, body1, body2)
		}
	}
}
// TestTransportConnectionCloseOnResponse checks connection reuse when the
// server is asked (via the close form value) to answer with
// "Connection: close". Two requests are made per setting; their bodies (the
// client's host:port as seen by the server) must differ exactly when close
// was requested. At the end every dialed connection must have been closed.
func TestTransportConnectionCloseOnResponse(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(hostPortHandler)
	defer ts.Close()

	connSet, testDial := makeTestDial(t)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Dial = testDial

	for _, connectionClose := range []bool{false, true} {
		fetch := func(n int) string {
			target, err := url.Parse(fmt.Sprintf("%s/?close=%v", ts.URL, connectionClose))
			if err != nil {
				t.Fatalf("URL parse error: %v", err)
			}
			req := &Request{
				URL:        target,
				Method:     "GET",
				Proto:      "HTTP/1.1",
				ProtoMajor: 1,
				ProtoMinor: 1,
			}

			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
			}
			defer res.Body.Close()

			payload, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
			}
			return string(payload)
		}

		body1, body2 := fetch(1), fetch(2)
		if bodiesDiffer := body1 != body2; bodiesDiffer != connectionClose {
			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				connectionClose, bodiesDiffer, body1, body2)
		}

		tr.CloseIdleConnections()
	}

	connSet.check(t)
}
// TestTransportConnectionCloseOnRequest checks connection reuse when the
// client sets Request.Close. For each setting it makes two requests and
// verifies that the handler observed the same Close value (reported in the
// X-Saw-Close response header) and that the two bodies (the client's
// host:port as seen by the server) differ exactly when Close was set. At the
// end every dialed connection must have been closed.
func TestTransportConnectionCloseOnRequest(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(hostPortHandler)
	defer ts.Close()

	connSet, testDial := makeTestDial(t)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Dial = testDial
	for _, connectionClose := range []bool{false, true} {
		fetch := func(n int) string {
			req := new(Request)
			var err error
			req.URL, err = url.Parse(ts.URL)
			if err != nil {
				t.Fatalf("URL parse error: %v", err)
			}
			req.Method = "GET"
			req.Proto = "HTTP/1.1"
			req.ProtoMajor = 1
			req.ProtoMinor = 1
			req.Close = connectionClose

			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
			}
			defer res.Body.Close()
			// Report the value that was actually compared against; the
			// message previously printed the negated setting as "want".
			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want {
				t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v",
					connectionClose, got, want)
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)
		bodiesDiffer := body1 != body2
		if bodiesDiffer != connectionClose {
			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				connectionClose, bodiesDiffer, body1, body2)
		}

		tr.CloseIdleConnections()
	}

	connSet.check(t)
}
// If the Transport's DisableKeepAlives is set, all requests should
// send Connection: close. The handler reports what it saw through the
// X-Saw-Close response header.
// HTTP/1-only (Connection: close doesn't exist in h2)
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(hostPortHandler)
	defer ts.Close()

	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.DisableKeepAlives = true

	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	if sawClose := res.Header.Get("X-Saw-Close"); sawClose != "true" {
		t.Errorf("handler didn't see Connection: close ")
	}
}
// Test that Transport only sends one "Connection: close", regardless of
// how "close" was indicated.
//
// Every combination of Transport.DisableKeepAlives and Request.Close runs as
// its own subtest. A client trace counts the written Connection header fields
// that contain the "close" token; the count must be 1 when either knob asks
// for close and 0 otherwise.
func TestTransportRespectRequestWantsClose(t *testing.T) {
	flags := []bool{false, true}
	for _, disableKeepAlives := range flags {
		for _, reqClose := range flags {
			name := fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", disableKeepAlives, reqClose)
			t.Run(name, func(t *testing.T) {
				defer afterTest(t)
				ts := httptest.NewServer(hostPortHandler)
				defer ts.Close()

				c := ts.Client()
				c.Transport.(*Transport).DisableKeepAlives = disableKeepAlives

				req, err := NewRequest("GET", ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}

				closeCount := 0
				trace := &httptrace.ClientTrace{
					WroteHeaderField: func(key string, field []string) {
						if key == "Connection" && httpguts.HeaderValuesContainsToken(field, "close") {
							closeCount++
						}
					},
				}
				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
				req.Close = reqClose

				res, err := c.Do(req)
				if err != nil {
					t.Fatal(err)
				}
				defer res.Body.Close()

				want := disableKeepAlives || reqClose
				sentOnce := closeCount == 1
				if closeCount > 1 || sentOnce != want {
					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, closeCount)
				}
			})
		}
	}
}
// TestTransportIdleCacheKeys checks the transport's idle connection cache
// keys: none before any request, exactly one ("|http|" + server address)
// after a completed GET, and none again after CloseIdleConnections.
func TestTransportIdleCacheKeys(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(hostPortHandler)
	defer ts.Close()
	c := ts.Client()
	tr := c.Transport.(*Transport)

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("Before any request expected %d idle conn cache keys; got %d", e, g)
	}

	resp, err := c.Get(ts.URL)
	if err != nil {
		// resp is nil here; continuing would panic on resp.Body.
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if _, err := io.ReadAll(resp.Body); err != nil {
		t.Fatal(err)
	}

	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
	}

	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
	}

	tr.CloseIdleConnections()
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
	}
}
// Tests that the HTTP transport re-uses connections when a client
// reads to the end of a response Body without closing it.
//
// The handler counts requests per client RemoteAddr in addrSeen. For each of
// the two paths the test issues three GETs, reads each body to EOF with a
// single Read, and expects the server to have seen exactly one client address.
func TestTransportReadToEndReusesConn(t *testing.T) {
defer afterTest(t)
const msg = "foobar"
// addrSeen is re-created for every path by the loop below.
var addrSeen map[string]int
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
addrSeen[r.RemoteAddr]++
if r.URL.Path == "/chunked/" {
// No Content-Length is set and the header is flushed before the
// body is written; the test expects res.ContentLength == -1 here.
w.WriteHeader(200)
w.(Flusher).Flush()
} else {
w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
w.WriteHeader(200)
}
w.Write([]byte(msg))
}))
defer ts.Close()
buf := make([]byte, len(msg))
for pi, path := range []string{"/content-length/", "/chunked/"} {
// Expected res.ContentLength for this path (indexed like the path list).
wantLen := []int{len(msg), -1}[pi]
addrSeen = make(map[string]int)
for i := 0; i < 3; i++ {
res, err := Get(ts.URL + path)
if err != nil {
t.Errorf("Get %s: %v", path, err)
continue
}
// We want to close this body eventually (before the
// defer afterTest at top runs), but not before the
// len(addrSeen) check at the bottom of this test,
// since Closing this early in the loop would risk
// making connections be re-used for the wrong reason.
defer res.Body.Close()
if res.ContentLength != int64(wantLen) {
t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
}
// A single Read is expected to return the whole message together with io.EOF.
n, err := res.Body.Read(buf)
if n != len(msg) || err != io.EOF {
t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
}
}
if len(addrSeen) != 1 {
t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
}
}
}
// TestTransportMaxPerHostIdleConns sets MaxIdleConnsPerHost to 2, starts
// three concurrent requests whose responses are released one at a time, and
// checks the transport's idle connection count after each response: 1, 2,
// and then still 2 (the configured maximum) after the third.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
defer afterTest(t)
stop := make(chan struct{}) // stop marks the exit of main Test goroutine
defer close(stop)
// resch carries the body each blocked handler should write.
resch := make(chan string)
// gotReq is signalled by the handler as soon as a request arrives.
gotReq := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
gotReq <- true
var msg string
// Block until the test provides a response body or the test has exited.
select {
case <-stop:
return
case msg = <-resch:
}
_, err := w.Write([]byte(msg))
if err != nil {
t.Errorf("Write: %v", err)
return
}
}))
defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
maxIdleConnsPerHost := 2
tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
// Start 3 outstanding requests and wait for the server to get them.
// Their responses will hang until we write to resch, though.
donech := make(chan bool)
doReq := func() {
// Report completion on donech unless the test has already exited.
defer func() {
select {
case <-stop:
return
case donech <- t.Failed():
}
}()
resp, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
return
}
if _, err := io.ReadAll(resp.Body); err != nil {
t.Errorf("ReadAll: %v", err)
return
}
}
go doReq()
<-gotReq
go doReq()
<-gotReq
go doReq()
<-gotReq
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
}
// Release the first response and wait for its request to finish.
resch <- "res1"
<-donech
keys := tr.IdleConnKeysForTesting()
if e, g := 1, len(keys); e != g {
t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
}
addr := ts.Listener.Addr().String()
cacheKey := "|http|" + addr
if keys[0] != cacheKey {
t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
}
if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
}
resch <- "res2"
<-donech
if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
t.Errorf("after second response, idle conns = %d; want %d", g, w)
}
// The third connection must not push the idle count above the maximum.
resch <- "res3"
<-donech
if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
t.Errorf("after third response, idle conns = %d; want %d", g, w)
}
}
// TestTransportMaxConnsPerHostIncludeDialInProgress verifies that a dial
// that is still in progress counts against MaxConnsPerHost. With the limit
// set to 1 and the dial function stalled, a second request must not start
// its own dial until the first request has completed.
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler runs on a server goroutine, where t.Fatalf must
		// not be used; report the failure with Errorf instead.
		if _, err := w.Write([]byte("foo")); err != nil {
			t.Errorf("Write: %v", err)
		}
	}))
	defer ts.Close()
	c := ts.Client()
	tr := c.Transport.(*Transport)
	dialStarted := make(chan struct{})
	stallDial := make(chan struct{})
	// Every dial announces itself on dialStarted and then waits for the
	// test to release it through stallDial.
	tr.Dial = func(network, addr string) (net.Conn, error) {
		dialStarted <- struct{}{}
		<-stallDial
		return net.Dial(network, addr)
	}

	tr.DisableKeepAlives = true
	tr.MaxConnsPerHost = 1

	preDial := make(chan struct{})
	reqComplete := make(chan struct{})
	doReq := func(reqId string) {
		// Signal completion on every path so a failed request cannot
		// leave the test goroutine blocked on reqComplete.
		defer func() { reqComplete <- struct{}{} }()

		req, _ := NewRequest("GET", ts.URL, nil)
		trace := &httptrace.ClientTrace{
			GetConn: func(hostPort string) {
				preDial <- struct{}{}
			},
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
		resp, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("unexpected error for request %s: %v", reqId, err)
			// resp is nil on error; do not touch resp.Body.
			return
		}
		defer resp.Body.Close()
		if _, err = io.ReadAll(resp.Body); err != nil {
			t.Errorf("unexpected error for request %s: %v", reqId, err)
		}
	}
	// get req1 to dial-in-progress
	go doReq("req1")
	<-preDial
	<-dialStarted

	// get req2 to waiting on conns per host to go down below max
	go doReq("req2")
	<-preDial
	select {
	case <-dialStarted:
		t.Error("req2 dial started while req1 dial in progress")
		return
	default:
	}

	// let req1 complete
	stallDial <- struct{}{}
	<-reqComplete

	// let req2 complete
	<-dialStarted
	stallDial <- struct{}{}
	<-reqComplete
}
// TestTransportMaxConnsPerHost verifies that MaxConnsPerHost = 1 limits the
// transport to a single dial, a single non-reused connection and (for TLS
// servers) a single TLS handshake while ten requests run concurrently. After
// that connection is closed, one more request must cause exactly one more of
// each. The scenario runs against an http, an https and an h2-enabled server.
func TestTransportMaxConnsPerHost(t *testing.T) {
	defer afterTest(t)
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler runs on a server goroutine, where t.Fatalf must
		// not be used; report the failure with Errorf instead.
		if _, err := w.Write([]byte("foo")); err != nil {
			t.Errorf("Write: %v", err)
		}
	})

	testMaxConns := func(scheme string, ts *httptest.Server) {
		defer ts.Close()

		c := ts.Client()
		tr := c.Transport.(*Transport)
		tr.MaxConnsPerHost = 1
		if err := ExportHttp2ConfigureTransport(tr); err != nil {
			t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
		}

		// connCh hands the dialed connection to the test so it can be
		// closed between the two rounds.
		connCh := make(chan net.Conn, 1)
		var dialCnt, gotConnCnt, tlsHandshakeCnt int32
		tr.Dial = func(network, addr string) (net.Conn, error) {
			atomic.AddInt32(&dialCnt, 1)
			c, err := net.Dial(network, addr)
			connCh <- c
			return c, err
		}

		// doReq performs one GET and drains the body. It is called from
		// worker goroutines in round 1, so failures are reported with
		// Errorf (t.Fatalf is only valid on the test goroutine).
		doReq := func() {
			trace := &httptrace.ClientTrace{
				GotConn: func(connInfo httptrace.GotConnInfo) {
					if !connInfo.Reused {
						atomic.AddInt32(&gotConnCnt, 1)
					}
				},
				TLSHandshakeStart: func() {
					atomic.AddInt32(&tlsHandshakeCnt, 1)
				},
			}
			req, _ := NewRequest("GET", ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

			resp, err := c.Do(req)
			if err != nil {
				t.Errorf("request failed: %v", err)
				return
			}
			defer resp.Body.Close()
			if _, err = io.ReadAll(resp.Body); err != nil {
				t.Errorf("read body failed: %v", err)
			}
		}

		wg := sync.WaitGroup{}
		for i := 0; i < 10; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				doReq()
			}()
		}
		wg.Wait()

		expected := int32(tr.MaxConnsPerHost)
		if dialCnt != expected {
			t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected)
		}
		if gotConnCnt != expected {
			t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
		}
		if ts.TLS != nil && tlsHandshakeCnt != expected {
			t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
		}

		// Stop here (on the test goroutine) if round 1 failed in any way.
		if t.Failed() {
			t.FailNow()
		}

		(<-connCh).Close()
		tr.CloseIdleConnections()

		doReq()
		// A failed round-2 request used to abort the test via Fatalf;
		// keep that behavior now that doReq only records the error.
		if t.Failed() {
			t.FailNow()
		}

		expected++
		if dialCnt != expected {
			t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt)
		}
		if gotConnCnt != expected {
			t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
		}
		if ts.TLS != nil && tlsHandshakeCnt != expected {
			t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
		}
	}

	testMaxConns("http", httptest.NewServer(h))
	testMaxConns("https", httptest.NewTLSServer(h))

	ts := httptest.NewUnstartedServer(h)
	ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
	ts.StartTLS()
	testMaxConns("http2", ts)
}
// TestTransportRemovesDeadIdleConnections verifies that the transport drops
// an idle connection once the server has closed it: after
// CloseClientConnections the idle key list must become empty within three
// seconds, and the next request must arrive from a different client address.
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	}))
	defer ts.Close()

	client := ts.Client()
	tr := client.Transport.(*Transport)

	// doReq returns the response body, i.e. the client address the server saw.
	doReq := func(name string) string {
		// Do a POST instead of a GET to prevent the Transport's
		// idempotent request retry logic from kicking in...
		res, err := client.Post(ts.URL, "", nil)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		if res.StatusCode != 200 {
			t.Fatalf("%s: %v", name, res.Status)
		}
		defer res.Body.Close()
		payload, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		return string(payload)
	}

	first := doReq("first")
	keysBefore := tr.IdleConnKeysForTesting()

	ts.CloseClientConnections()

	var keysAfter []string
	idleGone := waitCondition(3*time.Second, 50*time.Millisecond, func() bool {
		keysAfter = tr.IdleConnKeysForTesting()
		return len(keysAfter) == 0
	})
	if !idleGone {
		t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keysBefore, keysAfter)
	}

	if second := doReq("second"); first == second {
		t.Errorf("expected a different connection between requests. got %q both times", first)
	}
}
// Test that the Transport notices when a server hangs up on it
// unexpectedly (a keep-alive connection is closed).
//
// Two requests are expected to share a connection (equal bodies); after the
// transport's connections are closed abruptly, a third request must report a
// different client address.
func TestTransportServerClosingUnexpectedly(t *testing.T) {
setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
c := ts.Client()
// fetch performs request number n and returns the response body. Errors are
// fatal once retries is exhausted; otherwise they are logged and the request
// is retried after a pause.
fetch := func(n, retries int) string {
condFatalf := func(format string, arg ...interface{}) {
if retries <= 0 {
t.Fatalf(format, arg...)
}
t.Logf("retrying shortly after expected error: "+format, arg...)
// The pause grows as the remaining retry count shrinks.
time.Sleep(time.Second / time.Duration(retries))
}
for retries >= 0 {
retries--
res, err := c.Get(ts.URL)
if err != nil {
condFatalf("error in req #%d, GET: %v", n, err)
continue
}
body, err := io.ReadAll(res.Body)
if err != nil {
condFatalf("error in req #%d, ReadAll: %v", n, err)
continue
}
res.Body.Close()
return string(body)
}
panic("unreachable")
}
body1 := fetch(1, 0)
body2 := fetch(2, 0)
// Close all the idle connections in a way that's similar to
// the server hanging up on us. We don't use
// httptest.Server.CloseClientConnections because it's
// best-effort and stops blocking after 5 seconds. On a loaded
// machine running many tests concurrently it's possible for
// that method to be async and cause the body3 fetch below to
// run on an old connection. This function is synchronous.
ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
body3 := fetch(3, 5)
if body1 != body2 {
t.Errorf("expected body1 and body2 to be equal")
}
if body2 == body3 {
t.Errorf("expected body2 and body3 to be different")
}
}
// Test for https://golang.org/issue/2616 (appropriate issue number)
// This fails pretty reliably with GOMAXPROCS=100 or something high.
//
// The handler writes a complete 5-byte response and then hijacks and closes
// the connection. numClients goroutines each issue reqsPerClient GETs; the
// test fails if no request completes within 5 seconds. Skipped in short mode.
func TestStressSurpriseServerCloses(t *testing.T) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in short mode")
}
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "5")
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte("Hello"))
w.(Flusher).Flush()
// Take over the connection and close it right after the response.
conn, buf, _ := w.(Hijacker).Hijack()
buf.Flush()
conn.Close()
}))
defer ts.Close()
c := ts.Client()
// Do a bunch of traffic from different goroutines. Send to activityc
// after each request completes, regardless of whether it failed.
// If these are too high, OS X exhausts its ephemeral ports
// and hangs waiting for them to transition TCP states. That's
// not what we want to test. TODO(bradfitz): use an io.Pipe
// dialer for this test instead?
const (
numClients = 20
reqsPerClient = 25
)
// activityc is unbuffered: each client receives from it after every
// request, and the test goroutine sends on it once per expected request.
activityc := make(chan bool)
for i := 0; i < numClients; i++ {
go func() {
for i := 0; i < reqsPerClient; i++ {
res, err := c.Get(ts.URL)
if err == nil {
// We expect errors since the server is
// hanging up on us after telling us to
// send more requests, so we don't
// actually care what the error is.
// But we want to close the body in cases
// where we won the race.
res.Body.Close()
}
if !<-activityc { // Receives false when close(activityc) is executed
return
}
}
}()
}
// Make sure all the request come back, one way or another.
for i := 0; i < numClients*reqsPerClient; i++ {
select {
case activityc <- true:
case <-time.After(5 * time.Second):
// Closing the channel lets the client goroutines exit.
close(activityc)
t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
}
}
}
// TestTransportHeadResponses verifies that we deal with Content-Lengths
// with no bodies properly.
//
// The server answers HEAD with Content-Length: 123 and no body. Each of two
// HEAD requests must expose that length both in the header and in
// res.ContentLength while reading the body yields zero bytes.
func TestTransportHeadResponses(t *testing.T) {
	defer afterTest(t)
	handler := func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Content-Length", "123")
		w.WriteHeader(200)
	}
	ts := httptest.NewServer(HandlerFunc(handler))
	defer ts.Close()

	client := ts.Client()
	for i := 0; i < 2; i++ {
		res, err := client.Head(ts.URL)
		if err != nil {
			t.Errorf("error on loop %d: %v", i, err)
			continue
		}
		if e, g := "123", res.Header.Get("Content-Length"); e != g {
			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
		}
		if e, g := int64(123), res.ContentLength; e != g {
			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
		}

		all, err := io.ReadAll(res.Body)
		switch {
		case err != nil:
			t.Errorf("loop %d: Body ReadAll: %v", i, err)
		case len(all) != 0:
			t.Errorf("Bogus body %q", all)
		}
	}
}
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
// on responses to HEAD requests.
//
// The handler echoes the client's RemoteAddr in the x-client-ipport header;
// two HEAD requests must report the same value.
func TestTransportHeadChunkedResponse(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
}
w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
w.Header().Set("x-client-ipport", r.RemoteAddr)
w.WriteHeader(200)
}))
defer ts.Close()
c := ts.Client()
// Ensure that we wait for the readLoop to complete before
// calling Head again
didRead := make(chan bool)
SetReadLoopBeforeNextReadHook(func() { didRead <- true })
// Remove the hook again when the test finishes.
defer SetReadLoopBeforeNextReadHook(nil)
res1, err := c.Head(ts.URL)
// NOTE(review): this receive happens before err is checked — presumably the
// hook fires even when Head fails; confirm, otherwise this could block.
<-didRead
if err != nil {
t.Fatalf("request 1 error: %v", err)
}
res2, err := c.Head(ts.URL)
<-didRead
if err != nil {
t.Fatalf("request 2 error: %v", err)
}
if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
}
}
// roundTripTests is the table driving TestRoundTripGzip.
var roundTripTests = []struct {
	accept       string // Accept-Encoding set on the outgoing request; "" leaves it unset
	expectAccept string // Accept-Encoding the server handler is expected to see
	compressed   bool   // whether the test has to gunzip the response body itself
}{
	// Requests with no accept-encoding header use transparent compression
	{accept: "", expectAccept: "gzip", compressed: false},
	// Requests with other accept-encoding should pass through unmodified
	{accept: "foo", expectAccept: "foo", compressed: false},
	// Requests with accept-encoding == gzip should be passed through
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
// Test that the modification made to the Request by the RoundTripper is cleaned up.
//
// For every roundTripTests entry the request goes through Transport.RoundTrip
// directly. The handler checks the Accept-Encoding it receives and answers
// gzip-compressed when that value is "gzip". Afterwards the caller's request
// header must be unchanged, the body must decode to responseBody, and the
// response Content-Encoding must equal the Accept-Encoding that was sent.
func TestRoundTripGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const responseBody = "test response body"
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
		accept := req.Header.Get("Accept-Encoding")
		if expect := req.FormValue("expect_accept"); accept != expect {
			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
				req.FormValue("testnum"), accept, expect)
		}
		if accept == "gzip" {
			rw.Header().Set("Content-Encoding", "gzip")
			gz := gzip.NewWriter(rw)
			gz.Write([]byte(responseBody))
			gz.Close()
		} else {
			rw.Header().Set("Content-Encoding", accept)
			rw.Write([]byte(responseBody))
		}
	}))
	defer ts.Close()
	tr := ts.Client().Transport.(*Transport)

	for i, test := range roundTripTests {
		// Test basic request (no accept-encoding)
		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
		if test.accept != "" {
			req.Header.Set("Accept-Encoding", test.accept)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("%d. RoundTrip: %v", i, err)
			continue
		}
		var body []byte
		if test.compressed {
			var r *gzip.Reader
			r, err = gzip.NewReader(res.Body)
			if err != nil {
				// Release the connection before moving on.
				res.Body.Close()
				t.Errorf("%d. gzip NewReader: %v", i, err)
				continue
			}
			body, err = io.ReadAll(r)
		} else {
			body, err = io.ReadAll(res.Body)
		}
		// Close the body on both the compressed and the plain path; it
		// used to be closed only after a compressed read.
		res.Body.Close()
		if err != nil {
			t.Errorf("%d. Error: %q", i, err)
			continue
		}
		if g, e := string(body), responseBody; g != e {
			t.Errorf("%d. body = %q; want %q", i, g, e)
		}
		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
		}
	}
}
// TestTransportGzip exercises gzip handling through a Client. The handler
// expects "Accept-Encoding: gzip" on non-HEAD requests and replies with a
// gzip-compressed testString, optionally followed by nRandBytes of random
// data (body=large). With chunked=0 the compressed output is buffered so a
// Content-Length can be set. The test checks partial reads of a large body,
// reads after Close and after EOF, that the client sees the decoded body with
// no Content-Encoding header, and finally a HEAD request.
func TestTransportGzip(t *testing.T) {
setParallel(t)
defer afterTest(t)
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
const nRandBytes = 1024 * 1024
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
if req.Method == "HEAD" {
if g := req.Header.Get("Accept-Encoding"); g != "" {
t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
}
return
}
if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
t.Errorf("Accept-Encoding = %q, want %q", g, e)
}
rw.Header().Set("Content-Encoding", "gzip")
var w io.Writer = rw
var buf bytes.Buffer
if req.FormValue("chunked") == "0" {
// Compress into buf instead of writing to rw directly. Deferred calls
// run last-in-first-out: Content-Length is set from the buffered size
// first, then the buffer is copied to rw.
w = &buf
defer io.Copy(rw, &buf)
defer func() {
rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
}()
}
gz := gzip.NewWriter(w)
gz.Write([]byte(testString))
if req.FormValue("body") == "large" {
io.CopyN(gz, rand.Reader, nRandBytes)
}
gz.Close()
}))
defer ts.Close()
c := ts.Client()
for _, chunked := range []string{"1", "0"} {
// First fetch something large, but only read some of it.
res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
if err != nil {
t.Fatalf("large get: %v", err)
}
buf := make([]byte, len(testString))
n, err := io.ReadFull(res.Body, buf)
if err != nil {
t.Fatalf("partial read of large response: size=%d, %v", n, err)
}
if e, g := testString, string(buf); e != g {
t.Errorf("partial read got %q, expected %q", g, e)
}
res.Body.Close()
// Read on the body, even though it's closed
n, err = res.Body.Read(buf)
if n != 0 || err == nil {
t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
}
// Then something small.
res, err = c.Get(ts.URL + "/?chunked=" + chunked)
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if g, e := string(body), testString; g != e {
t.Fatalf("body = %q; want %q", g, e)
}
// The client is expected to see no Content-Encoding header on the response.
if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
t.Fatalf("Content-Encoding = %q; want %q", g, e)
}
// Read on the body after it's been fully read:
n, err = res.Body.Read(buf)
if n != 0 || err == nil {
t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
}
res.Body.Close()
n, err = res.Body.Read(buf)
if n != 0 || err == nil {
t.Errorf("expected Read error after Close; got %d, %v", n, err)
}
}
// And a HEAD request too, because they're always weird.
res, err := c.Head(ts.URL)
if err != nil {
t.Fatalf("Head: %v", err)
}
if res.StatusCode != 200 {
t.Errorf("Head status=%d; want=200", res.StatusCode)
}
}
// If a request has an "Expect: 100-continue" header, the Transport holds back
// the request body until the first response arrives (or, as in the /timeout
// case, until ExpectContinueTimeout has passed).
// The request body must not be consumed prematurely.
//
// Each test case PUTs a 5-byte body to one endpoint and checks the response
// status and how many body bytes were read from the request's reader.
func TestTransportExpect100Continue(t *testing.T) {
setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
switch req.URL.Path {
case "/100":
// This endpoint implicitly responds 100 Continue and reads body.
if _, err := io.Copy(io.Discard, req.Body); err != nil {
t.Error("Failed to read Body", err)
}
rw.WriteHeader(StatusOK)
case "/200":
// Go 1.5 adds Connection: close header if the client expect
// continue but not entire request body is consumed.
rw.WriteHeader(StatusOK)
case "/500":
rw.WriteHeader(StatusInternalServerError)
case "/keepalive":
// This hijacked endpoint responds error without Connection:close.
_, bufrw, err := rw.(Hijacker).Hijack()
if err != nil {
log.Fatal(err)
}
bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
bufrw.WriteString("Content-Length: 0\r\n\r\n")
bufrw.Flush()
case "/timeout":
// This endpoint tries to read body without 100 (Continue) response.
// After ExpectContinueTimeout, the reading will be started.
conn, bufrw, err := rw.(Hijacker).Hijack()
if err != nil {
log.Fatal(err)
}
if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
t.Error("Failed to read Body", err)
}
bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
bufrw.Flush()
conn.Close()
}
}))
defer ts.Close()
// sent is the number of request body bytes expected to have been consumed;
// status is the expected response status code.
tests := []struct {
path string
body []byte
sent int
status int
}{
{path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent.
{path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent.
{path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent.
{path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent.
{path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
}
c := ts.Client()
for i, v := range tests {
// Each case gets a fresh Transport; the deferred cleanups all run when
// the test function returns.
tr := &Transport{
ExpectContinueTimeout: 2 * time.Second,
}
defer tr.CloseIdleConnections()
c.Transport = tr
body := bytes.NewReader(v.body)
req, err := NewRequest("PUT", ts.URL+v.path, body)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Expect", "100-continue")
req.ContentLength = int64(len(v.body))
resp, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
// Bytes no longer available in the reader are the bytes that were consumed.
sent := len(v.body) - body.Len()
if v.status != resp.StatusCode {
t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
}
if v.sent != sent {
t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
}
}
}
// TestSOCKS5Proxy runs a HEAD request through a minimal in-test SOCKS5 proxy,
// once against a plain HTTP server and once against a TLS server. It checks
// that the response came from the target server (via a sentinel header) and
// that the proxy was asked to connect to the target server's host:port.
func TestSOCKS5Proxy(t *testing.T) {
	defer afterTest(t)
	// ch carries the "proxy for ip:port" line from the proxy goroutine to
	// the subtest.
	ch := make(chan string, 1)
	l := newLocalListener(t)
	defer l.Close()
	defer close(ch)
	// proxy accepts a single connection on l, performs the SOCKS5
	// handshake (see RFC 1928), reports the requested target on ch and
	// then relays bytes between the client and the target.
	proxy := func(t *testing.T) {
		s, err := l.Accept()
		if err != nil {
			t.Errorf("socks5 proxy Accept(): %v", err)
			return
		}
		defer s.Close()
		// 22 bytes fit the 4-byte request header, an IPv6 address and
		// the 2-byte port.
		var buf [22]byte
		if _, err := io.ReadFull(s, buf[:3]); err != nil {
			t.Errorf("socks5 proxy initial read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
			return
		}
		if _, err := s.Write([]byte{5, 0}); err != nil {
			t.Errorf("socks5 proxy initial write: %v", err)
			return
		}
		if _, err := io.ReadFull(s, buf[:4]); err != nil {
			t.Errorf("socks5 proxy second read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
			return
		}
		// buf[3] is the address type; only raw IPv4/IPv6 addresses are
		// handled here.
		var ipLen int
		switch buf[3] {
		case 1:
			ipLen = net.IPv4len
		case 4:
			ipLen = net.IPv6len
		default:
			// Report the byte that was actually inspected; buf[4] has
			// not been read from the connection at this point.
			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[3])
			return
		}
		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy address read: %v", err)
			return
		}
		ip := net.IP(buf[4 : ipLen+4])
		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
		// Reply by echoing the request with the first three bytes
		// replaced by 5, 0, 0.
		copy(buf[:3], []byte{5, 0, 0})
		if _, err := s.Write(buf[:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy connect write: %v", err)
			return
		}

		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)

		// Implement proxying.
		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
		targetConn, err := net.Dial("tcp", targetHost)
		if err != nil {
			t.Errorf("net.Dial failed: %v", err)
			return
		}
		go io.Copy(targetConn, s)
		io.Copy(s, targetConn) // Wait for the client to close the socket.
		targetConn.Close()
	}

	pu, err := url.Parse("socks5://" + l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	sentinelHeader := "X-Sentinel"
	sentinelValue := "12345"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set(sentinelHeader, sentinelValue)
	})
	for _, useTLS := range []bool{false, true} {
		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
			var ts *httptest.Server
			if useTLS {
				ts = httptest.NewTLSServer(h)
			} else {
				ts = httptest.NewServer(h)
			}
			go proxy(t)
			c := ts.Client()
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			r, err := c.Head(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			if r.Header.Get(sentinelHeader) != sentinelValue {
				t.Errorf("Failed to retrieve sentinel value")
			}
			var got string
			select {
			case got = <-ch:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to socks5 proxy")
			}
			ts.Close()
			tsu, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			want := "proxy for " + tsu.Host
			if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		})
	}
}
// TestTransportProxy exercises the Transport's HTTP proxy support for every
// combination of plain/TLS site and plain/TLS proxy. For an HTTPS site the
// proxy must receive a CONNECT for the site's host and the site then sees a
// HEAD for "/"; for an HTTP site the proxy receives the HEAD itself with the
// absolute site URL.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ httpsSite, httpsProxy bool }{
		{false, false},
		{false, true},
		{true, false},
		{true, true},
	}
	for _, testCase := range testCases {
		httpsSite := testCase.httpsSite
		httpsProxy := testCase.httpsProxy
		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// Implement an entire CONNECT proxy
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed: %v", err)
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}
					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						// After Hijack the connection is ours to close;
						// without this the client would wait forever for
						// a CONNECT response.
						clientConn.Close()
						return
					}
					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						targetConn.Close()
						clientConn.Close()
						return
					}
					// Relay bytes in both directions until either side closes.
					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			var ts *httptest.Server
			if httpsSite {
				ts = httptest.NewTLSServer(h1)
			} else {
				ts = httptest.NewServer(h1)
			}
			var proxy *httptest.Server
			if httpsProxy {
				proxy = httptest.NewTLSServer(h2)
			} else {
				proxy = httptest.NewServer(h2)
			}
			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}
			// If neither server is HTTPS or both are, then c may be derived from either.
			// If only one server is HTTPS, c must be derived from that server in order
			// to ensure that it is configured to use the fake root CA from testcert.go.
			c := proxy.Client()
			if httpsSite {
				c = ts.Client()
			}
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			var got *Request
			select {
			case got = <-proxyCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to http proxy")
			}
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if httpsSite {
				// First message should be a CONNECT, asking for a socket to the real server,
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				siteURL, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatalf("Invalid site URL: %v", err)
				}
				if wantHost := siteURL.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}
				// The next message on the channel should be from the site's server.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// A plain-HTTP site is requested from the proxy directly,
				// using the absolute URL of the site.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
// when they're slow to reply to HTTPS CONNECT responses.
//
// The fake proxy below accepts one connection, reads the CONNECT request,
// cancels the request's context without ever replying, and then expects to
// observe EOF, which shows that the client side closed the connection.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
setParallel(t)
defer afterTest(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ln := newLocalListener(t)
defer ln.Close()
// listenerDone is closed when the fake proxy goroutine returns.
listenerDone := make(chan struct{})
go func() {
defer close(listenerDone)
c, err := ln.Accept()
if err != nil {
t.Errorf("Accept: %v", err)
return
}
defer c.Close()
// Read the CONNECT request
br := bufio.NewReader(c)
cr, err := ReadRequest(br)
if err != nil {
t.Errorf("proxy server failed to read CONNECT request")
return
}
if cr.Method != "CONNECT" {
t.Errorf("unexpected method %q", cr.Method)
return
}
// Now hang and never write a response; instead, cancel the request and wait
// for the client to close.
// (Prior to Issue 28012 being fixed, we never closed.)
cancel()
var buf [1]byte
_, err = br.Read(buf[:])
if err != io.EOF {
t.Errorf("proxy server Read err = %v; want EOF", err)
}
return
}()
// Route every request through the fake proxy listening on ln.
c := &Client{
Transport: &Transport{
Proxy: func(*Request) (*url.URL, error) {
return url.Parse("http://" + ln.Addr().String())
},
},
}
req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
if err != nil {
t.Fatal(err)
}
// The proxy never answers and cancels ctx, so Do is expected to fail.
_, err = c.Do(req)
if err == nil {
t.Errorf("unexpected Get success")
}
// Wait unconditionally for the listener goroutine to exit: this should never
// hang, so if it does we want a full goroutine dump — and that's exactly what
// the testing package will give us when the test run times out.
<-listenerDone
}
// Issue 16997: when dialing the proxy fails, the error returned by the client
// must be a *url.Error wrapping a *net.OpError whose Op is "proxyconnect" and
// whose Err is the dialer's original error value.
func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
	defer afterTest(t)

	dialErr := errors.New("some dial error")
	viaFakeProxy := func(*Request) (*url.URL, error) {
		return url.Parse("http://proxy.fake.tld/")
	}
	alwaysFail := func(string, string) (net.Conn, error) {
		return nil, dialErr
	}
	tr := &Transport{Proxy: viaFakeProxy, Dial: alwaysFail}
	defer tr.CloseIdleConnections()

	client := &Client{Transport: tr}
	req, _ := NewRequest("GET", "http://fake.tld", nil)
	res, err := client.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("wanted a non-nil error")
	}

	// Peel the error apart one layer at a time.
	urlErr, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %T, want *url.Error", err)
	}
	opErr, isOpErr := urlErr.Err.(*net.OpError)
	if !isOpErr {
		t.Fatalf("url.Error.Err = %T; want *net.OpError", urlErr.Err)
	}

	expected := &net.OpError{
		Op:  "proxyconnect",
		Net: "tcp",
		Err: dialErr, // original error, unwrapped.
	}
	if !reflect.DeepEqual(opErr, expected) {
		t.Errorf("Got error %#v; want %#v", opErr, expected)
	}
}
// Issue 36431: RoundTrip must leave t.ProxyConnectHeader untouched.
//
// (dialConn once wrote the per-request Proxy-Authorization header straight
// into the shared Header value, which was a data race.)
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	proxySrv := httptest.NewTLSServer(NotFoundHandler())
	defer proxySrv.Close()

	client := proxySrv.Client()
	tr := client.Transport.(*Transport)
	// Proxy credentials are what cause a Proxy-Authorization header to be
	// generated for the CONNECT request.
	tr.Proxy = func(*Request) (*url.URL, error) {
		proxyURL, _ := url.Parse(proxySrv.URL)
		proxyURL.User = url.UserPassword("aladdin", "opensesame")
		return proxyURL, nil
	}

	// Keep a private snapshot and hand the Transport a clone, so any write
	// made through the Transport shows up as a difference afterwards.
	snapshot := tr.ProxyConnectHeader
	if snapshot == nil {
		snapshot = make(Header)
	}
	tr.ProxyConnectHeader = snapshot.Clone()

	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := client.Do(req); err == nil {
		t.Errorf("unexpected Get success")
	}
	if !reflect.DeepEqual(tr.ProxyConnectHeader, snapshot) {
		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, snapshot)
	}
}
// TestTransportGzipRecursive sends a gzip quine and checks that the
// client gets the same value back. This is more cute than anything,
// but checks that we don't recurse forever, and checks that
// Content-Encoding is removed.
func TestTransportGzipRecursive(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(rgz)
	}))
	defer ts.Close()
	c := ts.Client()
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Release the response body on every exit path, including the Fatalf
	// calls below.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	// A quine decompresses to itself, so the decoded body equals rgz.
	if !bytes.Equal(body, rgz) {
		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
			body, rgz)
	}
	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
		t.Fatalf("Content-Encoding = %q; want %q", g, e)
	}
}
// golang.org/issue/7750: a response whose gzip body is cut off after the
// two magic bytes must yield io.ErrUnexpectedEOF from the body read rather
// than failing the request itself.
func TestTransportGzipShort(t *testing.T) {
	defer afterTest(t)

	truncatedGzip := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write([]byte{0x1f, 0x8b})
	}
	ts := httptest.NewServer(HandlerFunc(truncatedGzip))
	defer ts.Close()

	client := ts.Client()
	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	switch _, readErr := io.ReadAll(res.Body); readErr {
	case nil:
		t.Fatal("Expect an error from reading a body.")
	case io.ErrUnexpectedEOF:
		// This is the error the test wants.
	default:
		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", readErr)
	}
}
// waitNumGoroutine polls until the number of goroutines is no greater than
// nmax or until it has retried ten times, sleeping 50ms and forcing a GC
// between samples. It returns the last goroutine count observed.
func waitNumGoroutine(nmax int) int {
	const (
		maxRetries = 10
		pause      = 50 * time.Millisecond
	)
	count := runtime.NumGoroutine()
	for retry := 0; retry < maxRetries; retry++ {
		if count <= nmax {
			break
		}
		time.Sleep(pause)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
// TestTransportPersistConnLeak tests that the goroutines serving persistent
// connections shut down when the connections are no longer desired.
//
// It starts numReq concurrent requests that all block in the handler,
// releases them, closes idle connections, and then checks that the goroutine
// count has returned to roughly where it started.
func TestTransportPersistConnLeak(t *testing.T) {
// Not parallel: counts goroutines
defer afterTest(t)
const numReq = 25
// gotReqCh receives one value per request that reached the handler;
// closing unblockCh lets every blocked handler proceed.
gotReqCh := make(chan bool, numReq)
unblockCh := make(chan bool, numReq)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
gotReqCh <- true
<-unblockCh
w.Header().Set("Content-Length", "0")
w.WriteHeader(204)
}))
defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
// Baseline goroutine count, taken before any client request is started.
n0 := runtime.NumGoroutine()
// didReqCh receives one value per finished Get (success or failure);
// failed additionally receives one value per Get that returned an error.
didReqCh := make(chan bool, numReq)
failed := make(chan bool, numReq)
for i := 0; i < numReq; i++ {
go func() {
res, err := c.Get(ts.URL)
didReqCh <- true
if err != nil {
t.Logf("client fetch error: %v", err)
failed <- true
return
}
res.Body.Close()
}()
}
// Wait for all goroutines to be stuck in the Handler.
for i := 0; i < numReq; i++ {
select {
case <-gotReqCh:
// ok
case <-failed:
// Not great but not what we are testing:
// sometimes an overloaded system will fail to make all the connections.
}
}
// Goroutine count while the requests are in flight; only used in the log message.
nhigh := runtime.NumGoroutine()
// Tell all handlers to unblock and reply.
close(unblockCh)
// Wait for all HTTP clients to be done.
for i := 0; i < numReq; i++ {
<-didReqCh
}
tr.CloseIdleConnections()
nfinal := waitNumGoroutine(n0 + 5)
growth := nfinal - n0
// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
// Previously we were leaking one per numReq.
if int(growth) > 5 {
t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
t.Error("too many new goroutines")
}
}
// golang.org/issue/4531: the Transport leaked goroutines when a request's
// ContentLength was explicitly shorter than the body supplied with it.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	// Not parallel: measures goroutines.
	defer afterTest(t)

	doNothing := func(w ResponseWriter, r *Request) {}
	ts := httptest.NewServer(HandlerFunc(doNothing))
	defer ts.Close()

	client := ts.Client()
	tr := client.Transport.(*Transport)

	const attempts = 20
	baseline := runtime.NumGoroutine()
	payload := []byte("Hello")
	for attempt := 0; attempt < attempts; attempt++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		req.ContentLength = int64(len(payload) - 2) // explicitly short
		if _, err := client.Do(req); err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	peak := runtime.NumGoroutine()

	tr.CloseIdleConnections()
	settled := waitNumGoroutine(baseline + 5)
	delta := settled - baseline

	// Empirically 0 or 1 goroutines remain; tolerate up to 5. The old bug
	// leaked one per request.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", baseline, peak, settled, delta)
	if delta > 5 {
		t.Error("too many new goroutines")
	}
}
// A countedConn wraps a net.Conn in a distinct heap object so that a
// finalizer can be attached to it; see countingDialer.DialContext.
type countedConn struct {
	net.Conn
}

// A countingDialer dials connections and keeps two mutex-guarded counters:
// total, the number of connections dialed so far, and live, the number of
// those whose finalizer has not yet run.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with the embedded net.Dialer. On success it returns the
// connection wrapped in a countedConn, bumps both counters, and registers
// decrement as the wrapper's finalizer. Dial errors are returned unchanged
// and leave the counters alone.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	return wrapped, nil
}

// decrement lowers the live count by one. It is installed as the finalizer
// for every countedConn handed out by DialContext.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns the current total and live counts.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
// TestTransportPersistConnLeakNeverIdle checks that client connections which
// never reach the idle pool (the server hangs up on every request) still
// become unreachable: after enough request/GC rounds the dialer's live count
// must drop below its total count.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	defer afterTest(t)

	// Close every connection so that it cannot be kept alive.
	hangUp := func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	}
	ts := httptest.NewServer(HandlerFunc(hangUp))
	defer ts.Close()

	var d countingDialer
	client := ts.Client()
	client.Transport.(*Transport).DialContext = d.DialContext

	const maxRounds = 1 << 12
	payload := []byte("Hello")
	for round := 0; ; round++ {
		if total, live := d.Read(); live < total {
			// At least one connection has been finalized.
			break
		} else if round >= maxRounds {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, round)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		if _, err := client.Do(req); err == nil {
			t.Fatal("expected broken connection")
		}
		runtime.GC()
	}
}
// A countedContext wraps a context.Context in a distinct heap object so that
// a finalizer can be attached to it; see contextCounter.Track.
type countedContext struct {
	context.Context
}

// A contextCounter keeps a mutex-guarded count of the contexts returned by
// Track whose finalizer has not yet run.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track returns ctx wrapped in a countedConn-style wrapper, increments the
// live count, and registers decrement as the wrapper's finalizer.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}

	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	return wrapped
}

// decrement lowers the live count by one. It is installed as the finalizer
// for every countedContext handed out by Track.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns the current live count.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
// TestTransportPersistConnContextLeakMaxConnsPerHost checks that request
// contexts are not retained by the Transport when MaxConnsPerHost forces
// requests to queue for a connection. It issues a batch of POSTs with
// tracked contexts and then keeps issuing further batches and forcing GC
// until every context from the first batch has been finalized.
func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
	if runtime.Compiler == "gccgo" {
		t.Skip("fails with conservative stack GC")
	}
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		runtime.Gosched()
		w.WriteHeader(StatusOK)
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).MaxConnsPerHost = 1
	ctx := context.Background()
	body := []byte("Hello")
	// doPosts runs 64 concurrent POSTs, each with its own context tracked
	// by cc, and waits for all of them to finish.
	doPosts := func(cc *contextCounter) {
		var wg sync.WaitGroup
		for n := 64; n > 0; n-- {
			wg.Add(1)
			go func() {
				defer wg.Done()
				ctx := cc.Track(ctx)
				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
				if err != nil {
					t.Error(err)
					// req is nil here; going on would panic in WithContext.
					return
				}
				_, err = c.Do(req.WithContext(ctx))
				if err != nil {
					t.Errorf("Do failed with error: %v", err)
				}
			}()
		}
		wg.Wait()
	}
	var initialCC contextCounter
	doPosts(&initialCC)
	// flushCC exists only to put pressure on the GC to finalize the initialCC
	// contexts: the flushCC allocations should eventually displace the initialCC
	// allocations.
	var flushCC contextCounter
	for i := 0; ; i++ {
		live := initialCC.Read()
		if live == 0 {
			break
		}
		if i >= 100 {
			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
		}
		doPosts(&flushCC)
		runtime.GC()
	}
}
// https://golang.org/issue/3266: calling CloseIdleConnections from inside a
// handler while the client's request is still in flight used to crash.
func TestTransportIdleConnCrash(t *testing.T) {
	defer afterTest(t)

	var tr *Transport
	release := make(chan bool, 1)
	handler := func(w ResponseWriter, r *Request) {
		<-release
		tr.CloseIdleConnections()
	}
	ts := httptest.NewServer(HandlerFunc(handler))
	defer ts.Close()

	client := ts.Client()
	tr = client.Transport.(*Transport)

	finished := make(chan bool)
	go func() {
		res, err := client.Get(ts.URL)
		if err == nil {
			res.Body.Close() // returns idle conn
		} else {
			t.Error(err)
		}
		finished <- true
	}()

	release <- true
	<-finished
}
// TestIssue3644 checks that the transport does not close the TCP connection
// before the response body has been fully read. The regression originally
// had no test that triggered it; a large "Connection: close" response makes
// the old race easier to hit.
func TestIssue3644(t *testing.T) {
	defer afterTest(t)

	const numFoos = 5000
	bigReply := func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "close")
		for written := 0; written < numFoos; written++ {
			w.Write([]byte("foo "))
		}
	}
	ts := httptest.NewServer(HandlerFunc(bigReply))
	defer ts.Close()

	client := ts.Client()
	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	payload, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if wantLen := numFoos * len("foo "); len(payload) != wantLen {
		t.Errorf("unexpected response length")
	}
}
// Test that a client receives a server's reply, even if the server doesn't read
// the entire request body.
//
// The handler rejects the request immediately while the client is still
// uploading a never-ending body.
func TestIssue3595(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const deniedMsg = "sorry, denied."
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		Error(w, deniedMsg, StatusUnauthorized)
	}))
	defer ts.Close()
	c := ts.Client()
	res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
	if err != nil {
		t.Errorf("Post: %v", err)
		return
	}
	// Release the response body on every exit path.
	defer res.Body.Close()
	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("Body ReadAll: %v", err)
	}
	if !strings.Contains(string(got), deniedMsg) {
		t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
	}
}
// From https://golang.org/issue/4454 ,
// "client fails to handle requests with no body and chunked encoding".
//
// Four requests are sent against a 204 handler, first closing each response
// body and then deliberately leaving the bodies unclosed.
func TestChunkedNoContent(t *testing.T) {
	defer afterTest(t)

	noContent := func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNoContent)
	}
	ts := httptest.NewServer(HandlerFunc(noContent))
	defer ts.Close()

	client := ts.Client()
	for _, closeBody := range []bool{true, false} {
		const n = 4
		for reqNum := 1; reqNum <= n; reqNum++ {
			res, err := client.Get(ts.URL)
			if err != nil {
				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, reqNum, n, err)
				continue
			}
			if closeBody {
				res.Body.Close()
			}
		}
	}
}
// TestTransportConcurrency issues many echo requests from a pool of worker
// goroutines sharing one client and checks that every response body matches
// the value that was sent.
func TestTransportConcurrency(t *testing.T) {
	// Not parallel: uses global test hooks.
	defer afterTest(t)
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Raise GOMAXPROCS for the test and restore the previous value on exit.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	}))
	defer ts.Close()
	var wg sync.WaitGroup
	wg.Add(numReqs)
	// Due to the Transport's "socket late binding" (see
	// idleConnCh in transport.go), the numReqs HTTP requests
	// below can finish with a dial still outstanding. To keep
	// the leak checker happy, keep track of pending dials and
	// wait for them to finish (and be closed or returned to the
	// idle pool) before we close idle connections.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)
	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)
	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					t.Errorf("error on req %s: %v", req, err)
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				// Close the body before looking at the read error so a
				// failed read does not leak the connection.
				res.Body.Close()
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
					wg.Done()
					continue
				}
				if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
// TestIssue4191_InfiniteGetTimeout fetches a never-ending response over
// connections that carry a deadline and checks that copying the body
// eventually fails instead of succeeding or hanging.
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
setParallel(t)
defer afterTest(t)
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
ts := httptest.NewServer(mux)
defer ts.Close()
// timeout is read by the Dial closure below each time it dials, so the
// increase made in the retry path applies to later connections.
timeout := 100 * time.Millisecond
c := ts.Client()
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
// Set a deadline on every new connection so a stalled transfer fails.
conn.SetDeadline(time.Now().Add(timeout))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
}
getFailed := false
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := c.Get(ts.URL + "/get")
if err != nil {
if !getFailed {
// Make the timeout longer, once.
getFailed = true
t.Logf("increasing timeout")
// Undo the loop increment so this run is retried.
i--
timeout *= 10
continue
}
t.Errorf("Error issuing GET: %v", err)
break
}
// The body never ends, so the copy can only stop with an error.
_, err = io.Copy(io.Discard, sres.Body)
if err == nil {
t.Errorf("Unexpected successful copy")
break
}
}
if debug {
println("tests complete; waiting for handlers to finish")
}
}
// TestIssue4191_InfiniteGetToPutTimeout streams a never-ending GET response
// body into a PUT request over connections that carry a deadline and checks
// that the PUT fails instead of succeeding or hanging.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
setParallel(t)
defer afterTest(t)
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
defer r.Body.Close()
io.Copy(io.Discard, r.Body)
})
ts := httptest.NewServer(mux)
// timeout is read by the Dial closure below each time it dials, so the
// increase made in the retry path applies to later connections.
timeout := 100 * time.Millisecond
c := ts.Client()
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
// Set a deadline on every new connection so a stalled transfer fails.
conn.SetDeadline(time.Now().Add(timeout))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
}
getFailed := false
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := c.Get(ts.URL + "/get")
if err != nil {
if !getFailed {
// Make the timeout longer, once.
getFailed = true
t.Logf("increasing timeout")
// Undo the loop increment so this run is retried.
i--
timeout *= 10
continue
}
t.Errorf("Error issuing GET: %v", err)
break
}
// Use the endless GET response body as the PUT request body.
req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
_, err = c.Do(req)
if err == nil {
sres.Body.Close()
t.Errorf("Unexpected successful PUT")
break
}
sres.Body.Close()
}
if debug {
println("tests complete; waiting for handlers to finish")
}
ts.Close()
}
// TestTransportResponseHeaderTimeout checks Transport.ResponseHeaderTimeout:
// a handler that answers promptly yields a 200, a handler that sleeps longer
// than the timeout yields a timeout error mentioning the response headers,
// and the client still works for a fast request afterwards.
func TestTransportResponseHeaderTimeout(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}
	// inHandler receives a value each time a handler is entered.
	inHandler := make(chan bool, 1)
	mux := NewServeMux()
	mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
		inHandler <- true
	})
	mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
		inHandler <- true
		time.Sleep(2 * time.Second)
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond
	tests := []struct {
		path    string
		want    int
		wantErr string
	}{
		{path: "/fast", want: 200},
		{path: "/slow", wantErr: "timeout awaiting response headers"},
		{path: "/fast", want: 200},
	}
	for i, tt := range tests {
		req, _ := NewRequest("GET", ts.URL+tt.path, nil)
		req = req.WithT(t)
		res, err := c.Do(req)
		if err == nil {
			// Only the status code is inspected below, so the body can be
			// released right away.
			res.Body.Close()
		}
		select {
		case <-inHandler:
		case <-time.After(5 * time.Second):
			t.Errorf("never entered handler for test index %d, %s", i, tt.path)
			continue
		}
		if err != nil {
			uerr, ok := err.(*url.Error)
			if !ok {
				t.Errorf("error is not an url.Error; got: %#v", err)
				continue
			}
			nerr, ok := uerr.Err.(net.Error)
			if !ok {
				t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
				continue
			}
			if !nerr.Timeout() {
				t.Errorf("want timeout error; got: %q", nerr)
				continue
			}
			// strings.Contains reports true for an empty substring, so a
			// case that expects success (wantErr == "") must not be
			// allowed to match here.
			if tt.wantErr != "" && strings.Contains(err.Error(), tt.wantErr) {
				continue
			}
			t.Errorf("%d. unexpected error: %v", i, err)
			continue
		}
		if tt.wantErr != "" {
			t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
		}
	}
}
// TestTransportCancelRequest checks that Transport.CancelRequest interrupts a
// response body read that is blocked on the server: the read must return the
// bytes already received together with the request-canceled error, roughly
// one second after it started, and the Transport must end up with no pending
// requests.
func TestTransportCancelRequest(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	release := make(chan bool)
	stallAfterHello := func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush() // send headers and some body
		<-release
	}
	ts := httptest.NewServer(HandlerFunc(stallAfterHello))
	defer ts.Close()
	defer close(release)

	client := ts.Client()
	tr := client.Transport.(*Transport)
	req, _ := NewRequest("GET", ts.URL, nil)
	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Cancel from another goroutine while the main goroutine is blocked
	// reading the body.
	go func() {
		time.Sleep(1 * time.Second)
		tr.CancelRequest(req)
	}()
	readStart := time.Now()
	got, err := io.ReadAll(res.Body)
	elapsed := time.Since(readStart)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(got) != "Hello" {
		t.Errorf("Body = %q; want Hello", got)
	}
	if elapsed < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", elapsed)
	}

	// The readLoop/writeLoop goroutines need a moment to shut down; poll
	// until nothing is outstanding and complain only on the last attempt.
	const maxPolls = 5
	for poll := 1; poll <= maxPolls; poll++ {
		pending := tr.NumPendingRequestsForTesting()
		if pending == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if poll == maxPolls {
			t.Errorf("pending requests = %d; want 0", pending)
		}
	}
}
// testTransportCancelRequestInDo checks that Client.Do returns once the
// request is canceled while the handler is still blocked. body is used as
// the request body and may be nil. CancelRequest is retried every 100ms
// for up to ten seconds.
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	release := make(chan bool)
	block := func(w ResponseWriter, r *Request) {
		<-release
	}
	ts := httptest.NewServer(HandlerFunc(block))
	defer ts.Close()
	defer close(release)

	client := ts.Client()
	tr := client.Transport.(*Transport)
	req, _ := NewRequest("GET", ts.URL, body)

	doReturned := make(chan bool)
	go func() {
		defer close(doReturned)
		client.Do(req)
	}()

	timeout := 10 * time.Second
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		time.Sleep(100 * time.Millisecond)
		tr.CancelRequest(req)
		select {
		case <-doReturned:
			return
		default:
			// Do has not returned yet; cancel again on the next round.
		}
	}
	t.Errorf("Do of canceled request has not returned after %v", timeout)
}
// TestTransportCancelRequestInDo runs testTransportCancelRequestInDo with a
// nil request body.
func TestTransportCancelRequestInDo(t *testing.T) {
testTransportCancelRequestInDo(t, nil)
}
// TestTransportCancelRequestWithBodyInDo runs testTransportCancelRequestInDo
// with a one-byte request body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
}
// TestTransportCancelRequestInDial checks that canceling a request while the
// Transport is still blocked in Dial makes Do return with a "request
// canceled while waiting for connection" error. The sequence of events is
// recorded in an in-memory log and compared against the expected transcript.
func TestTransportCancelRequestInDial(t *testing.T) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
}
var logbuf bytes.Buffer
eventLog := log.New(&logbuf, "", 0)
// unblockDial is closed when the test returns, releasing the blocked dial.
unblockDial := make(chan bool)
defer close(unblockDial)
// inDial is how the dial function and the test goroutine rendezvous:
// the test sends true once the dial is running, or closes the channel
// (so the dial receives false) if it gives up waiting.
inDial := make(chan bool)
tr := &Transport{
Dial: func(network, addr string) (net.Conn, error) {
eventLog.Println("dial: blocking")
if !<-inDial {
return nil, errors.New("main Test goroutine exited")
}
<-unblockDial
return nil, errors.New("nope")
},
}
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
go func() {
_, err := cl.Do(req)
eventLog.Printf("Get = %v", err)
gotres <- true
}()
// Wait until the dial function is running before canceling.
select {
case inDial <- true:
case <-time.After(5 * time.Second):
close(inDial)
t.Fatal("timeout; never saw blocking dial")
}
eventLog.Printf("canceling")
tr.CancelRequest(req)
tr.CancelRequest(req) // used to panic on second call
select {
case <-gotres:
case <-time.After(5 * time.Second):
panic("hang. events are: " + logbuf.String())
}
got := logbuf.String()
// The expected transcript; the raw string's line breaks are significant.
want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
// TestCancelRequestWithChannel checks that closing Request.Cancel interrupts
// a response body read that is blocked on the server: the read must return
// the bytes already received together with the request-canceled error,
// roughly one second after it started, and the Transport must end up with
// no pending requests.
func TestCancelRequestWithChannel(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	release := make(chan bool)
	stallAfterHello := func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush() // send headers and some body
		<-release
	}
	ts := httptest.NewServer(HandlerFunc(stallAfterHello))
	defer ts.Close()
	defer close(release)

	client := ts.Client()
	tr := client.Transport.(*Transport)
	req, _ := NewRequest("GET", ts.URL, nil)
	cancelCh := make(chan struct{})
	req.Cancel = cancelCh
	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Close the cancel channel from another goroutine while the main
	// goroutine is blocked reading the body.
	go func() {
		time.Sleep(1 * time.Second)
		close(cancelCh)
	}()
	readStart := time.Now()
	got, err := io.ReadAll(res.Body)
	elapsed := time.Since(readStart)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(got) != "Hello" {
		t.Errorf("Body = %q; want Hello", got)
	}
	if elapsed < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", elapsed)
	}

	// The readLoop/writeLoop goroutines need a moment to shut down; poll
	// until nothing is outstanding and complain only on the last attempt.
	const maxPolls = 5
	for poll := 1; poll <= maxPolls; poll++ {
		pending := tr.NumPendingRequestsForTesting()
		if pending == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if poll == maxPolls {
			t.Errorf("pending requests = %d; want 0", pending)
		}
	}
}
// TestCancelRequestWithChannelBeforeDo_Cancel runs
// testCancelRequestWithChannelBeforeDo with withCtx set to false.
func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
testCancelRequestWithChannelBeforeDo(t, false)
}
// TestCancelRequestWithChannelBeforeDo_Context runs
// testCancelRequestWithChannelBeforeDo with withCtx set to true.
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
testCancelRequestWithChannelBeforeDo(t, true)
}
// testCancelRequestWithChannelBeforeDo checks that a request which is
// already canceled before Client.Do is called fails with a cancellation
// error. With withCtx the request carries a canceled context and the
// unwrapped error must be context.Canceled; otherwise Request.Cancel is a
// closed channel and the error text must mention "canceled".
func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
	setParallel(t)
	defer afterTest(t)

	release := make(chan bool)
	block := func(w ResponseWriter, r *Request) {
		<-release
	}
	ts := httptest.NewServer(HandlerFunc(block))
	defer ts.Close()
	defer close(release)

	client := ts.Client()
	req, _ := NewRequest("GET", ts.URL, nil)
	if !withCtx {
		canceled := make(chan struct{})
		close(canceled)
		req.Cancel = canceled
	} else {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		req = req.WithContext(ctx)
	}

	_, err := client.Do(req)
	// Look through the *url.Error that the client wraps errors in.
	if wrapped, ok := err.(*url.Error); ok {
		err = wrapped.Err
	}

	switch {
	case withCtx:
		if err != context.Canceled {
			t.Errorf("Do error = %v; want %v", err, context.Canceled)
		}
	case err == nil || !strings.Contains(err.Error(), "canceled"):
		t.Errorf("Do error = %v; want cancellation", err)
	}
}
// Issue 11020: canceling a request after it has been written but before any
// response headers arrive must make RoundTrip return errRequestCanceled.
//
// The test dials over an in-memory net.Pipe, reads the request method on the
// server end without ever replying, and then cancels the request.
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
	defer afterTest(t)

	serverEnds := make(chan net.Conn, 1)
	pipeDial := func(network, addr string) (net.Conn, error) {
		clientEnd, serverEnd := net.Pipe()
		serverEnds <- serverEnd
		return clientEnd, nil
	}
	tr := &Transport{Dial: pipeDial}
	defer tr.CloseIdleConnections()

	result := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	go func() {
		_, rtErr := tr.RoundTrip(req)
		result <- rtErr
	}()

	// Read just enough of the request to know it is being sent.
	sc := <-serverEnds
	method := make([]byte, 3)
	if _, err := io.ReadFull(sc, method); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(method) != "GET" {
		t.Errorf("server received %q; want GET", method)
	}
	defer sc.Close()

	tr.CancelRequest(req)

	rtErr := <-result
	if rtErr == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	if rtErr != ExportErrRequestCanceled {
		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", rtErr)
	}
}
// golang.org/issue/3672 -- Client can't close HTTP stream
// Calling Close on a Response.Body used to just read until EOF.
// Now it actually closes the TCP connection.
//
// The handler writes msg forever; the test reads a few repetitions, closes
// the body, and then expects both the Close call to return and the
// handler's Write to fail within ten seconds each.
func TestTransportCloseResponseBody(t *testing.T) {
defer afterTest(t)
// writeErr receives the first error returned by the handler's Write.
writeErr := make(chan error, 1)
msg := []byte("young\n")
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
for {
_, err := w.Write(msg)
if err != nil {
writeErr <- err
return
}
w.(Flusher).Flush()
}
}))
defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
defer tr.CancelRequest(req)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
// Read exactly repeats copies of msg to confirm the stream is flowing.
const repeats = 3
buf := make([]byte, len(msg)*repeats)
want := bytes.Repeat(msg, repeats)
_, err = io.ReadFull(res.Body, buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, want) {
t.Fatalf("read %q; want %q", buf, want)
}
// Close in a separate goroutine so that a Close that blocks can be
// detected with a timeout.
didClose := make(chan error, 1)
go func() {
didClose <- res.Body.Close()
}()
select {
case err := <-didClose:
if err != nil {
t.Errorf("Close = %v", err)
}
case <-time.After(10 * time.Second):
t.Fatal("too long waiting for close")
}
// Once the client has closed the body, the handler's Write must fail.
select {
case err := <-writeErr:
if err == nil {
t.Errorf("expected non-nil write error")
}
case <-time.After(10 * time.Second):
t.Fatal("too long waiting for write error")
}
}
// fooProto is a stub RoundTripper used to exercise alternate protocol
// registration on a Transport.
type fooProto struct{}

// RoundTrip answers every request with a synthetic 200 response whose body
// is "You wanted " followed by the request URL.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	echo := strings.NewReader("You wanted " + req.URL.String())
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(echo),
	}, nil
}
// TestTransportAltProto checks that a RoundTripper installed with
// Transport.RegisterProtocol handles requests whose URL uses the registered
// scheme.
func TestTransportAltProto(t *testing.T) {
	defer afterTest(t)
	tr := &Transport{}
	c := &Client{Transport: tr}
	tr.RegisterProtocol("foo", fooProto{})
	res, err := c.Get("foo://bar.com/path")
	if err != nil {
		t.Fatal(err)
	}
	// Release the body on every exit path, including the t.Fatal below.
	defer res.Body.Close()
	bodyb, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body := string(bodyb)
	if e := "You wanted foo://bar.com/path"; body != e {
		t.Errorf("got response %q, want %q", body, e)
	}
}
// TestTransportNoHost checks that RoundTrip rejects a request whose URL has
// a scheme but no host, and that it does so with the expected message.
func TestTransportNoHost(t *testing.T) {
	defer afterTest(t)

	const want = "http: no Host in request URL"
	hostless := &Request{
		Header: make(Header),
		URL:    &url.URL{Scheme: "http"},
	}
	tr := &Transport{}
	_, err := tr.RoundTrip(hostless)
	if got := fmt.Sprint(err); got != want {
		t.Errorf("error = %v; want %q", err, want)
	}
}
// Issue 13311: an empty Request.Method must go out on the wire as GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	// docs say "For client requests an empty string means GET"
	req.Method = ""

	// DumpRequestOut uses Transport, so the dump shows what would be sent.
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if wire := string(dump); !strings.Contains(wire, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
// TestTransportSocketLateBinding checks that a request is not tied to the
// dial it triggered: with only one dial allowed to succeed, the /bar request
// must be answered over the same client address that /foo used, once the
// /foo response has been finished and its connection freed.
func TestTransportSocketLateBinding(t *testing.T) {
setParallel(t)
defer afterTest(t)
mux := NewServeMux()
// fooGate keeps the /foo handler from returning until the test releases it.
fooGate := make(chan bool, 1)
mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
w.Header().Set("foo-ipport", r.RemoteAddr)
// Flush so the client receives the headers while the handler is still blocked.
w.(Flusher).Flush()
<-fooGate
})
mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
w.Header().Set("bar-ipport", r.RemoteAddr)
})
ts := httptest.NewServer(mux)
defer ts.Close()
// dialGate controls every dial attempt: each one blocks until it receives
// a value; true lets the dial proceed and false makes it fail.
dialGate := make(chan bool, 1)
c := ts.Client()
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
if <-dialGate {
return net.Dial(n, addr)
}
return nil, errors.New("manually closed")
}
dialGate <- true // only allow one dial
fooRes, err := c.Get(ts.URL + "/foo")
if err != nil {
t.Fatal(err)
}
fooAddr := fooRes.Header.Get("foo-ipport")
if fooAddr == "" {
t.Fatal("No addr on /foo request")
}
time.AfterFunc(200*time.Millisecond, func() {
// let the foo response finish so we can use its
// connection for /bar
fooGate <- true
io.Copy(io.Discard, fooRes.Body)
fooRes.Body.Close()
})
// No further dial is permitted yet, so this request can only complete on
// the connection released by the timer callback above.
barRes, err := c.Get(ts.URL + "/bar")
if err != nil {
t.Fatal(err)
}
barAddr := barRes.Header.Get("bar-ipport")
if barAddr != fooAddr {
t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
}
barRes.Body.Close()
// Release any dial still blocked on the gate, making it fail.
dialGate <- false
}
// Issue 2184
//
// TestTransportReading100Continue sends numReqs POST requests over a fake
// connection built from in-memory pipes. For each request the fake server
// writes an interim response (status "100 Continue" unless the request asks
// for something else via X-Want-Response-Code) followed by a 200 response,
// and the test checks that the client reports the 200 with the matching
// Echo-Request-Id.
func TestTransportReading100Continue(t *testing.T) {
defer afterTest(t)
const numReqs = 5
// reqBody and reqID derive the body and the Request-Id header for request n.
reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
// send100Response is the server side of one connection: it reads requests
// from r and writes the canned responses to w until EOF, a read error, or
// the request carrying the last expected ID.
send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
defer w.Close()
defer r.Close()
br := bufio.NewReader(r)
n := 0
for {
n++
req, err := ReadRequest(br)
if err == io.EOF {
return
}
if err != nil {
t.Error(err)
return
}
slurp, err := io.ReadAll(req.Body)
if err != nil {
t.Errorf("Server request body slurp: %v", err)
return
}
id := req.Header.Get("Request-Id")
resCode := req.Header.Get("X-Want-Response-Code")
if resCode == "" {
resCode = "100 Continue"
// The body is only verified when the default interim status is used.
if string(slurp) != reqBody(n) {
t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
}
}
body := fmt.Sprintf("Response number %d", n)
// The template is written with "\n" line endings and converted to
// "\r\n" before being sent.
v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT
HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d
%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
w.Write(v)
if id == reqID(numReqs) {
return
}
}
}
tr := &Transport{
Dial: func(n, addr string) (net.Conn, error) {
// Two pipes form the connection: the client writes to sw and reads
// from cr; the fake server reads from sr and writes to cw.
sr, sw := io.Pipe() // server read/write
cr, cw := io.Pipe() // client read/write
conn := &rwTestConn{
Reader: cr,
Writer: sw,
closeFunc: func() error {
sw.Close()
cw.Close()
return nil
},
}
go send100Response(cw, sr)
return conn, nil
},
DisableKeepAlives: false,
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
// testResponse performs req, requires status wantCode, checks that the
// echoed request ID matches (when the request carries one) and drains
// the response body.
testResponse := func(req *Request, name string, wantCode int) {
t.Helper()
res, err := c.Do(req)
if err != nil {
t.Fatalf("%s: Do: %v", name, err)
}
if res.StatusCode != wantCode {
t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
}
if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
t.Errorf("%s: response id %q != request id %q", name, idBack, id)
}
_, err = io.ReadAll(res.Body)
if err != nil {
t.Fatalf("%s: Slurp error: %v", name, err)
}
}
// Few 100 responses, making sure we're not off-by-one.
for i := 1; i <= numReqs; i++ {
req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
req.Header.Set("Request-Id", reqID(i))
testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
}
}
// Issue 17739: the HTTP client must ignore any unknown 1xx
// informational responses before the actual response.
//
// The handler hijacks the connection and writes a made-up 123 response
// followed by a 200. The test expects the 123 to be reported only through
// the Got1xxResponse trace hook and the 200 to be returned as the response.
func TestTransportIgnore1xxResponses(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	const wire = "HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\n" +
		"HTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"
	handler := func(w ResponseWriter, r *Request) {
		conn, bufrw, _ := w.(Hijacker).Hijack()
		bufrw.Write([]byte(wire))
		bufrw.Flush()
		conn.Close()
	}
	cst := newClientServerTest(t, h1Mode, HandlerFunc(handler))
	defer cst.close()
	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway

	// got collects the trace output first and the serialized final response
	// after it.
	var got bytes.Buffer
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
			return nil
		},
	}
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	res.Write(&got)

	const want = "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
	if got.String() != want {
		t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want)
	}
}
// TestTransportLimits1xxResponses checks that the client fails with a
// "too many 1xx informational responses" error when the server sends ten
// 1xx responses ahead of the final one.
func TestTransportLimits1xxResponses(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	const informationalCount = 10
	flood := func(w ResponseWriter, r *Request) {
		conn, bufrw, _ := w.(Hijacker).Hijack()
		for sent := 0; sent != informationalCount; sent++ {
			bufrw.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
		}
		bufrw.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
		bufrw.Flush()
		conn.Close()
	}
	cst := newClientServerTest(t, h1Mode, HandlerFunc(flood))
	defer cst.close()
	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway

	res, err := cst.c.Get(cst.ts.URL)
	if res != nil {
		defer res.Body.Close()
	}

	const wantSub = "too many 1xx informational responses"
	if msg := fmt.Sprint(err); !strings.Contains(msg, wantSub) {
		t.Errorf("Get error = %v; want substring %q", err, wantSub)
	}
}
// Issue 26161: the HTTP client must treat 101 responses
// as the final response.
//
// The handler writes a 101 followed by a 204 on the same connection; the
// client is expected to hand back the 101.
func TestTransportTreat101Terminal(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	handler := func(w ResponseWriter, r *Request) {
		conn, bufrw, _ := w.(Hijacker).Hijack()
		for _, response := range []string{
			"HTTP/1.1 101 Switching Protocols\r\n\r\n",
			"HTTP/1.1 204 No Content\r\n\r\n",
		} {
			bufrw.Write([]byte(response))
		}
		bufrw.Flush()
		conn.Close()
	}
	cst := newClientServerTest(t, h1Mode, HandlerFunc(handler))
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusSwitchingProtocols {
		t.Errorf("StatusCode = %v; want 101 Switching Protocols", got)
	}
}
// proxyFromEnvTest is one row of the proxy-from-environment test table: the
// environment to set up, the request URL to resolve, and the expected proxy
// URL or error.
type proxyFromEnvTest struct {
	req string // URL to fetch; blank means "http://example.com"

	env      string // HTTP_PROXY
	httpsenv string // HTTPS_PROXY
	noenv    string // NO_PROXY
	reqmeth  string // REQUEST_METHOD

	want    string
	wanterr error
}

// String renders the test case as space-separated key="value" pairs for use
// in failure messages. Environment settings that are empty are left out;
// the request URL is always included, with the default filled in when blank.
func (t proxyFromEnvTest) String() string {
	target := t.req
	if target == "" {
		target = "http://example.com"
	}
	fields := []struct{ label, value string }{
		{"http_proxy", t.env},
		{"https_proxy", t.httpsenv},
		{"no_proxy", t.noenv},
		{"request_method", t.reqmeth},
		{"req", target},
	}

	var out bytes.Buffer
	for _, f := range fields {
		if f.value == "" {
			continue
		}
		if out.Len() > 0 {
			out.WriteByte(' ')
		}
		fmt.Fprintf(&out, "%s=%q", f.label, f.value)
	}
	return strings.TrimSpace(out.String())
}
// proxyFromEnvTests is the table driven through testProxyForRequest by the
// ProxyFromEnvironment tests. Entries without req use the default request
// URL; a want of "<nil>" means no proxy is expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Proxy values with and without an explicit scheme.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},

	// Don't use secure for http
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// Use secure for https.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
	// where HTTP_PROXY can be attacker-controlled.
	{
		env:     "http://10.1.2.3:8080",
		reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy"),
	},

	// Nothing set at all.
	{want: "<nil>"},

	// NO_PROXY handling.
	{req: "http://example.com/", env: "proxy", noenv: "example.com", want: "<nil>"},
	{req: "http://example.com/", env: "proxy", noenv: ".example.com", want: "http://proxy"},
	{req: "http://example.com/", env: "proxy", noenv: "ample.com", want: "http://proxy"},
	{req: "http://foo.example.com/", env: "proxy", noenv: "example.com", want: "<nil>"},
	{req: "http://example.com/", env: "proxy", noenv: ".foo.com", want: "http://proxy"},
}
// testProxyForRequest builds a GET request for tt.req (or the default URL
// when that is blank), resolves its proxy through proxyForRequest, and
// compares the formatted error and proxy URL against tt.wanterr and tt.want.
// A mismatching error is reported without going on to check the URL.
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
	t.Helper()

	target := tt.req
	if target == "" {
		target = "http://example.com"
	}
	req, _ := NewRequest("GET", target, nil)

	proxyURL, err := proxyForRequest(req)
	gotErr, wantErr := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr)
	if gotErr != wantErr {
		t.Errorf("%v: got error = %q, want %q", tt, gotErr, wantErr)
		return
	}
	if fmt.Sprintf("%s", proxyURL) != tt.want {
		t.Errorf("%v: got URL = %q, want %q", tt, proxyURL, tt.want)
	}
}
// TestProxyFromEnvironment runs the proxyFromEnvTests table against
// ProxyFromEnvironment, configuring each case through the upper-case
// proxy environment variables.
func TestProxyFromEnvironment(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()

	for _, tt := range proxyFromEnvTests {
		settings := [][2]string{
			{"HTTP_PROXY", tt.env},
			{"HTTPS_PROXY", tt.httpsenv},
			{"NO_PROXY", tt.noenv},
			{"REQUEST_METHOD", tt.reqmeth},
		}
		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
			for _, kv := range settings {
				os.Setenv(kv[0], kv[1])
			}
			// NOTE(review): presumably makes the next lookup re-read the
			// environment just set above; confirm against the helper.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		})
	}
}
// TestProxyFromEnvironmentLowerCase runs the proxyFromEnvTests table against
// ProxyFromEnvironment, configuring each case through the lower-case proxy
// environment variables (REQUEST_METHOD stays upper-case).
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()

	for _, tt := range proxyFromEnvTests {
		settings := [][2]string{
			{"http_proxy", tt.env},
			{"https_proxy", tt.httpsenv},
			{"no_proxy", tt.noenv},
			{"REQUEST_METHOD", tt.reqmeth},
		}
		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
			for _, kv := range settings {
				os.Setenv(kv[0], kv[1])
			}
			// NOTE(review): presumably makes the next lookup re-read the
			// environment just set above; confirm against the helper.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		})
	}
}
// TestIdleConnChannelLeak issues nReqs requests to distinct host names (all
// dialed to the same test server), once with keep-alives disabled and once
// with them enabled, and checks after each round that the Transport's
// idle-conn wait map is empty.
func TestIdleConnChannelLeak(t *testing.T) {
// Not parallel: uses global test hooks.
var mu sync.Mutex
var n int
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
n++
mu.Unlock()
}))
defer ts.Close()
const nReqs = 5
// didRead receives one value each time the read-loop hook fires; it is
// buffered so the hook never blocks.
didRead := make(chan bool, nReqs)
SetReadLoopBeforeNextReadHook(func() { didRead <- true })
defer SetReadLoopBeforeNextReadHook(nil)
c := ts.Client()
tr := c.Transport.(*Transport)
// Every host name is routed to the test server, whatever address is asked for.
tr.Dial = func(netw, addr string) (net.Conn, error) {
return net.Dial(netw, ts.Listener.Addr().String())
}
// First without keep-alives, then with them.
for _, disableKeep := range []bool{true, false} {
tr.DisableKeepAlives = disableKeep
for i := 0; i < nReqs; i++ {
_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
if err != nil {
t.Fatal(err)
}
// Note: no res.Body.Close is needed here, since the
// response Content-Length is zero. Perhaps the test
// should be more explicit and use a HEAD, but tests
// elsewhere guarantee that zero byte responses generate
// a "Content-Length: 0" instead of chunking.
}
// At this point, each of the 5 Transport.readLoop goroutines
// is scheduled, noting that there are no response bodies (see
// earlier comment), and is then calling putIdleConn, which
// decrements this count. Usually that happens quickly, which is
// why this test has seemed to work for ages. But it's still
// racy: we have to wait for them to finish first. See Issue 10427
for i := 0; i < nReqs; i++ {
<-didRead
}
if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
}
}
}
// Verify the status quo: that the Client.Post function coerces its
// body into a ReadCloser if it's a Closer, and that the Transport
// then closes it.
//
// The request body counts its own Close calls; exactly one is expected.
func TestTransportClosesRequestBody(t *testing.T) {
	defer afterTest(t)

	drain := func(w ResponseWriter, r *Request) {
		io.Copy(io.Discard, r.Body)
	}
	ts := httptest.NewServer(HandlerFunc(drain))
	defer ts.Close()

	var closeCount int
	body := countCloseReader{&closeCount, strings.NewReader("hello")}
	res, err := ts.Client().Post(ts.URL, "text/plain", body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	if closeCount != 1 {
		t.Errorf("closes = %d; want 1", closeCount)
	}
}
// TestTransportTLSHandshakeTimeout checks that Transport.TLSHandshakeTimeout
// aborts an https request to a peer that accepts the TCP connection but
// never writes anything. The failure must be a *url.Error wrapping a
// net.Error that reports a timeout and mentions "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Server side: accept one connection and sit on it until the test ends.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testDone
		conn.Close()
	}()

	// Client side: runs in its own goroutine so the test can bound how long
	// it is willing to wait for the handshake to give up.
	getDone := make(chan struct{})
	go func() {
		defer close(getDone)
		client := &Client{Transport: &Transport{
			Dial: func(_, _ string) (net.Conn, error) {
				return net.Dial("tcp", ln.Addr().String())
			},
			TLSHandshakeTimeout: 250 * time.Millisecond,
		}}
		_, err := client.Get("https://dummy.tld/")
		if err == nil {
			t.Error("expected error")
			return
		}
		urlErr, isURLErr := err.(*url.Error)
		if !isURLErr {
			t.Errorf("expected url.Error; got %#v", err)
			return
		}
		netErr, isNetErr := urlErr.Err.(net.Error)
		if !isNetErr {
			t.Errorf("expected net.Error; got %#v", err)
			return
		}
		if !netErr.Timeout() {
			t.Errorf("expected timeout error; got %v", err)
		}
		if msg := err.Error(); !strings.Contains(msg, "handshake timeout") {
			t.Errorf("expected 'handshake timeout' in error; got %v", err)
		}
	}()

	select {
	case <-getDone:
	case <-time.After(5 * time.Second):
		t.Error("test timeout; TLS handshake hung?")
	}
}
// Trying to repro golang.org/issue/3514
//
// Each trial fetches /keep-alive-then-die, whose handler sends a response
// without "Connection: close" and then closes the connection anyway, and
// then immediately issues a second request. Failures of the second request
// are collected rather than fatal; the test only fails if every trial fails.
func TestTLSServerClosesConnection(t *testing.T) {
	defer afterTest(t)
	// closedc signals that the handler has closed the hijacked connection.
	closedc := make(chan bool, 1)
	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
			conn, _, _ := w.(Hijacker).Hijack()
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
			conn.Close()
			closedc <- true
			return
		}
		fmt.Fprintf(w, "hello")
	}))
	defer ts.Close()

	c := ts.Client()
	tr := c.Transport.(*Transport)

	var nSuccess = 0
	var errs []error
	const trials = 20
	for i := 0; i < trials; i++ {
		tr.CloseIdleConnections()
		res, err := c.Get(ts.URL + "/keep-alive-then-die")
		if err != nil {
			t.Fatal(err)
		}
		<-closedc
		slurp, err := io.ReadAll(res.Body)
		// Close before checking err so the body is released on every path.
		res.Body.Close()
		if err != nil {
			t.Fatal(err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Got %q, want foo", slurp)
		}

		// Now try again and see if we successfully
		// pick a new connection.
		res, err = c.Get(ts.URL + "/")
		if err != nil {
			errs = append(errs, err)
			continue
		}
		// Only the read error matters here; the content is not inspected.
		_, err = io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			errs = append(errs, err)
			continue
		}
		nSuccess++
	}
	if nSuccess > 0 {
		t.Logf("successes = %d of %d", nSuccess, trials)
	} else {
		t.Errorf("All runs failed:")
	}
	for _, err := range errs {
		t.Logf(" err: %v", err)
	}
}
// byteFromChanReader is an io.Reader that delivers at most one byte per
// Read call, taking each byte from the channel. Once the channel has been
// closed and drained, Read reports io.EOF.
type byteFromChanReader chan byte

// Read blocks for the next byte on the channel and stores it in p[0].
// A zero-length p returns (0, nil) without touching the channel.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) < 1 {
		return 0, nil
	}
	if b, open := <-c; open {
		p[0] = b
		return 1, nil
	}
	return 0, io.EOF
}
// Verifies that the Transport doesn't reuse a connection in the case
// where the server replies before the request has been fully
// written. We still honor that reply (see TestIssue3595), but don't
// send future requests on the connection because it's then in a
// questionable state.
// golang.org/issue/7569
//
// The POST body is one byte short until the test feeds finalBit, so the
// request is still being written when the server's reply arrives. The GET
// issued afterwards has to complete while that POST is still unfinished.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
setParallel(t)
defer afterTest(t)
// sconn holds the hijacked server side of the POST connection.
var sconn struct {
sync.Mutex
c net.Conn
}
var getOkay bool
// closeConn closes the hijacked connection at most once and logs the
// closure unless the GET has already succeeded.
closeConn := func() {
sconn.Lock()
defer sconn.Unlock()
if sconn.c != nil {
sconn.c.Close()
sconn.c = nil
if !getOkay {
t.Logf("Closed server connection")
}
}
}
defer closeConn()
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method == "GET" {
io.WriteString(w, "bar")
return
}
// Non-GET: reply right away over the hijacked connection, then keep
// discarding whatever the client is still sending.
conn, _, _ := w.(Hijacker).Hijack()
sconn.Lock()
sconn.c = conn
sconn.Unlock()
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
go io.Copy(io.Discard, conn)
}))
defer ts.Close()
c := ts.Client()
const bodySize = 256 << 10
// finalBit supplies the last byte of the POST body on demand.
finalBit := make(byteFromChanReader, 1)
req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
req.ContentLength = bodySize
res, err := c.Do(req)
if err := wantBody(res, err, "foo"); err != nil {
t.Errorf("POST response: %v", err)
}
donec := make(chan bool)
go func() {
defer close(donec)
res, err = c.Get(ts.URL)
if err := wantBody(res, err, "bar"); err != nil {
t.Errorf("GET response: %v", err)
return
}
getOkay = true // suppress test noise
}()
// If the GET has not finished after 5 seconds, close the POST connection
// from the server side.
time.AfterFunc(5*time.Second, closeConn)
select {
case <-donec:
finalBit <- 'x' // unblock the writeloop of the first Post
close(finalBit)
case <-time.After(7 * time.Second):
t.Fatal("timeout waiting for GET request to finish")
}
}
// Tests that we don't leak Transport persistConn.readLoop goroutines
// when a server hangs up immediately after saying it would keep-alive.
func TestTransportIssue10457(t *testing.T) {
	defer afterTest(t) // used to fail in goroutine leak check

	// A response with no body and keep-alive (implicit). The handler then
	// lies and immediately closes the connection, which forces the
	// Transport's readLoop to immediately Peek an io.EOF and get to the
	// point that used to hang.
	const wire = "HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"
	hangUp := func(w ResponseWriter, r *Request) {
		conn, _, _ := w.(Hijacker).Hijack()
		conn.Write([]byte(wire))
		conn.Close()
	}
	ts := httptest.NewServer(HandlerFunc(hangUp))
	defer ts.Close()

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer res.Body.Close()

	// Just a sanity check that we at least get the response. The real
	// test here is that the "defer afterTest" above doesn't find any
	// leaked goroutines.
	const want = "Bar"
	if got := res.Header.Get("Foo"); got != want {
		t.Errorf("Foo header = %q; want %q", got, want)
	}
}
// closerFunc lets a plain function be used where a value with a
// Close() error method is needed.
type closerFunc func() error

// Close calls the wrapped function and returns its result.
func (fn closerFunc) Close() error {
	return fn()
}
// writerFuncConn is a net.Conn whose Write is replaced by the write
// function; every other method comes from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write hands p to the write function and returns whatever it returns.
func (wc writerFuncConn) Write(p []byte) (n int, err error) {
	return wc.write(p)
}
// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
// server is in the process of closing, we may end up successfully writing out
// our request (or a portion of our request) only to find a connection error
// when we try to read from (or finish writing to) the socket.
//
// NOTE: we resend a request only if:
// - we reused a keep-alive connection
// - we haven't yet received any header data
// - either we wrote no bytes to the server, or the request is idempotent
// This automatically prevents an infinite resend loop because we'll run out of
// the cached keep-alive connections eventually.
//
// Each subtest sends three requests through a client whose connections fail
// the second Write call overall, and then compares a log of dial, write,
// handler and retry events against the exact expected sequence.
func TestRetryRequestsOnError(t *testing.T) {
// newRequest is NewRequest with a fatal failure on error.
newRequest := func(method, urlStr string, body io.Reader) *Request {
req, err := NewRequest(method, urlStr, body)
if err != nil {
t.Fatal(err)
}
return req
}
testCases := []struct {
name string
// failureN and failureErr are what the injected failing Write returns.
failureN int
failureErr error
// Note that we can't just re-use the Request object across calls to c.Do
// because we need to rewind Body between calls. (GetBody is only used to
// rewind Body on failure and redirects, not just because it's done.)
req func() *Request
// reqString is the request as it is expected to appear, %q-escaped, in
// the logged Write calls.
reqString string
}{
{
name: "IdempotentNoBodySomeWritten",
// Believe that we've written some bytes to the server, so we know we're
// not just in the "retry when no bytes sent" case".
failureN: 1,
// Use the specific error that shouldRetryRequest looks for with idempotent requests.
failureErr: ExportErrServerClosedIdle,
req: func() *Request {
return newRequest("GET", "http://fake.golang", nil)
},
reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
},
{
name: "IdempotentGetBodySomeWritten",
// Believe that we've written some bytes to the server, so we know we're
// not just in the "retry when no bytes sent" case".
failureN: 1,
// Use the specific error that shouldRetryRequest looks for with idempotent requests.
failureErr: ExportErrServerClosedIdle,
req: func() *Request {
return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
},
reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
},
{
name: "NothingWrittenNoBody",
// It's key that we return 0 here -- that's what enables Transport to know
// that nothing was written, even though this is a non-idempotent request.
failureN: 0,
failureErr: errors.New("second write fails"),
req: func() *Request {
return newRequest("DELETE", "http://fake.golang", nil)
},
reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
},
{
name: "NothingWrittenGetBody",
// It's key that we return 0 here -- that's what enables Transport to know
// that nothing was written, even though this is a non-idempotent request.
failureN: 0,
failureErr: errors.New("second write fails"),
// Note that NewRequest will set up GetBody for strings.Reader, which is
// required for the retry to occur
req: func() *Request {
return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
},
reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defer afterTest(t)
// logbuf records the events, one per line; mu guards it because the
// handler, dialer and write hook all append to it.
var (
mu sync.Mutex
logbuf bytes.Buffer
)
logf := func(format string, args ...interface{}) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&logbuf, format, args...)
logbuf.WriteByte('\n')
}
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
logf("Handler")
w.Header().Set("X-Status", "ok")
}))
defer ts.Close()
// writeNumAtomic counts Write calls across all connections of this subtest.
var writeNumAtomic int32
c := ts.Client()
c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
logf("Dial")
// The requested address is ignored; every dial goes to the test server.
c, err := net.Dial(network, ts.Listener.Addr().String())
if err != nil {
logf("Dial error: %v", err)
return nil, err
}
return &writerFuncConn{
Conn: c,
write: func(p []byte) (n int, err error) {
// Only the second Write overall fails, without touching the real conn.
if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
logf("intentional write failure")
return tc.failureN, tc.failureErr
}
logf("Write(%q)", p)
return c.Write(p)
},
}, nil
}
SetRoundTripRetried(func() {
logf("Retried.")
})
defer SetRoundTripRetried(nil)
for i := 0; i < 3; i++ {
t0 := time.Now()
req := tc.req()
res, err := c.Do(req)
if err != nil {
// A quick failure is a real failure; a slow one is put down to
// the connection not having been recycled in time, and skips.
if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 {
mu.Lock()
got := logbuf.String()
mu.Unlock()
t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
}
t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse)
}
res.Body.Close()
if res.Request != req {
t.Errorf("Response.Request != original request; want identical Request")
}
}
mu.Lock()
got := logbuf.String()
mu.Unlock()
// Expected sequence: the first request goes out on a fresh connection,
// the second request's write fails and is retried on a new connection,
// and the third request is written without another dial.
want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
if got != want {
t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
}
})
}
}
// Issue 6981
//
// TestTransportClosesBodyOnError sends a POST whose body yields 1 MiB of
// data and then fails with a fake error. It checks that Do reports that
// error, that the server-side body read fails too, and that the Transport
// closed the request body.
func TestTransportClosesBodyOnError(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	handlerReadErr := make(chan error, 1)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		handlerReadErr <- err
	}))
	defer ts.Close()
	c := ts.Client()

	fakeErr := errors.New("fake error")
	sawClose := make(chan bool, 1)

	// The body is an anonymous reader+closer pair so that its Close can be
	// observed; the non-blocking send tolerates repeated Close calls.
	var body struct {
		io.Reader
		io.Closer
	}
	body.Reader = io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr))
	body.Closer = closerFunc(func() error {
		select {
		case sawClose <- true:
		default:
		}
		return nil
	})
	req, _ := NewRequest("POST", ts.URL, body)

	res, err := c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}

	select {
	case err := <-handlerReadErr:
		if err == nil {
			t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
		}
	case <-time.After(5 * time.Second):
		t.Error("timeout waiting for server handler to complete")
	}

	select {
	case <-sawClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
// TestTransportDialTLS checks that the Transport uses its DialTLS hook for
// https requests: the hook must be called and the request must reach the
// handler.
func TestTransportDialTLS(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	var mu sync.Mutex // guards following
	var gotReq, didDial bool

	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
		mu.Lock()
		didDial = true
		mu.Unlock()
		// Reuse the test client's TLS config for the dial.
		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return c, c.Handshake()
	}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	mu.Lock()
	// Release the mutex on return instead of leaving it locked for good.
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if !didDial {
		t.Error("didn't use dial hook")
	}
}
// TestTransportDialContext checks that the Transport passes the request's
// context to its DialContext hook and that the request reaches the handler.
func TestTransportDialContext(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	var mu sync.Mutex // guards following
	var gotReq bool
	var receivedContext context.Context

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		receivedContext = ctx
		mu.Unlock()
		return net.Dial(netw, addr)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// The value makes this context distinguishable from any other one.
	ctx := context.WithValue(context.Background(), "some-key", "some-value")
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	mu.Lock()
	// Release the mutex on return instead of leaving it locked for good.
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if receivedContext != ctx {
		t.Error("didn't receive correct context")
	}
}
// TestTransportDialTLSContext checks that the Transport passes the request's
// context to its DialTLSContext hook for https requests and that the request
// reaches the handler.
func TestTransportDialTLSContext(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	var mu sync.Mutex // guards following
	var gotReq bool
	var receivedContext context.Context

	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		receivedContext = ctx
		mu.Unlock()
		// Reuse the test client's TLS config for the dial.
		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return c, c.Handshake()
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// The value makes this context distinguishable from any other one.
	ctx := context.WithValue(context.Background(), "some-key", "some-value")
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	mu.Lock()
	// Release the mutex on return instead of leaving it locked for good.
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if receivedContext != ctx {
		t.Error("didn't receive correct context")
	}
}
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		// A Proxy function that always fails.
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}
	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
// TestTransportCloseIdleConnsThenReturn walks a Transport through its idle
// bookkeeping: idle conns can be put, CloseIdleConnections drops them and
// makes further puts fail, and queueing for an idle conn re-enables puts.
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
	tr := &Transport{}
	const scheme, host = "http", "example.com" // key used by PutIdleTestConn

	// wantIdle reports, and flags as a test error, whether the idle count
	// for the key differs from n.
	wantIdle := func(when string, n int) bool {
		if got := tr.IdleConnCountForTesting(scheme, host); got != n {
			t.Errorf("%s: idle conns = %d; want %d", when, got, n)
			return false
		}
		return true
	}

	wantIdle("start", 0)
	if !tr.PutIdleTestConn(scheme, host) {
		t.Fatal("put failed")
	}
	if !tr.PutIdleTestConn(scheme, host) {
		t.Fatal("second put failed")
	}
	wantIdle("after put", 2)

	tr.CloseIdleConnections()
	if !tr.IsIdleForTesting() {
		t.Error("should be idle after CloseIdleConnections")
	}
	wantIdle("after close idle", 0)

	// While the transport is in idle mode a put must be refused.
	if tr.PutIdleTestConn(scheme, host) {
		t.Fatal("put didn't fail")
	}
	wantIdle("after second put", 0)

	tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
	if tr.IsIdleForTesting() {
		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
	}
	if !tr.PutIdleTestConn(scheme, host) {
		t.Fatal("after re-activation")
	}
	wantIdle("after final put", 1)
}
// Test for issue 34282
// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn
func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
	tr := &Transport{}
	const scheme, addr = "https", "example.com:443" // key used by PutIdleTestConnH2

	// wantIdle reports, and flags as a test error, whether the idle count
	// for the key differs from n.
	wantIdle := func(when string, n int) bool {
		if got := tr.IdleConnCountForTesting(scheme, addr); got != n {
			t.Errorf("%s: idle conns = %d; want %d", when, got, n)
			return false
		}
		return true
	}

	wantIdle("start", 0)
	alt := funcRoundTripper(func() {})
	if !tr.PutIdleTestConnH2(scheme, addr, alt) {
		t.Fatal("put failed")
	}
	wantIdle("after put", 1)

	trace := &httptrace.ClientTrace{
		GotConn: func(httptrace.GotConnInfo) {
			// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
			t.Error("GotConn called")
		},
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)
	req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
	if _, err := tr.RoundTrip(req); err != errFakeRoundTrip {
		t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
	}
	wantIdle("after round trip", 1)
}
// TestTransportRemovesH2ConnsAfterIdle makes one HTTP/2 request with a
// transport limited to a single connection per host and a 10ms idle
// timeout, waits well past that timeout, and then requires that a second
// request still completes instead of hanging.
func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	configure := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
		tr.IdleConnTimeout = 10 * time.Millisecond
	}
	noop := HandlerFunc(func(w ResponseWriter, r *Request) {})
	cst := newClientServerTest(t, h2Mode, noop, configure)
	defer cst.close()

	if _, err := cst.c.Get(cst.ts.URL); err != nil {
		t.Fatalf("got error: %s", err)
	}

	// Ten times the idle timeout configured above.
	time.Sleep(100 * time.Millisecond)

	// The second request runs in the background so a hang can be detected;
	// the channel is closed without a value when the request succeeds.
	result := make(chan error)
	go func() {
		defer close(result)
		if _, err := cst.c.Get(cst.ts.URL); err != nil {
			result <- err
		}
	}()

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	select {
	case err := <-result:
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
	case <-deadline.C:
		t.Fatal("request never completed")
	}
}
// This tests that a client requesting a content range won't also
// implicitly ask for gzip support. If they want that, they need to do it
// on their own.
// golang.org/issue/8923
func TestTransportRangeAndGzip(t *testing.T) {
	defer afterTest(t)
	// reqc hands the request seen by the handler back to the test.
	reqc := make(chan *Request, 1)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		reqc <- r
	}))
	defer ts.Close()
	c := ts.Client()

	req, _ := NewRequest("GET", ts.URL, nil)
	req.Header.Set("Range", "bytes=7-11")
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Deferred so the body is also closed when the timeout case calls t.Fatal.
	defer res.Body.Close()

	select {
	case r := <-reqc:
		if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
			t.Error("Transport advertised gzip support in the Accept-Encoding header")
		}
		if r.Header.Get("Range") == "" {
			t.Error("no Range in request")
		}
	case <-time.After(10 * time.Second):
		t.Fatal("timeout")
	}
}
// TestTransportResponseCancelRace is the regression test for issue 10474.
// It fully reads one response, cancels that (already finished) request, and
// then requires a second round trip on the same Transport to succeed.
func TestTransportResponseCancelRace(t *testing.T) {
	defer afterTest(t)

	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		// The response must carry a body for the test to be meaningful.
		var payload [1024]byte
		w.Write(payload[:])
	})
	ts := httptest.NewServer(handler)
	defer ts.Close()
	tr := ts.Client().Transport.(*Transport)

	firstReq, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(firstReq)
	if err != nil {
		t.Fatal(err)
	}

	// Closing early would make the Transport discard the connection rather
	// than reuse it, and the bug only shows up on reuse, so drain the body.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	secondReq, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	tr.CancelRequest(firstReq)

	res, err = tr.RoundTrip(secondReq)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
// TestTransportContentEncodingCaseInsensitive covers issue 19248: the value
// of Content-Encoding is case insensitive, so a gzip body must be decoded
// whether the server announces it as "gzip" or "GZIP".
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	encodings := []string{"gzip", "GZIP"}
	for _, ce := range encodings {
		ce := ce // per-iteration copy for the closures below
		t.Run(ce, func(t *testing.T) {
			const encodedString = "Hello Gopher"

			gzipHandler := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", ce)
				zw := gzip.NewWriter(w)
				zw.Write([]byte(encodedString))
				zw.Close()
			})
			ts := httptest.NewServer(gzipHandler)
			defer ts.Close()

			res, err := ts.Client().Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Fatal(err)
			}

			if got := string(body); got != encodedString {
				t.Fatalf("Expected body %q, got: %q\n", encodedString, got)
			}
		})
	}
}
// TestTransportDialCancelRace cancels a request from the enter-RoundTrip
// hook and expects RoundTrip to report the request-canceled error.
func TestTransportDialCancelRace(t *testing.T) {
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer ts.Close()
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	cancelOnEnter := func() {
		tr.CancelRequest(req)
	}
	SetEnterRoundTripHook(cancelOnEnter)
	defer SetEnterRoundTripHook(nil)

	res, err := tr.RoundTrip(req)
	if err == ExportErrRequestCanceled {
		return
	}
	t.Errorf("expected canceled request error; got %v", err)
	if err == nil {
		// Unexpected success: release the response we were handed.
		res.Body.Close()
	}
}
// logWritesConn is a net.Conn that records every Write call in writes
// before forwarding the bytes to w.
// Reads are served from a reader that is received once from rch.
type logWritesConn struct {
	net.Conn // nil. crash on use.

	w io.Writer

	rch <-chan io.Reader
	r   io.Reader // nil until received by rch

	mu     sync.Mutex
	writes []string
}

// Write appends a copy of p to the write log and then passes p on to w,
// returning whatever w.Write returns.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	logged := string(p)
	c.writes = append(c.writes, logged)
	return c.w.Write(p)
}

// Read blocks until a reader has been delivered on rch (first call only)
// and then reads from that reader.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close does nothing and always reports success.
func (c *logWritesConn) Close() error {
	return nil
}
// Issue 6574
//
// TestTransportFlushesBodyChunks records every Write the Transport makes on
// a fake connection while it sends a chunked POST body, and compares the
// recorded writes with the expected sequence: the request header, one write
// per body chunk, and the terminating chunk.
func TestTransportFlushesBodyChunks(t *testing.T) {
defer afterTest(t)
// resBody delivers the canned response to the fake connection's Read side.
resBody := make(chan io.Reader, 1)
connr, connw := io.Pipe() // connection pipe pair
lw := &logWritesConn{
rch: resBody,
w: connw,
}
tr := &Transport{
Dial: func(network, addr string) (net.Conn, error) {
return lw, nil
},
}
bodyr, bodyw := io.Pipe() // body pipe pair
// Produce the request body as three separate writes.
go func() {
defer bodyw.Close()
for i := 0; i < 3; i++ {
fmt.Fprintf(bodyw, "num%d\n", i)
}
}()
resc := make(chan *Response)
// Run the round trip in the background; resc is closed (without a value)
// if RoundTrip fails.
go func() {
req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
req.Header.Set("User-Agent", "x") // known value for test
res, err := tr.RoundTrip(req)
if err != nil {
t.Errorf("RoundTrip: %v", err)
close(resc)
return
}
resc <- res
}()
// Fully consume the request before checking the Write log vs. want.
req, err := ReadRequest(bufio.NewReader(connr))
if err != nil {
t.Fatal(err)
}
io.Copy(io.Discard, req.Body)
// Unblock the transport's roundTrip goroutine.
resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
res, ok := <-resc
if !ok {
// RoundTrip failed and has already reported the error.
return
}
defer res.Body.Close()
want := []string{
"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
"5\r\nnum0\n\r\n",
"5\r\nnum1\n\r\n",
"5\r\nnum2\n\r\n",
"0\r\n\r\n",
}
if !reflect.DeepEqual(lw.writes, want) {
t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
}
}
// TestTransportFlushesRequestHeader covers issue 22088: the Transport must
// flush the request headers when it cannot be sure that reading the body
// won't block. The body here is a pipe that is only closed once the server
// handler has seen the request.
func TestTransportFlushesRequestHeader(t *testing.T) {
	defer afterTest(t)

	gotReq := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		close(gotReq)
	})
	cst := newClientServerTest(t, h1Mode, handler)
	defer cst.close()

	bodyR, bodyW := io.Pipe()
	req, err := NewRequest("POST", cst.ts.URL, bodyR)
	if err != nil {
		t.Fatal(err)
	}

	gotRes := make(chan struct{})
	go func() {
		defer close(gotRes)
		res, err := cst.tr.RoundTrip(req)
		if err != nil {
			t.Error(err)
			return
		}
		res.Body.Close()
	}()

	select {
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for handler to get request")
	case <-gotReq:
		// The handler ran, so the headers were flushed; end the body.
		bodyW.Close()
	}
	<-gotRes
}
// TestTransportPrefersResponseOverWriteError covers issue 11745. The server
// answers oversized uploads with 400 and closes the request body without
// reading it; the client must still receive that 400 response on every
// attempt instead of an error.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	defer afterTest(t)

	const contentLengthLimit = 1024 * 1024 // 1MB
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.ContentLength < contentLengthLimit {
			w.WriteHeader(StatusOK)
			return
		}
		w.WriteHeader(StatusBadRequest)
		r.Body.Close()
	}))
	defer ts.Close()
	c := ts.Client()

	const count = 100
	failures := 0
	bigBody := strings.Repeat("a", contentLengthLimit*2)
	for i := 0; i < count; i++ {
		req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
		if err != nil {
			t.Fatal(err)
		}

		resp, err := c.Do(req)
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode != StatusBadRequest {
				t.Errorf("Expected status code 400, got %v", resp.Status)
			}
			continue
		}

		// Log each layer of the error to help diagnose the failure.
		failures++
		t.Logf("%d = %#v", i, err)
		ue, ok := err.(*url.Error)
		if !ok {
			continue
		}
		t.Logf("urlErr = %#v", ue.Err)
		if ne, ok := ue.Err.(*net.OpError); ok {
			t.Logf("netOpError = %#v", ne.Err)
		}
	}
	if failures > 0 {
		t.Errorf("Failed %v out of %v\n", failures, count)
	}
}
// TestTransportAutomaticHTTP2 expects "h2" to be registered on a zero
// Transport.
func TestTransportAutomaticHTTP2(t *testing.T) {
	tr := new(Transport)
	testTransportAutoHTTP(t, tr, true)
}

// With ForceAttemptHTTP2 set, "h2" is expected to be registered even though
// a TLSClientConfig is supplied.
func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
	tr := &Transport{
		ForceAttemptHTTP2: true,
		TLSClientConfig:   new(tls.Config),
	}
	testTransportAutoHTTP(t, tr, true)
}

// golang.org/issue/14391: also check DefaultTransport
func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
	tr := DefaultTransport.(*Transport)
	testTransportAutoHTTP(t, tr, true)
}

// A non-nil (even if empty) TLSNextProto map: "h2" must not be registered.
func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
	tr := &Transport{
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{},
	}
	testTransportAutoHTTP(t, tr, false)
}

// A custom TLSClientConfig: "h2" must not be registered.
func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
	tr := &Transport{
		TLSClientConfig: new(tls.Config),
	}
	testTransportAutoHTTP(t, tr, false)
}

// Setting ExpectContinueTimeout: "h2" is still expected to be registered.
func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
	tr := &Transport{
		ExpectContinueTimeout: 1 * time.Second,
	}
	testTransportAutoHTTP(t, tr, true)
}

// A custom Dial func: "h2" must not be registered.
func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
	var d net.Dialer
	tr := &Transport{
		Dial: d.Dial,
	}
	testTransportAutoHTTP(t, tr, false)
}

// A custom DialContext func: "h2" must not be registered.
func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
	var d net.Dialer
	tr := &Transport{
		DialContext: d.DialContext,
	}
	testTransportAutoHTTP(t, tr, false)
}

// A custom DialTLS func: "h2" must not be registered. The test panics if
// the dialer is ever invoked.
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
	dialTLS := func(network, addr string) (net.Conn, error) {
		panic("unused")
	}
	tr := &Transport{
		DialTLS: dialTLS,
	}
	testTransportAutoHTTP(t, tr, false)
}
// testTransportAutoHTTP sends an empty Request through tr (which must fail)
// and then checks whether tr.TLSNextProto holds a non-nil "h2" entry,
// comparing that with wantH2.
func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
	CondSkipHTTP2(t)

	if _, err := tr.RoundTrip(new(Request)); err == nil {
		t.Error("expected error from RoundTrip")
	}

	registered := tr.TLSNextProto["h2"] != nil
	if registered != wantH2 {
		t.Errorf("HTTP/2 registered = %v; want %v", registered, wantH2)
	}
}
// Issue 13633: there was a race where we returned bodyless responses
// to callers before recycling the persistent connection, which meant
// a client doing two subsequent requests could end up on different
// connections. It's somewhat harmless but enough tests assume it's
// not true in order to test other things that it's worth fixing.
// Plus it's nice to be consistent and not have timing-dependent
// behavior.
//
// The handler echoes the client's remote address in the X-Addr header;
// every sequential request must report the same address as the first.
func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
		// Empty response body.
	}))
	defer cst.close()

	n := 100
	if testing.Short() {
		n = 10
	}

	var firstAddr string
	for i := 0; i < n; i++ {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// t.Fatal rather than log.Fatal: fail only this test and let
			// the deferred cleanups run instead of exiting the process.
			t.Fatal(err)
		}
		addr := res.Header.Get("X-Addr")
		// Close before comparing so the body is released on the failure
		// path too.
		res.Body.Close()
		if i == 0 {
			firstAddr = addr
		} else if addr != firstAddr {
			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
		}
	}
}
// Issue 13839
//
// TestNoCrashReturningTransportAltConn cancels a request from inside DialTLS,
// after a TLS connection negotiating the "foo" protocol has been set up, and
// only lets DialTLS return that connection once Do has already failed. The
// test then waits for the "foo" TLSNextProto constructor to run and for the
// pending-dial hooks to balance.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
t.Fatal(err)
}
ln := newLocalListener(t)
defer ln.Close()
// wg counts pending dials through the hooks installed below.
var wg sync.WaitGroup
SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
defer SetPendingDialHooks(nil, nil)
testDone := make(chan struct{})
defer close(testDone)
// Server side: accept one TLS connection offering "foo", complete the
// handshake, and hold the connection open until the test finishes.
go func() {
tln := tls.NewListener(ln, &tls.Config{
NextProtos: []string{"foo"},
Certificates: []tls.Certificate{cert},
})
sc, err := tln.Accept()
if err != nil {
t.Error(err)
return
}
if err := sc.(*tls.Conn).Handshake(); err != nil {
t.Error(err)
return
}
<-testDone
sc.Close()
}()
addr := ln.Addr().String()
req, _ := NewRequest("GET", "https://fake.tld/", nil)
cancel := make(chan struct{})
req.Cancel = cancel
doReturned := make(chan bool, 1)
madeRoundTripper := make(chan bool, 1)
tr := &Transport{
DisableKeepAlives: true,
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
"foo": func(authority string, c *tls.Conn) RoundTripper {
madeRoundTripper <- true
return funcRoundTripper(func() {
t.Error("foo RoundTripper should not be called")
})
},
},
Dial: func(_, _ string) (net.Conn, error) {
panic("shouldn't be called")
},
DialTLS: func(_, _ string) (net.Conn, error) {
tc, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"foo"},
})
if err != nil {
return nil, err
}
if err := tc.Handshake(); err != nil {
return nil, err
}
// Cancel the request, then wait for Do to return before handing
// the established connection back to the Transport.
close(cancel)
<-doReturned
return tc, nil
},
}
c := &Client{Transport: tr}
_, err = c.Do(req)
if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
}
// Release DialTLS, then wait for the alt-protocol constructor and for all
// pending dials to finish.
doReturned <- true
<-madeRoundTripper
wg.Wait()
}
// TestTransportReuseConnection_Gzip_Chunked runs the gzip reuse check with
// the handler flushing before it writes the body.
func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
	testTransportReuseConnection_Gzip(t, true)
}

// TestTransportReuseConnection_Gzip_ContentLength runs the gzip reuse check
// without the explicit flush.
func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
	testTransportReuseConnection_Gzip(t, false)
}

// Make sure we re-use underlying TCP connection for gzipped responses too.
// Two sequential requests are made and the remote addresses the server saw
// for them must match.
func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
	setParallel(t)
	defer afterTest(t)

	addr := make(chan string, 2)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz) // arbitrary gzip response
	})
	ts := httptest.NewServer(handler)
	defer ts.Close()
	c := ts.Client()

	for i := 0; i < 2; i++ {
		res, err := c.Get(ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		scratch := make([]byte, len(rgz))
		n, err := io.ReadFull(res.Body, scratch)
		if err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// Note: no res.Body.Close call. It should work without it,
		// since the flate.Reader's internal buffering will hit EOF
		// and that should be sufficient.
	}

	first := <-addr
	second := <-addr
	if first != second {
		t.Fatalf("didn't reuse connection")
	}
}
// TestTransportResponseHeaderLength sets MaxResponseHeaderBytes to 512 KiB.
// A response with small headers must succeed, while the "/long" path, whose
// handler sets a 1 MiB header value, must fail with the header-size error.
func TestTransportResponseHeaderLength(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.URL.Path != "/long" {
			return
		}
		w.Header().Set("Long", strings.Repeat("a", 1<<20))
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10

	small, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	small.Body.Close()

	res, err := c.Get(ts.URL + "/long")
	if err == nil {
		defer res.Body.Close()
		// Report how many header bytes slipped through.
		var n int64
		for k, vv := range res.Header {
			for _, v := range vv {
				n += int64(len(k) + len(v))
			}
		}
		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
	}

	const want = "server response headers exceeded 524288 bytes"
	if !strings.Contains(err.Error(), want) {
		t.Errorf("got error: %v; want %q", err, want)
	}
}
// The TestTransportEventTrace variants run testTransportEventTrace for
// HTTP/1 and HTTP/2, with hooks installed (noHooks == false) or not.
func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) }
func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) }
// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) }
func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) }
// testTransportEventTrace performs a POST with "Expect: 100-continue"
// against a faked DNS name while logging every httptrace hook into a
// buffer, then checks how often each expected log line appears. A follow-up
// GET checks that GetConn is logged a second time. When noHooks is true the
// ClientTrace is zeroed and only the successful round trip is verified.
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
defer afterTest(t)
const resBody = "some body"
gotWroteReqEvent := make(chan struct{}, 500)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method == "GET" {
// Do nothing for the second request.
return
}
if _, err := io.ReadAll(r.Body); err != nil {
t.Error(err)
}
if !noHooks {
// Hold the response until the client's WroteRequest hook has fired.
select {
case <-gotWroteReqEvent:
case <-time.After(5 * time.Second):
t.Error("timeout waiting for WroteRequest event")
}
}
io.WriteString(w, resBody)
}))
defer cst.close()
cst.tr.ExpectContinueTimeout = 1 * time.Second
var mu sync.Mutex // guards buf
var buf bytes.Buffer
// logf appends one formatted line to buf under mu.
logf := func(format string, args ...interface{}) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&buf, format, args...)
buf.WriteByte('\n')
}
addrStr := cst.ts.Listener.Addr().String()
ip, port, err := net.SplitHostPort(addrStr)
if err != nil {
t.Fatal(err)
}
// Install a fake DNS server.
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
if host != "dns-is-faked.golang" {
t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
return nil, nil
}
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
})
body := "some body"
req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
// Every hook logs a recognizable line that is counted further below.
trace := &httptrace.ClientTrace{
GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
GotFirstResponseByte: func() { logf("first response byte") },
PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) },
DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
ConnectDone: func(network, addr string, err error) {
if err != nil {
t.Errorf("ConnectDone: %v", err)
}
logf("ConnectDone: connected to %s %s = %v", network, addr, err)
},
WroteHeaderField: func(key string, value []string) {
logf("WroteHeaderField: %s: %v", key, value)
},
WroteHeaders: func() {
logf("WroteHeaders")
},
Wait100Continue: func() { logf("Wait100Continue") },
Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) {
logf("WroteRequest: %+v", e)
gotWroteReqEvent <- struct{}{}
},
}
if h2 {
trace.TLSHandshakeStart = func() { logf("tls handshake start") }
trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
}
}
if noHooks {
// zero out all func pointers, trying to get some path to crash
*trace = httptrace.ClientTrace{}
}
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
req.Header.Set("Expect", "100-continue")
res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
logf("got roundtrip.response")
slurp, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
logf("consumed body")
if string(slurp) != resBody || res.StatusCode != 200 {
t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
}
res.Body.Close()
if noHooks {
// Done at this point. Just testing a full HTTP
// requests can happen with a trace pointing to a zero
// ClientTrace, full of nil func pointers.
return
}
// Snapshot the log under the lock before inspecting it.
mu.Lock()
got := buf.String()
mu.Unlock()
// wantOnce requires sub to occur exactly once in the log.
wantOnce := func(sub string) {
if strings.Count(got, sub) != 1 {
t.Errorf("expected substring %q exactly once in output.", sub)
}
}
// wantOnceOrMore requires sub to occur at least once in the log.
wantOnceOrMore := func(sub string) {
if strings.Count(got, sub) == 0 {
t.Errorf("expected substring %q at least once in output.", sub)
}
}
wantOnce("Getting conn for dns-is-faked.golang:" + port)
wantOnce("DNS start: {Host:dns-is-faked.golang}")
wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
wantOnce("got conn: {")
wantOnceOrMore("Connecting to tcp " + addrStr)
wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
wantOnce("Reused:false WasIdle:false IdleTime:0s")
wantOnce("first response byte")
if h2 {
wantOnce("tls handshake start")
wantOnce("tls handshake done")
} else {
wantOnce("PutIdleConn = <nil>")
wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
// WroteHeaderField hook is not yet implemented in h2.)
wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
}
wantOnce("WroteHeaders")
wantOnce("Wait100Continue")
wantOnce("Got100Continue")
wantOnce("WroteRequest: {Err:<nil>}")
if strings.Contains(got, " to udp ") {
t.Errorf("should not see UDP (DNS) connections")
}
// Dump the whole log once if any of the checks above failed.
if t.Failed() {
t.Errorf("Output:\n%s", got)
}
// And do a second request:
req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
res, err = cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != 200 {
t.Fatal(res.Status)
}
res.Body.Close()
mu.Lock()
got = buf.String()
mu.Unlock()
// GetConn must now have been logged once per request.
sub := "Getting conn for dns-is-faked.golang:"
if gotn, want := strings.Count(got, sub), 2; gotn != want {
t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
}
}
// TestTransportEventTraceTLSVerify makes a request whose TLS ServerName does
// not match the test server's certificate. The request must fail, and the
// TLSHandshakeStart/TLSHandshakeDone hooks must each be logged exactly once,
// with the verification error passed to TLSHandshakeDone.
func TestTransportEventTraceTLSVerify(t *testing.T) {
	var mu sync.Mutex // guards buf
	var buf bytes.Buffer
	logf := func(format string, args ...interface{}) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("Unexpected request")
	}))
	defer ts.Close()
	// Route the server's error log into the same buffer as the trace output.
	ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
		logf("%s", p)
		return len(p), nil
	}), "", 0)

	certpool := x509.NewCertPool()
	certpool.AddCert(ts.Certificate())

	c := &Client{Transport: &Transport{
		TLSClientConfig: &tls.Config{
			ServerName: "dns-is-faked.golang",
			RootCAs:    certpool,
		},
	}}

	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
		},
	}

	req, _ := NewRequest("GET", ts.URL, nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
	res, err := c.Do(req)
	if err == nil {
		// Unexpected success: release the body rather than leaking it.
		res.Body.Close()
		t.Error("Expected request to fail TLS verification")
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}

	wantOnce("TLSHandshakeStart")
	wantOnce("TLSHandshakeDone")
	wantOnce("err = x509: certificate is valid for example.com")

	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
// isDNSHijacked caches, via isDNSHijackedOnce, whether a lookup of a name
// that should not resolve returned any addresses.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips the calling test when the user's (shady/ISP) DNS
// server hijacks queries, i.e. returns addresses for a name that should not
// resolve. The probe runs once per process.
// See issues 16732, 16716.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		// The lookup error is ignored on purpose: only the presence of
		// returned addresses matters here.
		addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
		isDNSHijacked = len(addrs) != 0
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
// TestTransportEventTraceRealDNS requests a host name that should not
// resolve and checks that the DNSStart and DNSDone hooks were logged while
// no Connect hook was.
func TestTransportEventTraceRealDNS(t *testing.T) {
	skipIfDNSHijacked(t)
	defer afterTest(t)

	tr := &Transport{}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	var mu sync.Mutex // guards buf
	var buf bytes.Buffer
	logf := func(format string, args ...interface{}) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	trace := &httptrace.ClientTrace{
		DNSStart: func(e httptrace.DNSStartInfo) {
			logf("DNSStart: %+v", e)
		},
		DNSDone: func(e httptrace.DNSDoneInfo) {
			logf("DNSDone: %+v", e)
		},
		ConnectStart: func(network, addr string) {
			logf("ConnectStart: %s %s", network, addr)
		},
		ConnectDone: func(network, addr string, err error) {
			logf("ConnectDone: %s %s %v", network, addr, err)
		},
	}
	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))

	resp, err := c.Do(req)
	if err == nil {
		resp.Body.Close()
		t.Fatal("expected error during DNS lookup")
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	expected := []string{
		"DNSStart: {Host:dns-should-not-resolve.golang}",
		"DNSDone: {Addrs:[] Err:",
	}
	for _, sub := range expected {
		if !strings.Contains(got, sub) {
			t.Errorf("expected substring %q in output.", sub)
		}
	}
	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
		t.Errorf("should not see Connect events")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
// TestTransportRejectsAlphaPort covers issue 14353: a port may only contain
// digits, so a URL with the port "123foo" must be rejected with a
// *url.Error that wraps the invalid-port message.
func TestTransportRejectsAlphaPort(t *testing.T) {
	const want = `invalid port ":123foo" after host`

	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}

	ue, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	if got := ue.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
// connections. The http2 test is done in TestTransportEventTrace_h2
func TestTLSHandshakeTrace(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer ts.Close()

	var mu sync.Mutex // guards start and done
	var start, done bool
	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() {
			mu.Lock()
			defer mu.Unlock()
			start = true
		},
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			mu.Lock()
			defer mu.Unlock()
			done = true
			if err != nil {
				// The hook may run on a goroutine other than the one
				// running the test, where t.Fatal must not be called;
				// record the failure with t.Error instead.
				t.Error("Expected error to be nil but was:", err)
			}
		},
	}

	c := ts.Client()
	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal("Unable to construct test request:", err)
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	r, err := c.Do(req)
	if err != nil {
		t.Fatal("Unexpected error making request:", err)
	}
	r.Body.Close()

	mu.Lock()
	defer mu.Unlock()
	if !start {
		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
	}
	if !done {
		t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
	}
}
// TestTransportMaxIdleConns sets MaxIdleConns to 4 and visits five distinct
// (DNS-faked) hosts. After four hosts the idle-connection keys must list all
// four; after the fifth, the key of the first host must have been dropped.
func TestTransportMaxIdleConns(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		// No body for convenience.
	}))
	defer ts.Close()
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxIdleConns = 4

	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Resolve every host name to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	// hitHost performs one GET against the n-th fake host.
	hitHost := func(n int) {
		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	// keysFor builds the expected idle-conn keys for hosts first..last.
	keysFor := func(first, last int) []string {
		keys := make([]string, 0, last-first+1)
		for n := first; n <= last; n++ {
			keys = append(keys, fmt.Sprintf("|http|host-%d.dns-is-faked.golang:%s", n, port))
		}
		return keys
	}

	for i := 0; i < 4; i++ {
		hitHost(i)
	}
	want := keysFor(0, 3)
	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
	}

	// Now hitting the 5th host should kick out the first host:
	hitHost(4)
	want = keysFor(1, 4)
	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
	}
}
// TestTransportIdleConnTimeout_h1 runs the idle-timeout check over HTTP/1.
func TestTransportIdleConnTimeout_h1(t *testing.T) {
	testTransportIdleConnTimeout(t, h1Mode)
}

// TestTransportIdleConnTimeout_h2 runs the idle-timeout check over HTTP/2.
func TestTransportIdleConnTimeout_h2(t *testing.T) {
	testTransportIdleConnTimeout(t, h2Mode)
}

// testTransportIdleConnTimeout makes three requests spaced half an
// IdleConnTimeout apart and requires that exactly one idle connection, the
// same one each time, is cached after every request. It then waits one and
// a half timeouts and requires the idle list to be empty.
func testTransportIdleConnTimeout(t *testing.T, h2 bool) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	defer afterTest(t)

	const timeout = 1 * time.Second

	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		// No body for convenience.
	}))
	defer cst.close()
	tr := cst.tr
	tr.IdleConnTimeout = timeout
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// idleConns lists the idle connections for the protocol under test.
	idleConns := func() []string {
		if h2 {
			return tr.IdleConnStrsForTesting_h2()
		}
		return tr.IdleConnStrsForTesting()
	}

	var conn string // the idle connection seen after the first request
	doReq := func(n int) {
		trace := &httptrace.ClientTrace{
			PutIdleConn: func(err error) {
				if err != nil {
					t.Errorf("failed to keep idle conn: %v", err)
				}
			},
		}
		req, _ := NewRequest("GET", cst.ts.URL, nil)
		req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		conns := idleConns()
		if len(conns) != 1 {
			t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
		}
		switch {
		case conn == "":
			conn = conns[0]
		case conn != conns[0]:
			t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n)
		}
	}

	for i := 0; i < 3; i++ {
		doReq(i)
		time.Sleep(timeout / 2)
	}
	time.Sleep(timeout * 3 / 2)

	if remaining := idleConns(); len(remaining) != 0 {
		t.Errorf("idle conns = %q; want none", remaining)
	}
}
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
// HTTP/2 connection was established but its caller no longer
// wanted it. (Assuming the connection cache was enabled, which it is
// by default)
//
// This test reproduced the crash by setting the IdleConnTimeout low
// (to make the test reasonable) and then making a request which is
// canceled by the DialTLS hook, which then also waits to return the
// real connection until after the RoundTrip saw the error. Then we
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) {
setParallel(t)
cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
// nothing
}))
defer cst.close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// sawDoErr is signaled once cst.c.Do has returned its error.
sawDoErr := make(chan bool, 1)
// testDone is closed when the test function returns.
testDone := make(chan struct{})
defer close(testDone)
cst.tr.IdleConnTimeout = 5 * time.Millisecond
cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := tls.Dial(network, addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
})
if err != nil {
t.Error(err)
return nil, err
}
if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
c.Close()
return nil, errors.New("bogus")
}
// Cancel the request, then hold the established connection back until
// Do has failed, the test has ended, or the safety timer fires.
cancel()
failTimer := time.NewTimer(5 * time.Second)
defer failTimer.Stop()
select {
case <-sawDoErr:
case <-testDone:
case <-failTimer.C:
t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
}
return c, nil
}
req, _ := NewRequest("GET", cst.ts.URL, nil)
req = req.WithContext(ctx)
res, err := cst.c.Do(req)
if err == nil {
res.Body.Close()
t.Fatal("unexpected success")
}
sawDoErr <- true
// Wait for the explosion.
time.Sleep(cst.tr.IdleConnTimeout * 10)
}
// funcConn is a net.Conn whose Read and Write are supplied as funcs and
// whose Close is a no-op. All other net.Conn methods come from the embedded
// Conn.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read delegates to the read func.
func (fc funcConn) Read(p []byte) (int, error) {
	return fc.read(p)
}

// Write delegates to the write func.
func (fc funcConn) Write(p []byte) (int, error) {
	return fc.write(p)
}

// Close does nothing and always returns nil.
func (fc funcConn) Close() error {
	return nil
}
// TestTransportReturnsPeekError covers issue 16465: Transport.RoundTrip
// should hand the raw net.Conn.Read error from Peek back to the caller. The
// fake connection's Read waits for the first Write and then fails with a
// specific error value, which RoundTrip must return unchanged.
func TestTransportReturnsPeekError(t *testing.T) {
	errValue := errors.New("specific error value")

	wrote := make(chan struct{})
	var wroteOnce sync.Once

	dial := func(network, addr string) (net.Conn, error) {
		conn := funcConn{
			read: func([]byte) (int, error) {
				<-wrote
				return 0, errValue
			},
			write: func(p []byte) (int, error) {
				// Signal the reader on the first write only.
				wroteOnce.Do(func() { close(wrote) })
				return len(p), nil
			},
		}
		return conn, nil
	}
	tr := &Transport{Dial: dial}

	req := httptest.NewRequest("GET", "http://fake.tld/", nil)
	_, err := tr.RoundTrip(req)
	if err != errValue {
		t.Errorf("error = %#v; want %v", err, errValue)
	}
}
// Issue 13835: international domain names should work
func TestTransportIDNA_h1(t *testing.T) {
	testTransportIDNA(t, h1Mode)
}

func TestTransportIDNA_h2(t *testing.T) {
	testTransportIDNA(t, h2Mode)
}

// testTransportIDNA requests a URL with a Unicode host name and checks that
// the Punycode form is what appears in the Host header, in the TLS
// ServerName (h2 only), in the GetConn and DNSStart trace hooks, and in the
// faked DNS lookup.
func testTransportIDNA(t *testing.T, h2 bool) {
	defer afterTest(t)

	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is assigned below, once the server is listening; the handler
	// only reads it while serving a request.
	var port string
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		if want := punyDomain + ":" + port; r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if h2 {
			switch {
			case r.TLS == nil:
				t.Errorf("r.TLS == nil")
			case r.TLS.ServerName != punyDomain:
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		w.Header().Set("Hit-Handler", "1")
	}))
	defer cst.close()

	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host == punyDomain {
			return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
		}
		t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
		return nil, nil
	})

	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			if want := net.JoinHostPort(punyDomain, port); hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	if res.Header.Get("Hit-Handler") == "1" {
		return
	}
	// The response did not come from our handler; dump it for diagnosis.
	out, err := httputil.DumpResponse(res, true)
	if err != nil {
		t.Fatal(err)
	}
	t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
}
// TestTransportProxyConnectHeader covers issue 13290: the headers set in
// Transport.ProxyConnectHeader (including User-Agent) must be sent in the
// CONNECT request to the proxy.
func TestTransportProxyConnectHeader(t *testing.T) {
	defer afterTest(t)

	reqc := make(chan *Request, 1)
	// The "proxy" records the CONNECT request and then drops the connection.
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		reqc <- r
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		conn.Close()
	}))
	defer ts.Close()

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Proxy = func(r *Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}
	tr.ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}

	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	select {
	case r := <-reqc:
		if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
		}
		if got, want := r.Header.Get("Other"), "bar"; got != want {
			t.Errorf("CONNECT request Other = %q; want %q", got, want)
		}
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestTransportProxyGetConnectHeader checks that when both
// ProxyConnectHeader and GetProxyConnectHeader are set on a Transport,
// the CONNECT request carries the headers returned by the callback.
func TestTransportProxyGetConnectHeader(t *testing.T) {
	defer afterTest(t)

	// gotReq carries the request observed by the fake proxy back to the test.
	gotReq := make(chan *Request, 1)
	proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		gotReq <- r
		// Only the request headers matter, so drop the connection at once.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		conn.Close()
	}))
	defer proxy.Close()

	client := proxy.Client()
	tr := client.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		return url.Parse(proxy.URL)
	}
	// The static headers below are expected to lose against the callback.
	tr.ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}
	tr.GetProxyConnectHeader = func(context.Context, *url.URL, string) (Header, error) {
		dynamic := Header{
			"User-Agent": {"foo2"},
			"Other":      {"bar2"},
		}
		return dynamic, nil
	}

	// An https URL forces the Transport to issue a CONNECT.
	if res, err := client.Get("https://dummy.tld/"); err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	select {
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	case connect := <-gotReq:
		if got, want := connect.Header.Get("User-Agent"), "foo2"; got != want {
			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
		}
		if got, want := connect.Header.Get("Other"), "bar2"; got != want {
			t.Errorf("CONNECT request Other = %q; want %q", got, want)
		}
	}
}
// errFakeRoundTrip is the error returned by every funcRoundTripper
// round trip.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a func() with a RoundTrip method: each round trip
// runs the function and then fails with errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip calls the underlying function, ignores the request, and
// returns a nil response together with errFakeRoundTrip.
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	return nil, errFakeRoundTrip
}
// wantBody checks the outcome of an HTTP round trip.
//
// If err is non-nil it is returned unchanged and res is not touched.
// Otherwise the whole response body is read and compared with want.
// The body is always closed once it has been read, including when the
// read fails or the content does not match, so a failed check does not
// leak the body. The returned error describes, in order of precedence,
// a read failure, a content mismatch, or a Close failure; nil means the
// body matched and closed cleanly.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	// Close before inspecting the result so no return path skips it.
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %v", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %v", closeErr)
	}
	return nil
}
// newLocalListener returns a TCP listener bound to an ephemeral port on
// the IPv4 loopback address, falling back to the IPv6 loopback address
// when the IPv4 listen fails. The test is aborted with t.Fatal if
// neither address can be listened on.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
// countCloseReader pairs an io.Reader with a Close method that counts
// how many times it has been called.
type countCloseReader struct {
	n *int // incremented on every Close
	io.Reader
}

// Close bumps the shared counter and always reports success.
func (cr countCloseReader) Close() error {
	*cr.n += 1
	return nil
}
// rgz is a gzip quine that uncompresses to itself.
//
// The stream opens with the gzip magic bytes (0x1f, 0x8b), selects the
// deflate compression method (0x08) and sets the FNAME flag (0x08); the
// NUL-terminated file name stored after the fixed header spells
// "recursive".
var rgz = []byte{
0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
0x00, 0x00,
}
// Ensure that a proxy reply whose status line has no reason phrase
// ("HTTP/1.1 400") makes Transport.RoundTrip return an error instead of
// panicking.
// See Issue https://golang.org/issues/21701
func TestMissingStatusNoPanic(t *testing.T) {
	t.Parallel()

	const want = "unknown status code"

	ln := newLocalListener(t)
	addr := ln.Addr().String()
	done := make(chan bool)
	fullAddrURL := fmt.Sprintf("http://%s", addr)
	raw := "HTTP/1.1 400\r\n" +
		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
		"Content-Type: text/html; charset=utf-8\r\n" +
		"Content-Length: 10\r\n" +
		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
		"Vary: Accept-Encoding\r\n\r\n" +
		"Aloha Olaa"

	// Fake proxy: answer the first connection with the canned reply,
	// then drain it until the client hangs up.
	go func() {
		defer close(done)

		conn, _ := ln.Accept()
		if conn != nil {
			io.WriteString(conn, raw)
			io.ReadAll(conn)
			conn.Close()
		}
	}()
	// Tear down on every exit path, including t.Fatalf below: closing
	// the listener unblocks a pending Accept so the goroutine can end,
	// and waiting on done guarantees it has.
	defer func() {
		ln.Close()
		<-done
	}()

	proxyURL, err := url.Parse(fullAddrURL)
	if err != nil {
		t.Fatalf("proxyURL: %v", err)
	}

	tr := &Transport{Proxy: ProxyURL(proxyURL)}

	// An https URL makes the Transport issue a CONNECT to the proxy.
	req, _ := NewRequest("GET", "https://golang.org/", nil)
	res, err, panicked := doFetchCheckPanic(tr, req)
	if panicked {
		t.Error("panicked, expecting an error")
	}
	if res != nil && res.Body != nil {
		io.Copy(io.Discard, res.Body)
		res.Body.Close()
	}

	if err == nil || !strings.Contains(err.Error(), want) {
		t.Errorf("got=%v want=%q", err, want)
	}
}
// doFetchCheckPanic performs tr.RoundTrip(req) and reports whether the
// call panicked. When it does, the panic is swallowed, panicked is true
// and res and err keep their zero values; otherwise res and err are
// exactly what RoundTrip returned.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
// Issue 22330: do not allow the response body to be read when the status code
// forbids a response body.
//
// The handler hijacks the connection and answers with a 304 that
// nevertheless announces a chunked body; the client must expose NoBody.
func TestNoBodyOnChunked304Response(t *testing.T) {
	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			// Report instead of dereferencing nil conn/buf below.
			t.Error(err)
			return
		}
		defer conn.Close()
		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
		buf.Flush()
	}))
	defer cst.close()

	// Our test server above is sending back bogus data after the
	// response (the "0\r\n\r\n" part), which causes the Transport
	// code to log spam. Disable keep-alives so we never even try
	// to reuse the connection.
	cst.tr.DisableKeepAlives = true

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	if res.Body != NoBody {
		t.Errorf("Unexpected body on 304 response")
	}
}
// funcWriter is a function with a Write method: every Write call is
// forwarded to the function itself.
type funcWriter func([]byte) (int, error)

// Write hands p to the underlying function and returns its results.
func (fw funcWriter) Write(p []byte) (int, error) {
	return fw(p)
}
// doneContext wraps a context.Context so that it always looks finished:
// Done returns an already-closed channel and Err returns the stored err.
type doneContext struct {
	context.Context
	err error
}

// Done returns a fresh channel that is already closed.
func (doneContext) Done() <-chan struct{} {
	done := make(chan struct{})
	defer close(done)
	return done
}

// Err returns the error the doneContext was built with.
func (d doneContext) Err() error {
	return d.err
}
// Issue 25852: RoundTrip must notice a context that is already done and
// return that context's error.
func TestTransportCheckContextDoneEarly(t *testing.T) {
	wantErr := errors.New("some error")
	ctx := doneContext{context.Background(), wantErr}

	req, _ := NewRequest("GET", "http://fake.example/", nil)
	tr := new(Transport)
	if _, err := tr.RoundTrip(req.WithContext(ctx)); err != wantErr {
		t.Errorf("error = %v; want %v", err, wantErr)
	}
}
// Issue 23399: when a client request times out, the Transport's conn
// must be closed so that it cannot be reused.
//
// In this variant the timeout fires before the server has replied with
// any response headers.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	hijacked := make(chan net.Conn, 1) // conn taken over by the handler
	readDone := make(chan bool, 1)     // handler's Read has returned
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		hijacked <- conn
		// Block until the client goes away; a clean EOF with no data
		// is the only acceptable result.
		var one [1]byte
		if n, err := conn.Read(one[:]); n != 0 || err != io.EOF {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		readDone <- true
	}))
	defer cst.close()

	const timeout = 50 * time.Millisecond
	cst.c.Timeout = timeout

	if _, err := cst.c.Get(cst.ts.URL); err == nil {
		t.Fatal("unexpected Get succeess")
	}

	var conn net.Conn
	select {
	case conn = <-hijacked:
	case <-time.After(timeout * 10):
		// If we didn't get into the Handler in 50ms, that probably means
		// the builder was just slow and the Get failed in that time
		// but never made it to the server. That's fine. We'll usually
		// test the part above on faster machines.
		t.Skip("skipping test on slow builder")
	}

	select {
	case <-readDone:
		// Success.
	case <-time.After(5 * time.Second):
		t.Error("Handler's conn.Read seems to be stuck in Read")
		conn.Close() // close it to unblock Handler
	}
}
// Issue 23399: when a client request times out, the Transport's conn
// must be closed so that it cannot be reused.
//
// In this variant the server sends the response headers first and the
// request is canceled while the response body is still outstanding.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	hijacked := make(chan net.Conn, 1) // conn taken over by the handler
	handlerErr := make(chan error, 1)  // verdict on the handler's Read
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 bytes but deliver only 3 after the headers.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))

		hijacked <- conn
		var one [1]byte
		n, err := conn.Read(one[:])
		// The error should be io.EOF or "read tcp
		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
		// reset by peer" depending on timing. Really we just
		// care that it returns at all. But if it returns with
		// data, that's weird.
		var verdict error
		if n != 0 || err == nil {
			verdict = fmt.Errorf("unexpected Read result: %v, %v", n, err)
		}
		handlerErr <- verdict
	}))
	defer cst.close()

	// Set Timeout to something very long but non-zero to exercise
	// the codepaths that check for it. But rather than wait for it to fire
	// (which would make the test slow), we send on the req.Cancel channel instead,
	// which happens to exercise the same code paths.
	cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
	cancel := make(chan struct{})
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		select {
		case <-hijacked:
			t.Fatalf("Get error: %v", err)
		default:
			// Failed before entering handler. Ignore result.
			t.Skip("skipping test on slow builder")
		}
	}

	close(cancel)
	if got, err := io.ReadAll(res.Body); err == nil {
		t.Fatalf("unexpected success; read %q, nil", got)
	}

	var conn net.Conn
	select {
	case conn = <-hijacked:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout")
	}

	select {
	case err := <-handlerErr:
		if err != nil {
			t.Errorf("handler: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Error("Handler's conn.Read seems to be stuck in Read")
		conn.Close() // close it to unblock Handler
	}
}
// TestTransportResponseBodyWritableOnProtocolSwitch checks that after a
// 101 Switching Protocols response the response body is an
// io.ReadWriteCloser: data already sent by the server can be read from
// it, and a line written to it is echoed back upper-cased by the server.
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	// release keeps the handler (and thus its conn) alive until the test ends.
	release := make(chan struct{})
	defer close(release)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
		// Echo one client line back in upper case.
		lines := bufio.NewScanner(conn)
		lines.Scan()
		fmt.Fprintf(conn, "%s\n", strings.ToUpper(lines.Text()))
		<-release
	}))
	defer cst.close()

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req.Header.Set("Upgrade", "foo")
	req.Header.Set("Connection", "upgrade")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 101 {
		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
	}
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
	}
	defer rwc.Close()

	in := bufio.NewScanner(rwc)

	// First line: the data the server sent right behind the 101 response.
	if !in.Scan() {
		t.Fatalf("expected readable input")
	}
	if got, want := in.Text(), "Some buffered data"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}

	// Second line: the upper-cased echo of what we write now.
	io.WriteString(rwc, "echo\n")
	if !in.Scan() {
		t.Fatalf("expected another line")
	}
	if got, want := in.Text(), "ECHO"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
}
// TestTransportCONNECTBidi sends a CONNECT request whose body is a pipe
// and checks that request and response bodies work as a bidirectional
// stream: each line written to the pipe comes back upper-cased on the
// response body.
func TestTransportCONNECTBidi(t *testing.T) {
	defer afterTest(t)
	const target = "backend:443"
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("unexpected method %q", r.Method)
			w.WriteHeader(500)
			return
		}
		if r.RequestURI != target {
			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
			w.WriteHeader(500)
			return
		}
		conn, rw, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
		// Switch to a little protocol that capitalize its input lines:
		for {
			line, err := rw.ReadString('\n')
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			io.WriteString(rw, strings.ToUpper(line))
			rw.Flush()
		}
	}))
	defer cst.close()

	// The pipe's read end is the request body; the test feeds the write end.
	bodyR, bodyW := io.Pipe()
	defer bodyW.Close()
	req, err := NewRequest("CONNECT", cst.ts.URL, bodyR)
	if err != nil {
		t.Fatal(err)
	}
	req.URL.Opaque = target

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("status code = %d; want 200", res.StatusCode)
	}

	replies := bufio.NewReader(res.Body)
	for _, word := range []string{"foo", "bar", "baz"} {
		fmt.Fprintf(bodyW, "%s\n", word)
		reply, err := replies.ReadString('\n')
		if err != nil {
			t.Fatal(err)
		}
		got, want := strings.TrimSpace(reply), strings.ToUpper(word)
		if got != want {
			t.Fatalf("got %q; want %q", got, want)
		}
	}
}
func TestTransportRequestReplayable(t *testing.T) {
someBody := io.NopCloser(strings.NewReader(""))
tests := []struct {
name string
req *Request
want bool
}{
{
name: "GET",
req: &Request{Method: "GET"},
want: true,
},
{
name: "GET_http.NoBody",
req: &Request{Method: "GET", Body: NoBody},
want: true,
},
{
name: "GET_body",
req: &Request{Method: "GET", Body: someBody},
want: false,
},
{
name: "POST",
req: &Request{Method: "POST"},
want: false,
},
{
name: "POST_idempotency-key",
req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
want: true,
},
{
name: "POST_x-idempotency-key",
req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
want: true,
},
{
name: "POST_body",
req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.req.ExportIsReplayable()
if got != tt.want {
t.Errorf("replyable = %v; want %v", got, tt.want)
}
})
}
}
// testMockTCPConn is a mock TCP connection used to test that
// ReadFrom is called when sending the request body.
type testMockTCPConn struct {
*net.TCPConn
ReadFromCalled bool
}
func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
c.ReadFromCalled = true
return c.TCPConn.ReadFrom(r)
}
// TestTransportRequestWriteRoundTrip sends PUT requests whose body is
// either a temporary file or an in-memory buffer, with various
// ContentLength settings, over a connection wrapped in testMockTCPConn.
// For each case it checks whether the connection's ReadFrom method was
// used while the request was written; per the table this is expected
// only for a file body with an explicit positive length.
func TestTransportRequestWriteRoundTrip(t *testing.T) {
	nBytes := int64(1 << 10)

	// newFileFunc returns a temp file holding nBytes random bytes,
	// positioned at its start, plus a cleanup func that closes and
	// removes it. On failure the file is cleaned up before returning.
	newFileFunc := func() (r io.Reader, done func(), err error) {
		f, err := os.CreateTemp("", "net-http-newfilefunc")
		if err != nil {
			return nil, nil, err
		}
		done = func() {
			f.Close()
			os.Remove(f.Name())
		}

		// Write some bytes to the file to enable reading.
		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
			done() // don't leave the temp file behind
			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
		}
		if _, err := f.Seek(0, 0); err != nil {
			done() // don't leave the temp file behind
			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
		}

		return f, done, nil
	}

	// newBufferFunc returns an nBytes zero-filled buffer; nothing to clean up.
	newBufferFunc := func() (io.Reader, func(), error) {
		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
	}

	cases := []struct {
		name             string
		readerFunc       func() (io.Reader, func(), error)
		contentLength    int64
		expectedReadFrom bool
	}{
		{
			name:             "file, length",
			readerFunc:       newFileFunc,
			contentLength:    nBytes,
			expectedReadFrom: true,
		},
		{
			name:       "file, no length",
			readerFunc: newFileFunc,
		},
		{
			name:          "file, negative length",
			readerFunc:    newFileFunc,
			contentLength: -1,
		},
		{
			name:          "buffer",
			contentLength: nBytes,
			readerFunc:    newBufferFunc,
		},
		{
			name:       "buffer, no length",
			readerFunc: newBufferFunc,
		},
		{
			name:          "buffer, length -1",
			contentLength: -1,
			readerFunc:    newBufferFunc,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r, cleanup, err := tc.readerFunc()
			if err != nil {
				t.Fatal(err)
			}
			defer cleanup()

			// tConn is filled in by the dialer below and inspected
			// after the round trip.
			tConn := &testMockTCPConn{}
			trFunc := func(tr *Transport) {
				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
					var d net.Dialer
					conn, err := d.DialContext(ctx, network, addr)
					if err != nil {
						return nil, err
					}

					tcpConn, ok := conn.(*net.TCPConn)
					if !ok {
						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
					}

					tConn.TCPConn = tcpConn
					return tConn, nil
				}
			}

			cst := newClientServerTest(
				t,
				h1Mode,
				HandlerFunc(func(w ResponseWriter, r *Request) {
					io.Copy(io.Discard, r.Body)
					r.Body.Close()
					w.WriteHeader(200)
				}),
				trFunc,
			)
			defer cst.close()

			req, err := NewRequest("PUT", cst.ts.URL, r)
			if err != nil {
				t.Fatal(err)
			}
			// Override whatever NewRequest derived from the reader.
			req.ContentLength = tc.contentLength
			req.Header.Set("Content-Type", "application/octet-stream")
			resp, err := cst.c.Do(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != 200 {
				t.Fatalf("status code = %d; want 200", resp.StatusCode)
			}

			if !tConn.ReadFromCalled && tc.expectedReadFrom {
				t.Fatalf("did not call ReadFrom")
			}

			if tConn.ReadFromCalled && !tc.expectedReadFrom {
				t.Fatalf("ReadFrom was unexpectedly invoked")
			}
		})
	}
}
// TestTransportClone populates every listed Transport field with a
// non-zero value, clones the Transport, and uses reflection to check
// that no exported field of the clone is zero. It also checks that the
// TLSNextProto entry survives cloning and that a nil TLSNextProto map
// stays nil.
func TestTransportClone(t *testing.T) {
	orig := &Transport{
		Proxy:                  func(*Request) (*url.URL, error) { panic("") },
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	clone := orig.Clone()

	// Walk the clone's fields; unexported ones are not checked.
	cloned := reflect.ValueOf(clone).Elem()
	for i, typ := 0, cloned.Type(); i < typ.NumField(); i++ {
		name := typ.Field(i).Name
		if token.IsExported(name) && cloned.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", name)
		}
	}

	if _, ok := clone.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// But test that a nil TLSNextProto is kept nil:
	if empty := new(Transport).Clone(); empty.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
// TestIs408 is a table test for Export_is408Message: each row is a raw
// status-line prefix and whether it should be recognized as a 408.
func TestIs408(t *testing.T) {
	for _, tc := range []struct {
		in   string
		want bool
	}{
		{"HTTP/1.0 408", true},
		{"HTTP/1.1 408", true},
		{"HTTP/1.8 408", true},
		{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
		{"HTTP/1.1 408 ", true},
		{"HTTP/1.1 40", false},
		{"http/1.0 408", false},
		{"HTTP/1-1 408", false},
	} {
		got := Export_is408Message([]byte(tc.in))
		if got == tc.want {
			continue
		}
		t.Errorf("is408Message(%q) = %v; want %v", tc.in, got, tc.want)
	}
}
// TestTransportIgnores408 has the server append an unsolicited
// "HTTP/1.1 408" status line after a complete 200 response. The client
// must still read the 200 body, the Transport's idle connections must
// drop to zero, and nothing may be written to the log meanwhile.
func TestTransportIgnores408(t *testing.T) {
	// Not parallel. Relies on mutating the log package's global Output.
	defer log.SetOutput(log.Writer())

	var logout bytes.Buffer
	log.SetOutput(&logout)

	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		nc, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer nc.Close()
		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
		nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
	}))
	defer cst.close()

	req, err := NewRequest("GET", cst.ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(res.Body)
	// The body has been consumed (or failed); release it either way.
	res.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != "ok" {
		t.Fatalf("got %q; want ok", slurp)
	}

	// Poll with a growing back-off until the idle connection is gone.
	t0 := time.Now()
	for i := 0; i < 50; i++ {
		time.Sleep(time.Duration(i) * 5 * time.Millisecond)
		if cst.tr.IdleConnKeyCountForTesting() == 0 {
			if got := logout.String(); got != "" {
				t.Fatalf("expected no log output; got: %s", got)
			}
			return
		}
	}
	t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0))
}
// TestInvalidHeaderResponse serves a raw response containing the header
// line "Foo : bar" (a space before the colon) and checks how the client
// exposes it: nothing under the key "Foo", and "bar" under the key
// "Foo " (with the trailing space).
func TestInvalidHeaderResponse(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			// Report instead of dereferencing nil conn/buf below.
			t.Error(err)
			return
		}
		defer conn.Close()
		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
			"Content-Type: text/html; charset=utf-8\r\n" +
			"Content-Length: 0\r\n" +
			"Foo : bar\r\n\r\n"))
		buf.Flush()
	}))
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	if v := res.Header.Get("Foo"); v != "" {
		t.Errorf(`unexpected "Foo" header: %q`, v)
	}
	if v := res.Header.Get("Foo "); v != "bar" {
		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
	}
}
// bodyCloser is an always-empty request body whose value records
// whether Close has been called on it.
type bodyCloser bool

// Close marks the body as closed and reports success.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read never yields data; it reports io.EOF immediately.
func (*bodyCloser) Read([]byte) (int, error) {
	return 0, io.EOF
}
// Issue 35015: the request body must be closed whenever a request is
// rejected as invalid, as promised by the Client.Do docs.
//
// Every table row is a malformed request; the test server's handler
// must never run, Do must fail with an error containing wantErr, and
// the body attached to the request must have been closed.
func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Errorf("Should not have been invoked")
	}))
	defer srv.Close()

	u, _ := url.Parse(srv.URL)

	cases := []struct {
		name    string
		req     *Request
		wantErr string
	}{
		{
			name:    "invalid method",
			req:     &Request{Method: " ", URL: u},
			wantErr: "invalid method",
		},
		{
			name:    "nil URL",
			req:     &Request{Method: "GET"},
			wantErr: "nil Request.URL",
		},
		{
			name:    "invalid header key",
			req:     &Request{Method: "GET", Header: Header{"💡": {"emoji"}}, URL: u},
			wantErr: "invalid header field name",
		},
		{
			name:    "invalid header value",
			req:     &Request{Method: "POST", Header: Header{"key": {"\x19"}}, URL: u},
			wantErr: "invalid header field value",
		},
		{
			name:    "non HTTP(s) scheme",
			req:     &Request{Method: "POST", URL: &url.URL{Scheme: "faux"}},
			wantErr: "unsupported protocol scheme",
		},
		{
			name:    "no Host in URL",
			req:     &Request{Method: "POST", URL: &url.URL{Scheme: "http"}},
			wantErr: "no Host",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			closed := new(bodyCloser)
			tc.req.Body = closed

			_, err := DefaultClient.Do(tc.req)
			switch {
			case err == nil:
				t.Fatal("Expected an error")
			case !*closed:
				t.Fatal("Expected body to have been closed")
			case !strings.Contains(err.Error(), tc.wantErr):
				t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", err.Error(), tc.wantErr)
			}
		})
	}
}
// breakableConn wraps a net.Conn so that writes can be made to fail on
// demand: while the shared brokenState says broken, Write returns an
// error instead of reaching the underlying connection.
type breakableConn struct {
	net.Conn

	*brokenState
}

// brokenState is the mutex-guarded on/off switch that breakableConn
// consults on every Write.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the state is broken and
// otherwise forwards b to the wrapped connection. The state lock is
// held for the whole call.
func (c *breakableConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	if !c.broken {
		return c.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
// Issue 34978: don't cache a broken HTTP/2 connection
//
// Every dialed connection is wrapped in a breakableConn. For all but
// the last request the connection is switched to "broken" as soon as
// the TLS handshake is done, so the request must fail; the last request
// is left alone and must succeed. In total the test expects one dial
// per request and a single GotConn trace callback.
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
	defer cst.close()

	// state is shared by every connection the dialer hands out.
	var state brokenState

	const numReqs = 5
	var numDials, gotConns uint32 // atomic

	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
		atomic.AddUint32(&numDials, 1)
		c, err := net.Dial(netw, addr)
		if err != nil {
			t.Errorf("unexpected Dial error: %v", err)
			return nil, err
		}
		return &breakableConn{c, &state}, nil
	}

	for i := 1; i <= numReqs; i++ {
		// Start each attempt with a healthy connection state.
		state.Lock()
		state.broken = false
		state.Unlock()

		// doBreak controls whether we break the TCP connection after the TLS
		// handshake (before the HTTP/2 handshake). We test a few failures
		// in a row followed by a final success.
		doBreak := i != numReqs

		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
				atomic.AddUint32(&gotConns, 1)
			},
			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
				state.Lock()
				defer state.Unlock()
				if doBreak {
					state.broken = true
				}
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err == nil {
			// Release the successful response instead of leaking it.
			res.Body.Close()
		}
		if doBreak != (err != nil) {
			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
		}
	}
	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
		t.Errorf("GotConn calls = %v; want %v", got, want)
	}
	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
		t.Errorf("Dials = %v; want %v", got, want)
	}
}
// Issue 34941
// When the client has too many concurrent requests on a single connection,
// http.http2noCachedConnError is reported on multiple requests. There should
// only be one decrement regardless of the number of failures.
//
// The test fires numReqs concurrent GETs at an HTTP/2 TLS server through
// a Transport limited to one connection per host and fails if any of
// them returns an error.
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
	defer afterTest(t)
	CondSkipHTTP2(t)

	const numReqs = 300

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, err := w.Write([]byte("foo")); err != nil {
			// Handlers run on server goroutines, where FailNow
			// (t.Fatalf) must not be called; record the failure instead.
			t.Errorf("Write: %v", err)
		}
	})

	ts := httptest.NewUnstartedServer(h)
	ts.EnableHTTP2 = true
	ts.StartTLS()
	defer ts.Close()

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1
	if err := ExportHttp2ConfigureTransport(tr); err != nil {
		t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
	}

	// Buffered for one error per request so workers never block on send.
	errCh := make(chan error, numReqs)
	doReq := func() {
		resp, err := c.Get(ts.URL)
		if err != nil {
			errCh <- fmt.Errorf("request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		_, err = io.ReadAll(resp.Body)
		if err != nil {
			errCh <- fmt.Errorf("read body failed: %v", err)
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < numReqs; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doReq()
		}()
	}
	wg.Wait()
	close(errCh)

	for err := range errCh {
		t.Errorf("error occurred: %v", err)
	}
}
// Issue 36820
// A RoundTripper registered via RegisterProtocol must be driven with the
// older, backward compatible cancellation protocol: the client timeout
// has to surface through req.Cancel, which timeoutProto turns into
// timeoutProtoErr.
func TestAltProtoCancellation(t *testing.T) {
	defer afterTest(t)

	tr := new(Transport)
	tr.RegisterProtocol("timeout", timeoutProto{})
	c := &Client{Transport: tr, Timeout: time.Millisecond}

	_, err := c.Get("timeout://bar.com/path")
	switch {
	case err == nil:
		t.Error("request unexpectedly succeeded")
	case !strings.Contains(err.Error(), timeoutProtoErr.Error()):
		t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
	}
}
// timeoutProtoErr is what timeoutProto returns once it has observed the
// request's Cancel channel firing.
var timeoutProtoErr = errors.New("canceled as expected")

// timeoutProto is a RoundTrip implementation that never produces a
// response; it only waits to be canceled.
type timeoutProto struct{}

// RoundTrip blocks until r.Cancel fires, returning timeoutProtoErr, or
// gives up after five seconds with a different error.
func (timeoutProto) RoundTrip(r *Request) (*Response, error) {
	giveUp := time.After(5 * time.Second)
	select {
	case <-giveUp:
		return nil, errors.New("request was not canceled")
	case <-r.Cancel:
		return nil, timeoutProtoErr
	}
}
// roundTripFunc is a function with a RoundTrip method: every round trip
// is forwarded to the function itself.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip hands req to the underlying function and returns its results.
func (fn roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	return fn(req)
}
// Issue 32441: body is not reset after ErrSkipAltProtocol
//
// An alternate "http" protocol handler drains the request body and then
// declines the request with ErrSkipAltProtocol. The real server must
// still receive a non-empty body.
func TestIssue32441(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
			t.Error("body length is zero")
		}
	}))
	defer ts.Close()
	c := ts.Client()
	c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
		// Draining body to trigger failure condition on actual request to server.
		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
			t.Error("body length is zero during round trip")
		}
		return nil, ErrSkipAltProtocol
	}))
	res, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data"))
	if err != nil {
		t.Error(err)
		return
	}
	// Release the response instead of discarding it unclosed.
	res.Body.Close()
}
// Issue 39017. HTTP/1 transports must reject a Content-Length header
// that carries a sign (eg. "+3"), per RFC 2616, Section 14.13: the Get
// has to fail with no response and an error naming the bad value.
func TestTransportRejectsSignInContentLength(t *testing.T) {
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "+3")
		w.Write([]byte("abc"))
	}))
	defer srv.Close()

	res, err := srv.Client().Get(srv.URL)
	if res != nil || err == nil {
		t.Fatal("Expected a non-nil error and a nil http.Response")
	}

	const want = `bad Content-Length "+3"`
	if got := err.Error(); !strings.Contains(got, want) {
		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
	}
}
// dumpConn is a net.Conn assembled from an io.Writer and an io.Reader:
// Write and Read come from the embedded values, every other method is
// a stub returning nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

// Close does nothing.
func (*dumpConn) Close() error { return nil }

// LocalAddr returns a nil address.
func (*dumpConn) LocalAddr() net.Addr { return nil }

// RemoteAddr returns a nil address.
func (*dumpConn) RemoteAddr() net.Addr { return nil }

// SetDeadline ignores its argument.
func (*dumpConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline ignores its argument.
func (*dumpConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline ignores its argument.
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
// delegateReader is an io.Reader whose real source is delivered later
// over a channel. The first Read blocks until a reader arrives on c (or
// c is closed); from then on every Read goes to that reader.
type delegateReader struct {
	c chan io.Reader
	r io.Reader // nil until received from c
}

// Read forwards to the delegate, first waiting for it if it has not
// arrived yet. If the channel is closed before a delegate arrives, Read
// fails with "delegate closed".
func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	src, ok := <-d.c
	d.r = src
	if !ok {
		return 0, errors.New("delegate closed")
	}
	return src.Read(p)
}
// testTransportRace runs req through a throwaway Transport whose only
// "connection" is a dumpConn: everything the Transport writes goes into
// an in-memory pipe, and everything it reads comes from a delegateReader.
// A helper goroutine parses the request off the pipe and then either
// supplies a canned 204 response or, if the caller is already waiting
// for it to finish, closes the delegate channel. Once RoundTrip has
// returned and the goroutine is done, req.Body is restored to the value
// it had on entry. The result of RoundTrip is discarded.
func testTransportRace(req *Request) {
// Remember the caller's body so it can be put back at the end.
save := req.Body
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
dr := &delegateReader{c: make(chan io.Reader)}
// Note: t is a Transport here, and the Dial parameter named net
// shadows the net package inside the closure's signature.
t := &Transport{
Dial: func(net, addr string) (net.Conn, error) {
return &dumpConn{pw, dr}, nil
},
}
defer t.CloseIdleConnections()
// quitReadCh is closed when the goroutine below exits; before that,
// the goroutine may also send on it (see the select).
quitReadCh := make(chan struct{})
// Wait for the request before replying with a dummy response:
go func() {
defer close(quitReadCh)
req, err := ReadRequest(bufio.NewReader(pr))
if err == nil {
// Ensure all the body is read; otherwise
// we'll get a partial dump.
io.Copy(io.Discard, req.Body)
req.Body.Close()
}
// Either the Transport's read picks up the canned response, or the
// caller has already reached its receive on quitReadCh, in which
// case the send on quitReadCh succeeds instead.
select {
case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
case quitReadCh <- struct{}{}:
// Ensure delegate is closed so Read doesn't block forever.
close(dr.c)
}
}()
t.RoundTrip(req)
// Ensure the reader returns before we reset req.Body to prevent
// a data race on req.Body.
pw.Close()
<-quitReadCh
req.Body = save
}
// Issue 37669
// Test that a cancellation doesn't result in a data race due to the writeLoop
// goroutine being left running, if the caller mutates the processed Request
// upon completion.
//
// The test issues 1000 POSTs through testTransportRace, each with a
// random 0-4ms context timeout. It is skipped in short mode.
func TestErrorWriteLoopRace(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	t.Parallel()

	// attempt is one iteration. Keeping it in its own function lets the
	// deferred cancel run as soon as the iteration ends, rather than
	// stacking up 1000 deferred calls until the test returns.
	attempt := func() {
		delay := time.Duration(mrand.Intn(5)) * time.Millisecond
		ctx, cancel := context.WithTimeout(context.Background(), delay)
		defer cancel()

		r := bytes.NewBuffer(make([]byte, 10000))
		req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
		if err != nil {
			t.Fatal(err)
		}

		testTransportRace(req)
	}

	for i := 0; i < 1000; i++ {
		attempt()
	}
}
// Issue 41600
// Test that a new request which uses the connection of an active request
// cannot cause it to be canceled as well.
//
// Ten background goroutines keep issuing requests that are canceled
// almost immediately, while the test goroutine issues plain requests on
// the same single-connection Transport; none of the plain requests may
// fail. The whole exercise runs until a ten second context expires.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Header().Add("Content-Length", "0")
	}))
	defer srv.Close()

	client := srv.Client()
	tr := client.Transport.(*Transport)
	tr.MaxIdleConns, tr.MaxConnsPerHost = 1, 1

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

	var wg sync.WaitGroup
	// canceler loops until ctx ends, each time racing a request against
	// the cancellation of its own context.
	canceler := func() {
		defer wg.Done()
		for ctx.Err() == nil {
			reqctx, reqcancel := context.WithCancel(ctx)
			go reqcancel()
			req, _ := NewRequestWithContext(reqctx, "GET", srv.URL, nil)
			if res, err := client.Do(req); err == nil {
				res.Body.Close()
			}
		}
	}
	const numCancelers = 10
	wg.Add(numCancelers)
	for i := 0; i < numCancelers; i++ {
		go canceler()
	}

	// Requests without a cancelable context must never fail.
	for ctx.Err() == nil {
		req, _ := NewRequest("GET", srv.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			t.Errorf("unexpected: %p %v", req, err)
			break
		}
		res.Body.Close()
	}

	cancel()
	wg.Wait()
}