Commit bfd63b75 authored by Eric Chiang's avatar Eric Chiang

db: add sqlite3 support

parent 8f16279f
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
...@@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo { ...@@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
} }
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) { func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo) tx, err := dbm.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, c := range clients { for _, c := range clients {
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret) dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata) cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = repo.dbMap.Insert(cm) err = tx.Insert(cm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return repo, nil if err := tx.Commit(); err != nil {
return nil, err
}
return NewClientIdentityRepo(dbm), nil
} }
type clientIdentityRepo struct { type clientIdentityRepo struct {
...@@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { ...@@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
m, err := r.dbMap.Get(clientIdentityModel{}, clientID) m, err := tx.Get(clientIdentityModel{}, clientID)
if m == nil || err != nil { if m == nil || err != nil {
rollback(tx) rollback(tx)
return err return err
...@@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { ...@@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
cim, ok := m.(*clientIdentityModel) cim, ok := m.(*clientIdentityModel)
if !ok { if !ok {
rollback(tx)
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m)) log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model") return errors.New("unrecognized model")
} }
cim.DexAdmin = isAdmin cim.DexAdmin = isAdmin
_, err = r.dbMap.Update(cim) _, err = tx.Update(cim)
if err != nil { if err != nil {
rollback(tx)
return err
}
err = tx.Commit()
if err != nil {
rollback(tx)
return err return err
} }
return nil return tx.Commit()
} }
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
...@@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli ...@@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
} }
if err := r.dbMap.Insert(cim); err != nil { if err := r.dbMap.Insert(cim); err != nil {
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation { switch sqlErr := err.(type) {
err = errors.New("client ID already exists") case *pq.Error:
if sqlErr.Code == pgErrorCodeUniqueViolation {
err = errors.New("client ID already exists")
}
case *sqlite3.Error:
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
err = errors.New("client ID already exists")
}
} }
return nil, err return nil, err
...@@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli ...@@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
} }
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) { func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
qt := pq.QuoteIdentifier(clientIdentityTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt) q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&clientIdentityModel{}, q) objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
if err != nil { if err != nil {
......
...@@ -4,13 +4,16 @@ import ( ...@@ -4,13 +4,16 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "net/url"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
_ "github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
// Import database drivers
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
type table struct { type table struct {
...@@ -43,23 +46,36 @@ type Config struct { ...@@ -43,23 +46,36 @@ type Config struct {
} }
func NewConnection(cfg Config) (*gorp.DbMap, error) { func NewConnection(cfg Config) (*gorp.DbMap, error) {
if !strings.HasPrefix(cfg.DSN, "postgres://") { u, err := url.Parse(cfg.DSN)
return nil, errors.New("unrecognized database driver")
}
db, err := sql.Open("postgres", cfg.DSN)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("parse DSN: %v", err)
} }
var (
db.SetMaxIdleConns(cfg.MaxIdleConnections) db *sql.DB
db.SetMaxOpenConns(cfg.MaxOpenConnections) dialect gorp.Dialect
)
dbm := gorp.DbMap{ switch u.Scheme {
Db: db, case "postgres":
Dialect: gorp.PostgresDialect{}, db, err = sql.Open("postgres", cfg.DSN)
if err != nil {
return nil, err
}
db.SetMaxIdleConns(cfg.MaxIdleConnections)
db.SetMaxOpenConns(cfg.MaxOpenConnections)
dialect = gorp.PostgresDialect{}
case "sqlite3":
db, err = sql.Open("sqlite3", u.Host)
if err != nil {
return nil, err
}
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
dialect = gorp.SqliteDialect{}
default:
return nil, errors.New("unrecognized database driver")
} }
dbm := gorp.DbMap{Db: db, Dialect: dialect}
for _, t := range tables { for _, t := range tables {
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...) tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
for _, unique := range t.unique { for _, unique := range t.unique {
...@@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) { ...@@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
cm.SetUnique(true) cm.SetUnique(true)
} }
} }
return &dbm, nil return &dbm, nil
} }
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
...@@ -69,7 +68,7 @@ type ConnectorConfigRepo struct { ...@@ -69,7 +68,7 @@ type ConnectorConfigRepo struct {
} }
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt) q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&connectorConfigModel{}, q) objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
if err != nil { if err != nil {
...@@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { ...@@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
} }
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil { if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound return nil, connector.ErrorNotFound
} }
...@@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { ...@@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback()
qt := pq.QuoteIdentifier(connectorConfigTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("DELETE FROM %s", qt) q := fmt.Sprintf("DELETE FROM %s", qt)
if _, err = r.dbMap.Exec(q); err != nil { if _, err = tx.Exec(q); err != nil {
return err return err
} }
if err = r.dbMap.Insert(insert...); err != nil { if err = tx.Insert(insert...); err != nil {
return fmt.Errorf("DB insert failed %#v: %v", insert, err) return fmt.Errorf("DB insert failed %#v: %v", insert, err)
} }
return tx.Commit() return tx.Commit()
} }
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
pcrypto "github.com/coreos/dex/pkg/crypto" pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
...@@ -114,7 +113,7 @@ type PrivateKeySetRepo struct { ...@@ -114,7 +113,7 @@ type PrivateKeySetRepo struct {
} }
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
qt := pq.QuoteIdentifier(keyTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt)) _, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
if err != nil { if err != nil {
return err return err
...@@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { ...@@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
} }
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) { func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
qt := pq.QuoteIdentifier(keyTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt)) objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
if err != nil { if err != nil {
return nil, err return nil, err
......
package db package db
import ( import (
"errors"
"fmt" "fmt"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq" "github.com/rubenv/sql-migrate"
migrate "github.com/rubenv/sql-migrate"
"github.com/coreos/dex/db/migrations" "github.com/coreos/dex/db/migrations"
) )
const ( const (
migrationDialect = "postgres" migrationTable = "dex_migrations"
migrationTable = "dex_migrations" migrationDir = "db/migrations"
migrationDir = "db/migrations"
) )
func init() { func init() {
...@@ -21,32 +20,57 @@ func init() { ...@@ -21,32 +20,57 @@ func init() {
} }
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) { func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
source := getSource() source, dialect, err := migrationSource(dbMap)
if err != nil {
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up) return 0, err
}
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
} }
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) { func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
source := getSource() source, dialect, err := migrationSource(dbMap)
if err != nil {
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max) return 0, err
}
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
} }
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) { func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0) source, dialect, err := migrationSource(dbMap)
if err != nil {
return nil, err
}
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
return migrations, err return migrations, err
} }
func DropMigrationsTable(dbMap *gorp.DbMap) error { func DropMigrationsTable(dbMap *gorp.DbMap) error {
qt := pq.QuoteIdentifier(migrationTable) qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt)) _, err := dbMap.Exec(qt)
return err return err
} }
func getSource() migrate.MigrationSource { func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
return &migrate.AssetMigrationSource{ switch dbMap.Dialect.(type) {
Dir: migrationDir, case gorp.PostgresDialect:
Asset: migrations.Asset, src = &migrate.AssetMigrationSource{
AssetDir: migrations.AssetDir, Dir: migrationDir,
Asset: migrations.Asset,
AssetDir: migrations.AssetDir,
}
return src, "postgres", nil
case gorp.SqliteDialect:
src = &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "dex.sql",
Up: []string{sqlite3Migration},
},
},
}
return src, "sqlite3", nil
default:
return nil, "", errors.New("unsupported migration driver")
} }
} }
package db
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
const sqlite3Migration = `
CREATE TABLE authd_user (
id text NOT NULL UNIQUE,
email text,
email_verified integer,
display_name text,
admin integer,
created_at bigint,
disabled integer
);
CREATE TABLE client_identity (
id text NOT NULL UNIQUE,
secret blob,
metadata text,
dex_admin integer
);
CREATE TABLE connector_config (
id text NOT NULL UNIQUE,
type text,
config text
);
CREATE TABLE key (
value blob
);
CREATE TABLE password_info (
user_id text NOT NULL UNIQUE,
password text,
password_expires bigint
);
CREATE TABLE refresh_token (
id integer PRIMARY KEY,
payload_hash blob,
user_id text,
client_id text
);
CREATE TABLE remote_identity_mapping (
connector_id text NOT NULL,
user_id text,
remote_id text NOT NULL
);
CREATE TABLE session (
id text NOT NULL UNIQUE,
state text,
created_at bigint,
expires_at bigint,
client_id text,
client_state text,
redirect_url text,
identity text,
connector_id text,
user_id text,
register integer,
nonce text,
scope text
);
CREATE TABLE session_key (
key text NOT NULL UNIQUE,
session_id text,
expires_at bigint,
stale integer
);
`
...@@ -5,10 +5,11 @@ import ( ...@@ -5,10 +5,11 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/go-gorp/gorp"
) )
const ( const (
...@@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err ...@@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
return nil return nil
} }
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) { func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
m, err := ex.Get(passwordInfoModel{}, id) m, err := ex.Get(passwordInfoModel{}, id)
if err != nil { if err != nil {
...@@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf ...@@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
} }
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error { func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
pm, err := newPasswordInfoModel(&pw) pm, err := newPasswordInfoModel(&pw)
if err != nil { if err != nil {
return err return err
...@@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err ...@@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
} }
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error { func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
pm, err := newPasswordInfoModel(&pw) pm, err := newPasswordInfoModel(&pw)
if err != nil { if err != nil {
return err return err
......
...@@ -8,10 +8,11 @@ import ( ...@@ -8,10 +8,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
) )
const ( const (
...@@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { ...@@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return nil return nil
} }
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
return tx
}
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) { func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
result, err := ex.Get(refreshTokenModel{}, tokenID) result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
...@@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error { ...@@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
} }
func (r *SessionRepo) purge() error { func (r *SessionRepo) purge() error {
qt := pq.QuoteIdentifier(sessionTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt) q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
if err != nil { if err != nil {
return err return err
} }
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
...@@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { ...@@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
return "", errors.New("invalid session key") return "", errors.New("invalid session key")
} }
qt := pq.QuoteIdentifier(sessionKeyTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt) q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
res, err := r.dbMap.Exec(q, true, key, false) res, err := executor(r.dbMap, nil).Exec(q, true, key, false)
if err != nil { if err != nil {
return "", err return "", err
} }
...@@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { ...@@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
} }
func (r *SessionKeyRepo) purge() error { func (r *SessionKeyRepo) purge() error {
qt := pq.QuoteIdentifier(sessionKeyTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt) q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix()) res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix())
if err != nil { if err != nil {
return err return err
} }
......
package db
import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/db/translate"
"github.com/coreos/dex/repo"
)
func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor {
var exec gorp.SqlExecutor
if tx == nil {
exec = dbMap
} else {
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
// Check if the underlying value of the pointer is nil.
// This is not caught by the initial comparison (tx == nil).
if gorpTx == nil {
exec = dbMap
} else {
exec = gorpTx
}
}
if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok {
exec = translate.NewExecutor(exec, translate.PostgresToSQLite)
}
return exec
}
/*
Package translate implements translation of driver specific SQL queries.
*/
package translate
import (
"database/sql"
"regexp"
"github.com/go-gorp/gorp"
)
var (
bindRegexp = regexp.MustCompile(`\$\d+`)
trueRegexp = regexp.MustCompile(`\btrue\b`)
)
// PostgresToSQLite implements translation of the pq driver to sqlite3.
func PostgresToSQLite(query string) string {
query = bindRegexp.ReplaceAllString(query, "?")
query = trueRegexp.ReplaceAllString(query, "1")
return query
}
func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
return &executor{exec, translate}
}
type executor struct {
gorp.SqlExecutor
Translate func(string) string
}
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
return e.SqlExecutor.Exec(e.Translate(query), args...)
}
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return e.SqlExecutor.Select(i, e.Translate(query), args...)
}
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
}
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
}
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
}
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
}
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
}
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
}
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
}
package translate
import "testing"
func TestPostgresToSQLite(t *testing.T) {
tests := []struct {
query string
want string
}{
{"SELECT * FROM foo", "SELECT * FROM foo"},
{"SELECT * FROM %s", "SELECT * FROM %s"},
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
{"$1", "?"},
{"$", "$"},
}
for _, tt := range tests {
got := PostgresToSQLite(tt.query)
if got != tt.want {
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
}
}
}
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
...@@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err ...@@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
return user.ErrorInvalidID return user.ErrorInvalidID
} }
qt := pq.QuoteIdentifier(userTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := r.executor(tx) ex := executor(r.dbMap, tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable) result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
if err != nil { if err != nil {
return err return err
} }
...@@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid ...@@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
return err return err
} }
ex := r.executor(tx) ex := executor(r.dbMap, tx)
deleted, err := ex.Delete(rim) deleted, err := ex.Delete(rim)
if err != nil { if err != nil {
...@@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid ...@@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
} }
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) { func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
if userID == "" { if userID == "" {
return nil, user.ErrorInvalidID return nil, user.ErrorInvalidID
} }
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName)
rims, err := ex.Select(&remoteIdentityMappingModel{}, rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
...@@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us ...@@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
} }
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) { func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
qt := pq.QuoteIdentifier(userTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := r.executor(tx) ex := executor(r.dbMap, tx)
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt)) i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
return int(i), err return int(i), err
} }
...@@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults ...@@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ex := r.executor(tx) ex := executor(r.dbMap, tx)
qt := pq.QuoteIdentifier(userTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
// Ask for one more than needed so we know if there's more results, and // Ask for one more than needed so we know if there's more results, and
// hence, whether a nextPageToken is necessary. // hence, whether a nextPageToken is necessary.
ums, err := ex.Select(&userModel{}, ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
...@@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults ...@@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
} }
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
um, err := newUserModel(&usr) um, err := newUserModel(&usr)
if err != nil { if err != nil {
return err return err
...@@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { ...@@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
} }
func (r *userRepo) update(tx repo.Transaction, usr user.User) error { func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
um, err := newUserModel(&usr) um, err := newUserModel(&usr)
if err != nil { if err != nil {
return err return err
...@@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error { ...@@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
} }
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
m, err := ex.Get(userModel{}, userID) m, err := ex.Get(userModel{}, userID)
if err != nil { if err != nil {
...@@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { ...@@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
} }
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) { func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID) m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
if err != nil { if err != nil {
...@@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot ...@@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
} }
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) { func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
qt := pq.QuoteIdentifier(userTableName) qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := r.executor(tx) ex := executor(r.dbMap, tx)
var um userModel var um userModel
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email) err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
...@@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err ...@@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
} }
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error { func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
ex := r.executor(tx) ex := executor(r.dbMap, tx)
rim, err := newRemoteIdentityMappingModel(userID, ri) rim, err := newRemoteIdentityMappingModel(userID, ri)
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment