Commit 53fc330e authored by Brad Fitzpatrick

net/http: add Server.Close & Server.Shutdown for forced & graceful shutdown

Also updates x/net/http2 to git rev 541150 for:

   http2: add support for graceful shutdown of Server
   https://golang.org/cl/32412

   http2: make http2.Server access http1's Server via an interface check
   https://golang.org/cl/32417

Fixes #4674
Fixes #9478

Change-Id: I8021a18dee0ef2fe3946ac1776d2b10d3d429052
Reviewed-on: https://go-review.googlesource.com/32329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
parent 2d4d22af
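
For readers coming to this commit for the new API rather than the diff itself, here is a minimal usage sketch, separate from the commit, of how the two new entry points fit together. The SIGINT wiring, listen address, and 5-second budget are illustrative assumptions, not part of the change; only Shutdown, Close, and ErrServerClosed come from the commit below.

package main

import (
	"context"
	"log"
	"net/http"
	"os"
	"os/signal"
	"time"
)

func main() {
	srv := &http.Server{Addr: ":8080"}

	idleConnsClosed := make(chan struct{})
	go func() {
		sigint := make(chan os.Signal, 1)
		signal.Notify(sigint, os.Interrupt)
		<-sigint

		// Graceful path: stop accepting, let in-flight requests finish,
		// bounded by the context deadline.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := srv.Shutdown(ctx); err != nil {
			// Deadline hit before the server went idle; force it.
			srv.Close()
		}
		close(idleConnsClosed)
	}()

	// After Shutdown or Close, Serve/ListenAndServe report ErrServerClosed.
	if err := srv.ListenAndServe(); err != http.ErrServerClosed {
		log.Fatalf("ListenAndServe: %v", err)
	}
	<-idleConnsClosed
}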
@@ -87,6 +87,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
return
}
func (t *Transport) IdleConnKeyCountForTesting() int {
t.idleMu.Lock()
defer t.idleMu.Unlock()
return len(t.idleConn)
}
func (t *Transport) IdleConnStrsForTesting() []string {
var ret []string
t.idleMu.Lock()
...
@@ -2982,10 +2982,6 @@ func (s *http2Server) maxConcurrentStreams() uint32 {
return http2defaultMaxStreams
}
-// List of funcs for ConfigureServer to run. Both h1 and h2 are guaranteed
-// to be non-nil.
-var http2configServerFuncs []func(h1 *Server, h2 *http2Server) error
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
@@ -3512,6 +3508,11 @@ func (sc *http2serverConn) serve() {
sc.idleTimerCh = sc.idleTimer.C
}
var gracefulShutdownCh <-chan struct{}
if sc.hs != nil {
gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs)
}
go sc.readFrames()
settingsTimer := time.NewTimer(http2firstSettingsTimeout)
@@ -3539,6 +3540,9 @@
case <-settingsTimer.C:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
case <-gracefulShutdownCh:
gracefulShutdownCh = nil
sc.goAwayIn(http2ErrCodeNo, 0)
case <-sc.shutdownTimerCh:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@@ -3548,6 +3552,10 @@
case fn := <-sc.testHookCh:
fn(loopNum)
}
if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
return
}
}
}
@@ -3803,7 +3811,7 @@ func (sc *http2serverConn) scheduleFrameWrite() {
sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
continue
}
-if !sc.inGoAway {
+if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
if wr, ok := sc.writeSched.Pop(); ok {
sc.startFrameWrite(wr)
continue
@@ -3821,14 +3829,23 @@ func (sc *http2serverConn) scheduleFrameWrite() {
func (sc *http2serverConn) goAway(code http2ErrCode) {
sc.serveG.check()
-if sc.inGoAway {
-return
-}
+var forceCloseIn time.Duration
if code != http2ErrCodeNo {
-sc.shutDownIn(250 * time.Millisecond)
+forceCloseIn = 250 * time.Millisecond
} else {
-sc.shutDownIn(1 * time.Second)
+forceCloseIn = 1 * time.Second
+}
+sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) {
+sc.serveG.check()
+if sc.inGoAway {
+return
+}
+if forceCloseIn != 0 {
+sc.shutDownIn(forceCloseIn)
}
sc.inGoAway = true
sc.needToSendGoAway = true
@@ -5264,6 +5281,31 @@ var http2badTrailer = map[string]bool{
"Www-Authenticate": true,
}
// h1ServerShutdownChan returns a channel that will be closed when the
// provided *http.Server wants to shut down.
//
// This is a somewhat hacky way to get at http1 innards. It works
// when the http2 code is bundled into the net/http package in the
// standard library. The alternatives ended up making the cmd/go tool
// depend on http Servers. This is the lightest option for now.
// This is tested via the TestServeShutdown* tests in net/http.
func http2h1ServerShutdownChan(hs *Server) <-chan struct{} {
if fn := http2testh1ServerShutdownChan; fn != nil {
return fn(hs)
}
var x interface{} = hs
type I interface {
getDoneChan() <-chan struct{}
}
if hs, ok := x.(I); ok {
return hs.getDoneChan()
}
return nil
}
// optional test hook for h1ServerShutdownChan.
var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{}
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
...
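
The interface assertion in http2h1ServerShutdownChan above is the commit's way of letting the bundled http2 code observe http1's shutdown state without a hard API dependency. It only succeeds when both sides are compiled into the same package (the bundle inside net/http), because an interface with an unexported method can only be satisfied by a type from the same package. A small standalone sketch of the pattern; all names here are invented for illustration and are not part of the commit.

package main

import "fmt"

// server stands in for http.Server; done stands in for its doneChan.
type server struct{ done chan struct{} }

// getDoneChan is unexported, so only code compiled in this package can
// satisfy an interface naming it -- which is exactly why the
// bundled-into-net/http case works and an out-of-tree build falls back to nil.
func (s *server) getDoneChan() <-chan struct{} { return s.done }

// shutdownChan mirrors the shape of http2h1ServerShutdownChan: declare a
// local interface and check whether the value happens to implement it.
func shutdownChan(v interface{}) <-chan struct{} {
	type doneChanner interface {
		getDoneChan() <-chan struct{}
	}
	if d, ok := v.(doneChanner); ok {
		return d.getDoneChan()
	}
	return nil // not a type we know how to observe
}

func main() {
	s := &server{done: make(chan struct{})}
	fmt.Println(shutdownChan(s) != nil) // true: same package, assertion succeeds
}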
@@ -12,8 +12,13 @@ import (
"os/exec"
"reflect"
"testing"
"time"
)
func init() {
shutdownPollInterval = 5 * time.Millisecond
}
func TestForeachHeaderElement(t *testing.T) {
tests := []struct {
in string
...
@@ -4832,3 +4832,96 @@ func TestServerIdleTimeout(t *testing.T) {
t.Fatal("copy byte succeeded; want err")
}
}
func get(t *testing.T, c *Client, url string) string {
res, err := c.Get(url)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
return string(slurp)
}
// Tests that calls to Server.SetKeepAlivesEnabled(false) close any
// currently-open connections.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
if runtime.GOOS == "nacl" {
t.Skip("skipping on nacl; see golang.org/issue/17695")
}
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, r.RemoteAddr)
}))
defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
get := func() string { return get(t, c, ts.URL) }
a1, a2 := get(), get()
if a1 != a2 {
t.Fatal("expected first two requests on same connection")
}
var idle0 int
if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
idle0 = tr.IdleConnKeyCountForTesting()
return idle0 == 1
}) {
t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0)
}
ts.Config.SetKeepAlivesEnabled(false)
var idle1 int
if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
idle1 = tr.IdleConnKeyCountForTesting()
return idle1 == 0
}) {
t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
}
a3 := get()
if a3 == a2 {
t.Fatal("expected third request on new connection")
}
}
func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) }
func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) }
func testServerShutdown(t *testing.T, h2 bool) {
defer afterTest(t)
var doShutdown func() // set later
var shutdownRes = make(chan error, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
go doShutdown()
// Shutdown is graceful, so it should not interrupt
// this in-flight response. Add a tiny sleep here to
// increase the odds of a failure if shutdown has
// bugs.
time.Sleep(20 * time.Millisecond)
io.WriteString(w, r.RemoteAddr)
}))
defer cst.close()
doShutdown = func() {
shutdownRes <- cst.ts.Config.Shutdown(context.Background())
}
get(t, cst.c, cst.ts.URL) // calls t.Fail on failure
if err := <-shutdownRes; err != nil {
t.Fatalf("Shutdown: %v", err)
}
res, err := cst.c.Get(cst.ts.URL)
if err == nil {
res.Body.Close()
t.Fatal("second request should fail. server should be shut down")
}
}
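
One detail of testServerShutdown worth drawing out: the handler launches Shutdown in a separate goroutine (go doShutdown()). Since Shutdown waits for in-flight requests to finish, calling it synchronously from inside a handler would block on the very request being served. A condensed sketch of the same shape outside the test harness; the handler path and address are illustrative, not from the commit.

package main

import (
	"context"
	"io"
	"log"
	"net/http"
)

func main() {
	srv := &http.Server{Addr: "127.0.0.1:8080"}
	http.HandleFunc("/quit", func(w http.ResponseWriter, r *http.Request) {
		// Must run asynchronously: Shutdown does not return until this
		// very response (and any other in-flight work) has completed.
		go srv.Shutdown(context.Background())
		io.WriteString(w, "shutting down\n")
	})
	if err := srv.ListenAndServe(); err != http.ErrServerClosed {
		log.Fatal(err)
	}
}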
@@ -248,6 +248,8 @@ type conn struct {
curReq atomic.Value // of *response (which has a Request in it)
curState atomic.Value // of ConnectionState
// mu guards hijackedv
mu sync.Mutex
@@ -1586,11 +1588,30 @@ func validNPN(proto string) bool {
}
func (c *conn) setState(nc net.Conn, state ConnState) {
-if hook := c.server.ConnState; hook != nil {
+srv := c.server
+switch state {
+case StateNew:
+srv.trackConn(c, true)
+case StateHijacked, StateClosed:
+srv.trackConn(c, false)
+}
+c.curState.Store(connStateInterface[state])
+if hook := srv.ConnState; hook != nil {
hook(nc, state)
}
}
+
+// connStateInterface is an array of the interface{} versions of
+// ConnState values, so we can use them in atomic.Values later without
+// paying the cost of shoving their integers in an interface{}.
+var connStateInterface = [...]interface{}{
+StateNew: StateNew,
+StateActive: StateActive,
+StateIdle: StateIdle,
+StateHijacked: StateHijacked,
+StateClosed: StateClosed,
+}
// badRequestError is a literal string (used by in the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embedded errors.
@@ -2247,8 +2268,120 @@ type Server struct {
ErrorLog *log.Logger
disableKeepAlives int32 // accessed atomically.
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
nextProtoOnce sync.Once // guards setupHTTP2_* init
nextProtoErr error // result of http2.ConfigureServer if used
mu sync.Mutex
listeners map[net.Listener]struct{}
activeConn map[*conn]struct{}
doneChan chan struct{}
}
func (s *Server) getDoneChan() <-chan struct{} {
s.mu.Lock()
defer s.mu.Unlock()
return s.getDoneChanLocked()
}
func (s *Server) getDoneChanLocked() chan struct{} {
if s.doneChan == nil {
s.doneChan = make(chan struct{})
}
return s.doneChan
}
func (s *Server) closeDoneChanLocked() {
ch := s.getDoneChanLocked()
select {
case <-ch:
// Already closed. Don't close again.
default:
// Safe to close here. We're the only closer, guarded
// by s.mu.
close(ch)
}
}
// Close immediately closes all active net.Listeners and connections,
// regardless of their state. For a graceful shutdown, use Shutdown.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
s.closeDoneChanLocked()
err := s.closeListenersLocked()
for c := range s.activeConn {
c.rwc.Close()
delete(s.activeConn, c)
}
return err
}
// shutdownPollInterval is how often we poll for quiescence
// during Server.Shutdown. This is lower during tests, to
// speed up tests.
// Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't
// involve any contentious mutexes), but that is left as an
// exercise for the reader.
var shutdownPollInterval = 500 * time.Millisecond
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// then the context's error is returned.
func (s *Server) Shutdown(ctx context.Context) error {
atomic.AddInt32(&s.inShutdown, 1)
defer atomic.AddInt32(&s.inShutdown, -1)
s.mu.Lock()
lnerr := s.closeListenersLocked()
s.closeDoneChanLocked()
s.mu.Unlock()
ticker := time.NewTicker(shutdownPollInterval)
defer ticker.Stop()
for {
if s.closeIdleConns() {
return lnerr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// closeIdleConns closes all idle connections and reports whether the
// server is quiescent.
func (s *Server) closeIdleConns() bool {
s.mu.Lock()
defer s.mu.Unlock()
quiescent := true
for c := range s.activeConn {
st, ok := c.curState.Load().(ConnState)
if !ok || st != StateIdle {
quiescent = false
continue
}
c.rwc.Close()
delete(s.activeConn, c)
}
return quiescent
}
func (s *Server) closeListenersLocked() error {
var err error
for ln := range s.listeners {
if cerr := ln.Close(); cerr != nil && err == nil {
err = cerr
}
delete(s.listeners, ln)
}
return err
}
// A ConnState represents the state of a client connection to a server.
@@ -2361,6 +2494,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
}
var ErrServerClosed = errors.New("http: Server closed")
// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
@@ -2370,7 +2505,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
// srv.TLSConfig is non-nil and doesn't include the string "h2" in
// Config.NextProtos, HTTP/2 support is not enabled.
//
-// Serve always returns a non-nil error.
+// Serve always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
if fn := testHookServerServe; fn != nil {
@@ -2382,12 +2518,20 @@ func (srv *Server) Serve(l net.Listener) error {
return err
}
srv.trackListener(l, true)
defer srv.trackListener(l, false)
baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr())
for {
rw, e := l.Accept()
if e != nil {
select {
case <-srv.getDoneChan():
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
@@ -2410,6 +2554,37 @@
}
}
func (s *Server) trackListener(ln net.Listener, add bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.listeners == nil {
s.listeners = make(map[net.Listener]struct{})
}
if add {
// If the *Server is being reused after a previous
// Close or Shutdown, reset its doneChan:
if len(s.listeners) == 0 && len(s.activeConn) == 0 {
s.doneChan = nil
}
s.listeners[ln] = struct{}{}
} else {
delete(s.listeners, ln)
}
}
func (s *Server) trackConn(c *conn, add bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.activeConn == nil {
s.activeConn = make(map[*conn]struct{})
}
if add {
s.activeConn[c] = struct{}{}
} else {
delete(s.activeConn, c)
}
}
func (s *Server) idleTimeout() time.Duration {
if s.IdleTimeout != 0 {
return s.IdleTimeout
@@ -2425,7 +2600,11 @@ func (s *Server) readHeaderTimeout() time.Duration {
}
func (s *Server) doKeepAlives() bool {
-return atomic.LoadInt32(&s.disableKeepAlives) == 0
+return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+return atomic.LoadInt32(&s.inShutdown) != 0
}
// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
@@ -2435,9 +2614,21 @@ func (s *Server) doKeepAlives() bool {
func (srv *Server) SetKeepAlivesEnabled(v bool) {
if v {
atomic.StoreInt32(&srv.disableKeepAlives, 0)
-} else {
-atomic.StoreInt32(&srv.disableKeepAlives, 1)
+return
}
+atomic.StoreInt32(&srv.disableKeepAlives, 1)
+
+// Close idle HTTP/1 conns:
+srv.closeIdleConns()
+
+// Close HTTP/2 conns, as soon as they become idle, but reset
+// the chan so future conns (if the listener is still active)
+// still work and don't get a GOAWAY immediately, before their
+// first request:
+srv.mu.Lock()
+defer srv.mu.Unlock()
+srv.closeDoneChanLocked() // closes http2 conns
+srv.doneChan = nil
}
func (s *Server) logf(format string, args ...interface{}) {
...
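
Finally, the doc change above ("After Shutdown or Close, the returned error is ErrServerClosed") is the hook callers use to tell a deliberate shutdown apart from a real serving failure. A self-contained sketch, not taken from the commit, of that Serve/Shutdown handshake; the sleep is a crude stand-in for proper startup synchronization.

package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"time"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "ok")
	})}

	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.Serve(ln) }()

	// Crude: give Serve a moment to register the listener before shutting down.
	time.Sleep(50 * time.Millisecond)

	if err := srv.Shutdown(context.Background()); err != nil {
		panic(err)
	}
	fmt.Println(<-serveErr == http.ErrServerClosed) // true
}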