diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index a686fd0de0..404cca0825 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6725,28 +6725,3 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) } } - -// Issue 48642: close accepted connection -func TestServerCloseAccepted(t *testing.T) { - closed := 0 - conn := &rwTestConn{ - closeFunc: func() error { - closed++ - return nil - }, - } - ln := &oneConnListener{conn: conn} - var srv Server - // Use ConnContext to close server after connection is accepted but before it is tracked - srv.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - srv.Close() - return ctx - } - got := srv.Serve(ln) - if got != ErrServerClosed { - t.Errorf("Serve err = %v; want ErrServerClosed", got) - } - if closed != 1 { - t.Errorf("Connection expected to be closed") - } -} diff --git a/src/net/http/server.go b/src/net/http/server.go index 34d6ec828b..d44b0fb256 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1754,10 +1754,10 @@ const ( func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { srv := c.server switch state { + case StateNew: + srv.trackConn(c, true) case StateHijacked, StateClosed: - srv.mu.Lock() - delete(srv.activeConn, c) - srv.mu.Unlock() + srv.trackConn(c, false) } if state > 0xff || state < 0 { panic("internal error") @@ -3068,10 +3068,6 @@ func (srv *Server) Serve(l net.Listener) error { } tempDelay = 0 c := srv.newConn(rw) - if !srv.trackConn(c) { - rw.Close() - return ErrServerClosed - } c.setState(c.rwc, StateNew, runHooks) // before Serve can return go c.serve(connCtx) } @@ -3143,19 +3139,17 @@ func (s *Server) trackListener(ln *net.Listener, add bool) bool { return true } -// trackConn adds a connection to the set of tracked connections. -// It reports whether the server is still up (not Shutdown or Closed). -func (s *Server) trackConn(c *conn) bool { +func (s *Server) trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() - if s.shuttingDown() { - return false - } if s.activeConn == nil { s.activeConn = make(map[*conn]struct{}) } - s.activeConn[c] = struct{}{} - return true + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } } func (s *Server) idleTimeout() time.Duration {