Commit 17e503f7 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: prevent Server reuse after a Shutdown

Fixes #20239

Change-Id: Icb021daad82e6905f536e4ef09ab219500b08167
Reviewed-on: https://go-review.googlesource.com/81778
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarIan Lance Taylor <iant@golang.org>
parent 4ba55273
......@@ -5980,6 +5980,21 @@ func TestServerCloseListenerOnce(t *testing.T) {
}
}
// Issue 20239: don't block in Serve if Shutdown is called first.
func TestServerShutdownThenServe(t *testing.T) {
var srv Server
cl := &countCloseListener{Listener: nil}
srv.Shutdown(context.Background())
got := srv.Serve(cl)
if got != ErrServerClosed {
t.Errorf("Serve err = %v; want ErrServerClosed", got)
}
nclose := atomic.LoadInt32(&cl.closes)
if nclose != 1 {
t.Errorf("Close calls = %v; want 1", nclose)
}
}
// Issue 23351: document and test behavior of ServeMux with ports
func TestStripPortFromHost(t *testing.T) {
mux := NewServeMux()
......
......@@ -2541,6 +2541,7 @@ func (s *Server) closeDoneChanLocked() {
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
atomic.StoreInt32(&srv.inShutdown, 1)
srv.mu.Lock()
defer srv.mu.Unlock()
srv.closeDoneChanLocked()
......@@ -2578,9 +2579,11 @@ var shutdownPollInterval = 500 * time.Millisecond
// separately notify such long-lived connections of shutdown and wait
// for them to close, if desired. See RegisterOnShutdown for a way to
// register shutdown notification functions.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
atomic.AddInt32(&srv.inShutdown, 1)
defer atomic.AddInt32(&srv.inShutdown, -1)
atomic.StoreInt32(&srv.inShutdown, 1)
srv.mu.Lock()
lnerr := srv.closeListenersLocked()
......@@ -2727,6 +2730,9 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
// If srv.Addr is blank, ":http" is used.
// ListenAndServe always returns a non-nil error.
func (srv *Server) ListenAndServe() error {
if srv.shuttingDown() {
return ErrServerClosed
}
addr := srv.Addr
if addr == "" {
addr = ":http"
......@@ -2775,8 +2781,8 @@ var ErrServerClosed = errors.New("http: Server closed")
// srv.TLSConfig is non-nil and doesn't include the string "h2" in
// Config.NextProtos, HTTP/2 support is not enabled.
//
// Serve always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
// Serve always returns a non-nil error and closes l.
// After Shutdown or Close, the returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
if fn := testHookServerServe; fn != nil {
fn(srv, l) // call hook with unwrapped listener
......@@ -2785,15 +2791,19 @@ func (srv *Server) Serve(l net.Listener) error {
l = &onceCloseListener{Listener: l}
defer l.Close()
var tempDelay time.Duration // how long to sleep on accept failure
if err := srv.setupHTTP2_Serve(); err != nil {
return err
}
srv.trackListener(&l, true)
serveDone := make(chan struct{})
defer close(serveDone)
if !srv.trackListener(&l, true) {
return ErrServerClosed
}
defer srv.trackListener(&l, false)
var tempDelay time.Duration // how long to sleep on accept failure
baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
for {
......@@ -2877,13 +2887,18 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
// trackListener via Serve and can track+defer untrack the same
// pointer to local variable there. We never need to compare a
// Listener from another caller.
func (s *Server) trackListener(ln *net.Listener, add bool) {
//
// It reports whether the server is still up (not Shutdown or Closed).
func (s *Server) trackListener(ln *net.Listener, add bool) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.listeners == nil {
s.listeners = make(map[*net.Listener]struct{})
}
if add {
if s.shuttingDown() {
return false
}
// If the *Server is being reused after a previous
// Close or Shutdown, reset its doneChan:
if len(s.listeners) == 0 && len(s.activeConn) == 0 {
......@@ -2893,6 +2908,7 @@ func (s *Server) trackListener(ln *net.Listener, add bool) {
} else {
delete(s.listeners, ln)
}
return true
}
func (s *Server) trackConn(c *conn, add bool) {
......@@ -2927,6 +2943,8 @@ func (s *Server) doKeepAlives() bool {
}
func (s *Server) shuttingDown() bool {
// TODO: replace inShutdown with the existing atomicBool type;
// see https://github.com/golang/go/issues/20239#issuecomment-381434582
return atomic.LoadInt32(&s.inShutdown) != 0
}
......@@ -3055,6 +3073,9 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
//
// ListenAndServeTLS always returns a non-nil error.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
if srv.shuttingDown() {
return ErrServerClosed
}
addr := srv.Addr
if addr == "" {
addr = ":https"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment