Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dex
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
go
dex
Commits
bfd63b75
Commit
bfd63b75
authored
Feb 09, 2016
by
Eric Chiang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
db: add sqlite3 support
parent
8f16279f
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
345 additions
and
145 deletions
+345
-145
client.go
db/client.go
+25
-18
conn.go
db/conn.go
+31
-16
connector_config.go
db/connector_config.go
+7
-19
key.go
db/key.go
+2
-3
migrate.go
db/migrate.go
+43
-19
migrate_sqlite3.go
db/migrate_sqlite3.go
+73
-0
password.go
db/password.go
+5
-16
refresh.go
db/refresh.go
+4
-11
session.go
db/session.go
+2
-3
session_key.go
db/session_key.go
+4
-5
transaction.go
db/transaction.go
+33
-0
translate.go
db/translate/translate.go
+68
-0
translate_test.go
db/translate/translate_test.go
+28
-0
user.go
db/user.go
+20
-35
No files found.
db/client.go
View file @
bfd63b75
...
...
@@ -11,6 +11,7 @@ import (
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
...
...
@@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
}
func
NewClientIdentityRepoFromClients
(
dbm
*
gorp
.
DbMap
,
clients
[]
oidc
.
ClientIdentity
)
(
client
.
ClientIdentityRepo
,
error
)
{
repo
:=
NewClientIdentityRepo
(
dbm
)
.
(
*
clientIdentityRepo
)
tx
,
err
:=
dbm
.
Begin
()
if
err
!=
nil
{
return
nil
,
err
}
defer
tx
.
Rollback
()
for
_
,
c
:=
range
clients
{
dec
,
err
:=
base64
.
URLEncoding
.
DecodeString
(
c
.
Credentials
.
Secret
)
if
err
!=
nil
{
return
nil
,
err
}
cm
,
err
:=
newClientIdentityModel
(
c
.
Credentials
.
ID
,
dec
,
&
c
.
Metadata
)
if
err
!=
nil
{
return
nil
,
err
}
err
=
repo
.
dbMap
.
Insert
(
cm
)
err
=
tx
.
Insert
(
cm
)
if
err
!=
nil
{
return
nil
,
err
}
}
return
repo
,
nil
if
err
:=
tx
.
Commit
();
err
!=
nil
{
return
nil
,
err
}
return
NewClientIdentityRepo
(
dbm
),
nil
}
type
clientIdentityRepo
struct
{
...
...
@@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
if
err
!=
nil
{
return
err
}
defer
tx
.
Rollback
()
m
,
err
:=
r
.
dbMap
.
Get
(
clientIdentityModel
{},
clientID
)
m
,
err
:=
tx
.
Get
(
clientIdentityModel
{},
clientID
)
if
m
==
nil
||
err
!=
nil
{
rollback
(
tx
)
return
err
...
...
@@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
cim
,
ok
:=
m
.
(
*
clientIdentityModel
)
if
!
ok
{
rollback
(
tx
)
log
.
Errorf
(
"expected clientIdentityModel but found %v"
,
reflect
.
TypeOf
(
m
))
return
errors
.
New
(
"unrecognized model"
)
}
cim
.
DexAdmin
=
isAdmin
_
,
err
=
r
.
dbMap
.
Update
(
cim
)
_
,
err
=
tx
.
Update
(
cim
)
if
err
!=
nil
{
rollback
(
tx
)
return
err
}
err
=
tx
.
Commit
()
if
err
!=
nil
{
rollback
(
tx
)
return
err
}
return
nil
return
tx
.
Commit
()
}
func
(
r
*
clientIdentityRepo
)
Authenticate
(
creds
oidc
.
ClientCredentials
)
(
bool
,
error
)
{
...
...
@@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
}
if
err
:=
r
.
dbMap
.
Insert
(
cim
);
err
!=
nil
{
if
perr
,
ok
:=
err
.
(
*
pq
.
Error
);
ok
&&
perr
.
Code
==
pgErrorCodeUniqueViolation
{
err
=
errors
.
New
(
"client ID already exists"
)
switch
sqlErr
:=
err
.
(
type
)
{
case
*
pq
.
Error
:
if
sqlErr
.
Code
==
pgErrorCodeUniqueViolation
{
err
=
errors
.
New
(
"client ID already exists"
)
}
case
*
sqlite3
.
Error
:
if
sqlErr
.
ExtendedCode
==
sqlite3
.
ErrConstraintUnique
{
err
=
errors
.
New
(
"client ID already exists"
)
}
}
return
nil
,
err
...
...
@@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
}
func
(
r
*
clientIdentityRepo
)
All
()
([]
oidc
.
ClientIdentity
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
clientIdentityTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
clientIdentityTableName
)
q
:=
fmt
.
Sprintf
(
"SELECT * FROM %s"
,
qt
)
objs
,
err
:=
r
.
dbMap
.
Select
(
&
clientIdentityModel
{},
q
)
if
err
!=
nil
{
...
...
db/conn.go
View file @
bfd63b75
...
...
@@ -4,13 +4,16 @@ import (
"database/sql"
"errors"
"fmt"
"
strings
"
"
net/url
"
"github.com/go-gorp/gorp"
_
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
// Import database drivers
_
"github.com/lib/pq"
_
"github.com/mattn/go-sqlite3"
)
type
table
struct
{
...
...
@@ -43,23 +46,36 @@ type Config struct {
}
func
NewConnection
(
cfg
Config
)
(
*
gorp
.
DbMap
,
error
)
{
if
!
strings
.
HasPrefix
(
cfg
.
DSN
,
"postgres://"
)
{
return
nil
,
errors
.
New
(
"unrecognized database driver"
)
}
db
,
err
:=
sql
.
Open
(
"postgres"
,
cfg
.
DSN
)
u
,
err
:=
url
.
Parse
(
cfg
.
DSN
)
if
err
!=
nil
{
return
nil
,
err
return
nil
,
fmt
.
Errorf
(
"parse DSN: %v"
,
err
)
}
db
.
SetMaxIdleConns
(
cfg
.
MaxIdleConnections
)
db
.
SetMaxOpenConns
(
cfg
.
MaxOpenConnections
)
dbm
:=
gorp
.
DbMap
{
Db
:
db
,
Dialect
:
gorp
.
PostgresDialect
{},
var
(
db
*
sql
.
DB
dialect
gorp
.
Dialect
)
switch
u
.
Scheme
{
case
"postgres"
:
db
,
err
=
sql
.
Open
(
"postgres"
,
cfg
.
DSN
)
if
err
!=
nil
{
return
nil
,
err
}
db
.
SetMaxIdleConns
(
cfg
.
MaxIdleConnections
)
db
.
SetMaxOpenConns
(
cfg
.
MaxOpenConnections
)
dialect
=
gorp
.
PostgresDialect
{}
case
"sqlite3"
:
db
,
err
=
sql
.
Open
(
"sqlite3"
,
u
.
Host
)
if
err
!=
nil
{
return
nil
,
err
}
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
dialect
=
gorp
.
SqliteDialect
{}
default
:
return
nil
,
errors
.
New
(
"unrecognized database driver"
)
}
dbm
:=
gorp
.
DbMap
{
Db
:
db
,
Dialect
:
dialect
}
for
_
,
t
:=
range
tables
{
tm
:=
dbm
.
AddTableWithName
(
t
.
model
,
t
.
name
)
.
SetKeys
(
t
.
autoinc
,
t
.
pkey
...
)
for
_
,
unique
:=
range
t
.
unique
{
...
...
@@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
cm
.
SetUnique
(
true
)
}
}
return
&
dbm
,
nil
}
...
...
db/connector_config.go
View file @
bfd63b75
...
...
@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
...
...
@@ -69,7 +68,7 @@ type ConnectorConfigRepo struct {
}
func
(
r
*
ConnectorConfigRepo
)
All
()
([]
connector
.
ConnectorConfig
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
connectorConfigTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
connectorConfigTableName
)
q
:=
fmt
.
Sprintf
(
"SELECT * FROM %s"
,
qt
)
objs
,
err
:=
r
.
dbMap
.
Select
(
&
connectorConfigModel
{},
q
)
if
err
!=
nil
{
...
...
@@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
}
func
(
r
*
ConnectorConfigRepo
)
GetConnectorByID
(
tx
repo
.
Transaction
,
id
string
)
(
connector
.
ConnectorConfig
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
connectorConfigTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
connectorConfigTableName
)
q
:=
fmt
.
Sprintf
(
"SELECT * FROM %s WHERE id = $1"
,
qt
)
var
c
connectorConfigModel
if
err
:=
r
.
executor
(
tx
)
.
SelectOne
(
&
c
,
q
,
id
);
err
!=
nil
{
if
err
:=
executor
(
r
.
dbMap
,
tx
)
.
SelectOne
(
&
c
,
q
,
id
);
err
!=
nil
{
if
err
==
sql
.
ErrNoRows
{
return
nil
,
connector
.
ErrorNotFound
}
...
...
@@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
if
err
!=
nil
{
return
err
}
defer
tx
.
Rollback
()
qt
:=
pq
.
QuoteIdentifier
(
connectorConfigTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
connectorConfigTableName
)
q
:=
fmt
.
Sprintf
(
"DELETE FROM %s"
,
qt
)
if
_
,
err
=
r
.
dbMap
.
Exec
(
q
);
err
!=
nil
{
if
_
,
err
=
tx
.
Exec
(
q
);
err
!=
nil
{
return
err
}
if
err
=
r
.
dbMap
.
Insert
(
insert
...
);
err
!=
nil
{
if
err
=
tx
.
Insert
(
insert
...
);
err
!=
nil
{
return
fmt
.
Errorf
(
"DB insert failed %#v: %v"
,
insert
,
err
)
}
return
tx
.
Commit
()
}
func
(
r
*
ConnectorConfigRepo
)
executor
(
tx
repo
.
Transaction
)
gorp
.
SqlExecutor
{
if
tx
==
nil
{
return
r
.
dbMap
}
gorpTx
,
ok
:=
tx
.
(
*
gorp
.
Transaction
)
if
!
ok
{
panic
(
"wrong kind of transaction passed to a DB repo"
)
}
return
gorpTx
}
db/key.go
View file @
bfd63b75
...
...
@@ -8,7 +8,6 @@ import (
"time"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
pcrypto
"github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key"
...
...
@@ -114,7 +113,7 @@ type PrivateKeySetRepo struct {
}
func
(
r
*
PrivateKeySetRepo
)
Set
(
ks
key
.
KeySet
)
error
{
qt
:=
pq
.
QuoteIdentifier
(
keyTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
keyTableName
)
_
,
err
:=
r
.
dbMap
.
Exec
(
fmt
.
Sprintf
(
"DELETE FROM %s"
,
qt
))
if
err
!=
nil
{
return
err
...
...
@@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
}
func
(
r
*
PrivateKeySetRepo
)
Get
()
(
key
.
KeySet
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
keyTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
keyTableName
)
objs
,
err
:=
r
.
dbMap
.
Select
(
&
privateKeySetBlob
{},
fmt
.
Sprintf
(
"SELECT * FROM %s"
,
qt
))
if
err
!=
nil
{
return
nil
,
err
...
...
db/migrate.go
View file @
bfd63b75
package
db
import
(
"errors"
"fmt"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
migrate
"github.com/rubenv/sql-migrate"
"github.com/rubenv/sql-migrate"
"github.com/coreos/dex/db/migrations"
)
const
(
migrationDialect
=
"postgres"
migrationTable
=
"dex_migrations"
migrationDir
=
"db/migrations"
migrationTable
=
"dex_migrations"
migrationDir
=
"db/migrations"
)
func
init
()
{
...
...
@@ -21,32 +20,57 @@ func init() {
}
func
MigrateToLatest
(
dbMap
*
gorp
.
DbMap
)
(
int
,
error
)
{
source
:=
getSource
()
return
migrate
.
Exec
(
dbMap
.
Db
,
migrationDialect
,
source
,
migrate
.
Up
)
source
,
dialect
,
err
:=
migrationSource
(
dbMap
)
if
err
!=
nil
{
return
0
,
err
}
return
migrate
.
Exec
(
dbMap
.
Db
,
dialect
,
source
,
migrate
.
Up
)
}
func
MigrateMaxMigrations
(
dbMap
*
gorp
.
DbMap
,
max
int
)
(
int
,
error
)
{
source
:=
getSource
()
return
migrate
.
ExecMax
(
dbMap
.
Db
,
migrationDialect
,
source
,
migrate
.
Up
,
max
)
source
,
dialect
,
err
:=
migrationSource
(
dbMap
)
if
err
!=
nil
{
return
0
,
err
}
return
migrate
.
ExecMax
(
dbMap
.
Db
,
dialect
,
source
,
migrate
.
Up
,
max
)
}
func
GetPlannedMigrations
(
dbMap
*
gorp
.
DbMap
)
([]
*
migrate
.
PlannedMigration
,
error
)
{
migrations
,
_
,
err
:=
migrate
.
PlanMigration
(
dbMap
.
Db
,
migrationDialect
,
getSource
(),
migrate
.
Up
,
0
)
source
,
dialect
,
err
:=
migrationSource
(
dbMap
)
if
err
!=
nil
{
return
nil
,
err
}
migrations
,
_
,
err
:=
migrate
.
PlanMigration
(
dbMap
.
Db
,
dialect
,
source
,
migrate
.
Up
,
0
)
return
migrations
,
err
}
func
DropMigrationsTable
(
dbMap
*
gorp
.
DbMap
)
error
{
qt
:=
pq
.
QuoteIdentifier
(
migrationTable
)
_
,
err
:=
dbMap
.
Exec
(
fmt
.
Sprintf
(
"drop table if exists %s ;"
,
qt
)
)
qt
:=
fmt
.
Sprintf
(
"DROP TABLE IF EXISTS %s;"
,
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
migrationTable
)
)
_
,
err
:=
dbMap
.
Exec
(
qt
)
return
err
}
func
getSource
()
migrate
.
MigrationSource
{
return
&
migrate
.
AssetMigrationSource
{
Dir
:
migrationDir
,
Asset
:
migrations
.
Asset
,
AssetDir
:
migrations
.
AssetDir
,
func
migrationSource
(
dbMap
*
gorp
.
DbMap
)
(
src
migrate
.
MigrationSource
,
dialect
string
,
err
error
)
{
switch
dbMap
.
Dialect
.
(
type
)
{
case
gorp
.
PostgresDialect
:
src
=
&
migrate
.
AssetMigrationSource
{
Dir
:
migrationDir
,
Asset
:
migrations
.
Asset
,
AssetDir
:
migrations
.
AssetDir
,
}
return
src
,
"postgres"
,
nil
case
gorp
.
SqliteDialect
:
src
=
&
migrate
.
MemoryMigrationSource
{
Migrations
:
[]
*
migrate
.
Migration
{
{
Id
:
"dex.sql"
,
Up
:
[]
string
{
sqlite3Migration
},
},
},
}
return
src
,
"sqlite3"
,
nil
default
:
return
nil
,
""
,
errors
.
New
(
"unsupported migration driver"
)
}
}
db/migrate_sqlite3.go
0 → 100644
View file @
bfd63b75
package
db
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
const
sqlite3Migration
=
`
CREATE TABLE authd_user (
id text NOT NULL UNIQUE,
email text,
email_verified integer,
display_name text,
admin integer,
created_at bigint,
disabled integer
);
CREATE TABLE client_identity (
id text NOT NULL UNIQUE,
secret blob,
metadata text,
dex_admin integer
);
CREATE TABLE connector_config (
id text NOT NULL UNIQUE,
type text,
config text
);
CREATE TABLE key (
value blob
);
CREATE TABLE password_info (
user_id text NOT NULL UNIQUE,
password text,
password_expires bigint
);
CREATE TABLE refresh_token (
id integer PRIMARY KEY,
payload_hash blob,
user_id text,
client_id text
);
CREATE TABLE remote_identity_mapping (
connector_id text NOT NULL,
user_id text,
remote_id text NOT NULL
);
CREATE TABLE session (
id text NOT NULL UNIQUE,
state text,
created_at bigint,
expires_at bigint,
client_id text,
client_state text,
redirect_url text,
identity text,
connector_id text,
user_id text,
register integer,
nonce text,
scope text
);
CREATE TABLE session_key (
key text NOT NULL UNIQUE,
session_id text,
expires_at bigint,
stale integer
);
`
db/password.go
View file @
bfd63b75
...
...
@@ -5,10 +5,11 @@ import (
"reflect"
"time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
"github.com/go-gorp/gorp"
)
const
(
...
...
@@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
return
nil
}
func
(
r
*
passwordInfoRepo
)
executor
(
tx
repo
.
Transaction
)
gorp
.
SqlExecutor
{
if
tx
==
nil
{
return
r
.
dbMap
}
gorpTx
,
ok
:=
tx
.
(
*
gorp
.
Transaction
)
if
!
ok
{
panic
(
"wrong kind of transaction passed to a DB repo"
)
}
return
gorpTx
}
func
(
r
*
passwordInfoRepo
)
get
(
tx
repo
.
Transaction
,
id
string
)
(
user
.
PasswordInfo
,
error
)
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
m
,
err
:=
ex
.
Get
(
passwordInfoModel
{},
id
)
if
err
!=
nil
{
...
...
@@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
}
func
(
r
*
passwordInfoRepo
)
insert
(
tx
repo
.
Transaction
,
pw
user
.
PasswordInfo
)
error
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
pm
,
err
:=
newPasswordInfoModel
(
&
pw
)
if
err
!=
nil
{
return
err
...
...
@@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
}
func
(
r
*
passwordInfoRepo
)
update
(
tx
repo
.
Transaction
,
pw
user
.
PasswordInfo
)
error
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
pm
,
err
:=
newPasswordInfoModel
(
&
pw
)
if
err
!=
nil
{
return
err
...
...
db/refresh.go
View file @
bfd63b75
...
...
@@ -8,10 +8,11 @@ import (
"strconv"
"strings"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
)
const
(
...
...
@@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return
nil
}
func
(
r
*
refreshTokenRepo
)
executor
(
tx
*
gorp
.
Transaction
)
gorp
.
SqlExecutor
{
if
tx
==
nil
{
return
r
.
dbMap
}
return
tx
}
func
(
r
*
refreshTokenRepo
)
get
(
tx
*
gorp
.
Transaction
,
tokenID
int64
)
(
*
refreshTokenModel
,
error
)
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
result
,
err
:=
ex
.
Get
(
refreshTokenModel
{},
tokenID
)
if
err
!=
nil
{
return
nil
,
err
...
...
db/session.go
View file @
bfd63b75
...
...
@@ -11,7 +11,6 @@ import (
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
...
...
@@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
}
func
(
r
*
SessionRepo
)
purge
()
error
{
qt
:=
pq
.
QuoteIdentifier
(
sessionTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
sessionTableName
)
q
:=
fmt
.
Sprintf
(
"DELETE FROM %s WHERE expires_at < $1 OR state = $2"
,
qt
)
res
,
err
:=
r
.
dbMap
.
Exec
(
q
,
r
.
clock
.
Now
()
.
Unix
(),
string
(
session
.
SessionStateDead
))
res
,
err
:=
executor
(
r
.
dbMap
,
nil
)
.
Exec
(
q
,
r
.
clock
.
Now
()
.
Unix
(),
string
(
session
.
SessionStateDead
))
if
err
!=
nil
{
return
err
}
...
...
db/session_key.go
View file @
bfd63b75
...
...
@@ -8,7 +8,6 @@ import (
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
...
...
@@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
return
""
,
errors
.
New
(
"invalid session key"
)
}
qt
:=
pq
.
QuoteIdentifier
(
sessionKeyTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
sessionKeyTableName
)
q
:=
fmt
.
Sprintf
(
"UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3"
,
qt
)
res
,
err
:=
r
.
dbMap
.
Exec
(
q
,
true
,
key
,
false
)
res
,
err
:=
executor
(
r
.
dbMap
,
nil
)
.
Exec
(
q
,
true
,
key
,
false
)
if
err
!=
nil
{
return
""
,
err
}
...
...
@@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
}
func
(
r
*
SessionKeyRepo
)
purge
()
error
{
qt
:=
pq
.
QuoteIdentifier
(
sessionKeyTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
sessionKeyTableName
)
q
:=
fmt
.
Sprintf
(
"DELETE FROM %s WHERE stale = $1 OR expires_at < $2"
,
qt
)
res
,
err
:=
r
.
dbMap
.
Exec
(
q
,
true
,
r
.
clock
.
Now
()
.
Unix
())
res
,
err
:=
executor
(
r
.
dbMap
,
nil
)
.
Exec
(
q
,
true
,
r
.
clock
.
Now
()
.
Unix
())
if
err
!=
nil
{
return
err
}
...
...
db/transaction.go
0 → 100644
View file @
bfd63b75
package
db
import
(
"github.com/go-gorp/gorp"
"github.com/coreos/dex/db/translate"
"github.com/coreos/dex/repo"
)
func
executor
(
dbMap
*
gorp
.
DbMap
,
tx
repo
.
Transaction
)
gorp
.
SqlExecutor
{
var
exec
gorp
.
SqlExecutor
if
tx
==
nil
{
exec
=
dbMap
}
else
{
gorpTx
,
ok
:=
tx
.
(
*
gorp
.
Transaction
)
if
!
ok
{
panic
(
"wrong kind of transaction passed to a DB repo"
)
}
// Check if the underlying value of the pointer is nil.
// This is not caught by the initial comparison (tx == nil).
if
gorpTx
==
nil
{
exec
=
dbMap
}
else
{
exec
=
gorpTx
}
}
if
_
,
ok
:=
dbMap
.
Dialect
.
(
gorp
.
SqliteDialect
);
ok
{
exec
=
translate
.
NewExecutor
(
exec
,
translate
.
PostgresToSQLite
)
}
return
exec
}
db/translate/translate.go
0 → 100644
View file @
bfd63b75
/*
Package translate implements translation of driver specific SQL queries.
*/
package
translate
import
(
"database/sql"
"regexp"
"github.com/go-gorp/gorp"
)
var
(
bindRegexp
=
regexp
.
MustCompile
(
`\$\d+`
)
trueRegexp
=
regexp
.
MustCompile
(
`\btrue\b`
)
)
// PostgresToSQLite implements translation of the pq driver to sqlite3.
func
PostgresToSQLite
(
query
string
)
string
{
query
=
bindRegexp
.
ReplaceAllString
(
query
,
"?"
)
query
=
trueRegexp
.
ReplaceAllString
(
query
,
"1"
)
return
query
}
func
NewExecutor
(
exec
gorp
.
SqlExecutor
,
translate
func
(
string
)
string
)
gorp
.
SqlExecutor
{
return
&
executor
{
exec
,
translate
}
}
type
executor
struct
{
gorp
.
SqlExecutor
Translate
func
(
string
)
string
}
func
(
e
*
executor
)
Exec
(
query
string
,
args
...
interface
{})
(
sql
.
Result
,
error
)
{
return
e
.
SqlExecutor
.
Exec
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
Select
(
i
interface
{},
query
string
,
args
...
interface
{})
([]
interface
{},
error
)
{
return
e
.
SqlExecutor
.
Select
(
i
,
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectInt
(
query
string
,
args
...
interface
{})
(
int64
,
error
)
{
return
e
.
SqlExecutor
.
SelectInt
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectNullInt
(
query
string
,
args
...
interface
{})
(
sql
.
NullInt64
,
error
)
{
return
e
.
SqlExecutor
.
SelectNullInt
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectFloat
(
query
string
,
args
...
interface
{})
(
float64
,
error
)
{
return
e
.
SqlExecutor
.
SelectFloat
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectNullFloat
(
query
string
,
args
...
interface
{})
(
sql
.
NullFloat64
,
error
)
{
return
e
.
SqlExecutor
.
SelectNullFloat
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectStr
(
query
string
,
args
...
interface
{})
(
string
,
error
)
{
return
e
.
SqlExecutor
.
SelectStr
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectNullStr
(
query
string
,
args
...
interface
{})
(
sql
.
NullString
,
error
)
{
return
e
.
SqlExecutor
.
SelectNullStr
(
e
.
Translate
(
query
),
args
...
)
}
func
(
e
*
executor
)
SelectOne
(
holder
interface
{},
query
string
,
args
...
interface
{})
error
{
return
e
.
SqlExecutor
.
SelectOne
(
holder
,
e
.
Translate
(
query
),
args
...
)
}
db/translate/translate_test.go
0 → 100644
View file @
bfd63b75
package
translate
import
"testing"
func
TestPostgresToSQLite
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
query
string
want
string
}{
{
"SELECT * FROM foo"
,
"SELECT * FROM foo"
},
{
"SELECT * FROM %s"
,
"SELECT * FROM %s"
},
{
"SELECT * FROM foo WHERE is_admin=true"
,
"SELECT * FROM foo WHERE is_admin=1"
},
{
"SELECT * FROM foo WHERE is_admin=true;"
,
"SELECT * FROM foo WHERE is_admin=1;"
},
{
"SELECT * FROM foo WHERE is_admin=$10"
,
"SELECT * FROM foo WHERE is_admin=?"
},
{
"SELECT * FROM foo WHERE is_admin=$10;"
,
"SELECT * FROM foo WHERE is_admin=?;"
},
{
"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;"
,
"SELECT * FROM foo WHERE name=? AND is_admin=?;"
},
{
"$1"
,
"?"
},
{
"$"
,
"$"
},
}
for
_
,
tt
:=
range
tests
{
got
:=
PostgresToSQLite
(
tt
.
query
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"PostgresToSQLite(%q): want=%q, got=%q"
,
tt
.
query
,
tt
.
want
,
got
)
}
}
}
db/user.go
View file @
bfd63b75
...
...
@@ -8,7 +8,6 @@ import (
"time"
"github.com/go-gorp/gorp"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
...
...
@@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
return
user
.
ErrorInvalidID
}
qt
:=
pq
.
QuoteIdentifier
(
userTableName
)
ex
:=
r
.
executor
(
tx
)
result
,
err
:=
ex
.
Exec
(
fmt
.
Sprintf
(
"UPDATE %s SET disabled = $
2 WHERE id = $1"
,
qt
),
userID
,
disable
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
userTableName
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
result
,
err
:=
ex
.
Exec
(
fmt
.
Sprintf
(
"UPDATE %s SET disabled = $
1 WHERE id = $2;"
,
qt
),
disable
,
userID
)
if
err
!=
nil
{
return
err
}
...
...
@@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
return
err
}
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
deleted
,
err
:=
ex
.
Delete
(
rim
)
if
err
!=
nil
{
...
...
@@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
}
func
(
r
*
userRepo
)
GetRemoteIdentities
(
tx
repo
.
Transaction
,
userID
string
)
([]
user
.
RemoteIdentity
,
error
)
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
if
userID
==
""
{
return
nil
,
user
.
ErrorInvalidID
}
qt
:=
pq
.
QuoteIdentifier
(
remoteIdentityMappingTableName
)
rims
,
err
:=
ex
.
Select
(
&
remoteIdentityMappingModel
{},
fmt
.
Sprintf
(
"select * from %s where user_id = $1"
,
qt
),
userID
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
remoteIdentityMappingTableName
)
rims
,
err
:=
ex
.
Select
(
&
remoteIdentityMappingModel
{},
fmt
.
Sprintf
(
"SELECT * FROM %s WHERE user_id = $1"
,
qt
),
userID
)
if
err
!=
nil
{
if
err
!=
sql
.
ErrNoRows
{
...
...
@@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
}
func
(
r
*
userRepo
)
GetAdminCount
(
tx
repo
.
Transaction
)
(
int
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
userTableName
)
ex
:=
r
.
executor
(
tx
)
i
,
err
:=
ex
.
SelectInt
(
fmt
.
Sprintf
(
"SELECT count(*) FROM %s
where admin=true
"
,
qt
))
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
userTableName
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
i
,
err
:=
ex
.
SelectInt
(
fmt
.
Sprintf
(
"SELECT count(*) FROM %s
WHERE admin=true;
"
,
qt
))
return
int
(
i
),
err
}
...
...
@@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
if
err
!=
nil
{
return
nil
,
""
,
err
}
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
qt
:=
pq
.
QuoteIdentifier
(
userTableName
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
userTableName
)
// Ask for one more than needed so we know if there's more results, and
// hence, whether a nextPageToken is necessary.
ums
,
err
:=
ex
.
Select
(
&
userModel
{},
fmt
.
Sprintf
(
"SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 "
,
qt
),
maxResults
+
1
,
offset
)
ums
,
err
:=
ex
.
Select
(
&
userModel
{},
fmt
.
Sprintf
(
"SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2"
,
qt
),
maxResults
+
1
,
offset
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
...
...
@@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
}
func
(
r
*
userRepo
)
executor
(
tx
repo
.
Transaction
)
gorp
.
SqlExecutor
{
if
tx
==
nil
{
return
r
.
dbMap
}
gorpTx
,
ok
:=
tx
.
(
*
gorp
.
Transaction
)
if
!
ok
{
panic
(
"wrong kind of transaction passed to a DB repo"
)
}
return
gorpTx
}
func
(
r
*
userRepo
)
insert
(
tx
repo
.
Transaction
,
usr
user
.
User
)
error
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
um
,
err
:=
newUserModel
(
&
usr
)
if
err
!=
nil
{
return
err
...
...
@@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
}
func
(
r
*
userRepo
)
update
(
tx
repo
.
Transaction
,
usr
user
.
User
)
error
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
um
,
err
:=
newUserModel
(
&
usr
)
if
err
!=
nil
{
return
err
...
...
@@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
}
func
(
r
*
userRepo
)
get
(
tx
repo
.
Transaction
,
userID
string
)
(
user
.
User
,
error
)
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
m
,
err
:=
ex
.
Get
(
userModel
{},
userID
)
if
err
!=
nil
{
...
...
@@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
}
func
(
r
*
userRepo
)
getUserIDForRemoteIdentity
(
tx
repo
.
Transaction
,
ri
user
.
RemoteIdentity
)
(
string
,
error
)
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
m
,
err
:=
ex
.
Get
(
remoteIdentityMappingModel
{},
ri
.
ConnectorID
,
ri
.
ID
)
if
err
!=
nil
{
...
...
@@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
}
func
(
r
*
userRepo
)
getByEmail
(
tx
repo
.
Transaction
,
email
string
)
(
user
.
User
,
error
)
{
qt
:=
pq
.
QuoteIdentifier
(
userTableName
)
ex
:=
r
.
executor
(
tx
)
qt
:=
r
.
dbMap
.
Dialect
.
QuotedTableForQuery
(
""
,
userTableName
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
var
um
userModel
err
:=
ex
.
SelectOne
(
&
um
,
fmt
.
Sprintf
(
"select * from %s where email = $1"
,
qt
),
email
)
...
...
@@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
}
func
(
r
*
userRepo
)
insertRemoteIdentity
(
tx
repo
.
Transaction
,
userID
string
,
ri
user
.
RemoteIdentity
)
error
{
ex
:=
r
.
executor
(
tx
)
ex
:=
executor
(
r
.
dbMap
,
tx
)
rim
,
err
:=
newRemoteIdentityMappingModel
(
userID
,
ri
)
if
err
!=
nil
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment