Commit bfd63b75 authored by Eric Chiang's avatar Eric Chiang

db: add sqlite3 support

parent 8f16279f
......@@ -11,6 +11,7 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
......@@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
}
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo)
tx, err := dbm.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, c := range clients {
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
if err != nil {
return nil, err
}
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
if err != nil {
return nil, err
}
err = repo.dbMap.Insert(cm)
err = tx.Insert(cm)
if err != nil {
return nil, err
}
}
return repo, nil
if err := tx.Commit(); err != nil {
return nil, err
}
return NewClientIdentityRepo(dbm), nil
}
type clientIdentityRepo struct {
......@@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
if err != nil {
return err
}
defer tx.Rollback()
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
m, err := tx.Get(clientIdentityModel{}, clientID)
if m == nil || err != nil {
rollback(tx)
return err
......@@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
cim, ok := m.(*clientIdentityModel)
if !ok {
rollback(tx)
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model")
}
cim.DexAdmin = isAdmin
_, err = r.dbMap.Update(cim)
_, err = tx.Update(cim)
if err != nil {
rollback(tx)
return err
}
err = tx.Commit()
if err != nil {
rollback(tx)
return err
}
return nil
return tx.Commit()
}
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
......@@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
}
if err := r.dbMap.Insert(cim); err != nil {
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation {
err = errors.New("client ID already exists")
switch sqlErr := err.(type) {
case *pq.Error:
if sqlErr.Code == pgErrorCodeUniqueViolation {
err = errors.New("client ID already exists")
}
case *sqlite3.Error:
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
err = errors.New("client ID already exists")
}
}
return nil, err
......@@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
}
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
qt := pq.QuoteIdentifier(clientIdentityTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
if err != nil {
......
......@@ -4,13 +4,16 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"net/url"
"github.com/go-gorp/gorp"
_ "github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
// Import database drivers
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
type table struct {
......@@ -43,23 +46,36 @@ type Config struct {
}
func NewConnection(cfg Config) (*gorp.DbMap, error) {
if !strings.HasPrefix(cfg.DSN, "postgres://") {
return nil, errors.New("unrecognized database driver")
}
db, err := sql.Open("postgres", cfg.DSN)
u, err := url.Parse(cfg.DSN)
if err != nil {
return nil, err
return nil, fmt.Errorf("parse DSN: %v", err)
}
db.SetMaxIdleConns(cfg.MaxIdleConnections)
db.SetMaxOpenConns(cfg.MaxOpenConnections)
dbm := gorp.DbMap{
Db: db,
Dialect: gorp.PostgresDialect{},
var (
db *sql.DB
dialect gorp.Dialect
)
switch u.Scheme {
case "postgres":
db, err = sql.Open("postgres", cfg.DSN)
if err != nil {
return nil, err
}
db.SetMaxIdleConns(cfg.MaxIdleConnections)
db.SetMaxOpenConns(cfg.MaxOpenConnections)
dialect = gorp.PostgresDialect{}
case "sqlite3":
db, err = sql.Open("sqlite3", u.Host)
if err != nil {
return nil, err
}
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
dialect = gorp.SqliteDialect{}
default:
return nil, errors.New("unrecognized database driver")
}
dbm := gorp.DbMap{Db: db, Dialect: dialect}
for _, t := range tables {
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
for _, unique := range t.unique {
......@@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
cm.SetUnique(true)
}
}
return &dbm, nil
}
......
......@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
......@@ -69,7 +68,7 @@ type ConnectorConfigRepo struct {
}
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
if err != nil {
......@@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
}
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound
}
......@@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
if err != nil {
return err
}
defer tx.Rollback()
qt := pq.QuoteIdentifier(connectorConfigTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
q := fmt.Sprintf("DELETE FROM %s", qt)
if _, err = r.dbMap.Exec(q); err != nil {
if _, err = tx.Exec(q); err != nil {
return err
}
if err = r.dbMap.Insert(insert...); err != nil {
if err = tx.Insert(insert...); err != nil {
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
}
return tx.Commit()
}
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
......@@ -8,7 +8,6 @@ import (
"time"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key"
......@@ -114,7 +113,7 @@ type PrivateKeySetRepo struct {
}
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
qt := pq.QuoteIdentifier(keyTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
if err != nil {
return err
......@@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
}
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
qt := pq.QuoteIdentifier(keyTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
if err != nil {
return nil, err
......
package db
import (
"errors"
"fmt"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
migrate "github.com/rubenv/sql-migrate"
"github.com/rubenv/sql-migrate"
"github.com/coreos/dex/db/migrations"
)
const (
migrationDialect = "postgres"
migrationTable = "dex_migrations"
migrationDir = "db/migrations"
migrationTable = "dex_migrations"
migrationDir = "db/migrations"
)
func init() {
......@@ -21,32 +20,57 @@ func init() {
}
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
source := getSource()
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up)
source, dialect, err := migrationSource(dbMap)
if err != nil {
return 0, err
}
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
}
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
source := getSource()
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max)
source, dialect, err := migrationSource(dbMap)
if err != nil {
return 0, err
}
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
}
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0)
source, dialect, err := migrationSource(dbMap)
if err != nil {
return nil, err
}
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
return migrations, err
}
func DropMigrationsTable(dbMap *gorp.DbMap) error {
qt := pq.QuoteIdentifier(migrationTable)
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt))
qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
_, err := dbMap.Exec(qt)
return err
}
func getSource() migrate.MigrationSource {
return &migrate.AssetMigrationSource{
Dir: migrationDir,
Asset: migrations.Asset,
AssetDir: migrations.AssetDir,
func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
switch dbMap.Dialect.(type) {
case gorp.PostgresDialect:
src = &migrate.AssetMigrationSource{
Dir: migrationDir,
Asset: migrations.Asset,
AssetDir: migrations.AssetDir,
}
return src, "postgres", nil
case gorp.SqliteDialect:
src = &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "dex.sql",
Up: []string{sqlite3Migration},
},
},
}
return src, "sqlite3", nil
default:
return nil, "", errors.New("unsupported migration driver")
}
}
package db
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
const sqlite3Migration = `
CREATE TABLE authd_user (
id text NOT NULL UNIQUE,
email text,
email_verified integer,
display_name text,
admin integer,
created_at bigint,
disabled integer
);
CREATE TABLE client_identity (
id text NOT NULL UNIQUE,
secret blob,
metadata text,
dex_admin integer
);
CREATE TABLE connector_config (
id text NOT NULL UNIQUE,
type text,
config text
);
CREATE TABLE key (
value blob
);
CREATE TABLE password_info (
user_id text NOT NULL UNIQUE,
password text,
password_expires bigint
);
CREATE TABLE refresh_token (
id integer PRIMARY KEY,
payload_hash blob,
user_id text,
client_id text
);
CREATE TABLE remote_identity_mapping (
connector_id text NOT NULL,
user_id text,
remote_id text NOT NULL
);
CREATE TABLE session (
id text NOT NULL UNIQUE,
state text,
created_at bigint,
expires_at bigint,
client_id text,
client_state text,
redirect_url text,
identity text,
connector_id text,
user_id text,
register integer,
nonce text,
scope text
);
CREATE TABLE session_key (
key text NOT NULL UNIQUE,
session_id text,
expires_at bigint,
stale integer
);
`
......@@ -5,10 +5,11 @@ import (
"reflect"
"time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
"github.com/go-gorp/gorp"
)
const (
......@@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
return nil
}
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
m, err := ex.Get(passwordInfoModel{}, id)
if err != nil {
......@@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
}
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
pm, err := newPasswordInfoModel(&pw)
if err != nil {
return err
......@@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
}
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
pm, err := newPasswordInfoModel(&pw)
if err != nil {
return err
......
......@@ -8,10 +8,11 @@ import (
"strconv"
"strings"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
)
const (
......@@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return nil
}
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
return tx
}
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
......
......@@ -11,7 +11,6 @@ import (
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
......@@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
}
func (r *SessionRepo) purge() error {
qt := pq.QuoteIdentifier(sessionTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
if err != nil {
return err
}
......
......@@ -8,7 +8,6 @@ import (
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
......@@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
return "", errors.New("invalid session key")
}
qt := pq.QuoteIdentifier(sessionKeyTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
res, err := r.dbMap.Exec(q, true, key, false)
res, err := executor(r.dbMap, nil).Exec(q, true, key, false)
if err != nil {
return "", err
}
......@@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
}
func (r *SessionKeyRepo) purge() error {
qt := pq.QuoteIdentifier(sessionKeyTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix())
res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix())
if err != nil {
return err
}
......
package db
import (
"github.com/go-gorp/gorp"
"github.com/coreos/dex/db/translate"
"github.com/coreos/dex/repo"
)
func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor {
var exec gorp.SqlExecutor
if tx == nil {
exec = dbMap
} else {
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
// Check if the underlying value of the pointer is nil.
// This is not caught by the initial comparison (tx == nil).
if gorpTx == nil {
exec = dbMap
} else {
exec = gorpTx
}
}
if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok {
exec = translate.NewExecutor(exec, translate.PostgresToSQLite)
}
return exec
}
/*
Package translate implements translation of driver specific SQL queries.
*/
package translate
import (
"database/sql"
"regexp"
"github.com/go-gorp/gorp"
)
var (
bindRegexp = regexp.MustCompile(`\$\d+`)
trueRegexp = regexp.MustCompile(`\btrue\b`)
)
// PostgresToSQLite implements translation of the pq driver to sqlite3.
func PostgresToSQLite(query string) string {
query = bindRegexp.ReplaceAllString(query, "?")
query = trueRegexp.ReplaceAllString(query, "1")
return query
}
func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
return &executor{exec, translate}
}
type executor struct {
gorp.SqlExecutor
Translate func(string) string
}
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
return e.SqlExecutor.Exec(e.Translate(query), args...)
}
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return e.SqlExecutor.Select(i, e.Translate(query), args...)
}
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
}
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
}
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
}
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
}
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
}
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
}
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
}
package translate
import "testing"
func TestPostgresToSQLite(t *testing.T) {
tests := []struct {
query string
want string
}{
{"SELECT * FROM foo", "SELECT * FROM foo"},
{"SELECT * FROM %s", "SELECT * FROM %s"},
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
{"$1", "?"},
{"$", "$"},
}
for _, tt := range tests {
got := PostgresToSQLite(tt.query)
if got != tt.want {
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
}
}
}
......@@ -8,7 +8,6 @@ import (
"time"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
......@@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
return user.ErrorInvalidID
}
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := executor(r.dbMap, tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
if err != nil {
return err
}
......@@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
return err
}
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
deleted, err := ex.Delete(rim)
if err != nil {
......@@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
}
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
if userID == "" {
return nil, user.ErrorInvalidID
}
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName)
rims, err := ex.Select(&remoteIdentityMappingModel{},
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName)
rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
if err != nil {
if err != sql.ErrNoRows {
......@@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
}
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt))
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := executor(r.dbMap, tx)
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
return int(i), err
}
......@@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
if err != nil {
return nil, "", err
}
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
qt := pq.QuoteIdentifier(userTableName)
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
// Ask for one more than needed so we know if there's more results, and
// hence, whether a nextPageToken is necessary.
ums, err := ex.Select(&userModel{},
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
if err != nil {
return nil, "", err
}
......@@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
}
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
um, err := newUserModel(&usr)
if err != nil {
return err
......@@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
}
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
um, err := newUserModel(&usr)
if err != nil {
return err
......@@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
}
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
m, err := ex.Get(userModel{}, userID)
if err != nil {
......@@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
}
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
if err != nil {
......@@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
}
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
ex := executor(r.dbMap, tx)
var um userModel
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
......@@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
}
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
ex := r.executor(tx)
ex := executor(r.dbMap, tx)
rim, err := newRemoteIdentityMappingModel(userID, ri)
if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment