Commit a846016c authored by Evan Cordell's avatar Evan Cordell

Merge pull request #442 from ecordell/client-manager

Adds client manager
parents b0f17c94 73d9742c
// package admin provides an implementation of the API described in auth/schema/adminschema.
// Package admin provides an implementation of the API described in auth/schema/adminschema.
package admin
import (
"net/http"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
var (
ClientIDGenerator = oidc.GenClientID
usermanager "github.com/coreos/dex/user/manager"
)
// AdminAPI provides the logic necessary to implement the Admin API.
type AdminAPI struct {
userManager *manager.UserManager
userManager *usermanager.UserManager
userRepo user.UserRepo
passwordInfoRepo user.PasswordInfoRepo
clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
localConnectorID string
}
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *manager.UserManager, localConnectorID string) *AdminAPI {
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, localConnectorID string) *AdminAPI {
if localConnectorID == "" {
panic("must specify non-blank localConnectorID")
}
......@@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe
userRepo: userRepo,
passwordInfoRepo: pwiRepo,
clientRepo: clientRepo,
clientManager: clientManager,
localConnectorID: localConnectorID,
}
}
......@@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
}
// metadata is guaranteed to have at least one redirect_uri by earlier validation.
id, err := ClientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}
cli.Credentials.ID = id
creds, err := a.clientRepo.New(cli)
creds, err := a.clientManager.New(cli)
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}
......
......@@ -4,6 +4,7 @@ import (
"testing"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema"
......@@ -17,6 +18,7 @@ type testFixtures struct {
ur user.UserRepo
pwr user.PasswordInfoRepo
cr client.ClientRepo
cm *clientmanager.ClientManager
mgr *manager.UserManager
adAPI *AdminAPI
}
......@@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures {
}()
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, "local")
f.cm = clientmanager.NewClientManager(f.cr, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, f.cm, "local")
return f
}
......
package client
import (
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/url"
"reflect"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
)
......@@ -17,6 +21,24 @@ var (
ErrorNotFound = errors.New("no data found")
)
const (
bcryptHashCost = 10
)
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
secretBytes, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
return nil, err
}
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
if err != nil {
return nil, err
}
return hashed, nil
}
type Client struct {
Credentials oidc.ClientCredentials
Metadata oidc.ClientMetadata
......@@ -24,30 +46,20 @@ type Client struct {
}
type ClientRepo interface {
Get(clientID string) (Client, error)
Get(tx repo.Transaction, clientID string) (Client, error)
// Metadata returns one matching ClientMetadata if the given client
// exists, otherwise nil. The returned error will be non-nil only
// if the repo was unable to determine client existence.
Metadata(clientID string) (*oidc.ClientMetadata, error)
// Authenticate asserts that a client with the given ID exists and
// that the provided secret matches. If either of these assertions
// fail, (false, nil) will be returned. Only if the repo is unable
// to make these assertions will a non-nil error be returned.
Authenticate(creds oidc.ClientCredentials) (bool, error)
// GetSecret returns the (base64 encoded) hashed client secret
GetSecret(tx repo.Transaction, clientID string) ([]byte, error)
// All returns all registered Clients
All() ([]Client, error)
All(tx repo.Transaction) ([]Client, error)
// New registers a Client with the repo.
// An unused ID must be provided. A corresponding secret will be returned
// in a ClientCredentials struct along with the provided ID.
New(client Client) (*oidc.ClientCredentials, error)
SetDexAdmin(clientID string, isAdmin bool) error
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
IsDexAdmin(clientID string) (bool, error)
Update(tx repo.Transaction, client Client) error
}
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
......
......@@ -34,7 +34,7 @@ var (
badSecretClient = `{
"id": "my_id",
"secret": "` + "****" + `",
"secret": "` + "" + `",
"redirectURLs": ["https://client.example.com"]
}`
......@@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) {
{
Credentials: oidc.ClientCredentials{
ID: "my_id",
Secret: "my_secret",
Secret: goodSecret1,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) {
{
Credentials: oidc.ClientCredentials{
ID: "my_id",
Secret: "my_secret",
Secret: goodSecret1,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) {
{
Credentials: oidc.ClientCredentials{
ID: "my_other_id",
Secret: "my_other_secret",
Secret: goodSecret2,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) {
},
},
},
}, {
},
{
json: "[" + badURLClient + "]",
wantErr: true,
},
......
package manager
import (
"encoding/base64"
"fmt"
"errors"
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
"golang.org/x/crypto/bcrypt"
)
const (
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
)
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func CompareHashAndPassword(hashedPassword, password []byte) error {
if len(password) > maxSecretLength {
return errors.New("password length greater than max secret length")
}
return bcrypt.CompareHashAndPassword(hashedPassword, password)
}
// ClientManager performs client-related "business-logic" functions on client and related objects.
// This is in contrast to the Repos which perform little more than CRUD operations.
type ClientManager struct {
clientRepo client.ClientRepo
begin repo.TransactionFactory
secretGenerator SecretGenerator
clientIDGenerator func(string) (string, error)
}
type ManagerOptions struct {
SecretGenerator func() ([]byte, error)
ClientIDGenerator func(string) (string, error)
}
func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *ClientManager {
if options.SecretGenerator == nil {
options.SecretGenerator = DefaultSecretGenerator
}
if options.ClientIDGenerator == nil {
options.ClientIDGenerator = oidc.GenClientID
}
return &ClientManager{
clientRepo: clientRepo,
begin: txnFactory,
secretGenerator: options.SecretGenerator,
clientIDGenerator: options.ClientIDGenerator,
}
}
func NewClientManagerFromClients(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, clients []client.Client, options ManagerOptions) (*ClientManager, error) {
clientManager := NewClientManager(clientRepo, txnFactory, options)
tx, err := clientManager.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cli, err := clientManager.generateClientCredentials(c)
if err != nil {
return nil, err
}
_, err = clientRepo.New(tx, cli)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return clientManager, nil
}
func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) {
tx, err := m.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
c, err := m.generateClientCredentials(cli)
if err != nil {
return nil, err
}
creds := c.Credentials
// Save Client
_, err = m.clientRepo.New(tx, c)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
// Returns creds with unhashed secret
return &creds, nil
}
func (m *ClientManager) Get(id string) (client.Client, error) {
return m.clientRepo.Get(nil, id)
}
func (m *ClientManager) All() ([]client.Client, error) {
return m.clientRepo.All(nil)
}
func (m *ClientManager) Metadata(clientID string) (*oidc.ClientMetadata, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return nil, err
}
return &c.Metadata, nil
}
func (m *ClientManager) IsDexAdmin(clientID string) (bool, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return false, err
}
return c.Admin, nil
}
func (m *ClientManager) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := m.begin()
if err != nil {
return err
}
defer tx.Rollback()
c, err := m.clientRepo.Get(tx, clientID)
if err != nil {
return err
}
c.Admin = isAdmin
err = m.clientRepo.Update(tx, c)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (m *ClientManager) Authenticate(creds oidc.ClientCredentials) (bool, error) {
clientSecret, err := m.clientRepo.GetSecret(nil, creds.ID)
if err != nil || clientSecret == nil {
return false, nil
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
ok := CompareHashAndPassword(clientSecret, dec) == nil
return ok, nil
}
func (m *ClientManager) generateClientCredentials(cli client.Client) (client.Client, error) {
// Generate Client ID
if len(cli.Metadata.RedirectURIs) < 1 {
return cli, errors.New("no client redirect url given")
}
clientID, err := m.clientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
if err != nil {
return cli, err
}
// Generate Secret
secret, err := m.secretGenerator()
if err != nil {
return cli, err
}
clientSecret := base64.URLEncoding.EncodeToString(secret)
cli.Credentials = oidc.ClientCredentials{
ID: clientID,
Secret: clientSecret,
}
return cli, nil
}
package manager
import (
"encoding/base64"
"fmt"
"net/url"
"testing"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc"
)
type testFixtures struct {
clientRepo client.ClientRepo
mgr *ClientManager
}
var (
goodSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
)
func makeTestFixtures() *testFixtures {
f := &testFixtures{}
dbMap := db.NewMemDB()
clients := []client.Client{
{
Credentials: oidc.ClientCredentials{
ID: "client.example.com",
Secret: goodSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "http", Host: "client.example.com", Path: "/"},
},
},
Admin: true,
},
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
f.clientRepo = db.NewClientRepo(dbMap)
clientManager, err := NewClientManagerFromClients(f.clientRepo, db.TransactionFactory(dbMap), clients, ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
}
f.mgr = clientManager
return f
}
func TestMetadata(t *testing.T) {
tests := []struct {
clientID string
uri string
wantErr bool
}{
{
clientID: "client.example.com",
uri: "http://client.example.com/",
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
md, err := f.mgr.Metadata(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if md.RedirectURIs[0].String() != tt.uri {
t.Errorf("case %d: manager.Metadata.RedirectURIs: want=%q got=%q", i, tt.uri, md.RedirectURIs[0].String())
continue
}
}
}
func TestIsDexAdmin(t *testing.T) {
tests := []struct {
clientID string
isAdmin bool
wantErr bool
}{
{
clientID: "client.example.com",
isAdmin: true,
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
admin, err := f.mgr.IsDexAdmin(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if admin != tt.isAdmin {
t.Errorf("case %d: manager.Admin want=%t got=%t", i, tt.isAdmin, admin)
continue
}
}
}
func TestSetDexAdmin(t *testing.T) {
f := makeTestFixtures()
err := f.mgr.SetDexAdmin("client.example.com", false)
if err != nil {
t.Errorf("unexpected err: %v", err)
}
admin, _ := f.mgr.IsDexAdmin("client.example.com")
if admin {
t.Errorf("expected admin to be false")
}
}
func TestAuthenticate(t *testing.T) {
f := makeTestFixtures()
cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com", Path: "/cb"},
},
}
cli := client.Client{
Metadata: cm,
}
cc, err := f.mgr.New(cli)
if err != nil {
t.Fatalf(err.Error())
}
ok, err := f.mgr.Authenticate(*cc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if !ok {
t.Fatalf("Authentication failed for good creds")
}
creds := []oidc.ClientCredentials{
//completely made up
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
// good client ID, bad secret
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
// bad client ID, good secret
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
// good client ID, secret with some fluff on the end
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
}
for i, c := range creds {
ok, err := f.mgr.Authenticate(c)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok {
t.Errorf("case %d: authentication succeeded for bad creds", i)
}
}
}
......@@ -15,6 +15,7 @@ import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/admin"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
pflag "github.com/coreos/dex/pkg/flag"
"github.com/coreos/dex/pkg/log"
......@@ -119,8 +120,9 @@ func main() {
clientRepo := db.NewClientRepo(dbc)
userManager := manager.NewUserManager(userRepo,
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, *localConnectorID)
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, clientManager, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil {
log.Fatalf(err.Error())
......
......@@ -2,6 +2,7 @@ package main
import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc"
......@@ -14,34 +15,26 @@ func newDBDriver(dsn string) (driver, error) {
}
drv := &dbDriver{
ciRepo: db.NewClientRepo(dbc),
cfgRepo: db.NewConnectorConfigRepo(dbc),
cfgRepo: db.NewConnectorConfigRepo(dbc),
ciManager: manager.NewClientManager(db.NewClientRepo(dbc), db.TransactionFactory(dbc), manager.ManagerOptions{}),
}
return drv, nil
}
type dbDriver struct {
ciRepo client.ClientRepo
cfgRepo *db.ConnectorConfigRepo
ciManager *manager.ClientManager
cfgRepo *db.ConnectorConfigRepo
}
func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
if err := meta.Valid(); err != nil {
return nil, err
}
clientID, err := oidc.GenClientID(meta.RedirectURIs[0].Host)
if err != nil {
return nil, err
}
return d.ciRepo.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: clientID,
},
cli := client.Client{
Metadata: meta,
})
}
return d.ciManager.New(cli)
}
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {
......
......@@ -2,7 +2,6 @@ package db
import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
......@@ -10,24 +9,15 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
)
const (
clientTableName = "client_identity"
bcryptHashCost = 10
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
// postgres error codes
pgErrorCodeUniqueViolation = "23505" // unique_violation
)
......@@ -42,17 +32,10 @@ func init() {
}
func newClientModel(cli client.Client) (*clientModel, error) {
secretBytes, err := base64.URLEncoding.DecodeString(cli.Credentials.Secret)
if err != nil {
return nil, err
}
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
hashed, err := client.HashSecret(cli.Credentials)
if err != nil {
return nil, err
}
bmeta, err := json.Marshal(&cli.Metadata)
if err != nil {
return nil, err
......@@ -92,56 +75,20 @@ func (m *clientModel) Client() (*client.Client, error) {
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
return newClientRepo(dbm)
}
func NewClientRepoWithSecretGenerator(dbm *gorp.DbMap, secGen SecretGenerator) client.ClientRepo {
rep := newClientRepo(dbm)
rep.secretGenerator = secGen
return rep
}
func newClientRepo(dbm *gorp.DbMap) *clientRepo {
return &clientRepo{
db: &db{dbm},
secretGenerator: DefaultSecretGenerator,
}
}
func NewClientRepoFromClients(dbm *gorp.DbMap, clients []client.Client) (client.ClientRepo, error) {
repo := newClientRepo(dbm)
tx, err := repo.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
exec := repo.executor(tx)
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cm, err := newClientModel(c)
if err != nil {
return nil, err
}
err = exec.Insert(cm)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
db: &db{dbm},
}
return repo, nil
}
type clientRepo struct {
*db
secretGenerator SecretGenerator
}
func (r *clientRepo) Get(clientID string) (client.Client, error) {
m, err := r.executor(nil).Get(clientModel{}, clientID)
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
m, err := r.executor(tx).Get(clientModel{}, clientID)
if err == sql.ErrNoRows || m == nil {
return client.Client{}, client.ErrorNotFound
}
......@@ -163,82 +110,28 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
return *ci, nil
}
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
c, err := r.Get(clientID)
if err != nil {
func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
m, err := r.getModel(tx, clientID)
if err != nil || m == nil {
return nil, err
}
return &c.Metadata, nil
return m.Secret, nil
}
func (r *clientRepo) IsDexAdmin(clientID string) (bool, error) {
m, err := r.executor(nil).Get(clientModel{}, clientID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
if cli.Credentials.ID == "" {
return client.ErrorNotFound
}
return cim.DexAdmin, nil
}
func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := r.begin()
// make sure this client exists already
_, err := r.get(tx, cli.Credentials.ID)
if err != nil {
return err
}
defer tx.Rollback()
exec := r.executor(tx)
m, err := exec.Get(clientModel{}, clientID)
if m == nil || err != nil {
return err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model")
}
cim.DexAdmin = isAdmin
_, err = exec.Update(cim)
err = r.update(tx, cli)
if err != nil {
return err
}
return tx.Commit()
}
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
if len(dec) > maxSecretLength {
return false, nil
}
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
return ok, nil
return nil
}
var alreadyExistsCheckers []func(err error) bool
......@@ -260,26 +153,14 @@ func isAlreadyExistsErr(err error) bool {
return false
}
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
secret, err := r.secretGenerator()
if err != nil {
return nil, err
}
cli.Credentials.Secret = base64.URLEncoding.EncodeToString(secret)
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
cim, err := newClientModel(cli)
if err != nil {
return nil, err
}
if err := r.executor(nil).Insert(cim); err != nil {
if err := r.executor(tx).Insert(cim); err != nil {
if isAlreadyExistsErr(err) {
err = errors.New("client ID already exists")
}
......@@ -294,10 +175,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
return &cc, nil
}
func (r *clientRepo) All() ([]client.Client, error) {
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
qt := r.quote(clientTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.executor(nil).Select(&clientModel{}, q)
objs, err := r.executor(tx).Select(&clientModel{}, q)
if err != nil {
return nil, err
}
......@@ -317,3 +198,47 @@ func (r *clientRepo) All() ([]client.Client, error) {
}
return cs, nil
}
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
cm, err := r.getModel(tx, clientID)
if err != nil {
return client.Client{}, err
}
cli, err := cm.Client()
if err != nil {
return client.Client{}, err
}
return *cli, nil
}
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
ex := r.executor(tx)
m, err := ex.Get(clientModel{}, clientID)
if err != nil {
return nil, err
}
if m == nil {
return nil, client.ErrorNotFound
}
cm, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return nil, errors.New("unrecognized model")
}
return cm, nil
}
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
ex := r.executor(tx)
cm, err := newClientModel(cli)
if err != nil {
return err
}
_, err = ex.Update(cm)
return err
}
......@@ -14,6 +14,7 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/session"
......@@ -191,7 +192,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
},
}
_, err := r.New(client.Client{
_, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
......@@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) {
t.Fatalf(err.Error())
}
got, err := r.Metadata("foo")
got, err := r.Get(nil, "foo")
if err != nil {
t.Fatalf(err.Error())
}
if diff := pretty.Compare(cm, *got); diff != "" {
if diff := pretty.Compare(cm, got.Metadata); diff != "" {
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
}
}
func TestDBClientRepoMetadataNoExist(t *testing.T) {
r := db.NewClientRepo(connect(t))
c := connect(t)
r := db.NewClientRepo(c)
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{})
got, err := r.Metadata("noexist")
got, err := m.Metadata("noexist")
if err != client.ErrorNotFound {
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
}
......@@ -232,7 +235,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
},
}
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
......@@ -247,7 +250,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
},
}
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
......@@ -261,7 +264,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
for _, admin := range []bool{true, false} {
r := db.NewClientRepo(connect(t))
if _, err := r.New(client.Client{
if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
......@@ -275,15 +278,15 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
t.Fatalf("expected non-nil error: %v", err)
}
gotAdmin, err := r.IsDexAdmin("foo")
gotAdmin, err := r.Get(nil, "foo")
if err != nil {
t.Fatalf("expected non-nil error")
}
if gotAdmin != admin {
if gotAdmin.Admin != admin {
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
}
cli, err := r.Get("foo")
cli, err := r.Get(nil, "foo")
if err != nil {
t.Fatalf("expected non-nil error")
}
......@@ -294,29 +297,35 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
}
func TestDBClientRepoAuthenticate(t *testing.T) {
r := db.NewClientRepo(connect(t))
c := connect(t)
r := db.NewClientRepo(c)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
},
}
cc, err := r.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: "baz",
},
cli := client.Client{
Metadata: cm,
})
}
cc, err := m.New(cli)
if err != nil {
t.Fatalf(err.Error())
}
if cc.ID != "baz" {
if cc.ID != "127.0.0.1:5556" {
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
}
ok, err := r.Authenticate(*cc)
ok, err := m.Authenticate(*cc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if !ok {
......@@ -337,7 +346,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
}
for i, c := range creds {
ok, err := r.Authenticate(c)
ok, err := m.Authenticate(c)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok {
......@@ -355,7 +364,7 @@ func TestDBClientAll(t *testing.T) {
},
}
_, err := r.New(client.Client{
_, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "foo",
},
......@@ -365,7 +374,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error())
}
got, err := r.All()
got, err := r.All(nil)
if err != nil {
t.Fatalf(err.Error())
}
......@@ -383,7 +392,7 @@ func TestDBClientAll(t *testing.T) {
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
},
}
_, err = r.New(client.Client{
_, err = r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: "bar",
},
......@@ -393,7 +402,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error())
}
got, err = r.All()
got, err = r.All(nil)
if err != nil {
t.Fatalf(err.Error())
}
......
......@@ -3,14 +3,10 @@ package repo
import (
"encoding/base64"
"net/url"
"os"
"testing"
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
)
var (
......@@ -47,95 +43,3 @@ var (
},
}
)
func newClientRepo(t *testing.T) client.ClientRepo {
dsn := os.Getenv("DEX_TEST_DSN")
var dbMap *gorp.DbMap
if dsn == "" {
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
}
repo, err := db.NewClientRepoFromClients(dbMap, testClients)
if err != nil {
t.Fatalf("failed to create client repo from clients: %v", err)
}
return repo
}
func TestGetSetAdminClient(t *testing.T) {
startAdmins := []string{"client2"}
tests := []struct {
// client ID
cid string
// initial state of client
wantAdmin bool
// final state of client
setAdmin bool
wantErr bool
}{
{
cid: "client1",
wantAdmin: false,
setAdmin: true,
},
{
cid: "client1",
wantAdmin: false,
setAdmin: false,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: true,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: false,
},
}
Tests:
for i, tt := range tests {
repo := newClientRepo(t)
for _, cid := range startAdmins {
err := repo.SetDexAdmin(cid, true)
if err != nil {
t.Errorf("case %d: failed to set dex admin: %v", i, err)
continue Tests
}
}
gotAdmin, err := repo.IsDexAdmin(tt.cid)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.wantAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
}
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
gotAdmin, err = repo.IsDexAdmin(tt.cid)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.setAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
}
}
}
......@@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/user"
......@@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err)
}
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
if _, err := manager.NewClientManagerFromClients(db.NewClientRepo(dbMap), db.TransactionFactory(dbMap), clients, manager.ManagerOptions{}); err != nil {
t.Fatalf("Unable to add clients: %v", err)
}
return db.NewRefreshTokenRepo(dbMap)
......
......@@ -14,6 +14,7 @@ import (
"github.com/coreos/dex/admin"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/server"
......@@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
secGen := func() ([]byte, error) {
return []byte(fmt.Sprintf("client_%v", cliCount)), nil
}
cr := db.NewClientRepoWithSecretGenerator(dbMap, secGen)
cr := db.NewClientRepo(dbMap)
clientIDGenerator := func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
cm := manager.NewClientManager(cr, db.TransactionFactory(dbMap), manager.ManagerOptions{SecretGenerator: secGen, ClientIDGenerator: clientIDGenerator})
f.cr = cr
f.ur = ur
f.pwr = pwr
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, "local")
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, cm, "local")
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
f.hc = &http.Client{
......@@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) {
}
func TestCreateClient(t *testing.T) {
oldGen := admin.ClientIDGenerator
admin.ClientIDGenerator = func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
defer func() {
admin.ClientIDGenerator = oldGen
}()
mustParseURL := func(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
......@@ -402,7 +399,7 @@ func TestCreateClient(t *testing.T) {
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
}
repoClient, err := f.cr.Get(resp.Client.Id)
repoClient, err := f.cr.Get(nil, resp.Client.Id)
if err != nil {
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
}
......
......@@ -14,9 +14,10 @@ import (
func TestClientCreate(t *testing.T) {
ci := client.Client{
// Credentials are for reference, they are actually generated by the client manager
Credentials: oidc.ClientCredentials{
ID: "72de74a9",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
ID: "authn.example.com",
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) {
t.Error("Expected non-empty Client Secret")
}
meta, err := srv.ClientRepo.Metadata(newClient.Id)
meta, err := srv.ClientManager.Metadata(newClient.Id)
if err != nil {
t.Errorf("Error looking up client metadata: %v", err)
} else if meta == nil {
......
......@@ -22,9 +22,10 @@ var (
clock = clockwork.NewFakeClock()
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
testClientID = "XXX"
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
testClientID = "client.example.com"
testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
testPrivKey, _ = key.GeneratePrivateKey()
)
......
......@@ -10,6 +10,7 @@ import (
"time"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http"
......@@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) {
if err != nil {
return nil, err
}
clientRepo, err := db.NewClientRepoFromClients(dbMap, cis)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
return nil, err
}
......@@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
ClientRepo: clientRepo,
ClientManager: clientManager,
SessionManager: sm,
}
......@@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is
expectedSub, expectedName = user.ID, user.DisplayName
}
if aud := claims["aud"].(string); aud != ci.Credentials.ID {
if aud, ok := claims["aud"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", ci.Credentials.ID)
} else if aud != ci.Credentials.ID {
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
}
if sub := claims["sub"].(string); sub != expectedSub {
if sub, ok := claims["sub"].(string); !ok {
return fmt.Errorf("unexpected claim value for sub, got=nil, want=%v", expectedSub)
} else if sub != expectedSub {
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
}
if name := claims["name"].(string); name != expectedName {
if name, ok := claims["name"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", expectedName)
} else if name != expectedName {
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
}
......@@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
ID: "local",
}
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: "72de74a9",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbMap := db.NewMemDB()
cir, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci})
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error())
t.Fatalf("Failed to create client identity manager: " + err.Error())
}
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
if err != nil {
......@@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
IssuerURL: issuerURL,
KeyManager: km,
SessionManager: sm,
ClientRepo: cir,
ClientRepo: clientRepo,
ClientManager: clientManager,
Templates: template.New(connector.LoginPageTemplateName),
Connectors: []connector.Connector{},
UserRepo: userRepo,
......@@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
HTTPClient: sClient,
ProviderConfig: pcfg,
Credentials: ci.Credentials,
RedirectURL: "http://client.example.com",
RedirectURL: validRedirURL.String(),
KeySet: *ks,
}
......@@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
}
func TestHTTPClientCredsToken(t *testing.T) {
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: "72de74a9",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
}
cis := []client.Client{ci}
......
......@@ -18,6 +18,7 @@ import (
"google.golang.org/api/googleapi"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/server"
......@@ -79,7 +80,7 @@ var (
},
}
userBadClientID = "ZZZ"
userBadClientID = testBadRedirectURL.Host
userGoodToken = makeUserToken(testIssuerURL,
"ID-1", testClientID, time.Hour*1, testPrivKey)
......@@ -101,38 +102,42 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f := &userAPITestFixtures{}
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
cir := func() client.ClientRepo {
repo, err := db.NewClientRepoFromClients(dbMap, []client.Client{
client.Client{
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: testClientSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
clients := []client.Client{
client.Client{
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: testClientSecret,
},
client.Client{
Credentials: oidc.ClientCredentials{
ID: userBadClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
},
client.Client{
Credentials: oidc.ClientCredentials{
ID: userBadClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testBadRedirectURL,
},
},
})
if err != nil {
panic("Failed to create client identity repo: " + err.Error())
}
return repo
}()
cir.SetDexAdmin(testClientID, true)
},
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte(testClientSecret), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client identity manager: " + err.Error())
}
clientManager.SetDexAdmin(testClientID, true)
noop := func() error { return nil }
......@@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.emailer = &testEmailer{}
um.Clock = clock
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
f.trans = &tokenHandlerTransport{
......@@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) {
wantEmalier := testEmailer{
cantEmail: tt.cantEmail,
lastEmail: tt.req.User.Email,
lastClientID: "XXX",
lastClientID: testClientID,
lastWasInvite: true,
lastRedirectURL: *urlParsed,
}
......@@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) {
wantEmalier := testEmailer{
cantEmail: tt.cantEmail,
lastEmail: strings.ToLower(tt.email),
lastClientID: "XXX",
lastClientID: testClientID,
lastWasInvite: true,
lastRedirectURL: *urlParsed,
}
......
......@@ -5,7 +5,7 @@ import (
"fmt"
"net/http"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
......@@ -14,7 +14,7 @@ import (
type clientTokenMiddleware struct {
issuerURL string
ciRepo client.ClientRepo
ciManager *manager.ClientManager
keysFunc func() ([]key.PublicKey, error)
next http.Handler
}
......@@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
if c.ciRepo == nil {
log.Errorf("Misconfigured clientTokenMiddleware, ClientRepo is not set")
if c.ciManager == nil {
log.Errorf("Misconfigured clientTokenMiddleware, ClientManager is not set")
respondError()
return
}
......@@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
md, err := c.ciRepo.Metadata(clientID)
md, err := c.ciManager.Metadata(clientID)
if md == nil || err != nil {
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
respondError()
......
package server
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
......@@ -10,6 +9,7 @@ import (
"time"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
......@@ -25,22 +25,23 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func TestClientToken(t *testing.T) {
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
validClientID := "valid-client"
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: validClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
},
clientMetadata := oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
},
}
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
cli := client.Client{
Metadata: clientMetadata,
}
creds, err := clientManager.New(cli)
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
t.Fatalf("Failed to create client: %v", err)
}
validClientID := creds.ID
privKey, err := key.GeneratePrivateKey()
if err != nil {
......@@ -65,63 +66,63 @@ func TestClientToken(t *testing.T) {
tests := []struct {
keys []key.PublicKey
repo client.ClientRepo
manager *clientmanager.ClientManager
header string
wantCode int
}{
// valid token
{
keys: []key.PublicKey{pubKey},
repo: repo,
manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusOK,
},
// invalid token
{
keys: []key.PublicKey{pubKey},
repo: repo,
manager: clientManager,
header: fmt.Sprintf("BEARER %s", invalidJWT),
wantCode: http.StatusUnauthorized,
},
// empty header
{
keys: []key.PublicKey{pubKey},
repo: repo,
manager: clientManager,
header: "",
wantCode: http.StatusUnauthorized,
},
// unparsable token
{
keys: []key.PublicKey{pubKey},
repo: repo,
manager: clientManager,
header: "BEARER xxx",
wantCode: http.StatusUnauthorized,
},
// no verification keys
{
keys: []key.PublicKey{},
repo: repo,
manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// nil repo
{
keys: []key.PublicKey{pubKey},
repo: nil,
manager: nil,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// empty repo
{
keys: []key.PublicKey{pubKey},
repo: db.NewClientRepo(db.NewMemDB()),
manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// client not in repo
{
keys: []key.PublicKey{pubKey},
repo: repo,
manager: clientManager,
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
wantCode: http.StatusUnauthorized,
},
......@@ -131,7 +132,7 @@ func TestClientToken(t *testing.T) {
w := httptest.NewRecorder()
mw := &clientTokenMiddleware{
issuerURL: validIss,
ciRepo: tt.repo,
ciManager: tt.manager,
keysFunc: func() ([]key.PublicKey, error) {
return tt.keys, nil
},
......
......@@ -39,18 +39,10 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
}
// metadata is guarenteed to have at least one redirect_uri by earlier validation.
id, err := oidc.GenClientID(clientMetadata.RedirectURIs[0].Host)
if err != nil {
log.Errorf("Faild to create client ID: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
}
creds, err := s.ClientRepo.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: id,
},
cli := client.Client{
Metadata: clientMetadata,
})
}
creds, err := s.ClientManager.New(cli)
if err != nil {
log.Errorf("Failed to create new client identity: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
......
......@@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
return fmt.Errorf("no client id in registration response")
}
metadata, err := fixtures.clientRepo.Metadata(r.ClientID)
metadata, err := fixtures.clientManager.Metadata(r.ClientID)
if err != nil {
return fmt.Errorf("failed to lookup client id after creation")
}
......
......@@ -6,21 +6,20 @@ import (
"net/http"
"path"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc"
)
type clientResource struct {
repo client.ClientRepo
manager *manager.ClientManager
}
func registerClientResource(prefix string, repo client.ClientRepo) (string, http.Handler) {
func registerClientResource(prefix string, manager *manager.ClientManager) (string, http.Handler) {
mux := http.NewServeMux()
c := &clientResource{
repo: repo,
manager: manager,
}
relPath := "clients"
absPath := path.Join(prefix, relPath)
......@@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
cs, err := c.repo.All()
cs, err := c.manager.All()
if err != nil {
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
return
......@@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
return
}
clientID, err := oidc.GenClientID(ci.Metadata.RedirectURIs[0].Host)
if err != nil {
log.Errorf("Failed generating ID for new client: %v", err)
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "unable to generate client ID"))
return
}
ci.Credentials.ID = clientID
creds, err := c.repo.New(ci)
creds, err := c.manager.New(ci)
if err != nil {
log.Errorf("Failed creating client: %v", err)
......
......@@ -15,6 +15,7 @@ import (
"testing"
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc"
......@@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser {
func TestCreateInvalidRequest(t *testing.T) {
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
h := http.Header{"Content-Type": []string{"application/json"}}
repo := db.NewClientRepo(db.NewMemDB())
res := &clientResource{repo: repo}
dbm := db.NewMemDB()
repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := []struct {
req *http.Request
wantCode int
......@@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) {
}
func TestCreate(t *testing.T) {
repo := db.NewClientRepo(db.NewMemDB())
res := &clientResource{repo: repo}
dbm := db.NewMemDB()
repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := [][]string{
[]string{"http://example.com"},
[]string{"https://example.com"},
......@@ -190,7 +195,7 @@ func TestList(t *testing.T) {
{
cs: []client.Client{
client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"},
......@@ -200,7 +205,7 @@ func TestList(t *testing.T) {
},
want: []*schema.Client{
&schema.Client{
Id: "foo",
Id: "example.com",
RedirectURIs: []string{"http://example.com"},
},
},
......@@ -209,7 +214,7 @@ func TestList(t *testing.T) {
{
cs: []client.Client{
client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"},
......@@ -217,21 +222,21 @@ func TestList(t *testing.T) {
},
},
client.Client{
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
},
},
},
},
want: []*schema.Client{
&schema.Client{
Id: "biz",
RedirectURIs: []string{"https://example.com/one/two/three"},
Id: "example2.com",
RedirectURIs: []string{"https://example2.com/one/two/three"},
},
&schema.Client{
Id: "foo",
Id: "example.com",
RedirectURIs: []string{"http://example.com"},
},
},
......@@ -239,12 +244,20 @@ func TestList(t *testing.T) {
}
for i, tt := range tests {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), tt.cs)
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), tt.cs, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
t.Fatalf("Failed to create client identity manager: %v", err)
continue
}
res := &clientResource{repo: repo}
res := &clientResource{manager: clientManager}
r, err := http.NewRequest("GET", "http://example.com/clients", nil)
if err != nil {
......
......@@ -17,6 +17,7 @@ import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
......@@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
if err != nil {
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
}
ciRepo, err := db.NewClientRepoFromClients(dbMap, clients)
if err != nil {
return fmt.Errorf("failed to create client identity repo: %v", err)
clientRepo := db.NewClientRepo(dbMap)
for _, c := range clients {
clientRepo.New(nil, c)
}
f, err := os.Open(cfg.ConnectorsFile)
......@@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
txnFactory := db.TransactionFactory(dbMap)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientRepo = ciRepo
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{})
if err != nil {
return fmt.Errorf("Failed to create client identity manager: %v", err)
}
srv.ClientRepo = clientRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo
......@@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientRepo = ciRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo
......
......@@ -12,6 +12,7 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
......@@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc(
srvKeysFunc func() ([]key.PublicKey, error),
emailer *useremail.UserEmailer,
userRepo user.UserRepo,
clientRepo client.ClientRepo) http.HandlerFunc {
clientManager *clientmanager.ClientManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var params struct {
......@@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc(
return
}
cm, err := clientRepo.Metadata(clientID)
cm, err := clientManager.Metadata(clientID)
if err == client.ErrorNotFound {
log.Errorf("No such client: %v", err)
writeAPIError(w, http.StatusBadRequest,
......
......@@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) {
keysFunc,
f.srv.UserEmailer,
f.userRepo,
f.clientRepo)
f.clientManager)
w := httptest.NewRecorder()
u := "http://example.com"
......
This diff is collapsed.
......@@ -8,6 +8,7 @@ import (
"github.com/coreos/go-oidc/key"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
......@@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct {
tpl *template.Template
emailer *useremail.UserEmailer
sm *sessionmanager.SessionManager
cr client.ClientRepo
cm *clientmanager.ClientManager
}
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
......@@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
return url.URL{}, false
}
cm, err := h.cr.Metadata(clientID)
cm, err := h.cm.Metadata(clientID)
if err != nil || cm == nil {
log.Errorf("Error getting ClientMetadata: %v", err)
return url.URL{}, false
......
......@@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
}
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"})
_, err = f.srv.NewSession("local", testClientID, "", f.redirectURL, "", true, []string{"openid"})
if err != nil {
t.Fatalf("case %d: could not create new session: %v", i, err)
}
......@@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
tpl: f.srv.SendResetPasswordEmailTemplate,
emailer: f.srv.UserEmailer,
sm: f.sessionManager,
cr: f.clientRepo,
cm: f.clientManager,
}
w := httptest.NewRecorder()
......
......@@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) {
})
}
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"})
key, err := f.srv.NewSession(tt.connID, testClientID, "", f.redirectURL, "", true, []string{"openid"})
t.Logf("case %d: key for NewSession: %v", i, key)
if tt.attachRemote {
......
......@@ -19,6 +19,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
......@@ -72,6 +73,7 @@ type Server struct {
Connectors []connector.Connector
UserRepo user.UserRepo
UserManager *usermanager.UserManager
ClientManager *clientmanager.ClientManager
PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer
......@@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler {
s.KeyManager.PublicKeys,
s.UserEmailer,
s.UserRepo,
s.ClientRepo)))
s.ClientManager)))
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
tpl: s.SendResetPasswordEmailTemplate,
emailer: s.UserEmailer,
sm: s.SessionManager,
cr: s.ClientRepo,
cm: s.ClientManager,
})
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
......@@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler {
apiBasePath := path.Join(httpPathAPI, APIVersion)
registerDiscoveryResource(apiBasePath, mux)
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientRepo)
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientManager)
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientRepo).HTTPHandler()
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
mux.Handle(apiBasePath+"/", handler)
......@@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler {
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
return &clientTokenMiddleware{
issuerURL: s.IssuerURL.String(),
ciRepo: s.ClientRepo,
ciManager: s.ClientManager,
keysFunc: s.KeyManager.PublicKeys,
next: handler,
}
}
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
return s.ClientRepo.Metadata(clientID)
return s.ClientManager.Metadata(clientID)
}
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
......@@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
}
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
......@@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
}
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
......@@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
}
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds)
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
......
This diff is collapsed.
......@@ -10,6 +10,7 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
......@@ -26,7 +27,7 @@ const (
var (
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "XXX"
testClientID = "client.example.com"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
......@@ -79,6 +80,7 @@ type testFixtures struct {
emailer *email.TemplatizedEmailer
redirectURL url.URL
clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
}
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
......@@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err
}
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
userManager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
......@@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err
}
clientRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
clients := []client.Client{
client.Client{
Credentials: oidc.ClientCredentials{
ID: "XXX",
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
ID: testClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) {
},
},
},
})
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
return nil, err
}
km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
if err != nil {
......@@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) {
Templates: tpl,
UserRepo: userRepo,
PasswordInfoRepo: pwRepo,
UserManager: manager,
UserManager: userManager,
ClientManager: clientManager,
KeyManager: km,
}
......@@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) {
sessionManager: sessionManager,
emailer: emailer,
clientRepo: clientRepo,
clientManager: clientManager,
}, nil
}
......@@ -11,12 +11,12 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/julienschmidt/httprouter"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/api"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
const (
......@@ -38,16 +38,16 @@ var (
type UserMgmtServer struct {
api *api.UsersAPI
jwtvFactory JWTVerifierFactory
um *manager.UserManager
cir client.ClientRepo
um *usermanager.UserManager
cm *clientmanager.ClientManager
}
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientRepo) *UserMgmtServer {
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer {
return &UserMgmtServer{
api: userMgmtAPI,
jwtvFactory: jwtvFactory,
um: um,
cir: cir,
cm: cm,
}
}
......@@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
return api.Creds{}, err
}
isAdmin, err := s.cir.IsDexAdmin(clientID)
isAdmin, err := s.cm.IsDexAdmin(clientID)
if err != nil {
log.Errorf("userMgmtServer: GetCreds err: %q", err)
return api.Creds{}, err
......
......@@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
fi
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin"
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin client client/manager"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override
......
......@@ -9,15 +9,13 @@ import (
"net/url"
"time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
var (
......@@ -88,9 +86,9 @@ func (e Error) Error() string {
// calling User. It is assumed that the clientID has already validated as an
// admin app before calling.
type UsersAPI struct {
manager *manager.UserManager
userManager *usermanager.UserManager
localConnectorID string
clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
refreshRepo refresh.RefreshTokenRepo
emailer Emailer
}
......@@ -105,11 +103,11 @@ type Creds struct {
}
// TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{
manager: userManager,
refreshRepo: db.NewRefreshTokenRepo(dbMap),
clientRepo: db.NewClientRepo(dbMap),
userManager: userManager,
refreshRepo: refreshRepo,
clientManager: clientManager,
localConnectorID: localConnectorID,
emailer: emailer,
}
......@@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return schema.User{}, ErrorUnauthorized
}
usr, err := u.manager.Get(id)
usr, err := u.userManager.Get(id)
if err != nil {
return schema.User{}, mapError(err)
......@@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema
return schema.UserDisableResponse{}, ErrorUnauthorized
}
if err := u.manager.Disable(userID, disable); err != nil {
if err := u.userManager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err)
}
......@@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, mapError(err)
}
metadata, err := u.clientRepo.Metadata(creds.ClientID)
metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil {
return schema.UserCreateResponse{}, mapError(err)
}
......@@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
}
id, err := u.manager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
id, err := u.userManager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
if err != nil {
return schema.UserCreateResponse{}, mapError(err)
}
userUser, err := u.manager.Get(id)
userUser, err := u.userManager.Get(id)
if err != nil {
return schema.UserCreateResponse{}, mapError(err)
}
......@@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
}
metadata, err := u.clientRepo.Metadata(creds.ClientID)
metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err)
}
......@@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
}
// Retrieve user to check if it's already created
userUser, err := u.manager.Get(userID)
userUser, err := u.userManager.Get(userID)
if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err)
}
......@@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
return nil, "", ErrorMaxResultsTooHigh
}
users, tok, err := u.manager.List(user.UserFilter{}, maxResults, nextPageToken)
users, tok, err := u.userManager.List(user.UserFilter{}, maxResults, nextPageToken)
if err != nil {
return nil, "", mapError(err)
}
......
......@@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema"
......@@ -50,14 +51,15 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri
}
var (
clock = clockwork.NewFakeClock()
clock = clockwork.NewFakeClock()
goodClientID = "client.example.com"
goodCreds = Creds{
User: user.User{
ID: "ID-1",
Admin: true,
},
ClientID: "XXX",
ClientID: goodClientID,
}
badCreds = Creds{
......@@ -72,7 +74,7 @@ var (
Admin: true,
Disabled: true,
},
ClientID: "XXX",
ClientID: goodClientID,
}
resetPasswordURL = url.URL{
......@@ -82,7 +84,7 @@ var (
validRedirURL = url.URL{
Scheme: "http",
Host: "client.example.com",
Host: goodClientID,
Path: "/callback",
}
)
......@@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
mgr.Clock = clock
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: "XXX",
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
ID: goodClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
......@@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
},
},
}
if _, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}); err != nil {
panic("Failed to create client repo: " + err.Error())
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
}
// Used in TestRevokeRefreshToken test.
......@@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
clientID string
userID string
}{
{"XXX", "ID-1"},
{"XXX", "ID-2"},
{goodClientID, "ID-1"},
{goodClientID, "ID-2"},
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
......@@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}
emailer := &testEmailer{}
api := NewUsersAPI(dbMap, mgr, emailer, "local")
api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local")
return api, emailer
}
......@@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) {
before []string // clientIDs expected before the change.
after []string // clientIDs expected after the change.
}{
{"ID-1", "XXX", []string{"XXX"}, []string{}},
{"ID-2", "XXX", []string{"XXX"}, []string{}},
{"ID-1", goodClientID, []string{goodClientID}, []string{}},
{"ID-2", goodClientID, []string{goodClientID}, []string{}},
}
api, _ := makeTestFixtures()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment