fix(GODT-3102): Distinguish Vault Decryption from Serialization Errors

Rather than returning whether the vault was corrupt or not return the
error which caused the vault to be considered as corrupt.
This commit is contained in:
Leander Beernaert 2023-11-30 08:31:14 +01:00
parent 7a1c7e8743
commit 1b22c32ef9
10 changed files with 50 additions and 37 deletions

View File

@ -42,7 +42,7 @@ func TestMigratePrefsToVaultWithKeys(t *testing.T) {
// Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
// load the old prefs file.
configDir := filepath.Join("testdata", "with_keys")
@ -63,7 +63,7 @@ func TestMigratePrefsToVaultWithoutKeys(t *testing.T) {
// Create a new vault.
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
// load the old prefs file.
configDir := filepath.Join("testdata", "without_keys")
@ -173,7 +173,7 @@ func TestUserMigration(t *testing.T) {
v, corrupt, err := vault.New(settingsFolder, settingsFolder, token, async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
require.NoError(t, migrateOldAccounts(locations, kcl, v))
require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs())

View File

@ -42,21 +42,25 @@ func WithVault(locations *locations.Locations, keychains *keychain.List, panicHa
logrus.WithFields(logrus.Fields{
"insecure": insecure,
"corrupt": corrupt,
"corrupt": corrupt != nil,
}).Debug("Vault created")
if corrupt != nil {
logrus.WithError(corrupt).Warn("Failed to load existing vault, vault has been reset")
}
cert, _ := encVault.GetBridgeTLSCert()
certs.NewInstaller().LogCertInstallStatus(cert)
// GODT-1950: Add teardown actions (e.g. to close the vault).
return fn(encVault, insecure, corrupt)
return fn(encVault, insecure, corrupt != nil)
}
func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) {
func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, error, error) {
vaultDir, err := locations.ProvideSettingsPath()
if err != nil {
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
return nil, false, nil, fmt.Errorf("could not get vault dir: %w", err)
}
logrus.WithField("vaultDir", vaultDir).Debug("Loading vault from directory")
@ -78,12 +82,12 @@ func newVault(locations *locations.Locations, keychains *keychain.List, panicHan
gluonCacheDir, err := locations.ProvideGluonCachePath()
if err != nil {
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err)
return nil, false, nil, fmt.Errorf("could not provide gluon path: %w", err)
}
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
if err != nil {
return nil, false, false, fmt.Errorf("could not create vault: %w", err)
return nil, false, corrupt, fmt.Errorf("could not create vault: %w", err)
}
return vault, insecure, corrupt, nil

View File

@ -133,7 +133,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
require.NoError(tb, err)
require.False(tb, corrupt)
require.NoError(tb, corrupt)
vaultUser, err := v.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
require.NoError(tb, err)

View File

@ -55,7 +55,7 @@ func TestMigrate(t *testing.T) {
// Migrate the vault.
s, corrupt, err := New(dir, "default-gluon-dir", []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
// Check the migrated vault.
require.Equal(t, "v2.3.x-gluon-dir", s.GetGluonCacheDir())

View File

@ -68,7 +68,7 @@ func TestVault_Settings_GluonDir(t *testing.T) {
// create a new test vault.
s, corrupt, err := vault.New(t.TempDir(), "/path/to/gluon", []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
// Check the default gluon dir.
require.Equal(t, "/path/to/gluon", s.GetGluonCacheDir())

View File

@ -19,6 +19,7 @@ package vault
import (
"crypto/cipher"
"fmt"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/vmihailenco/msgpack/v5"
@ -34,12 +35,12 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
var f File
if err := msgpack.Unmarshal(b, &f); err != nil {
return err
return fmt.Errorf("%w: %v", ErrUnmarshal, err)
}
dec, err := gcm.Open(nil, f.Data[:gcm.NonceSize()], f.Data[gcm.NonceSize():], nil)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrDecryptFailed, err)
}
for v := f.Version; v < Current; v++ {
@ -48,7 +49,11 @@ func unmarshalFile[T any](gcm cipher.AEAD, b []byte, data *T) error {
}
}
return msgpack.Unmarshal(dec, data)
if err := msgpack.Unmarshal(dec, data); err != nil {
return fmt.Errorf("%w: %v", ErrUnmarshal, err)
}
return nil
}
func marshalFile[T any](gcm cipher.AEAD, t T) ([]byte, error) {

View File

@ -49,27 +49,31 @@ type Vault struct {
panicHandler async.PanicHandler
}
var ErrDecryptFailed = errors.New("failed to decrypt vault")
var ErrUnmarshal = errors.New("vault contents are corrupt")
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, bool, error) {
// The first error is a corruption error for an existing vault, the second errors refrain to all other errors.
func New(vaultDir, gluonCacheDir string, key []byte, panicHandler async.PanicHandler) (*Vault, error, error) {
if err := os.MkdirAll(vaultDir, 0o700); err != nil {
return nil, false, err
return nil, nil, err
}
hash256 := sha256.Sum256(key)
aes, err := aes.NewCipher(hash256[:])
if err != nil {
return nil, false, err
return nil, nil, err
}
gcm, err := cipher.NewGCM(aes)
if err != nil {
return nil, false, err
return nil, nil, err
}
vault, corrupt, err := newVault(filepath.Join(vaultDir, "vault.enc"), gluonCacheDir, gcm)
if err != nil {
return nil, false, err
return nil, corrupt, err
}
vault.panicHandler = panicHandler
@ -341,28 +345,28 @@ func (vault *Vault) detachUser(userID string) error {
return nil
}
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, error, error) {
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
if _, err := initVault(path, gluonDir, gcm); err != nil {
return nil, false, err
return nil, nil, err
}
}
enc, err := os.ReadFile(filepath.Clean(path))
if err != nil {
return nil, false, err
return nil, nil, err
}
var corrupt bool
var corrupt error
if err := unmarshalFile(gcm, enc, new(Data)); err != nil {
corrupt = true
corrupt = err
}
if corrupt {
if corrupt != nil {
newEnc, err := initVault(path, gluonDir, gcm)
if err != nil {
return nil, false, err
return nil, corrupt, err
}
enc = newEnc

View File

@ -34,7 +34,7 @@ func BenchmarkVault(b *testing.B) {
// Create a new vault.
s, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(b, err)
require.False(b, corrupt)
require.NoError(b, corrupt)
// Add 10kB of cookies to the vault.
require.NoError(b, s.SetCookies(bytes.Repeat([]byte("a"), 10_000)))

View File

@ -34,19 +34,19 @@ func TestVault_Corrupt(t *testing.T) {
{
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
}
{
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
}
{
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("bad key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.True(t, corrupt)
require.ErrorIs(t, corrupt, vault.ErrDecryptFailed)
}
}
@ -56,13 +56,13 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
{
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
}
{
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
}
{
@ -75,7 +75,7 @@ func TestVault_Corrupt_JunkData(t *testing.T) {
_, corrupt, err := vault.New(vaultDir, gluonDir, []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.True(t, corrupt)
require.ErrorIs(t, corrupt, vault.ErrUnmarshal)
}
}
@ -103,7 +103,7 @@ func newVault(t *testing.T) *vault.Vault {
s, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"), async.NoopPanicHandler{})
require.NoError(t, err)
require.False(t, corrupt)
require.NoError(t, corrupt)
return s
}

View File

@ -112,8 +112,8 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) {
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, t.storeKey, async.NoopPanicHandler{})
if err != nil {
return nil, fmt.Errorf("could not create vault: %w", err)
} else if corrupt {
return nil, fmt.Errorf("vault is corrupt")
} else if corrupt != nil {
return nil, fmt.Errorf("vault is corrupt: %w", corrupt)
}
t.vault = vault