Commit c58dd948 authored by Eric Chiang's avatar Eric Chiang Committed by GitHub

Merge pull request #749 from ericchiang/postgres-timezones

storage: fix postgres timezone handling
parents c7aa1548 fd20b213
......@@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"PasswordCRUD", testPasswordCRUD},
{"KeysCRUD", testKeysCRUD},
{"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones},
})
}
......@@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
}
func testGC(t *testing.T, s storage.Storage) {
n := time.Now().UTC()
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
}
pst, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
t.Fatal(err)
}
expiry := time.Now().In(est)
c := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: n.Add(time.Second),
Expiry: expiry,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
......@@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth code: %v", err)
}
if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthCode(c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
}
if _, err := s.GetAuthCode(c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthCodes != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
......@@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) {
State: "bar",
ForceApprovalPrompt: true,
LoggedIn: true,
Expiry: n,
Expiry: expiry,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
......@@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth request: %v", err)
}
if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthRequest(a.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
}
if _, err := s.GetAuthRequest(a.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
......@@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
func testTimezones(t *testing.T, s storage.Storage) {
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
}
// Create an expiry with timezone info. Only expect backends to be
// accurate to the millisecond
expiry := time.Now().In(est).Round(time.Millisecond)
c := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: expiry,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthCode(c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(c.ID)
if err != nil {
t.Fatalf("failed to get auth code: %v", err)
}
// Ensure that if the resulting time is converted to the same
// timezone, it's the same value. We DO NOT expect timezones
// to be preserved.
gotTime := got.Expiry.In(est)
wantTime := expiry
if !gotTime.Equal(wantTime) {
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
}
}
......@@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) {
_, err := c.Exec(`
create table if not exists migrations (
num integer not null,
at timestamp not null
at timestamptz not null
);
`)
if err != nil {
......@@ -100,7 +100,7 @@ var migrations = []migration{
connector_id text not null,
connector_data bytea,
expiry timestamp not null
expiry timestamptz not null
);
create table auth_code (
......@@ -119,7 +119,7 @@ var migrations = []migration{
connector_id text not null,
connector_data bytea,
expiry timestamp not null
expiry timestamptz not null
);
create table refresh_token (
......@@ -151,7 +151,7 @@ var migrations = []migration{
verification_keys bytea not null, -- JSON array
signing_key bytea not null, -- JSON object
signing_key_pub bytea not null, -- JSON object
next_rotation timestamp not null
next_rotation timestamptz not null
);
`,
},
......
......@@ -4,6 +4,7 @@ package sql
import (
"database/sql"
"regexp"
"time"
"github.com/Sirupsen/logrus"
"github.com/cockroachdb/cockroach-go/crdb"
......@@ -28,6 +29,9 @@ type flavor struct {
//
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
// Does the flavor support timezones?
supportsTimezones bool
}
// A regexp with a replacement string.
......@@ -69,6 +73,8 @@ var (
}
return tx.Commit()
},
supportsTimezones: true,
}
flavorSQLite3 = flavor{
......@@ -80,7 +86,7 @@ var (
{matchLiteral("boolean"), "integer"},
// Translate other types.
{matchLiteral("bytea"), "blob"},
// {matchLiteral("timestamp"), "integer"},
{matchLiteral("timestamptz"), "timestamp"},
// SQLite doesn't have a "now()" method, replace with "date('now')"
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
},
......@@ -107,6 +113,22 @@ func (f flavor) translate(query string) string {
return query
}
// translateArgs translates query parameters that may be unique to
// a specific SQL flavor. For example, standardizing "time.Time"
// types to UTC for clients that don't provide timezone support.
func (c *conn) translateArgs(args []interface{}) []interface{} {
if c.flavor.supportsTimezones {
return args
}
for i, arg := range args {
if t, ok := arg.(time.Time); ok {
args[i] = t.UTC()
}
}
return args
}
// conn is the main database connection.
type conn struct {
db *sql.DB
......@@ -122,17 +144,17 @@ func (c *conn) Close() error {
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
query = c.flavor.translate(query)
return c.db.Exec(query, args...)
return c.db.Exec(query, c.translateArgs(args)...)
}
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = c.flavor.translate(query)
return c.db.Query(query, args...)
return c.db.Query(query, c.translateArgs(args)...)
}
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
query = c.flavor.translate(query)
return c.db.QueryRow(query, args...)
return c.db.QueryRow(query, c.translateArgs(args)...)
}
// ExecTx runs a method which operates on a transaction.
......@@ -163,15 +185,15 @@ type trans struct {
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
query = t.c.flavor.translate(query)
return t.tx.Exec(query, args...)
return t.tx.Exec(query, t.c.translateArgs(args)...)
}
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = t.c.flavor.translate(query)
return t.tx.Query(query, args...)
return t.tx.Query(query, t.c.translateArgs(args)...)
}
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
query = t.c.flavor.translate(query)
return t.tx.QueryRow(query, args...)
return t.tx.QueryRow(query, t.c.translateArgs(args)...)
}
......@@ -44,11 +44,9 @@ type GCResult struct {
AuthCodes int64
}
// Storage is the storage interface used by the server. Implementations, at minimum
// require compare-and-swap atomic actions.
//
// Implementations are expected to perform their own garbage collection of
// expired objects (expect keys, which are handled by the server).
// Storage is the storage interface used by the server. Implementations are
// required to be able to perform atomic compare-and-swap updates and either
// support timezones or standardize on UTC.
type Storage interface {
Close() error
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment