Commit 3df1db18 authored by rithu john's avatar rithu john

storage: Surface "already exists" errors.

parent 7e9dc836
...@@ -53,8 +53,10 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap ...@@ -53,8 +53,10 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
LogoURL: req.Client.LogoUrl, LogoURL: req.Client.LogoUrl,
} }
if err := d.s.CreateClient(c); err != nil { if err := d.s.CreateClient(c); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreateClientResp{AlreadyExists: true}, nil
}
d.logger.Errorf("api: failed to create client: %v", err) d.logger.Errorf("api: failed to create client: %v", err)
// TODO(ericchiang): Surface "already exists" errors.
return nil, fmt.Errorf("create client: %v", err) return nil, fmt.Errorf("create client: %v", err)
} }
...@@ -109,6 +111,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) ...@@ -109,6 +111,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
UserID: req.Password.UserId, UserID: req.Password.UserId,
} }
if err := d.s.CreatePassword(p); err != nil { if err := d.s.CreatePassword(p); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreatePasswordResp{AlreadyExists: true}, nil
}
d.logger.Errorf("api: failed to create password: %v", err) d.logger.Errorf("api: failed to create password: %v", err)
return nil, fmt.Errorf("create password: %v", err) return nil, fmt.Errorf("create password: %v", err)
} }
......
...@@ -37,10 +37,18 @@ func TestPassword(t *testing.T) { ...@@ -37,10 +37,18 @@ func TestPassword(t *testing.T) {
Password: &p, Password: &p,
} }
if _, err := serv.CreatePassword(ctx, &createReq); err != nil { if resp, err := serv.CreatePassword(ctx, &createReq); err != nil || resp.AlreadyExists {
if resp.AlreadyExists {
t.Fatalf("Unable to create password since %s already exists", createReq.Password.Email)
}
t.Fatalf("Unable to create password: %v", err) t.Fatalf("Unable to create password: %v", err)
} }
// Attempt to create a password that already exists.
if resp, _ := serv.CreatePassword(ctx, &createReq); !resp.AlreadyExists {
t.Fatalf("Created password %s twice", createReq.Password.Email)
}
updateReq := api.UpdatePasswordReq{ updateReq := api.UpdatePasswordReq{
Email: "test@example.com", Email: "test@example.com",
NewUsername: "test1", NewUsername: "test1",
......
...@@ -70,6 +70,15 @@ func mustBeErrNotFound(t *testing.T, kind string, err error) { ...@@ -70,6 +70,15 @@ func mustBeErrNotFound(t *testing.T, kind string, err error) {
} }
} }
func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
switch {
case err == nil:
t.Errorf("attempting to create an existing %s should return an error", kind)
case err != storage.ErrAlreadyExists:
t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err)
}
}
func testAuthRequestCRUD(t *testing.T, s storage.Storage) { func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
a := storage.AuthRequest{ a := storage.AuthRequest{
ID: storage.NewID(), ID: storage.NewID(),
...@@ -98,6 +107,11 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { ...@@ -98,6 +107,11 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
if err := s.CreateAuthRequest(a); err != nil { if err := s.CreateAuthRequest(a); err != nil {
t.Fatalf("failed creating auth request: %v", err) t.Fatalf("failed creating auth request: %v", err)
} }
// Attempt to create same AuthRequest twice.
err := s.CreateAuthRequest(a)
mustBeErrAlreadyExists(t, "auth request", err)
if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
old.Claims = identity old.Claims = identity
old.ConnectorID = "connID" old.ConnectorID = "connID"
...@@ -138,6 +152,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { ...@@ -138,6 +152,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth code: %v", err) t.Fatalf("failed creating auth code: %v", err)
} }
// Attempt to create same AuthCode twice.
err := s.CreateAuthCode(a)
mustBeErrAlreadyExists(t, "auth code", err)
got, err := s.GetAuthCode(a.ID) got, err := s.GetAuthCode(a.ID)
if err != nil { if err != nil {
t.Fatalf("failed to get auth req: %v", err) t.Fatalf("failed to get auth req: %v", err)
...@@ -174,6 +192,10 @@ func testClientCRUD(t *testing.T, s storage.Storage) { ...@@ -174,6 +192,10 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("create client: %v", err) t.Fatalf("create client: %v", err)
} }
// Attempt to create same Client twice.
err = s.CreateClient(c)
mustBeErrAlreadyExists(t, "client", err)
getAndCompare := func(id string, want storage.Client) { getAndCompare := func(id string, want storage.Client) {
gc, err := s.GetClient(id) gc, err := s.GetClient(id)
if err != nil { if err != nil {
...@@ -230,6 +252,10 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ...@@ -230,6 +252,10 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("create refresh token: %v", err) t.Fatalf("create refresh token: %v", err)
} }
// Attempt to create same Refresh Token twice.
err := s.CreateRefresh(refresh)
mustBeErrAlreadyExists(t, "refresh token", err)
getAndCompare := func(id string, want storage.RefreshToken) { getAndCompare := func(id string, want storage.RefreshToken) {
gr, err := s.GetRefresh(id) gr, err := s.GetRefresh(id)
if err != nil { if err != nil {
...@@ -261,9 +287,8 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ...@@ -261,9 +287,8 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed to delete refresh request: %v", err) t.Fatalf("failed to delete refresh request: %v", err)
} }
if _, err := s.GetRefresh(id); err != storage.ErrNotFound { _, err = s.GetRefresh(id)
t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err) mustBeErrNotFound(t, "refresh token", err)
}
} }
type byEmail []storage.Password type byEmail []storage.Password
...@@ -289,6 +314,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { ...@@ -289,6 +314,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("create password token: %v", err) t.Fatalf("create password token: %v", err)
} }
// Attempt to create same Password twice.
err = s.CreatePassword(password)
mustBeErrAlreadyExists(t, "password", err)
getAndCompare := func(id string, want storage.Password) { getAndCompare := func(id string, want storage.Password) {
gr, err := s.GetPassword(id) gr, err := s.GetPassword(id)
if err != nil { if err != nil {
...@@ -335,9 +364,8 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { ...@@ -335,9 +364,8 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed to delete password: %v", err) t.Fatalf("failed to delete password: %v", err)
} }
if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound { _, err = s.GetPassword(password.Email)
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err) mustBeErrNotFound(t, "password", err)
}
} }
...@@ -354,6 +382,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { ...@@ -354,6 +382,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("create offline session: %v", err) t.Fatalf("create offline session: %v", err)
} }
// Attempt to create same OfflineSession twice.
err := s.CreateOfflineSessions(session)
mustBeErrAlreadyExists(t, "offline session", err)
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) { getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
gr, err := s.GetOfflineSessions(userID, connID) gr, err := s.GetOfflineSessions(userID, connID)
if err != nil { if err != nil {
...@@ -389,9 +421,8 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { ...@@ -389,9 +421,8 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed to delete offline session: %v", err) t.Fatalf("failed to delete offline session: %v", err)
} }
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound { _, err = s.GetOfflineSessions(session.UserID, session.ConnID)
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err) mustBeErrNotFound(t, "offline session", err)
}
} }
......
...@@ -8,6 +8,13 @@ import ( ...@@ -8,6 +8,13 @@ import (
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
)
const (
// postgres error codes
pgErrUniqueViolation = "23505" // unique_violation
) )
// SQLite3 options for creating an SQL db. // SQLite3 options for creating an SQL db.
...@@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) { ...@@ -35,7 +42,16 @@ func (s *SQLite3) open(logger logrus.FieldLogger) (*conn, error) {
// doesn't support this, so limit the number of connections to 1. // doesn't support this, so limit the number of connections to 1.
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
} }
c := &conn{db, flavorSQLite3, logger}
errCheck := func(err error) bool {
sqlErr, ok := err.(sqlite3.Error)
if !ok {
return false
}
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
}
c := &conn{db, flavorSQLite3, logger, errCheck}
if _, err := c.migrate(); err != nil { if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err) return nil, fmt.Errorf("failed to perform migrations: %v", err)
} }
...@@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) { ...@@ -114,7 +130,16 @@ func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := &conn{db, flavorPostgres, logger}
errCheck := func(err error) bool {
sqlErr, ok := err.(*pq.Error)
if !ok {
return false
}
return sqlErr.Code == pgErrUniqueViolation
}
c := &conn{db, flavorPostgres, logger, errCheck}
if _, err := c.migrate(); err != nil { if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err) return nil, fmt.Errorf("failed to perform migrations: %v", err)
} }
......
...@@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { ...@@ -125,6 +125,9 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
a.Expiry, a.Expiry,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert auth request: %v", err) return fmt.Errorf("insert auth request: %v", err)
} }
return nil return nil
...@@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error { ...@@ -212,7 +215,14 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.Expiry, a.ConnectorID, a.ConnectorData, a.Expiry,
) )
return err
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert auth code: %v", err)
}
return nil
} }
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
...@@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { ...@@ -256,6 +266,9 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
r.Token, r.CreatedAt, r.LastUsed, r.Token, r.CreatedAt, r.LastUsed,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert refresh_token: %v", err) return fmt.Errorf("insert refresh_token: %v", err)
} }
return nil return nil
...@@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error { ...@@ -477,6 +490,9 @@ func (c *conn) CreateClient(cli storage.Client) error {
cli.Public, cli.Name, cli.LogoURL, cli.Public, cli.Name, cli.LogoURL,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert client: %v", err) return fmt.Errorf("insert client: %v", err)
} }
return nil return nil
...@@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error { ...@@ -544,6 +560,9 @@ func (c *conn) CreatePassword(p storage.Password) error {
p.Email, p.Hash, p.Username, p.UserID, p.Email, p.Hash, p.Username, p.UserID,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert password: %v", err) return fmt.Errorf("insert password: %v", err)
} }
return nil return nil
...@@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { ...@@ -636,6 +655,9 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
s.UserID, s.ConnID, encoder(s.Refresh), s.UserID, s.ConnID, encoder(s.Refresh),
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert offline session: %v", err) return fmt.Errorf("insert offline session: %v", err)
} }
return nil return nil
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
sqlite3 "github.com/mattn/go-sqlite3"
) )
func TestMigrate(t *testing.T) { func TestMigrate(t *testing.T) {
...@@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) { ...@@ -21,7 +22,15 @@ func TestMigrate(t *testing.T) {
Level: logrus.DebugLevel, Level: logrus.DebugLevel,
} }
c := &conn{db, flavorSQLite3, logger} errCheck := func(err error) bool {
sqlErr, ok := err.(sqlite3.Error)
if !ok {
return false
}
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
}
c := &conn{db, flavorSQLite3, logger, errCheck}
for _, want := range []int{len(migrations), 0} { for _, want := range []int{len(migrations), 0} {
got, err := c.migrate() got, err := c.migrate()
if err != nil { if err != nil {
......
...@@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} { ...@@ -131,9 +131,10 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
// conn is the main database connection. // conn is the main database connection.
type conn struct { type conn struct {
db *sql.DB db *sql.DB
flavor flavor flavor flavor
logger logrus.FieldLogger logger logrus.FieldLogger
alreadyExistsCheck func(err error) bool
} }
func (c *conn) Close() error { func (c *conn) Close() error {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment