Commit 5abc7633 authored by bobbyrullo's avatar bobbyrullo

Merge pull request #87 from bobbyrullo/keyspace

Base64 Encode secrets, and allow >1 of them
parents 8cfffcc9 d0c199b6
......@@ -28,7 +28,10 @@ func init() {
func main() {
fs := flag.NewFlagSet("dex-overlord", flag.ExitOnError)
secret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB")
keySecrets := pflag.NewBase64List(32)
fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.")
dbURL := fs.String("db-url", "", "DSN-formatted database connection string")
dbMigrate := fs.Bool("db-migrate", true, "perform database migrations when starting up overlord. This includes the initial DB objects creation.")
......@@ -59,10 +62,6 @@ func main() {
log.EnableTimestamps()
}
if len(*secret) == 0 {
log.Fatalf("--key-secret unset")
}
adminURL, err := url.Parse(*adminListen)
if err != nil {
log.Fatalf("Unable to use --admin-listen flag: %v", err)
......@@ -96,11 +95,32 @@ func main() {
userManager := user.NewManager(userRepo,
pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *secret)
kRepo, err := db.NewPrivateKeySetRepo(dbc, keySecrets.BytesSlice()...)
if err != nil {
log.Fatalf(err.Error())
}
var sleep time.Duration
for {
var done bool
_, err := kRepo.Get()
switch err {
case nil:
done = true
case key.ErrorNoKeys:
done = true
case db.ErrorCannotDecryptKeys:
log.Fatalf("Cannot decrypt keys using any of the given key secrets. The key secrets must be changed to include one that can decrypt the existing keys, or the existing keys must be deleted.")
}
if done {
break
}
sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Errorf("Unable to get keys from repository, retrying in %v: %v", sleep, err)
time.Sleep(sleep)
}
krot := key.NewPrivateKeyRotator(kRepo, *keyPeriod)
s := server.NewAdminServer(adminAPI, krot)
h := s.HTTPHandler()
......
......@@ -41,7 +41,10 @@ func main() {
// ignored if --no-db is set
dbURL := fs.String("db-url", "", "DSN-formatted database connection string")
keySecret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB")
keySecrets := pflag.NewBase64List(32)
fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.")
dbMaxIdleConns := fs.Int("db-max-idle-conns", 0, "maximum number of connections in the idle connection pool")
dbMaxOpenConns := fs.Int("db-max-open-conns", 0, "maximum number of open connections to the database")
......@@ -109,7 +112,7 @@ func main() {
MaxOpenConnections: *dbMaxOpenConns,
}
scfg.StateConfig = &server.MultiServerConfig{
KeySecret: *keySecret,
KeySecrets: keySecrets.BytesSlice(),
DatabaseConfig: dbCfg,
}
}
......
......@@ -18,6 +18,10 @@ const (
keyTableName = "key"
)
var (
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
)
func init() {
register(table{
name: keyTableName,
......@@ -85,15 +89,16 @@ type privateKeySetBlob struct {
Value []byte `db:"value"`
}
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
bsecret := []byte(secret)
if len(bsecret) != 32 {
return nil, errors.New("expected 32-byte secret")
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
for i, secret := range secrets {
if len(secret) != 32 {
return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
}
}
r := &PrivateKeySetRepo{
dbMap: dbm,
secret: []byte(secret),
secrets: secrets,
}
return r, nil
......@@ -101,7 +106,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, e
type PrivateKeySetRepo struct {
dbMap *gorp.DbMap
secret []byte
secrets [][]byte
}
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
......@@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
return err
}
v, err := pcrypto.AESEncrypt(j, r.secret)
v, err := pcrypto.AESEncrypt(j, r.active())
if err != nil {
return err
}
......@@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
return nil, errors.New("unable to cast to KeySet")
}
j, err := pcrypto.AESDecrypt(b.Value, r.secret)
var pks *key.PrivateKeySet
for _, secret := range r.secrets {
var j []byte
j, err = pcrypto.AESDecrypt(b.Value, secret)
if err != nil {
return nil, errors.New("unable to decrypt key set")
continue
}
var m privateKeySetModel
if err := json.Unmarshal(j, &m); err != nil {
return nil, err
if err = json.Unmarshal(j, &m); err != nil {
continue
}
pks, err := m.PrivateKeySet()
pks, err = m.PrivateKeySet()
if err != nil {
return nil, err
continue
}
break
}
if err != nil {
return nil, ErrorCannotDecryptKeys
}
return key.KeySet(pks), nil
}
func (r *PrivateKeySetRepo) active() []byte {
return r.secrets[0]
}
......@@ -5,7 +5,7 @@ import (
)
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
_, err := NewPrivateKeySetRepo(nil, "sharks")
_, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
if err == nil {
t.Fatalf("Expected non-nil error")
}
......
......@@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
}
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
s2 := []byte("oooooooooooooooooooooooooooooooo")
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
keys := []*key.PrivateKey{}
for i := 0; i < 2; i++ {
k, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf(err.Error())
t.Fatalf("Unable to generate RSA key: %v", err)
}
keys = append(keys, k)
}
ks := key.NewPrivateKeySet(
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
tests := []struct {
setSecrets [][]byte
getSecrets [][]byte
wantErr bool
}{
{
// same secrets used to encrypt, decrypt
setSecrets: [][]byte{s1, s2},
getSecrets: [][]byte{s1, s2},
},
{
// setSecrets got rotated, but getSecrets didn't yet.
setSecrets: [][]byte{s2, s3},
getSecrets: [][]byte{s1, s2},
},
{
// getSecrets doesn't have s3
setSecrets: [][]byte{s3},
getSecrets: [][]byte{s1, s2},
wantErr: true,
},
}
k1, err := key.GeneratePrivateKey()
for i, tt := range tests {
setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
t.Fatalf(err.Error())
}
k2, err := key.GeneratePrivateKey()
getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
t.Fatalf(err.Error())
}
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
if err := r.Set(ks); err != nil {
t.Fatalf("Unexpected error: %v", err)
if err := setRepo.Set(ks); err != nil {
t.Fatalf("case %d: Unexpected error: %v", i, err)
}
got, err := r.Get()
got, err := getRepo.Get()
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want err, got nil", i)
}
continue
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
t.Fatalf("case %d: Unexpected error: %v", i, err)
}
if diff := pretty.Compare(ks, got); diff != "" {
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
}
}
}
......
......@@ -84,11 +84,13 @@ func AESDecrypt(ciphertext, key []byte) ([]byte, error) {
}
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(ciphertext, ciphertext)
if len(ciphertext)%aes.BlockSize != 0 {
plaintext := make([]byte, len(ciphertext))
mode.CryptBlocks(plaintext, ciphertext)
if len(plaintext)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
return unpad(ciphertext)
return unpad(plaintext)
}
package flag
import (
"encoding/base64"
"fmt"
"strings"
)
// Base64 implements flag.Value, and is used to populate []byte values from baes64 encoded strings.
type Base64 struct {
val []byte
len int
}
// NewBase64 returns a Base64 which accepts values which decode to len byte strings.
func NewBase64(len int) *Base64 {
return &Base64{
len: len,
}
}
func (f *Base64) String() string {
return base64.StdEncoding.EncodeToString(f.val)
}
// Set will set the []byte value of the Base64 to the base64 decoded values of the string, returning an error if it cannot be decoded or is of the wrong length.
func (f *Base64) Set(s string) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return err
}
if len(b) != f.len {
return fmt.Errorf("expected %d-byte secret", f.len)
}
f.val = b
return nil
}
// Bytes returns the set []byte value.
// If no value has been set, a nil []byte is returned.
func (f *Base64) Bytes() []byte {
return f.val
}
// NewBase64List returns a Base64List which accepts a comma-separated list of strings which must decode to len byte strings.
func NewBase64List(len int) *Base64List {
return &Base64List{
len: len,
}
}
// Base64List implements flag.Value and is used to populate [][]byte values from a comma-separated list of base64 encoded strings.
type Base64List struct {
val [][]byte
len int
}
// Set will set the [][]byte value of the Base64List to the base64 decoded values of the comma-separated strings, returning an error on the first error it encounters.
func (f *Base64List) Set(ss string) error {
if ss == "" {
return nil
}
for i, s := range strings.Split(ss, ",") {
b64 := NewBase64(f.len)
err := b64.Set(s)
if err != nil {
return fmt.Errorf("error decoding string %d: %q", i, err)
}
f.val = append(f.val, b64.Bytes())
}
return nil
}
func (f *Base64List) String() string {
ss := []string{}
for _, b := range f.val {
ss = append(ss, base64.StdEncoding.EncodeToString(b))
}
return strings.Join(ss, ",")
}
func (f *Base64List) BytesSlice() [][]byte {
return f.val
}
package flag
import (
"encoding/base64"
"strings"
"testing"
"github.com/kylelemons/godebug/pretty"
)
func TestBase64(t *testing.T) {
toB64 := func(b []byte) string {
return base64.StdEncoding.EncodeToString(b)
}
tests := []struct {
s string
l int
b []byte
wantError bool
}{
{
s: toB64([]byte("123456")),
l: 6,
b: []byte("123456"),
},
{
s: toB64([]byte("123456")),
l: 5,
wantError: true,
},
{
s: "not base64",
l: 5,
wantError: true,
},
}
for i, tt := range tests {
b64 := NewBase64(tt.l)
err := b64.Set(tt.s)
if tt.wantError {
if err == nil {
t.Errorf("case %d: want err, got nil", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error %q", i, err)
}
if diff := pretty.Compare(tt.b, b64.Bytes()); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i,
diff)
}
if b64.String() != tt.s {
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
}
}
}
func TestBase64List(t *testing.T) {
// toCSB64 == to comma separated base 64
toCSB64 := func(bb ...[]byte) string {
ss := []string{}
for _, b := range bb {
ss = append(ss, base64.StdEncoding.EncodeToString(b))
}
return strings.Join(ss, ",")
}
b123 := []byte("123456")
b567 := []byte("567890")
bShort := []byte("1234")
tests := []struct {
s string
l int
bb [][]byte
wantError bool
}{
{
s: toCSB64(b123, b567),
l: 6,
bb: [][]byte{b123, b567},
},
{
s: toCSB64(b123),
l: 6,
bb: [][]byte{b123},
},
{
s: "",
l: 6,
bb: [][]byte{},
},
{
s: toCSB64(b123, bShort),
l: 6,
wantError: true,
},
{
s: toCSB64(bShort, b123),
l: 6,
wantError: true,
},
}
for i, tt := range tests {
b64 := NewBase64List(tt.l)
err := b64.Set(tt.s)
if tt.wantError {
if err == nil {
t.Errorf("case %d: want err, got nil", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error %q", i, err)
}
if diff := pretty.Compare(tt.bb, b64.BytesSlice()); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i,
diff)
}
if b64.String() != tt.s {
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
}
}
}
......@@ -44,7 +44,7 @@ type SingleServerConfig struct {
}
type MultiServerConfig struct {
KeySecret string
KeySecrets [][]byte
DatabaseConfig db.Config
}
......@@ -141,7 +141,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
}
func (cfg *MultiServerConfig) Configure(srv *Server) error {
if cfg.KeySecret == "" {
if len(cfg.KeySecrets) == 0 {
return errors.New("missing key secret")
}
......@@ -154,7 +154,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
return fmt.Errorf("unable to initialize database connection: %v", err)
}
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecret)
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecrets...)
if err != nil {
return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment