Commit a846016c authored by Evan Cordell's avatar Evan Cordell

Merge pull request #442 from ecordell/client-manager

Adds client manager
parents b0f17c94 73d9742c
// package admin provides an implementation of the API described in auth/schema/adminschema. // Package admin provides an implementation of the API described in auth/schema/adminschema.
package admin package admin
import ( import (
"net/http" "net/http"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
)
var (
ClientIDGenerator = oidc.GenClientID
) )
// AdminAPI provides the logic necessary to implement the Admin API. // AdminAPI provides the logic necessary to implement the Admin API.
type AdminAPI struct { type AdminAPI struct {
userManager *manager.UserManager userManager *usermanager.UserManager
userRepo user.UserRepo userRepo user.UserRepo
passwordInfoRepo user.PasswordInfoRepo passwordInfoRepo user.PasswordInfoRepo
clientRepo client.ClientRepo clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
localConnectorID string localConnectorID string
} }
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *manager.UserManager, localConnectorID string) *AdminAPI { func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, localConnectorID string) *AdminAPI {
if localConnectorID == "" { if localConnectorID == "" {
panic("must specify non-blank localConnectorID") panic("must specify non-blank localConnectorID")
} }
...@@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe ...@@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe
userRepo: userRepo, userRepo: userRepo,
passwordInfoRepo: pwiRepo, passwordInfoRepo: pwiRepo,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager,
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
} }
} }
...@@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem ...@@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
} }
// metadata is guaranteed to have at least one redirect_uri by earlier validation. // metadata is guaranteed to have at least one redirect_uri by earlier validation.
id, err := ClientIDGenerator(cli.Metadata.RedirectURIs[0].Host) creds, err := a.clientManager.New(cli)
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}
cli.Credentials.ID = id
creds, err := a.clientRepo.New(cli)
if err != nil { if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err) return adminschema.ClientCreateResponse{}, mapError(err)
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
...@@ -17,6 +18,7 @@ type testFixtures struct { ...@@ -17,6 +18,7 @@ type testFixtures struct {
ur user.UserRepo ur user.UserRepo
pwr user.PasswordInfoRepo pwr user.PasswordInfoRepo
cr client.ClientRepo cr client.ClientRepo
cm *clientmanager.ClientManager
mgr *manager.UserManager mgr *manager.UserManager
adAPI *AdminAPI adAPI *AdminAPI
} }
...@@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures { ...@@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures {
}() }()
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{}) f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, "local") f.cm = clientmanager.NewClientManager(f.cr, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, f.cm, "local")
return f return f
} }
......
package client package client
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"net/url" "net/url"
"reflect" "reflect"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
...@@ -17,6 +21,24 @@ var ( ...@@ -17,6 +21,24 @@ var (
ErrorNotFound = errors.New("no data found") ErrorNotFound = errors.New("no data found")
) )
const (
bcryptHashCost = 10
)
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
secretBytes, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
return nil, err
}
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
if err != nil {
return nil, err
}
return hashed, nil
}
type Client struct { type Client struct {
Credentials oidc.ClientCredentials Credentials oidc.ClientCredentials
Metadata oidc.ClientMetadata Metadata oidc.ClientMetadata
...@@ -24,30 +46,20 @@ type Client struct { ...@@ -24,30 +46,20 @@ type Client struct {
} }
type ClientRepo interface { type ClientRepo interface {
Get(clientID string) (Client, error) Get(tx repo.Transaction, clientID string) (Client, error)
// Metadata returns one matching ClientMetadata if the given client
// exists, otherwise nil. The returned error will be non-nil only
// if the repo was unable to determine client existence.
Metadata(clientID string) (*oidc.ClientMetadata, error)
// Authenticate asserts that a client with the given ID exists and // GetSecret returns the (base64 encoded) hashed client secret
// that the provided secret matches. If either of these assertions GetSecret(tx repo.Transaction, clientID string) ([]byte, error)
// fail, (false, nil) will be returned. Only if the repo is unable
// to make these assertions will a non-nil error be returned.
Authenticate(creds oidc.ClientCredentials) (bool, error)
// All returns all registered Clients // All returns all registered Clients
All() ([]Client, error) All(tx repo.Transaction) ([]Client, error)
// New registers a Client with the repo. // New registers a Client with the repo.
// An unused ID must be provided. A corresponding secret will be returned // An unused ID must be provided. A corresponding secret will be returned
// in a ClientCredentials struct along with the provided ID. // in a ClientCredentials struct along with the provided ID.
New(client Client) (*oidc.ClientCredentials, error) New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
SetDexAdmin(clientID string, isAdmin bool) error
IsDexAdmin(clientID string) (bool, error) Update(tx repo.Transaction, client Client) error
} }
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
......
...@@ -34,7 +34,7 @@ var ( ...@@ -34,7 +34,7 @@ var (
badSecretClient = `{ badSecretClient = `{
"id": "my_id", "id": "my_id",
"secret": "` + "****" + `", "secret": "` + "" + `",
"redirectURLs": ["https://client.example.com"] "redirectURLs": ["https://client.example.com"]
}` }`
...@@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) { ...@@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_id", ID: "my_id",
Secret: "my_secret", Secret: goodSecret1,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) { ...@@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_id", ID: "my_id",
Secret: "my_secret", Secret: goodSecret1,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_other_id", ID: "my_other_id",
Secret: "my_other_secret", Secret: goodSecret2,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) { ...@@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) {
}, },
}, },
}, },
}, { },
{
json: "[" + badURLClient + "]", json: "[" + badURLClient + "]",
wantErr: true, wantErr: true,
}, },
......
package manager
import (
"encoding/base64"
"fmt"
"errors"
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
"golang.org/x/crypto/bcrypt"
)
const (
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
)
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func CompareHashAndPassword(hashedPassword, password []byte) error {
if len(password) > maxSecretLength {
return errors.New("password length greater than max secret length")
}
return bcrypt.CompareHashAndPassword(hashedPassword, password)
}
// ClientManager performs client-related "business-logic" functions on client and related objects.
// This is in contrast to the Repos which perform little more than CRUD operations.
type ClientManager struct {
clientRepo client.ClientRepo
begin repo.TransactionFactory
secretGenerator SecretGenerator
clientIDGenerator func(string) (string, error)
}
type ManagerOptions struct {
SecretGenerator func() ([]byte, error)
ClientIDGenerator func(string) (string, error)
}
func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *ClientManager {
if options.SecretGenerator == nil {
options.SecretGenerator = DefaultSecretGenerator
}
if options.ClientIDGenerator == nil {
options.ClientIDGenerator = oidc.GenClientID
}
return &ClientManager{
clientRepo: clientRepo,
begin: txnFactory,
secretGenerator: options.SecretGenerator,
clientIDGenerator: options.ClientIDGenerator,
}
}
func NewClientManagerFromClients(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, clients []client.Client, options ManagerOptions) (*ClientManager, error) {
clientManager := NewClientManager(clientRepo, txnFactory, options)
tx, err := clientManager.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cli, err := clientManager.generateClientCredentials(c)
if err != nil {
return nil, err
}
_, err = clientRepo.New(tx, cli)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return clientManager, nil
}
func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) {
tx, err := m.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
c, err := m.generateClientCredentials(cli)
if err != nil {
return nil, err
}
creds := c.Credentials
// Save Client
_, err = m.clientRepo.New(tx, c)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
// Returns creds with unhashed secret
return &creds, nil
}
func (m *ClientManager) Get(id string) (client.Client, error) {
return m.clientRepo.Get(nil, id)
}
func (m *ClientManager) All() ([]client.Client, error) {
return m.clientRepo.All(nil)
}
func (m *ClientManager) Metadata(clientID string) (*oidc.ClientMetadata, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return nil, err
}
return &c.Metadata, nil
}
func (m *ClientManager) IsDexAdmin(clientID string) (bool, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return false, err
}
return c.Admin, nil
}
func (m *ClientManager) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := m.begin()
if err != nil {
return err
}
defer tx.Rollback()
c, err := m.clientRepo.Get(tx, clientID)
if err != nil {
return err
}
c.Admin = isAdmin
err = m.clientRepo.Update(tx, c)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (m *ClientManager) Authenticate(creds oidc.ClientCredentials) (bool, error) {
clientSecret, err := m.clientRepo.GetSecret(nil, creds.ID)
if err != nil || clientSecret == nil {
return false, nil
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
ok := CompareHashAndPassword(clientSecret, dec) == nil
return ok, nil
}
func (m *ClientManager) generateClientCredentials(cli client.Client) (client.Client, error) {
// Generate Client ID
if len(cli.Metadata.RedirectURIs) < 1 {
return cli, errors.New("no client redirect url given")
}
clientID, err := m.clientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
if err != nil {
return cli, err
}
// Generate Secret
secret, err := m.secretGenerator()
if err != nil {
return cli, err
}
clientSecret := base64.URLEncoding.EncodeToString(secret)
cli.Credentials = oidc.ClientCredentials{
ID: clientID,
Secret: clientSecret,
}
return cli, nil
}
package manager
import (
"encoding/base64"
"fmt"
"net/url"
"testing"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc"
)
type testFixtures struct {
clientRepo client.ClientRepo
mgr *ClientManager
}
var (
goodSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
)
func makeTestFixtures() *testFixtures {
f := &testFixtures{}
dbMap := db.NewMemDB()
clients := []client.Client{
{
Credentials: oidc.ClientCredentials{
ID: "client.example.com",
Secret: goodSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "http", Host: "client.example.com", Path: "/"},
},
},
Admin: true,
},
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
f.clientRepo = db.NewClientRepo(dbMap)
clientManager, err := NewClientManagerFromClients(f.clientRepo, db.TransactionFactory(dbMap), clients, ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
}
f.mgr = clientManager
return f
}
func TestMetadata(t *testing.T) {
tests := []struct {
clientID string
uri string
wantErr bool
}{
{
clientID: "client.example.com",
uri: "http://client.example.com/",
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
md, err := f.mgr.Metadata(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if md.RedirectURIs[0].String() != tt.uri {
t.Errorf("case %d: manager.Metadata.RedirectURIs: want=%q got=%q", i, tt.uri, md.RedirectURIs[0].String())
continue
}
}
}
func TestIsDexAdmin(t *testing.T) {
tests := []struct {
clientID string
isAdmin bool
wantErr bool
}{
{
clientID: "client.example.com",
isAdmin: true,
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
admin, err := f.mgr.IsDexAdmin(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if admin != tt.isAdmin {
t.Errorf("case %d: manager.Admin want=%t got=%t", i, tt.isAdmin, admin)
continue
}
}
}
func TestSetDexAdmin(t *testing.T) {
f := makeTestFixtures()
err := f.mgr.SetDexAdmin("client.example.com", false)
if err != nil {
t.Errorf("unexpected err: %v", err)
}
admin, _ := f.mgr.IsDexAdmin("client.example.com")
if admin {
t.Errorf("expected admin to be false")
}
}
func TestAuthenticate(t *testing.T) {
f := makeTestFixtures()
cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com", Path: "/cb"},
},
}
cli := client.Client{
Metadata: cm,
}
cc, err := f.mgr.New(cli)
if err != nil {
t.Fatalf(err.Error())
}
ok, err := f.mgr.Authenticate(*cc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if !ok {
t.Fatalf("Authentication failed for good creds")
}
creds := []oidc.ClientCredentials{
//completely made up
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
// good client ID, bad secret
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
// bad client ID, good secret
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
// good client ID, secret with some fluff on the end
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
}
for i, c := range creds {
ok, err := f.mgr.Authenticate(c)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok {
t.Errorf("case %d: authentication succeeded for bad creds", i)
}
}
}
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/coreos/dex/admin" "github.com/coreos/dex/admin"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
pflag "github.com/coreos/dex/pkg/flag" pflag "github.com/coreos/dex/pkg/flag"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
...@@ -119,8 +120,9 @@ func main() { ...@@ -119,8 +120,9 @@ func main() {
clientRepo := db.NewClientRepo(dbc) clientRepo := db.NewClientRepo(dbc)
userManager := manager.NewUserManager(userRepo, userManager := manager.NewUserManager(userRepo,
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, *localConnectorID) adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, clientManager, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatalf(err.Error())
......
...@@ -2,6 +2,7 @@ package main ...@@ -2,6 +2,7 @@ package main
import ( import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
...@@ -14,15 +15,15 @@ func newDBDriver(dsn string) (driver, error) { ...@@ -14,15 +15,15 @@ func newDBDriver(dsn string) (driver, error) {
} }
drv := &dbDriver{ drv := &dbDriver{
ciRepo: db.NewClientRepo(dbc),
cfgRepo: db.NewConnectorConfigRepo(dbc), cfgRepo: db.NewConnectorConfigRepo(dbc),
ciManager: manager.NewClientManager(db.NewClientRepo(dbc), db.TransactionFactory(dbc), manager.ManagerOptions{}),
} }
return drv, nil return drv, nil
} }
type dbDriver struct { type dbDriver struct {
ciRepo client.ClientRepo ciManager *manager.ClientManager
cfgRepo *db.ConnectorConfigRepo cfgRepo *db.ConnectorConfigRepo
} }
...@@ -30,18 +31,10 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials, ...@@ -30,18 +31,10 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
if err := meta.Valid(); err != nil { if err := meta.Valid(); err != nil {
return nil, err return nil, err
} }
cli := client.Client{
clientID, err := oidc.GenClientID(meta.RedirectURIs[0].Host)
if err != nil {
return nil, err
}
return d.ciRepo.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: clientID,
},
Metadata: meta, Metadata: meta,
}) }
return d.ciManager.New(cli)
} }
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) { func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {
......
...@@ -2,7 +2,6 @@ package db ...@@ -2,7 +2,6 @@ package db
import ( import (
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -10,24 +9,15 @@ import ( ...@@ -10,24 +9,15 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
) )
const ( const (
clientTableName = "client_identity" clientTableName = "client_identity"
bcryptHashCost = 10
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
// postgres error codes // postgres error codes
pgErrorCodeUniqueViolation = "23505" // unique_violation pgErrorCodeUniqueViolation = "23505" // unique_violation
) )
...@@ -42,17 +32,10 @@ func init() { ...@@ -42,17 +32,10 @@ func init() {
} }
func newClientModel(cli client.Client) (*clientModel, error) { func newClientModel(cli client.Client) (*clientModel, error) {
secretBytes, err := base64.URLEncoding.DecodeString(cli.Credentials.Secret) hashed, err := client.HashSecret(cli.Credentials)
if err != nil {
return nil, err
}
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
if err != nil { if err != nil {
return nil, err return nil, err
} }
bmeta, err := json.Marshal(&cli.Metadata) bmeta, err := json.Marshal(&cli.Metadata)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -92,56 +75,20 @@ func (m *clientModel) Client() (*client.Client, error) { ...@@ -92,56 +75,20 @@ func (m *clientModel) Client() (*client.Client, error) {
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo { func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
return newClientRepo(dbm) return newClientRepo(dbm)
}
func NewClientRepoWithSecretGenerator(dbm *gorp.DbMap, secGen SecretGenerator) client.ClientRepo {
rep := newClientRepo(dbm)
rep.secretGenerator = secGen
return rep
} }
func newClientRepo(dbm *gorp.DbMap) *clientRepo { func newClientRepo(dbm *gorp.DbMap) *clientRepo {
return &clientRepo{ return &clientRepo{
db: &db{dbm}, db: &db{dbm},
secretGenerator: DefaultSecretGenerator,
}
}
func NewClientRepoFromClients(dbm *gorp.DbMap, clients []client.Client) (client.ClientRepo, error) {
repo := newClientRepo(dbm)
tx, err := repo.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
exec := repo.executor(tx)
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cm, err := newClientModel(c)
if err != nil {
return nil, err
}
err = exec.Insert(cm)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
} }
return repo, nil
} }
type clientRepo struct { type clientRepo struct {
*db *db
secretGenerator SecretGenerator
} }
func (r *clientRepo) Get(clientID string) (client.Client, error) { func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
m, err := r.executor(nil).Get(clientModel{}, clientID) m, err := r.executor(tx).Get(clientModel{}, clientID)
if err == sql.ErrNoRows || m == nil { if err == sql.ErrNoRows || m == nil {
return client.Client{}, client.ErrorNotFound return client.Client{}, client.ErrorNotFound
} }
...@@ -163,82 +110,28 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) { ...@@ -163,82 +110,28 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
return *ci, nil return *ci, nil
} }
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) { func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
c, err := r.Get(clientID) m, err := r.getModel(tx, clientID)
if err != nil { if err != nil || m == nil {
return nil, err return nil, err
} }
return m.Secret, nil
return &c.Metadata, nil
} }
func (r *clientRepo) IsDexAdmin(clientID string) (bool, error) { func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
m, err := r.executor(nil).Get(clientModel{}, clientID) if cli.Credentials.ID == "" {
if m == nil || err != nil { return client.ErrorNotFound
return false, err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
} }
// make sure this client exists already
return cim.DexAdmin, nil _, err := r.get(tx, cli.Credentials.ID)
}
func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() err = r.update(tx, cli)
exec := r.executor(tx)
m, err := exec.Get(clientModel{}, clientID)
if m == nil || err != nil {
return err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model")
}
cim.DexAdmin = isAdmin
_, err = exec.Update(cim)
if err != nil { if err != nil {
return err return err
} }
return nil
return tx.Commit()
}
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
if len(dec) > maxSecretLength {
return false, nil
}
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
return ok, nil
} }
var alreadyExistsCheckers []func(err error) bool var alreadyExistsCheckers []func(err error) bool
...@@ -260,26 +153,14 @@ func isAlreadyExistsErr(err error) bool { ...@@ -260,26 +153,14 @@ func isAlreadyExistsErr(err error) bool {
return false return false
} }
type SecretGenerator func() ([]byte, error) func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
secret, err := r.secretGenerator()
if err != nil {
return nil, err
}
cli.Credentials.Secret = base64.URLEncoding.EncodeToString(secret)
cim, err := newClientModel(cli) cim, err := newClientModel(cli)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := r.executor(nil).Insert(cim); err != nil { if err := r.executor(tx).Insert(cim); err != nil {
if isAlreadyExistsErr(err) { if isAlreadyExistsErr(err) {
err = errors.New("client ID already exists") err = errors.New("client ID already exists")
} }
...@@ -294,10 +175,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) { ...@@ -294,10 +175,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
return &cc, nil return &cc, nil
} }
func (r *clientRepo) All() ([]client.Client, error) { func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
qt := r.quote(clientTableName) qt := r.quote(clientTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt) q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.executor(nil).Select(&clientModel{}, q) objs, err := r.executor(tx).Select(&clientModel{}, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -317,3 +198,47 @@ func (r *clientRepo) All() ([]client.Client, error) { ...@@ -317,3 +198,47 @@ func (r *clientRepo) All() ([]client.Client, error) {
} }
return cs, nil return cs, nil
} }
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
cm, err := r.getModel(tx, clientID)
if err != nil {
return client.Client{}, err
}
cli, err := cm.Client()
if err != nil {
return client.Client{}, err
}
return *cli, nil
}
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
ex := r.executor(tx)
m, err := ex.Get(clientModel{}, clientID)
if err != nil {
return nil, err
}
if m == nil {
return nil, client.ErrorNotFound
}
cm, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return nil, errors.New("unrecognized model")
}
return cm, nil
}
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
ex := r.executor(tx)
cm, err := newClientModel(cli)
if err != nil {
return err
}
_, err = ex.Update(cm)
return err
}
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
...@@ -191,7 +192,7 @@ func TestDBClientRepoMetadata(t *testing.T) { ...@@ -191,7 +192,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
}, },
} }
_, err := r.New(client.Client{ _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "foo", ID: "foo",
}, },
...@@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) { ...@@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
got, err := r.Metadata("foo") got, err := r.Get(nil, "foo")
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
if diff := pretty.Compare(cm, *got); diff != "" { if diff := pretty.Compare(cm, got.Metadata); diff != "" {
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff) t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
} }
} }
func TestDBClientRepoMetadataNoExist(t *testing.T) { func TestDBClientRepoMetadataNoExist(t *testing.T) {
r := db.NewClientRepo(connect(t)) c := connect(t)
r := db.NewClientRepo(c)
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{})
got, err := r.Metadata("noexist") got, err := m.Metadata("noexist")
if err != client.ErrorNotFound { if err != client.ErrorNotFound {
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err) t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
} }
...@@ -232,7 +235,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) { ...@@ -232,7 +235,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
}, },
} }
if _, err := r.New(client.Client{ if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "foo", ID: "foo",
}, },
...@@ -247,7 +250,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) { ...@@ -247,7 +250,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
}, },
} }
if _, err := r.New(client.Client{ if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "foo", ID: "foo",
}, },
...@@ -261,7 +264,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) { ...@@ -261,7 +264,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
for _, admin := range []bool{true, false} { for _, admin := range []bool{true, false} {
r := db.NewClientRepo(connect(t)) r := db.NewClientRepo(connect(t))
if _, err := r.New(client.Client{ if _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "foo", ID: "foo",
}, },
...@@ -275,15 +278,15 @@ func TestDBClientRepoNewAdmin(t *testing.T) { ...@@ -275,15 +278,15 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
t.Fatalf("expected non-nil error: %v", err) t.Fatalf("expected non-nil error: %v", err)
} }
gotAdmin, err := r.IsDexAdmin("foo") gotAdmin, err := r.Get(nil, "foo")
if err != nil { if err != nil {
t.Fatalf("expected non-nil error") t.Fatalf("expected non-nil error")
} }
if gotAdmin != admin { if gotAdmin.Admin != admin {
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin) t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
} }
cli, err := r.Get("foo") cli, err := r.Get(nil, "foo")
if err != nil { if err != nil {
t.Fatalf("expected non-nil error") t.Fatalf("expected non-nil error")
} }
...@@ -294,29 +297,35 @@ func TestDBClientRepoNewAdmin(t *testing.T) { ...@@ -294,29 +297,35 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
} }
func TestDBClientRepoAuthenticate(t *testing.T) { func TestDBClientRepoAuthenticate(t *testing.T) {
r := db.NewClientRepo(connect(t)) c := connect(t)
r := db.NewClientRepo(c)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
cm := oidc.ClientMetadata{ cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"}, url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
}, },
} }
cli := client.Client{
cc, err := r.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: "baz",
},
Metadata: cm, Metadata: cm,
}) }
cc, err := m.New(cli)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
if cc.ID != "baz" { if cc.ID != "127.0.0.1:5556" {
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID) t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
} }
ok, err := r.Authenticate(*cc) ok, err := m.Authenticate(*cc)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} else if !ok { } else if !ok {
...@@ -337,7 +346,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) { ...@@ -337,7 +346,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)}, oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
} }
for i, c := range creds { for i, c := range creds {
ok, err := r.Authenticate(c) ok, err := m.Authenticate(c)
if err != nil { if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok { } else if ok {
...@@ -355,7 +364,7 @@ func TestDBClientAll(t *testing.T) { ...@@ -355,7 +364,7 @@ func TestDBClientAll(t *testing.T) {
}, },
} }
_, err := r.New(client.Client{ _, err := r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "foo", ID: "foo",
}, },
...@@ -365,7 +374,7 @@ func TestDBClientAll(t *testing.T) { ...@@ -365,7 +374,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
got, err := r.All() got, err := r.All(nil)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
...@@ -383,7 +392,7 @@ func TestDBClientAll(t *testing.T) { ...@@ -383,7 +392,7 @@ func TestDBClientAll(t *testing.T) {
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"}, url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
}, },
} }
_, err = r.New(client.Client{ _, err = r.New(nil, client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "bar", ID: "bar",
}, },
...@@ -393,7 +402,7 @@ func TestDBClientAll(t *testing.T) { ...@@ -393,7 +402,7 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
got, err = r.All() got, err = r.All(nil)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
......
...@@ -3,14 +3,10 @@ package repo ...@@ -3,14 +3,10 @@ package repo
import ( import (
"encoding/base64" "encoding/base64"
"net/url" "net/url"
"os"
"testing"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
) )
var ( var (
...@@ -47,95 +43,3 @@ var ( ...@@ -47,95 +43,3 @@ var (
}, },
} }
) )
func newClientRepo(t *testing.T) client.ClientRepo {
dsn := os.Getenv("DEX_TEST_DSN")
var dbMap *gorp.DbMap
if dsn == "" {
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
}
repo, err := db.NewClientRepoFromClients(dbMap, testClients)
if err != nil {
t.Fatalf("failed to create client repo from clients: %v", err)
}
return repo
}
func TestGetSetAdminClient(t *testing.T) {
startAdmins := []string{"client2"}
tests := []struct {
// client ID
cid string
// initial state of client
wantAdmin bool
// final state of client
setAdmin bool
wantErr bool
}{
{
cid: "client1",
wantAdmin: false,
setAdmin: true,
},
{
cid: "client1",
wantAdmin: false,
setAdmin: false,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: true,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: false,
},
}
Tests:
for i, tt := range tests {
repo := newClientRepo(t)
for _, cid := range startAdmins {
err := repo.SetDexAdmin(cid, true)
if err != nil {
t.Errorf("case %d: failed to set dex admin: %v", i, err)
continue Tests
}
}
gotAdmin, err := repo.IsDexAdmin(tt.cid)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.wantAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
}
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
gotAdmin, err = repo.IsDexAdmin(tt.cid)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.setAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
}
}
}
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
...@@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients ...@@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err) t.Fatalf("Unable to add users: %v", err)
} }
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil { if _, err := manager.NewClientManagerFromClients(db.NewClientRepo(dbMap), db.TransactionFactory(dbMap), clients, manager.ManagerOptions{}); err != nil {
t.Fatalf("Unable to add clients: %v", err) t.Fatalf("Unable to add clients: %v", err)
} }
return db.NewRefreshTokenRepo(dbMap) return db.NewRefreshTokenRepo(dbMap)
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/coreos/dex/admin" "github.com/coreos/dex/admin"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
...@@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures { ...@@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
secGen := func() ([]byte, error) { secGen := func() ([]byte, error) {
return []byte(fmt.Sprintf("client_%v", cliCount)), nil return []byte(fmt.Sprintf("client_%v", cliCount)), nil
} }
cr := db.NewClientRepoWithSecretGenerator(dbMap, secGen) cr := db.NewClientRepo(dbMap)
clientIDGenerator := func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
cm := manager.NewClientManager(cr, db.TransactionFactory(dbMap), manager.ManagerOptions{SecretGenerator: secGen, ClientIDGenerator: clientIDGenerator})
f.cr = cr f.cr = cr
f.ur = ur f.ur = ur
f.pwr = pwr f.pwr = pwr
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, "local") f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, cm, "local")
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret) f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler()) f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
f.hc = &http.Client{ f.hc = &http.Client{
...@@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) { ...@@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) {
} }
func TestCreateClient(t *testing.T) { func TestCreateClient(t *testing.T) {
oldGen := admin.ClientIDGenerator
admin.ClientIDGenerator = func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
defer func() {
admin.ClientIDGenerator = oldGen
}()
mustParseURL := func(s string) *url.URL { mustParseURL := func(s string) *url.URL {
u, err := url.Parse(s) u, err := url.Parse(s)
if err != nil { if err != nil {
...@@ -402,7 +399,7 @@ func TestCreateClient(t *testing.T) { ...@@ -402,7 +399,7 @@ func TestCreateClient(t *testing.T) {
t.Errorf("case %d: Compare(want, got) = %v", i, diff) t.Errorf("case %d: Compare(want, got) = %v", i, diff)
} }
repoClient, err := f.cr.Get(resp.Client.Id) repoClient, err := f.cr.Get(nil, resp.Client.Id)
if err != nil { if err != nil {
t.Errorf("case %d: Unexpected error getting client: %v", i, err) t.Errorf("case %d: Unexpected error getting client: %v", i, err)
} }
......
...@@ -14,9 +14,10 @@ import ( ...@@ -14,9 +14,10 @@ import (
func TestClientCreate(t *testing.T) { func TestClientCreate(t *testing.T) {
ci := client.Client{ ci := client.Client{
// Credentials are for reference, they are actually generated by the client manager
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "authn.example.com",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) { ...@@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) {
t.Error("Expected non-empty Client Secret") t.Error("Expected non-empty Client Secret")
} }
meta, err := srv.ClientRepo.Metadata(newClient.Id) meta, err := srv.ClientManager.Metadata(newClient.Id)
if err != nil { if err != nil {
t.Errorf("Error looking up client metadata: %v", err) t.Errorf("Error looking up client metadata: %v", err)
} else if meta == nil { } else if meta == nil {
......
...@@ -22,9 +22,10 @@ var ( ...@@ -22,9 +22,10 @@ var (
clock = clockwork.NewFakeClock() clock = clockwork.NewFakeClock()
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
testClientID = "XXX" testClientID = "client.example.com"
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy")) testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
testPrivKey, _ = key.GeneratePrivateKey() testPrivKey, _ = key.GeneratePrivateKey()
) )
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
...@@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) { ...@@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientRepo, err := db.NewClientRepoFromClients(dbMap, cis)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) { ...@@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
ClientRepo: clientRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
SessionManager: sm, SessionManager: sm,
} }
...@@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is ...@@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is
expectedSub, expectedName = user.ID, user.DisplayName expectedSub, expectedName = user.ID, user.DisplayName
} }
if aud := claims["aud"].(string); aud != ci.Credentials.ID { if aud, ok := claims["aud"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", ci.Credentials.ID)
} else if aud != ci.Credentials.ID {
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID) return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
} }
if sub := claims["sub"].(string); sub != expectedSub { if sub, ok := claims["sub"].(string); !ok {
return fmt.Errorf("unexpected claim value for sub, got=nil, want=%v", expectedSub)
} else if sub != expectedSub {
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub) return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
} }
if name := claims["name"].(string); name != expectedName { if name, ok := claims["name"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", expectedName)
} else if name != expectedName {
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName) return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
} }
...@@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
ID: "local", ID: "local",
} }
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
}, },
} }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbMap := db.NewMemDB() dbMap := db.NewMemDB()
cir, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}) clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error()) t.Fatalf("Failed to create client identity manager: " + err.Error())
} }
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo}) passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
if err != nil { if err != nil {
...@@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
IssuerURL: issuerURL, IssuerURL: issuerURL,
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: cir, ClientRepo: clientRepo,
ClientManager: clientManager,
Templates: template.New(connector.LoginPageTemplateName), Templates: template.New(connector.LoginPageTemplateName),
Connectors: []connector.Connector{}, Connectors: []connector.Connector{},
UserRepo: userRepo, UserRepo: userRepo,
...@@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
HTTPClient: sClient, HTTPClient: sClient,
ProviderConfig: pcfg, ProviderConfig: pcfg,
Credentials: ci.Credentials, Credentials: ci.Credentials,
RedirectURL: "http://client.example.com", RedirectURL: validRedirURL.String(),
KeySet: *ks, KeySet: *ks,
} }
...@@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
} }
func TestHTTPClientCredsToken(t *testing.T) { func TestHTTPClientCredsToken(t *testing.T) {
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
}, },
} }
cis := []client.Client{ci} cis := []client.Client{ci}
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
...@@ -79,7 +80,7 @@ var ( ...@@ -79,7 +80,7 @@ var (
}, },
} }
userBadClientID = "ZZZ" userBadClientID = testBadRedirectURL.Host
userGoodToken = makeUserToken(testIssuerURL, userGoodToken = makeUserToken(testIssuerURL,
"ID-1", testClientID, time.Hour*1, testPrivKey) "ID-1", testClientID, time.Hour*1, testPrivKey)
...@@ -101,8 +102,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -101,8 +102,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f := &userAPITestFixtures{} f := &userAPITestFixtures{}
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords) dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
cir := func() client.ClientRepo { clients := []client.Client{
repo, err := db.NewClientRepoFromClients(dbMap, []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: testClientID, ID: testClientID,
...@@ -121,18 +121,23 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -121,18 +121,23 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
testRedirectURL, testBadRedirectURL,
}, },
}, },
}, },
}) }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte(testClientSecret), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
panic("Failed to create client identity repo: " + err.Error()) panic("Failed to create client identity manager: " + err.Error())
} }
return repo clientManager.SetDexAdmin(testClientID, true)
}()
cir.SetDexAdmin(testClientID, true)
noop := func() error { return nil } noop := func() error { return nil }
...@@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.emailer = &testEmailer{} f.emailer = &testEmailer{}
um.Clock = clock um.Clock = clock
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir) api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
...@@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) { ...@@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) {
wantEmalier := testEmailer{ wantEmalier := testEmailer{
cantEmail: tt.cantEmail, cantEmail: tt.cantEmail,
lastEmail: tt.req.User.Email, lastEmail: tt.req.User.Email,
lastClientID: "XXX", lastClientID: testClientID,
lastWasInvite: true, lastWasInvite: true,
lastRedirectURL: *urlParsed, lastRedirectURL: *urlParsed,
} }
...@@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) { ...@@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) {
wantEmalier := testEmailer{ wantEmalier := testEmailer{
cantEmail: tt.cantEmail, cantEmail: tt.cantEmail,
lastEmail: strings.ToLower(tt.email), lastEmail: strings.ToLower(tt.email),
lastClientID: "XXX", lastClientID: testClientID,
lastWasInvite: true, lastWasInvite: true,
lastRedirectURL: *urlParsed, lastRedirectURL: *urlParsed,
} }
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/coreos/dex/client" "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
type clientTokenMiddleware struct { type clientTokenMiddleware struct {
issuerURL string issuerURL string
ciRepo client.ClientRepo ciManager *manager.ClientManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
next http.Handler next http.Handler
} }
...@@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request ...@@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return return
} }
if c.ciRepo == nil { if c.ciManager == nil {
log.Errorf("Misconfigured clientTokenMiddleware, ClientRepo is not set") log.Errorf("Misconfigured clientTokenMiddleware, ClientManager is not set")
respondError() respondError()
return return
} }
...@@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request ...@@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return return
} }
md, err := c.ciRepo.Metadata(clientID) md, err := c.ciManager.Metadata(clientID)
if md == nil || err != nil { if md == nil || err != nil {
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err) log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
respondError() respondError()
......
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -10,6 +9,7 @@ import ( ...@@ -10,6 +9,7 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
...@@ -25,22 +25,23 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -25,22 +25,23 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func TestClientToken(t *testing.T) { func TestClientToken(t *testing.T) {
now := time.Now() now := time.Now()
tomorrow := now.Add(24 * time.Hour) tomorrow := now.Add(24 * time.Hour)
validClientID := "valid-client" clientMetadata := oidc.ClientMetadata{
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: validClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"}, {Scheme: "https", Host: "authn.example.com", Path: "/callback"},
}, },
},
} }
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
cli := client.Client{
Metadata: clientMetadata,
}
creds, err := clientManager.New(cli)
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
validClientID := creds.ID
privKey, err := key.GeneratePrivateKey() privKey, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
...@@ -65,63 +66,63 @@ func TestClientToken(t *testing.T) { ...@@ -65,63 +66,63 @@ func TestClientToken(t *testing.T) {
tests := []struct { tests := []struct {
keys []key.PublicKey keys []key.PublicKey
repo client.ClientRepo manager *clientmanager.ClientManager
header string header string
wantCode int wantCode int
}{ }{
// valid token // valid token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusOK, wantCode: http.StatusOK,
}, },
// invalid token // invalid token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", invalidJWT), header: fmt.Sprintf("BEARER %s", invalidJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// empty header // empty header
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: "", header: "",
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// unparsable token // unparsable token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: "BEARER xxx", header: "BEARER xxx",
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// no verification keys // no verification keys
{ {
keys: []key.PublicKey{}, keys: []key.PublicKey{},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// nil repo // nil repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: nil, manager: nil,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// empty repo // empty repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: db.NewClientRepo(db.NewMemDB()), manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// client not in repo // client not in repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)), header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
...@@ -131,7 +132,7 @@ func TestClientToken(t *testing.T) { ...@@ -131,7 +132,7 @@ func TestClientToken(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
mw := &clientTokenMiddleware{ mw := &clientTokenMiddleware{
issuerURL: validIss, issuerURL: validIss,
ciRepo: tt.repo, ciManager: tt.manager,
keysFunc: func() ([]key.PublicKey, error) { keysFunc: func() ([]key.PublicKey, error) {
return tt.keys, nil return tt.keys, nil
}, },
......
...@@ -39,18 +39,10 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR ...@@ -39,18 +39,10 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
} }
// metadata is guarenteed to have at least one redirect_uri by earlier validation. // metadata is guarenteed to have at least one redirect_uri by earlier validation.
id, err := oidc.GenClientID(clientMetadata.RedirectURIs[0].Host) cli := client.Client{
if err != nil {
log.Errorf("Faild to create client ID: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
}
creds, err := s.ClientRepo.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: id,
},
Metadata: clientMetadata, Metadata: clientMetadata,
}) }
creds, err := s.ClientManager.New(cli)
if err != nil { if err != nil {
log.Errorf("Failed to create new client identity: %v", err) log.Errorf("Failed to create new client identity: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata") return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
......
...@@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) { ...@@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
return fmt.Errorf("no client id in registration response") return fmt.Errorf("no client id in registration response")
} }
metadata, err := fixtures.clientRepo.Metadata(r.ClientID) metadata, err := fixtures.clientManager.Metadata(r.ClientID)
if err != nil { if err != nil {
return fmt.Errorf("failed to lookup client id after creation") return fmt.Errorf("failed to lookup client id after creation")
} }
......
...@@ -6,21 +6,20 @@ import ( ...@@ -6,21 +6,20 @@ import (
"net/http" "net/http"
"path" "path"
"github.com/coreos/dex/client" "github.com/coreos/dex/client/manager"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc"
) )
type clientResource struct { type clientResource struct {
repo client.ClientRepo manager *manager.ClientManager
} }
func registerClientResource(prefix string, repo client.ClientRepo) (string, http.Handler) { func registerClientResource(prefix string, manager *manager.ClientManager) (string, http.Handler) {
mux := http.NewServeMux() mux := http.NewServeMux()
c := &clientResource{ c := &clientResource{
repo: repo, manager: manager,
} }
relPath := "clients" relPath := "clients"
absPath := path.Join(prefix, relPath) absPath := path.Join(prefix, relPath)
...@@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) { func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
cs, err := c.repo.All() cs, err := c.manager.All()
if err != nil { if err != nil {
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients")) writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
return return
...@@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) { ...@@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error())) writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
return return
} }
creds, err := c.manager.New(ci)
clientID, err := oidc.GenClientID(ci.Metadata.RedirectURIs[0].Host)
if err != nil {
log.Errorf("Failed generating ID for new client: %v", err)
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "unable to generate client ID"))
return
}
ci.Credentials.ID = clientID
creds, err := c.repo.New(ci)
if err != nil { if err != nil {
log.Errorf("Failed creating client: %v", err) log.Errorf("Failed creating client: %v", err)
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
...@@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser { ...@@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser {
func TestCreateInvalidRequest(t *testing.T) { func TestCreateInvalidRequest(t *testing.T) {
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"} u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
h := http.Header{"Content-Type": []string{"application/json"}} h := http.Header{"Content-Type": []string{"application/json"}}
repo := db.NewClientRepo(db.NewMemDB()) dbm := db.NewMemDB()
res := &clientResource{repo: repo} repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := []struct { tests := []struct {
req *http.Request req *http.Request
wantCode int wantCode int
...@@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) { ...@@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) {
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
repo := db.NewClientRepo(db.NewMemDB()) dbm := db.NewMemDB()
res := &clientResource{repo: repo} repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := [][]string{ tests := [][]string{
[]string{"http://example.com"}, []string{"http://example.com"},
[]string{"https://example.com"}, []string{"https://example.com"},
...@@ -190,7 +195,7 @@ func TestList(t *testing.T) { ...@@ -190,7 +195,7 @@ func TestList(t *testing.T) {
{ {
cs: []client.Client{ cs: []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
...@@ -200,7 +205,7 @@ func TestList(t *testing.T) { ...@@ -200,7 +205,7 @@ func TestList(t *testing.T) {
}, },
want: []*schema.Client{ want: []*schema.Client{
&schema.Client{ &schema.Client{
Id: "foo", Id: "example.com",
RedirectURIs: []string{"http://example.com"}, RedirectURIs: []string{"http://example.com"},
}, },
}, },
...@@ -209,7 +214,7 @@ func TestList(t *testing.T) { ...@@ -209,7 +214,7 @@ func TestList(t *testing.T) {
{ {
cs: []client.Client{ cs: []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
...@@ -217,21 +222,21 @@ func TestList(t *testing.T) { ...@@ -217,21 +222,21 @@ func TestList(t *testing.T) {
}, },
}, },
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")}, Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"}, url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
}, },
}, },
}, },
}, },
want: []*schema.Client{ want: []*schema.Client{
&schema.Client{ &schema.Client{
Id: "biz", Id: "example2.com",
RedirectURIs: []string{"https://example.com/one/two/three"}, RedirectURIs: []string{"https://example2.com/one/two/three"},
}, },
&schema.Client{ &schema.Client{
Id: "foo", Id: "example.com",
RedirectURIs: []string{"http://example.com"}, RedirectURIs: []string{"http://example.com"},
}, },
}, },
...@@ -239,12 +244,20 @@ func TestList(t *testing.T) { ...@@ -239,12 +244,20 @@ func TestList(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), tt.cs) dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), tt.cs, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err) t.Fatalf("Failed to create client identity manager: %v", err)
continue continue
} }
res := &clientResource{repo: repo} res := &clientResource{manager: clientManager}
r, err := http.NewRequest("GET", "http://example.com/clients", nil) r, err := http.NewRequest("GET", "http://example.com/clients", nil)
if err != nil { if err != nil {
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
...@@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { ...@@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
if err != nil { if err != nil {
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
} }
ciRepo, err := db.NewClientRepoFromClients(dbMap, clients)
if err != nil { clientRepo := db.NewClientRepo(dbMap)
return fmt.Errorf("failed to create client identity repo: %v", err)
for _, c := range clients {
clientRepo.New(nil, c)
} }
f, err := os.Open(cfg.ConnectorsFile) f, err := os.Open(cfg.ConnectorsFile)
...@@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { ...@@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
txnFactory := db.TransactionFactory(dbMap) txnFactory := db.TransactionFactory(dbMap)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientRepo = ciRepo clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{})
if err != nil {
return fmt.Errorf("Failed to create client identity manager: %v", err)
}
srv.ClientRepo = clientRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo srv.UserRepo = userRepo
...@@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { ...@@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := sessionmanager.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientRepo = ciRepo srv.ClientRepo = ciRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo srv.UserRepo = userRepo
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
...@@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc( ...@@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc(
srvKeysFunc func() ([]key.PublicKey, error), srvKeysFunc func() ([]key.PublicKey, error),
emailer *useremail.UserEmailer, emailer *useremail.UserEmailer,
userRepo user.UserRepo, userRepo user.UserRepo,
clientRepo client.ClientRepo) http.HandlerFunc { clientManager *clientmanager.ClientManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
var params struct { var params struct {
...@@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc( ...@@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc(
return return
} }
cm, err := clientRepo.Metadata(clientID) cm, err := clientManager.Metadata(clientID)
if err == client.ErrorNotFound { if err == client.ErrorNotFound {
log.Errorf("No such client: %v", err) log.Errorf("No such client: %v", err)
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
......
...@@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) { ...@@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) {
keysFunc, keysFunc,
f.srv.UserEmailer, f.srv.UserEmailer,
f.userRepo, f.userRepo,
f.clientRepo) f.clientManager)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := "http://example.com" u := "http://example.com"
......
This diff is collapsed.
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
sessionmanager "github.com/coreos/dex/session/manager" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
...@@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct { ...@@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct {
tpl *template.Template tpl *template.Template
emailer *useremail.UserEmailer emailer *useremail.UserEmailer
sm *sessionmanager.SessionManager sm *sessionmanager.SessionManager
cr client.ClientRepo cm *clientmanager.ClientManager
} }
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
...@@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red ...@@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
return url.URL{}, false return url.URL{}, false
} }
cm, err := h.cr.Metadata(clientID) cm, err := h.cm.Metadata(clientID)
if err != nil || cm == nil { if err != nil || cm == nil {
log.Errorf("Error getting ClientMetadata: %v", err) log.Errorf("Error getting ClientMetadata: %v", err)
return url.URL{}, false return url.URL{}, false
......
...@@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) { ...@@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err) t.Fatalf("case %d: could not make test fixtures: %v", i, err)
} }
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"}) _, err = f.srv.NewSession("local", testClientID, "", f.redirectURL, "", true, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: could not create new session: %v", i, err) t.Fatalf("case %d: could not create new session: %v", i, err)
} }
...@@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) { ...@@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
tpl: f.srv.SendResetPasswordEmailTemplate, tpl: f.srv.SendResetPasswordEmailTemplate,
emailer: f.srv.UserEmailer, emailer: f.srv.UserEmailer,
sm: f.sessionManager, sm: f.sessionManager,
cr: f.clientRepo, cm: f.clientManager,
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
......
...@@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) { ...@@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) {
}) })
} }
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"}) key, err := f.srv.NewSession(tt.connID, testClientID, "", f.redirectURL, "", true, []string{"openid"})
t.Logf("case %d: key for NewSession: %v", i, key) t.Logf("case %d: key for NewSession: %v", i, key)
if tt.attachRemote { if tt.attachRemote {
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
...@@ -72,6 +73,7 @@ type Server struct { ...@@ -72,6 +73,7 @@ type Server struct {
Connectors []connector.Connector Connectors []connector.Connector
UserRepo user.UserRepo UserRepo user.UserRepo
UserManager *usermanager.UserManager UserManager *usermanager.UserManager
ClientManager *clientmanager.ClientManager
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer
...@@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler { ...@@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler {
s.KeyManager.PublicKeys, s.KeyManager.PublicKeys,
s.UserEmailer, s.UserEmailer,
s.UserRepo, s.UserRepo,
s.ClientRepo))) s.ClientManager)))
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{ mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
tpl: s.SendResetPasswordEmailTemplate, tpl: s.SendResetPasswordEmailTemplate,
emailer: s.UserEmailer, emailer: s.UserEmailer,
sm: s.SessionManager, sm: s.SessionManager,
cr: s.ClientRepo, cm: s.ClientManager,
}) })
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{ mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
...@@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler { ...@@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler {
apiBasePath := path.Join(httpPathAPI, APIVersion) apiBasePath := path.Join(httpPathAPI, APIVersion)
registerDiscoveryResource(apiBasePath, mux) registerDiscoveryResource(apiBasePath, mux)
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientRepo) clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientManager)
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientRepo).HTTPHandler() handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
mux.Handle(apiBasePath+"/", handler) mux.Handle(apiBasePath+"/", handler)
...@@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler { ...@@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler {
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler { func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
return &clientTokenMiddleware{ return &clientTokenMiddleware{
issuerURL: s.IssuerURL.String(), issuerURL: s.IssuerURL.String(),
ciRepo: s.ClientRepo, ciManager: s.ClientManager,
keysFunc: s.KeyManager.PublicKeys, keysFunc: s.KeyManager.PublicKeys,
next: handler, next: handler,
} }
} }
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) { func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
return s.ClientRepo.Metadata(clientID) return s.ClientManager.Metadata(clientID)
} }
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
...@@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { ...@@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
} }
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
} }
if !ok { if !ok {
...@@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro ...@@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
} }
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) { func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
ok, err := s.ClientRepo.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
...@@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
} }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
......
This diff is collapsed.
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
...@@ -26,7 +27,7 @@ const ( ...@@ -26,7 +27,7 @@ const (
var ( var (
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "XXX" testClientID = "client.example.com"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
...@@ -79,6 +80,7 @@ type testFixtures struct { ...@@ -79,6 +80,7 @@ type testFixtures struct {
emailer *email.TemplatizedEmailer emailer *email.TemplatizedEmailer
redirectURL url.URL redirectURL url.URL
clientRepo client.ClientRepo clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
} }
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc { func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
...@@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()
...@@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
clientRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ clients := []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) {
}, },
}, },
}, },
}) }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
return nil, err return nil, err
} }
km := key.NewPrivateKeyManager() km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
if err != nil { if err != nil {
...@@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) {
Templates: tpl, Templates: tpl,
UserRepo: userRepo, UserRepo: userRepo,
PasswordInfoRepo: pwRepo, PasswordInfoRepo: pwRepo,
UserManager: manager, UserManager: userManager,
ClientManager: clientManager,
KeyManager: km, KeyManager: km,
} }
...@@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) {
sessionManager: sessionManager, sessionManager: sessionManager,
emailer: emailer, emailer: emailer,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager,
}, nil }, nil
} }
...@@ -11,12 +11,12 @@ import ( ...@@ -11,12 +11,12 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/coreos/dex/client" clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/api" "github.com/coreos/dex/user/api"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
...@@ -38,16 +38,16 @@ var ( ...@@ -38,16 +38,16 @@ var (
type UserMgmtServer struct { type UserMgmtServer struct {
api *api.UsersAPI api *api.UsersAPI
jwtvFactory JWTVerifierFactory jwtvFactory JWTVerifierFactory
um *manager.UserManager um *usermanager.UserManager
cir client.ClientRepo cm *clientmanager.ClientManager
} }
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientRepo) *UserMgmtServer { func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer {
return &UserMgmtServer{ return &UserMgmtServer{
api: userMgmtAPI, api: userMgmtAPI,
jwtvFactory: jwtvFactory, jwtvFactory: jwtvFactory,
um: um, um: um,
cir: cir, cm: cm,
} }
} }
...@@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred ...@@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
return api.Creds{}, err return api.Creds{}, err
} }
isAdmin, err := s.cir.IsDexAdmin(clientID) isAdmin, err := s.cm.IsDexAdmin(clientID)
if err != nil { if err != nil {
log.Errorf("userMgmtServer: GetCreds err: %q", err) log.Errorf("userMgmtServer: GetCreds err: %q", err)
return api.Creds{}, err return api.Creds{}, err
......
...@@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then ...@@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests." echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
fi fi
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin" TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin client client/manager"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override
......
...@@ -9,15 +9,13 @@ import ( ...@@ -9,15 +9,13 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db" clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
var ( var (
...@@ -88,9 +86,9 @@ func (e Error) Error() string { ...@@ -88,9 +86,9 @@ func (e Error) Error() string {
// calling User. It is assumed that the clientID has already validated as an // calling User. It is assumed that the clientID has already validated as an
// admin app before calling. // admin app before calling.
type UsersAPI struct { type UsersAPI struct {
manager *manager.UserManager userManager *usermanager.UserManager
localConnectorID string localConnectorID string
clientRepo client.ClientRepo clientManager *clientmanager.ClientManager
refreshRepo refresh.RefreshTokenRepo refreshRepo refresh.RefreshTokenRepo
emailer Emailer emailer Emailer
} }
...@@ -105,11 +103,11 @@ type Creds struct { ...@@ -105,11 +103,11 @@ type Creds struct {
} }
// TODO(ericchiang): Don't pass a dbMap. See #385. // TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI { func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
manager: userManager, userManager: userManager,
refreshRepo: db.NewRefreshTokenRepo(dbMap), refreshRepo: refreshRepo,
clientRepo: db.NewClientRepo(dbMap), clientManager: clientManager,
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
emailer: emailer, emailer: emailer,
} }
...@@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) { ...@@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return schema.User{}, ErrorUnauthorized return schema.User{}, ErrorUnauthorized
} }
usr, err := u.manager.Get(id) usr, err := u.userManager.Get(id)
if err != nil { if err != nil {
return schema.User{}, mapError(err) return schema.User{}, mapError(err)
...@@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema ...@@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema
return schema.UserDisableResponse{}, ErrorUnauthorized return schema.UserDisableResponse{}, ErrorUnauthorized
} }
if err := u.manager.Disable(userID, disable); err != nil { if err := u.userManager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err) return schema.UserDisableResponse{}, mapError(err)
} }
...@@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s ...@@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
metadata, err := u.clientRepo.Metadata(creds.ClientID) metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
...@@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s ...@@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
} }
id, err := u.manager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID) id, err := u.userManager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
userUser, err := u.manager.Get(id) userUser, err := u.userManager.Get(id)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
...@@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur ...@@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
} }
metadata, err := u.clientRepo.Metadata(creds.ClientID) metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil { if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err) return schema.ResendEmailInvitationResponse{}, mapError(err)
} }
...@@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur ...@@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
} }
// Retrieve user to check if it's already created // Retrieve user to check if it's already created
userUser, err := u.manager.Get(userID) userUser, err := u.userManager.Get(userID)
if err != nil { if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err) return schema.ResendEmailInvitationResponse{}, mapError(err)
} }
...@@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string) ...@@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
return nil, "", ErrorMaxResultsTooHigh return nil, "", ErrorMaxResultsTooHigh
} }
users, tok, err := u.manager.List(user.UserFilter{}, maxResults, nextPageToken) users, tok, err := u.userManager.List(user.UserFilter{}, maxResults, nextPageToken)
if err != nil { if err != nil {
return nil, "", mapError(err) return nil, "", mapError(err)
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
...@@ -51,13 +52,14 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri ...@@ -51,13 +52,14 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri
var ( var (
clock = clockwork.NewFakeClock() clock = clockwork.NewFakeClock()
goodClientID = "client.example.com"
goodCreds = Creds{ goodCreds = Creds{
User: user.User{ User: user.User{
ID: "ID-1", ID: "ID-1",
Admin: true, Admin: true,
}, },
ClientID: "XXX", ClientID: goodClientID,
} }
badCreds = Creds{ badCreds = Creds{
...@@ -72,7 +74,7 @@ var ( ...@@ -72,7 +74,7 @@ var (
Admin: true, Admin: true,
Disabled: true, Disabled: true,
}, },
ClientID: "XXX", ClientID: goodClientID,
} }
resetPasswordURL = url.URL{ resetPasswordURL = url.URL{
...@@ -82,7 +84,7 @@ var ( ...@@ -82,7 +84,7 @@ var (
validRedirURL = url.URL{ validRedirURL = url.URL{
Scheme: "http", Scheme: "http",
Host: "client.example.com", Host: goodClientID,
Path: "/callback", Path: "/callback",
} }
) )
...@@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
mgr.Clock = clock mgr.Clock = clock
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: goodClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
...@@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
if _, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}); err != nil {
panic("Failed to create client repo: " + err.Error()) clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
} }
// Used in TestRevokeRefreshToken test. // Used in TestRevokeRefreshToken test.
...@@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
clientID string clientID string
userID string userID string
}{ }{
{"XXX", "ID-1"}, {goodClientID, "ID-1"},
{"XXX", "ID-2"}, {goodClientID, "ID-2"},
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens { for _, token := range refreshTokens {
...@@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(dbMap, mgr, emailer, "local") api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local")
return api, emailer return api, emailer
} }
...@@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) { ...@@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) {
before []string // clientIDs expected before the change. before []string // clientIDs expected before the change.
after []string // clientIDs expected after the change. after []string // clientIDs expected after the change.
}{ }{
{"ID-1", "XXX", []string{"XXX"}, []string{}}, {"ID-1", goodClientID, []string{goodClientID}, []string{}},
{"ID-2", "XXX", []string{"XXX"}, []string{}}, {"ID-2", goodClientID, []string{goodClientID}, []string{}},
} }
api, _ := makeTestFixtures() api, _ := makeTestFixtures()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment