Commit 2d5fb0b4 authored by Eric Chiang's avatar Eric Chiang

Merge pull request #316 from fnordahl/issue/309-implement-connection-pooling

connector_ldap: Implement connection pooling for LDAP connections
parents 4440b3a0 3077979a
......@@ -167,6 +167,8 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio
* skipCertVerification: a `boolean`. Skip server certificate chain verification.
* maxIdleConn: a `integer`. Maximum number of idle LDAP Connections to keep in connection pool. Default: `5`
* baseDN: a `string`. Base DN from which Bind DN is built and searches are based.
* nameAttribute: a `string`. Attribute to map to Name. Default: `cn`
......
......@@ -12,6 +12,7 @@ import (
"net/url"
"path"
"strings"
"sync"
"time"
"github.com/coreos/dex/pkg/log"
......@@ -40,6 +41,7 @@ type LDAPConnectorConfig struct {
KeyFile string `json:"keyFile"`
CaFile string `json:"caFile"`
SkipCertVerification bool `json:"skipCertVerification"`
MaxIdleConn int `json:"maxIdleConn"`
BaseDN string `json:"baseDN"`
NameAttribute string `json:"nameAttribute"`
EmailAttribute string `json:"emailAttribute"`
......@@ -81,6 +83,8 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
const defaultEmailAttribute = "mail"
const defaultBindTemplate = "uid=%u,%b"
const defaultSearchScope = ldap.ScopeWholeSubtree
const defaultMaxIdleConns = 5
const defaultPoolCheckTimer = 7200 * time.Second
if cfg.UseTLS && cfg.UseSSL {
return nil, fmt.Errorf("Invalid configuration. useTLS and useSSL are mutual exclusive.")
......@@ -154,11 +158,22 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
tlsConfig.Certificates = []tls.Certificate{cert}
}
maxIdleConn := defaultMaxIdleConns
if cfg.MaxIdleConn > 0 {
maxIdleConn = cfg.MaxIdleConn
}
ldapPool := &LDAPPool{
MaxIdleConn: maxIdleConn,
PoolCheckTimer: defaultPoolCheckTimer,
ServerHost: cfg.ServerHost,
ServerPort: cfg.ServerPort,
UseTLS: cfg.UseTLS,
UseSSL: cfg.UseSSL,
TLSConfig: tlsConfig,
}
idp := &LDAPIdentityProvider{
serverHost: cfg.ServerHost,
serverPort: cfg.ServerPort,
useTLS: cfg.UseTLS,
useSSL: cfg.UseSSL,
baseDN: cfg.BaseDN,
nameAttribute: nameAttribute,
emailAttribute: emailAttribute,
......@@ -168,7 +183,7 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
searchBindDN: cfg.SearchBindDN,
searchBindPw: cfg.SearchBindPw,
bindTemplate: bindTemplate,
tlsConfig: tlsConfig,
ldapPool: ldapPool,
}
idpc := &LDAPConnector{
......@@ -188,9 +203,9 @@ func (c *LDAPConnector) ID() string {
}
func (c *LDAPConnector) Healthy() error {
ldapConn, err := c.idp.LDAPConnect()
ldapConn, err := c.idp.ldapPool.Acquire()
if err == nil {
ldapConn.Close()
c.idp.ldapPool.Put(ldapConn)
}
return err
}
......@@ -210,18 +225,145 @@ func (c *LDAPConnector) Register(mux *http.ServeMux, errorURL url.URL) {
}
func (c *LDAPConnector) Sync() chan struct{} {
return make(chan struct{})
stop := make(chan struct{})
go func() {
for {
select {
case <-time.After(c.idp.ldapPool.PoolCheckTimer):
alive, killed := c.idp.ldapPool.CheckConnections()
if alive > 0 {
log.Infof("Connector ID=%v idle_conns=%v", c.id, alive)
}
if killed > 0 {
log.Warningf("Connector ID=%v closed %v dead connections.", c.id, killed)
}
case <-stop:
return
}
}
}()
return stop
}
func (c *LDAPConnector) TrustedEmailProvider() bool {
return c.trustedEmailProvider
}
// A LDAPPool is a Connection Pool for LDAP connections
// Initialize exported fields and use Acquire() to get a connection.
// Use Put() to put it back into the pool.
type LDAPPool struct {
m sync.Mutex
conns map[*ldap.Conn]struct{}
MaxIdleConn int
PoolCheckTimer time.Duration
ServerHost string
ServerPort uint16
UseTLS bool
UseSSL bool
TLSConfig *tls.Config
}
// Acquire removes and returns a random connection from the pool. A new connection is returned
// if there are no connections available in the pool.
func (p *LDAPPool) Acquire() (*ldap.Conn, error) {
conn := p.removeRandomConn()
if conn != nil {
return conn, nil
}
return p.ldapConnect()
}
// Put makes a connection ready for re-use and puts it back into the pool. If the connection
// cannot be reused it is discarded. If there already are MaxIdleConn connections in the pool
// the connection is discarded.
func (p *LDAPPool) Put(c *ldap.Conn) {
p.m.Lock()
if p.conns == nil {
// First call to Put, initialize map
p.conns = make(map[*ldap.Conn]struct{})
}
if len(p.conns)+1 > p.MaxIdleConn {
p.m.Unlock()
c.Close()
return
}
p.m.Unlock()
// drop to anonymous bind
err := c.Bind("", "")
if err != nil {
// unsupported or disallowed, throw away connection
log.Warningf("Unable to re-use LDAP Connection after failure to bind anonymously: %v", err)
c.Close()
return
}
p.m.Lock()
p.conns[c] = struct{}{}
p.m.Unlock()
}
// removeConn attempts to remove the provided connection from the pool. If removeConn returns false
// another routine is using the connection and the caller should discard the pointer.
func (p *LDAPPool) removeConn(conn *ldap.Conn) bool {
p.m.Lock()
_, ok := p.conns[conn]
delete(p.conns, conn)
p.m.Unlock()
return ok
}
// removeRandomConn attempts to remove a random connection from the pool. If removeRandomConn
// returns nil the pool is empty.
func (p *LDAPPool) removeRandomConn() *ldap.Conn {
p.m.Lock()
defer p.m.Unlock()
for conn := range p.conns {
delete(p.conns, conn)
return conn
}
return nil
}
// CheckConnections attempts to iterate over all the connections in the pool and check wheter
// they are alive or not. Live connections are put back into the pool, dead ones are discarded.
func (p *LDAPPool) CheckConnections() (int, int) {
var conns []*ldap.Conn
var alive, killed int
// Get snapshot of connection-map while holding Lock
p.m.Lock()
for conn := range p.conns {
conns = append(conns, conn)
}
p.m.Unlock()
// Iterate over snapshot, Get and ping connections.
// Put live connections back into pool, Close dead ones.
for _, conn := range conns {
ok := p.removeConn(conn)
if ok {
err := ldapPing(conn)
if err == nil {
p.Put(conn)
alive++
} else {
conn.Close()
killed++
}
}
}
return alive, killed
}
func ldapPing(conn *ldap.Conn) error {
// Query root DSE
s := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{}, nil)
_, err := conn.Search(s)
return err
}
type LDAPIdentityProvider struct {
serverHost string
serverPort uint16
useTLS bool
useSSL bool
baseDN string
nameAttribute string
emailAttribute string
......@@ -231,26 +373,26 @@ type LDAPIdentityProvider struct {
searchBindDN string
searchBindPw string
bindTemplate string
tlsConfig *tls.Config
ldapPool *LDAPPool
}
func (m *LDAPIdentityProvider) LDAPConnect() (*ldap.Conn, error) {
func (p *LDAPPool) ldapConnect() (*ldap.Conn, error) {
var err error
var ldapConn *ldap.Conn
log.Debugf("LDAPConnect()")
if m.useSSL {
ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort), m.tlsConfig)
if p.UseSSL {
ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort), p.TLSConfig)
if err != nil {
return nil, err
}
} else {
ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort))
ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort))
if err != nil {
return nil, err
}
if m.useTLS {
err = ldapConn.StartTLS(m.tlsConfig)
if p.UseTLS {
err = ldapConn.StartTLS(p.TLSConfig)
if err != nil {
return nil, err
}
......@@ -273,11 +415,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi
var bindDN, ldapUid, ldapName, ldapEmail string
var ldapConn *ldap.Conn
ldapConn, err = m.LDAPConnect()
ldapConn, err = m.ldapPool.Acquire()
if err != nil {
return nil, err
}
defer ldapConn.Close()
defer m.ldapPool.Put(ldapConn)
if m.searchBeforeAuth {
err = ldapConn.Bind(m.searchBindDN, m.searchBindPw)
......@@ -307,16 +449,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi
ldapName = sr.Entries[0].GetAttributeValue(m.nameAttribute)
ldapEmail = sr.Entries[0].GetAttributeValue(m.emailAttribute)
// drop to anonymous bind, prepare for bind as user
err = ldapConn.Bind("", "")
// prepare LDAP connection for bind as user
m.ldapPool.Put(ldapConn)
ldapConn, err = m.ldapPool.Acquire()
if err != nil {
// unsupported or disallowed, reconnect
log.Warningf("Re-connecting to LDAP Server after failure to bind anonymously: %v", err)
ldapConn.Close()
ldapConn, err = m.LDAPConnect()
if err != nil {
return nil, err
}
return nil, err
}
} else {
bindDN = m.ParseString(m.bindTemplate, username)
......
......@@ -7,9 +7,12 @@ import (
"net/url"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/coreos/dex/connector"
"golang.org/x/net/context"
"gopkg.in/ldap.v2"
)
......@@ -123,3 +126,57 @@ func TestConnectorLDAPHealthy(t *testing.T) {
}
}
}
func TestLDAPPoolHighWatermarkAndLockContention(t *testing.T) {
server := ldapServer(t)
ldapPool := &connector.LDAPPool{
MaxIdleConn: 30,
ServerHost: server.Host,
ServerPort: server.Port,
UseTLS: false,
UseSSL: false,
}
// Excercise pool operations with MaxIdleConn + 10 concurrent goroutines.
// We are testing both pool high watermark code and lock contention
numRoutines := ldapPool.MaxIdleConn + 10
var wg sync.WaitGroup
wg.Add(numRoutines)
ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
for i := 0; i < numRoutines; i++ {
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
ldapConn, err := ldapPool.Acquire()
if err != nil {
t.Errorf("Unable to acquire LDAP Connection: %v", err)
}
s := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{}, nil)
_, err = ldapConn.Search(s)
if err != nil {
t.Errorf("Search request failed. Dead/invalid LDAP connection from pool?: %v", err)
ldapConn.Close()
} else {
ldapPool.Put(ldapConn)
}
_, _ = ldapPool.CheckConnections()
}
}
}()
}
// Wait for all operations to complete and check status.
// There should be MaxIdleConn connections in the pool. This confirms:
// 1. The tests was indeed executed concurrently
// 2. Even though we ran more routines than the configured MaxIdleConn the high
// watermark code did its job and closed surplus connections
wg.Wait()
alive, killed := ldapPool.CheckConnections()
if alive < ldapPool.MaxIdleConn {
t.Errorf("expected %v connections, got alive=%v killed=%v", ldapPool.MaxIdleConn, alive, killed)
}
}
......@@ -65,6 +65,7 @@ imports:
version: dfe268fd2bb5c793f4c083803609fce9806c6f80
subpackages:
- html
- context
- html/atom
- name: google.golang.org/api
version: d3edb0282bde692467788c50070a9211afe75cf3
......
......@@ -60,6 +60,7 @@ import:
version: dfe268fd2bb5c793f4c083803609fce9806c6f80
subpackages:
- html
- context
- package: google.golang.org/api
version: d3edb0282bde692467788c50070a9211afe75cf3
subpackages:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment