diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 1baec23d57..6075397a26 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -93,9 +93,6 @@ func NewUnstartedServer(handler http.Handler) *Server { return &Server{ Listener: newLocalListener(), Config: &http.Server{Handler: handler}, - client: &http.Client{ - Transport: &http.Transport{}, - }, } } @@ -104,6 +101,9 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } s.URL = "http://" + s.Listener.Addr().String() s.wrap() s.goServe() @@ -118,6 +118,9 @@ func (s *Server) StartTLS() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index d97cec5fdd..8ab50cdb0a 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -12,8 +12,48 @@ import ( "testing" ) +type newServerFunc func(http.Handler) *Server + +var newServers = map[string]newServerFunc{ + "NewServer": NewServer, + "NewTLSServer": NewTLSServer, + + // The manual variants of newServer create a Server manually by only filling + // in the exported fields of Server. + "NewServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.Start() + return ts + }, + "NewTLSServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.StartTLS() + return ts + }, +} + func TestServer(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, name := range []string{"NewServer", "NewServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) + t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) + t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) + t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) + t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) + }) + } + for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) + t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) + }) + } +} + +func testServer(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) defer ts.Close() @@ -32,8 +72,8 @@ func TestServer(t *testing.T) { } // Issue 12781 -func TestGetAfterClose(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testGetAfterClose(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) @@ -58,8 +98,8 @@ func TestGetAfterClose(t *testing.T) { } } -func TestServerCloseBlocking(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) dial := func() net.Conn { @@ -87,9 +127,9 @@ func TestServerCloseBlocking(t *testing.T) { } // Issue 14290 -func TestServerCloseClientConnections(t *testing.T) { +func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { var s *Server - s = NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.CloseClientConnections() })) defer s.Close() @@ -102,8 +142,8 @@ func TestServerCloseClientConnections(t *testing.T) { // Tests that the Server.Client method works and returns an http.Client that can hit // NewTLSServer without cert warnings. -func TestServerClient(t *testing.T) { - ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testServerClient(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) defer ts.Close() @@ -124,8 +164,8 @@ func TestServerClient(t *testing.T) { // Tests that the Server.Client.Transport interface is implemented // by a *http.Transport. -func TestServerClientTransportType(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testServerClientTransportType(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) defer ts.Close() client := ts.Client() @@ -136,8 +176,8 @@ func TestServerClientTransportType(t *testing.T) { // Tests that the TLS Server.Client.Transport interface is implemented // by a *http.Transport. -func TestTLSServerClientTransportType(t *testing.T) { - ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) defer ts.Close() client := ts.Client()