Commit 3c247db0 authored by Eric Chiang's avatar Eric Chiang Committed by GitHub

Merge pull request #757 from ericchiang/constant-refresh-tokens

*: update refresh tokens instead of deleting and creating another
parents c66cce8b ed20fee2
...@@ -55,7 +55,7 @@ fmt: ...@@ -55,7 +55,7 @@ fmt:
@go fmt $(shell go list ./... | grep -v '/vendor/') @go fmt $(shell go list ./... | grep -v '/vendor/')
lint: lint:
@for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api'); do \ @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api' | grep -v '/server/internal'); do \
golint -set_exit_status $$package $$i || exit 1; \ golint -set_exit_status $$package $$i || exit 1; \
done done
...@@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci ...@@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci
docker-image: clean-release _output/bin/dex docker-image: clean-release _output/bin/dex
@sudo docker build -t $(DOCKER_IMAGE) . @sudo docker build -t $(DOCKER_IMAGE) .
.PHONY: grpc .PHONY: proto
grpc: api/api.pb.go proto: api/api.pb.go server/internal/types.pb.go
api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go
@protoc --go_out=plugins=grpc:. api/*.proto @protoc --go_out=plugins=grpc:. api/*.proto
server/internal/types.pb.go: server/internal/types.proto bin/protoc bin/protoc-gen-go
@protoc --go_out=. server/internal/*.proto
bin/protoc: scripts/get-protoc bin/protoc: scripts/get-protoc
@./scripts/get-protoc bin/protoc @./scripts/get-protoc bin/protoc
......
...@@ -241,7 +241,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { ...@@ -241,7 +241,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
authCodeURL := "" authCodeURL := ""
scopes = append(scopes, "openid", "profile", "email") scopes = append(scopes, "openid", "profile", "email")
if r.FormValue("offline_acecss") != "yes" { if r.FormValue("offline_access") != "yes" {
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
} else if a.offlineAsScope { } else if a.offlineAsScope {
scopes = append(scopes, "offline_access") scopes = append(scopes, "offline_access")
...@@ -254,34 +254,42 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { ...@@ -254,34 +254,42 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
} }
func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
return
}
if state := r.FormValue("state"); state != exampleAppState {
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
return
}
code := r.FormValue("code")
refresh := r.FormValue("refresh_token")
var ( var (
err error err error
token *oauth2.Token token *oauth2.Token
) )
oauth2Config := a.oauth2Config(nil) oauth2Config := a.oauth2Config(nil)
switch { switch r.Method {
case code != "": case "GET":
// Authorization redirect callback from OAuth2 auth flow.
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
return
}
code := r.FormValue("code")
if code == "" {
http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
return
}
if state := r.FormValue("state"); state != exampleAppState {
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
return
}
token, err = oauth2Config.Exchange(a.ctx, code) token, err = oauth2Config.Exchange(a.ctx, code)
case refresh != "": case "POST":
// Form request from frontend to refresh a token.
refresh := r.FormValue("refresh_token")
if refresh == "" {
http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest)
return
}
t := &oauth2.Token{ t := &oauth2.Token{
RefreshToken: refresh, RefreshToken: refresh,
Expiry: time.Now().Add(-time.Hour), Expiry: time.Now().Add(-time.Hour),
} }
token, err = oauth2Config.TokenSource(r.Context(), t).Token() token, err = oauth2Config.TokenSource(r.Context(), t).Token()
default: default:
http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
return return
} }
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
var indexTmpl = template.Must(template.New("index.html").Parse(`<html> var indexTmpl = template.Must(template.New("index.html").Parse(`<html>
<body> <body>
<form action="/login"> <form action="/login" method="post">
<p> <p>
Authenticate for:<input type="text" name="cross_client" placeholder="list of client-ids"> Authenticate for:<input type="text" name="cross_client" placeholder="list of client-ids">
</p> </p>
...@@ -50,8 +50,13 @@ pre { ...@@ -50,8 +50,13 @@ pre {
<body> <body>
<p> Token: <pre><code>{{ .IDToken }}</code></pre></p> <p> Token: <pre><code>{{ .IDToken }}</code></pre></p>
<p> Claims: <pre><code>{{ .Claims }}</code></pre></p> <p> Claims: <pre><code>{{ .Claims }}</code></pre></p>
{{ if .RefreshToken }}
<p> Refresh Token: <pre><code>{{ .RefreshToken }}</code></pre></p> <p> Refresh Token: <pre><code>{{ .RefreshToken }}</code></pre></p>
<p><a href="{{ .RedirectURL }}?refresh_token={{ .RefreshToken }}">Redeem refresh token</a><p> <form action="{{ .RedirectURL }}" method="post">
<input type="hidden" name="refresh_token" value="{{ .RefreshToken }}">
<input type="submit" value="Redeem refresh token">
</form>
{{ end }}
</body> </body>
</html> </html>
`)) `))
......
...@@ -2,6 +2,7 @@ package server ...@@ -2,6 +2,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
...@@ -16,6 +17,7 @@ import ( ...@@ -16,6 +17,7 @@ import (
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/server/internal"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
...@@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s ...@@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
var refreshToken string var refreshToken string
if reqRefresh { if reqRefresh {
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
RefreshToken: storage.NewID(), ID: storage.NewID(),
Token: storage.NewID(),
ClientID: authCode.ClientID, ClientID: authCode.ClientID,
ConnectorID: authCode.ConnectorID, ConnectorID: authCode.ConnectorID,
Scopes: authCode.Scopes, Scopes: authCode.Scopes,
Claims: authCode.Claims, Claims: authCode.Claims,
Nonce: authCode.Nonce, Nonce: authCode.Nonce,
ConnectorData: authCode.ConnectorData, ConnectorData: authCode.ConnectorData,
CreatedAt: s.now(),
LastUsed: s.now(),
} }
token := &internal.RefreshToken{
RefreshId: refresh.ID,
Token: refresh.Token,
}
if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.CreateRefresh(refresh); err != nil { if err := s.storage.CreateRefresh(refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err) s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
refreshToken = refresh.RefreshToken
} }
s.writeAccessToken(w, idToken, refreshToken, expiry) s.writeAccessToken(w, idToken, refreshToken, expiry)
} }
...@@ -672,16 +686,37 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -672,16 +686,37 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
refresh, err := s.storage.GetRefresh(code) token := new(internal.RefreshToken)
if err != nil || refresh.ClientID != client.ID { if err := internal.Unmarshal(code, token); err != nil {
if err != storage.ErrNotFound { // For backward compatibility, assume the refresh_token is a raw refresh token ID
s.logger.Errorf("failed to get auth code: %v", err) // if it fails to decode.
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) //
} else { // Because refresh_token values that aren't unmarshable were generated by servers
// that don't have a Token value, we'll still reject any attempts to claim a
// refresh_token twice.
token = &internal.RefreshToken{RefreshId: code, Token: ""}
}
refresh, err := s.storage.GetRefresh(token.RefreshId)
if err != nil {
s.logger.Errorf("failed to get refresh token: %v", err)
if err == storage.ErrNotFound {
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
} else {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} }
return return
} }
if refresh.ClientID != client.ID {
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return
}
if refresh.Token != token.Token {
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return
}
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original // Per the OAuth2 spec, if the client has omitted the scopes, default to the original
// authorized scopes. // authorized scopes.
...@@ -720,6 +755,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -720,6 +755,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
ident := connector.Identity{
UserID: refresh.Claims.UserID,
Username: refresh.Claims.Username,
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData,
}
// Can the connector refresh the identity? If so, attempt to refresh the data // Can the connector refresh the identity? If so, attempt to refresh the data
// in the connector. // in the connector.
...@@ -727,52 +770,63 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -727,52 +770,63 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
// TODO(ericchiang): We may want a strict mode where connectors that don't implement // TODO(ericchiang): We may want a strict mode where connectors that don't implement
// this interface can't perform refreshing. // this interface can't perform refreshing.
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
ident := connector.Identity{ newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
UserID: refresh.Claims.UserID,
Username: refresh.Claims.Username,
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData,
}
ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
if err != nil { if err != nil {
s.logger.Errorf("failed to refresh identity: %v", err) s.logger.Errorf("failed to refresh identity: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
ident = newIdent
}
// Update the claims of the refresh token. claims := storage.Claims{
// UserID: ident.UserID,
// UserID intentionally ignored for now. Username: ident.Username,
refresh.Claims.Username = ident.Username Email: ident.Email,
refresh.Claims.Email = ident.Email EmailVerified: ident.EmailVerified,
refresh.Claims.EmailVerified = ident.EmailVerified Groups: ident.Groups,
refresh.Claims.Groups = ident.Groups
refresh.ConnectorData = ident.ConnectorData
} }
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
// Refresh tokens are claimed exactly once. Delete the current token and newToken := &internal.RefreshToken{
// create a new one. RefreshId: refresh.ID,
if err := s.storage.DeleteRefresh(code); err != nil { Token: storage.NewID(),
s.logger.Errorf("failed to delete auth code: %v", err) }
rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
refresh.RefreshToken = storage.NewID()
if err := s.storage.CreateRefresh(refresh); err != nil { updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
s.logger.Errorf("failed to create refresh token: %v", err) if old.Token != refresh.Token {
return old, errors.New("refresh token claimed twice")
}
old.Token = newToken.Token
// Update the claims of the refresh token.
//
// UserID intentionally ignored for now.
old.Claims.Username = ident.Username
old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = s.now()
return old, nil
}
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
s.logger.Errorf("failed to update refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry) s.writeAccessToken(w, idToken, rawNewToken, expiry)
} }
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) { func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
......
package internal
import (
"encoding/base64"
"github.com/golang/protobuf/proto"
)
// Marshal converts a protobuf message to a URL legal string.
func Marshal(message proto.Message) (string, error) {
data, err := proto.Marshal(message)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(data), nil
}
// Unmarshal decodes a protobuf message.
func Unmarshal(s string, message proto.Message) error {
data, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return err
}
return proto.Unmarshal(data, message)
}
// Code generated by protoc-gen-go.
// source: server/internal/types.proto
// DO NOT EDIT!
/*
Package internal is a generated protocol buffer package.
Package internal holds protobuf types used by the server
It is generated from these files:
server/internal/types.proto
It has these top-level messages:
RefreshToken
*/
package internal
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// RefreshToken is a message that holds refresh token data used by dex.
type RefreshToken struct {
RefreshId string `protobuf:"bytes,1,opt,name=refresh_id,json=refreshId" json:"refresh_id,omitempty"`
Token string `protobuf:"bytes,2,opt,name=token" json:"token,omitempty"`
}
func (m *RefreshToken) Reset() { *m = RefreshToken{} }
func (m *RefreshToken) String() string { return proto.CompactTextString(m) }
func (*RefreshToken) ProtoMessage() {}
func (*RefreshToken) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func init() {
proto.RegisterType((*RefreshToken)(nil), "internal.RefreshToken")
}
func init() { proto.RegisterFile("server/internal/types.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 112 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x92, 0x2e, 0x4e, 0x2d, 0x2a,
0x4b, 0x2d, 0xd2, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0xa9, 0x2c, 0x48,
0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0x89, 0x2a, 0x39, 0x73, 0xf1, 0x04,
0xa5, 0xa6, 0x15, 0xa5, 0x16, 0x67, 0x84, 0xe4, 0x67, 0xa7, 0xe6, 0x09, 0xc9, 0x72, 0x71, 0x15,
0x41, 0xf8, 0xf1, 0x99, 0x29, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x9c, 0x50, 0x11, 0xcf,
0x14, 0x21, 0x11, 0x2e, 0xd6, 0x12, 0x90, 0x3a, 0x09, 0x26, 0xb0, 0x0c, 0x84, 0x93, 0xc4, 0x06,
0x36, 0xd5, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x9b, 0xd0, 0x5a, 0x1d, 0x74, 0x00, 0x00, 0x00,
}
syntax = "proto3";
// Package internal holds protobuf types used by the server
package internal;
// RefreshToken is a message that holds refresh token data used by dex.
message RefreshToken {
string refresh_id = 1;
string token = 2;
}
...@@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
if token.RefreshToken == newToken.RefreshToken { if token.RefreshToken == newToken.RefreshToken {
return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
} }
if _, err := config.TokenSource(ctx, token).Token(); err == nil {
return errors.New("was able to redeem the same refresh token twice")
}
return nil return nil
}, },
}, },
......
...@@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) { ...@@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID() id := storage.NewID()
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
RefreshToken: id, ID: id,
ClientID: "client_id", Token: "bar",
ConnectorID: "client_secret", Nonce: "foo",
Scopes: []string{"openid", "email", "profile"}, ClientID: "client_id",
ConnectorID: "client_secret",
Scopes: []string{"openid", "email", "profile"},
CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{ Claims: storage.Claims{
UserID: "1", UserID: "1",
Username: "jane", Username: "jane",
...@@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ...@@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
getAndCompare(id, refresh) getAndCompare(id, refresh)
updatedAt := time.Now().UTC().Round(time.Millisecond)
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "spam"
r.LastUsed = updatedAt
return r, nil
}
if err := s.UpdateRefreshToken(id, updater); err != nil {
t.Errorf("failed to udpate refresh token: %v", err)
}
refresh.Token = "spam"
refresh.LastUsed = updatedAt
getAndCompare(id, refresh)
if err := s.DeleteRefresh(id); err != nil { if err := s.DeleteRefresh(id); err != nil {
t.Fatalf("failed to delete refresh request: %v", err) t.Fatalf("failed to delete refresh request: %v", err)
} }
......
...@@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error { ...@@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error {
} }
func (cli *client) CreateRefresh(r storage.RefreshToken) error { func (cli *client) CreateRefresh(r storage.RefreshToken) error {
refresh := RefreshToken{ return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
TypeMeta: k8sapi.TypeMeta{
Kind: kindRefreshToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: r.RefreshToken,
Namespace: cli.namespace,
},
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: fromStorageClaims(r.Claims),
ConnectorData: r.ConnectorData,
}
return cli.post(resourceRefreshToken, refresh)
} }
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
...@@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) { ...@@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) {
} }
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
var r RefreshToken r, err := cli.getRefreshToken(id)
if err := cli.get(resourceRefreshToken, id, &r); err != nil { if err != nil {
return storage.RefreshToken{}, err return storage.RefreshToken{}, err
} }
return storage.RefreshToken{ return toStorageRefreshToken(r), nil
RefreshToken: r.ObjectMeta.Name, }
ClientID: r.ClientID,
ConnectorID: r.ConnectorID, func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
Scopes: r.Scopes, err = cli.get(resourceRefreshToken, id, &r)
Nonce: r.Nonce, return
Claims: toStorageClaims(r.Claims),
ConnectorData: r.ConnectorData,
}, nil
} }
func (cli *client) ListClients() ([]storage.Client, error) { func (cli *client) ListClients() ([]storage.Client, error) {
...@@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error { ...@@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name) return cli.delete(resourcePassword, p.ObjectMeta.Name)
} }
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
r, err := cli.getRefreshToken(id)
if err != nil {
return err
}
updated, err := updater(toStorageRefreshToken(r))
if err != nil {
return err
}
updated.ID = id
newToken := cli.fromStorageRefreshToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
}
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
c, err := cli.getClient(id) c, err := cli.getClient(id)
if err != nil { if err != nil {
......
...@@ -362,9 +362,14 @@ type RefreshToken struct { ...@@ -362,9 +362,14 @@ type RefreshToken struct {
k8sapi.TypeMeta `json:",inline"` k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"` k8sapi.ObjectMeta `json:"metadata,omitempty"`
CreatedAt time.Time
LastUsed time.Time
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
Scopes []string `json:"scopes,omitempty"` Scopes []string `json:"scopes,omitempty"`
Token string `json:"token,omitempty"`
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
Claims Claims `json:"claims,omitempty"` Claims Claims `json:"claims,omitempty"`
...@@ -379,6 +384,43 @@ type RefreshList struct { ...@@ -379,6 +384,43 @@ type RefreshList struct {
RefreshTokens []RefreshToken `json:"items"` RefreshTokens []RefreshToken `json:"items"`
} }
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{
ID: r.ObjectMeta.Name,
Token: r.Token,
CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed,
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
ConnectorData: r.ConnectorData,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: toStorageClaims(r.Claims),
}
}
func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
return RefreshToken{
TypeMeta: k8sapi.TypeMeta{
Kind: kindRefreshToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: r.ID,
Namespace: cli.namespace,
},
Token: r.Token,
CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed,
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
ConnectorData: r.ConnectorData,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: fromStorageClaims(r.Claims),
}
}
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes // Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
// type metadata. // type metadata.
type Keys struct { type Keys struct {
......
...@@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { ...@@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
s.tx(func() { s.tx(func() {
if _, ok := s.refreshTokens[r.RefreshToken]; ok { if _, ok := s.refreshTokens[r.ID]; ok {
err = storage.ErrAlreadyExists err = storage.ErrAlreadyExists
} else { } else {
s.refreshTokens[r.RefreshToken] = r s.refreshTokens[r.ID] = r
} }
}) })
return return
...@@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor ...@@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
}) })
return return
} }
func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
s.tx(func() {
r, ok := s.refreshTokens[id]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.refreshTokens[id] = r
}
})
return
}
...@@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { ...@@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
) )
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`, `,
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed,
) )
if err != nil { if err != nil {
return fmt.Errorf("insert refresh_token: %v", err) return fmt.Errorf("insert refresh_token: %v", err)
...@@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { ...@@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
return nil return nil
} }
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update refresh_token
set
client_id = $1,
scopes = $2,
nonce = $3,
claims_user_id = $4,
claims_username = $5,
claims_email = $6,
claims_email_verified = $7,
claims_groups = $8,
connector_id = $9,
connector_data = $10,
token = $11,
created_at = $12,
last_used = $13
`,
r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed,
)
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
})
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return scanRefresh(c.QueryRow(` return getRefresh(c, id)
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
return scanRefresh(q.QueryRow(`
select select
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
from refresh_token where id = $1; from refresh_token where id = $1;
`, id)) `, id))
} }
...@@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { ...@@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
from refresh_token; from refresh_token;
`) `)
if err != nil { if err != nil {
...@@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { ...@@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
func scanRefresh(s scanner) (r storage.RefreshToken, err error) { func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
err = s.Scan( err = s.Scan(
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, &r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, &r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups), decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData, &r.ConnectorID, &r.ConnectorData,
&r.Token, &r.CreatedAt, &r.LastUsed,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
......
...@@ -155,4 +155,14 @@ var migrations = []migration{ ...@@ -155,4 +155,14 @@ var migrations = []migration{
); );
`, `,
}, },
{
stmt: `
alter table refresh_token
add column token text not null default '';
alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
alter table refresh_token
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
`,
},
} }
...@@ -94,6 +94,7 @@ type Storage interface { ...@@ -94,6 +94,7 @@ type Storage interface {
UpdateClient(id string, updater func(old Client) (Client, error)) error UpdateClient(id string, updater func(old Client) (Client, error)) error
UpdateKeys(updater func(old Keys) (Keys, error)) error UpdateKeys(updater func(old Keys) (Keys, error)) error
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests. // GarbageCollect deletes all expired AuthCodes and AuthRequests.
...@@ -216,8 +217,15 @@ type AuthCode struct { ...@@ -216,8 +217,15 @@ type AuthCode struct {
// RefreshToken is an OAuth2 refresh token which allows a client to request new // RefreshToken is an OAuth2 refresh token which allows a client to request new
// tokens on the end user's behalf. // tokens on the end user's behalf.
type RefreshToken struct { type RefreshToken struct {
// The actual refresh token. ID string
RefreshToken string
// A single token that's rotated every time the refresh token is refreshed.
//
// May be empty.
Token string
CreatedAt time.Time
LastUsed time.Time
// Client this refresh token is valid for. // Client this refresh token is valid for.
ClientID string ClientID string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment