Commit beab8eb8 authored by Andrew Gerrand's avatar Andrew Gerrand

go.net: add netutil package with LimitListener

Update golang/go#6012

R=golang-dev, dsymonds, rsc
CC=golang-dev
https://golang.org/cl/12727043
parent 77895031
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package netutil provides network utility functions, complementing the more
// common ones in the net package.
package netutil
import "net"
// LimitListener returns a Listener that accepts at most n simultaneous
// connections from the provided Listener.
func LimitListener(l net.Listener, n int) net.Listener {
ch := make(chan struct{}, n)
for i := 0; i < n; i++ {
ch <- struct{}{}
}
return &limitListener{l, ch}
}
type limitListener struct {
net.Listener
ch chan struct{}
}
func (l *limitListener) Accept() (net.Conn, error) {
<-l.ch
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &limitListenerConn{c, l.ch}, nil
}
type limitListenerConn struct {
net.Conn
ch chan<- struct{}
}
func (l *limitListenerConn) Close() error {
err := l.Conn.Close()
l.ch <- struct{}{}
return err
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netutil
import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestLimitListener(t *testing.T) {
const (
max = 5
num = 200
)
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
l = LimitListener(l, max)
var open int32
go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if n := atomic.AddInt32(&open, 1); n > max {
t.Errorf("%d open connections, want <= %d", n, max)
}
defer atomic.AddInt32(&open, -1)
time.Sleep(10 * time.Millisecond)
fmt.Fprint(w, "some body")
}))
var wg sync.WaitGroup
var failed int32
for i := 0; i < num; i++ {
wg.Add(1)
go func() {
defer wg.Done()
r, err := http.Get("http://" + l.Addr().String())
if err != nil {
t.Logf("Get: %v", err)
atomic.AddInt32(&failed, 1)
return
}
defer r.Body.Close()
io.Copy(ioutil.Discard, r.Body)
}()
}
wg.Wait()
// We expect some Gets to fail as the kernel's accept queue is filled,
// but most should succeed.
if failed >= num/2 {
t.Errorf("too many Gets failed")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment