Commit c8feb5c3 authored by Bobby Rullo's avatar Bobby Rullo

db: PrivateKeySetRepo now takes >1 secrets

The first secret is used to encrypt, the rest are for decryption; if the
first doesn't work, the rest are tried in order.

The makes it possible to rotate keys.
parent 72c3b0c3
......@@ -18,6 +18,10 @@ const (
keyTableName = "key"
)
var (
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
)
func init() {
register(table{
name: keyTableName,
......@@ -85,15 +89,16 @@ type privateKeySetBlob struct {
Value []byte `db:"value"`
}
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
bsecret := []byte(secret)
if len(bsecret) != 32 {
return nil, errors.New("expected 32-byte secret")
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
for i, secret := range secrets {
if len(secret) != 32 {
return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
}
}
r := &PrivateKeySetRepo{
dbMap: dbm,
secret: []byte(secret),
secrets: secrets,
}
return r, nil
......@@ -101,7 +106,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, e
type PrivateKeySetRepo struct {
dbMap *gorp.DbMap
secret []byte
secrets [][]byte
}
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
......@@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
return err
}
v, err := pcrypto.AESEncrypt(j, r.secret)
v, err := pcrypto.AESEncrypt(j, r.active())
if err != nil {
return err
}
......@@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
return nil, errors.New("unable to cast to KeySet")
}
j, err := pcrypto.AESDecrypt(b.Value, r.secret)
var pks *key.PrivateKeySet
for _, secret := range r.secrets {
var j []byte
j, err = pcrypto.AESDecrypt(b.Value, secret)
if err != nil {
return nil, errors.New("unable to decrypt key set")
continue
}
var m privateKeySetModel
if err := json.Unmarshal(j, &m); err != nil {
return nil, err
if err = json.Unmarshal(j, &m); err != nil {
continue
}
pks, err := m.PrivateKeySet()
pks, err = m.PrivateKeySet()
if err != nil {
return nil, err
continue
}
break
}
if err != nil {
return nil, ErrorCannotDecryptKeys
}
return key.KeySet(pks), nil
}
func (r *PrivateKeySetRepo) active() []byte {
return r.secrets[0]
}
......@@ -5,7 +5,7 @@ import (
)
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
_, err := NewPrivateKeySetRepo(nil, "sharks")
_, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
if err == nil {
t.Fatalf("Expected non-nil error")
}
......
......@@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
}
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
s2 := []byte("oooooooooooooooooooooooooooooooo")
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
keys := []*key.PrivateKey{}
for i := 0; i < 2; i++ {
k, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf(err.Error())
t.Fatalf("Unable to generate RSA key: %v", err)
}
keys = append(keys, k)
}
ks := key.NewPrivateKeySet(
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
tests := []struct {
setSecrets [][]byte
getSecrets [][]byte
wantErr bool
}{
{
// same secrets used to encrypt, decrypt
setSecrets: [][]byte{s1, s2},
getSecrets: [][]byte{s1, s2},
},
{
// setSecrets got rotated, but getSecrets didn't yet.
setSecrets: [][]byte{s2, s3},
getSecrets: [][]byte{s1, s2},
},
{
// getSecrets doesn't have s3
setSecrets: [][]byte{s3},
getSecrets: [][]byte{s1, s2},
wantErr: true,
},
}
k1, err := key.GeneratePrivateKey()
for i, tt := range tests {
setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
t.Fatalf(err.Error())
}
k2, err := key.GeneratePrivateKey()
getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
t.Fatalf(err.Error())
}
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
if err := r.Set(ks); err != nil {
t.Fatalf("Unexpected error: %v", err)
if err := setRepo.Set(ks); err != nil {
t.Fatalf("case %d: Unexpected error: %v", i, err)
}
got, err := r.Get()
got, err := getRepo.Get()
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want err, got nil", i)
}
continue
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
t.Fatalf("case %d: Unexpected error: %v", i, err)
}
if diff := pretty.Compare(ks, got); diff != "" {
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment