From e8d9534b9ca1f68885064e56a11397184c8903ab Mon Sep 17 00:00:00 2001 From: Romain Le Jeune Date: Wed, 8 Nov 2023 13:05:57 +0000 Subject: [PATCH] feat(GODT-2277): Move Keychain helpers creation in main. --- internal/app/app.go | 99 +++++++++++++---------- internal/app/bridge.go | 3 + internal/app/migration.go | 5 +- internal/app/migration_test.go | 9 +-- internal/app/vault.go | 12 +-- internal/bridge/bridge.go | 9 +++ internal/bridge/bridge_test.go | 2 + internal/bridge/heartbeat.go | 5 +- internal/bridge/keychain.go | 24 ++++++ internal/frontend/grpc/service_methods.go | 4 +- pkg/keychain/helper_darwin.go | 12 ++- pkg/keychain/helper_linux.go | 62 ++++---------- pkg/keychain/helper_windows.go | 15 ++-- pkg/keychain/keychain.go | 96 +++++++++++++++++++--- pkg/keychain/test_helper.go | 14 +++- tests/ctx_bridge_test.go | 2 + utils/vault-editor/main.go | 4 +- 17 files changed, 243 insertions(+), 134 deletions(-) create mode 100644 internal/bridge/keychain.go diff --git a/internal/app/app.go b/internal/app/app.go index f130e332..e87f22b3 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -41,6 +41,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/sentry" "github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/restarter" "github.com/pkg/profile" "github.com/sirupsen/logrus" @@ -234,56 +235,59 @@ func run(c *cli.Context) error { } return withSingleInstance(settings, locations.GetLockFile(), version, func() error { - // Unlock the encrypted vault. - return WithVault(locations, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { - if !v.Migrated() { - // Migrate old settings into the vault. - if err := migrateOldSettings(v); err != nil { - logrus.WithError(err).Error("Failed to migrate old settings") - } - - // Migrate old accounts into the vault. - if err := migrateOldAccounts(locations, v); err != nil { - logrus.WithError(err).Error("Failed to migrate old accounts") - } - - // The vault has been migrated. - if err := v.SetMigrated(); err != nil { - logrus.WithError(err).Error("Failed to mark vault as migrated") - } - } - - logrus.WithFields(logrus.Fields{ - "lastVersion": v.GetLastVersion().String(), - "showAllMail": v.GetShowAllMail(), - "updateCh": v.GetUpdateChannel(), - "autoUpdate": v.GetAutoUpdate(), - "rollout": v.GetUpdateRollout(), - "DoH": v.GetProxyAllowed(), - }).Info("Vault loaded") - - // Load the cookies from the vault. - return withCookieJar(v, func(cookieJar http.CookieJar) error { - // Create a new bridge instance. - return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, func(b *bridge.Bridge, eventCh <-chan events.Event) error { - if insecure { - logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") - b.PushError(bridge.ErrVaultInsecure) + // Look for available keychains + return withKeychainList(func(keychains *keychain.List) error { + // Unlock the encrypted vault. + return WithVault(locations, keychains, crashHandler, func(v *vault.Vault, insecure, corrupt bool) error { + if !v.Migrated() { + // Migrate old settings into the vault. + if err := migrateOldSettings(v); err != nil { + logrus.WithError(err).Error("Failed to migrate old settings") } - if corrupt { - logrus.Warn("The vault is corrupt and has been wiped") - b.PushError(bridge.ErrVaultCorrupt) + // Migrate old accounts into the vault. + if err := migrateOldAccounts(locations, keychains, v); err != nil { + logrus.WithError(err).Error("Failed to migrate old accounts") } - // Remove old updates files - b.RemoveOldUpdates() + // The vault has been migrated. + if err := v.SetMigrated(); err != nil { + logrus.WithError(err).Error("Failed to mark vault as migrated") + } + } - // Start telemetry heartbeat process - b.StartHeartbeat(b) + logrus.WithFields(logrus.Fields{ + "lastVersion": v.GetLastVersion().String(), + "showAllMail": v.GetShowAllMail(), + "updateCh": v.GetUpdateChannel(), + "autoUpdate": v.GetAutoUpdate(), + "rollout": v.GetUpdateRollout(), + "DoH": v.GetProxyAllowed(), + }).Info("Vault loaded") - // Run the frontend. - return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) + // Load the cookies from the vault. + return withCookieJar(v, func(cookieJar http.CookieJar) error { + // Create a new bridge instance. + return withBridge(c, exe, locations, version, identifier, crashHandler, reporter, v, cookieJar, keychains, func(b *bridge.Bridge, eventCh <-chan events.Event) error { + if insecure { + logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted") + b.PushError(bridge.ErrVaultInsecure) + } + + if corrupt { + logrus.Warn("The vault is corrupt and has been wiped") + b.PushError(bridge.ErrVaultCorrupt) + } + + // Remove old updates files + b.RemoveOldUpdates() + + // Start telemetry heartbeat process + b.StartHeartbeat(b) + + // Run the frontend. + return runFrontend(c, crashHandler, restarter, locations, b, eventCh, quitCh, c.Int(flagParentPID)) + }) }) }) }) @@ -480,6 +484,13 @@ func withCookieJar(vault *vault.Vault, fn func(http.CookieJar) error) error { return fn(persister) } +// List usable keychains. +func withKeychainList(fn func(*keychain.List) error) error { + logrus.Debug("Creating keychain list") + defer logrus.Debug("Keychain list stop") + return fn(keychain.NewList()) +} + func setDeviceCookies(jar *cookies.Jar) error { url, err := url.Parse(constants.APIHost) if err != nil { diff --git a/internal/app/bridge.go b/internal/app/bridge.go index d3cb79cf..97af2937 100644 --- a/internal/app/bridge.go +++ b/internal/app/bridge.go @@ -37,6 +37,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/internal/versioner" + "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) @@ -55,6 +56,7 @@ func withBridge( reporter *sentry.Reporter, vault *vault.Vault, cookieJar http.CookieJar, + keychains *keychain.List, fn func(*bridge.Bridge, <-chan events.Event) error, ) error { logrus.Debug("Creating bridge") @@ -97,6 +99,7 @@ func withBridge( autostarter, updater, version, + keychains, // The API stuff. constants.APIHost, diff --git a/internal/app/migration.go b/internal/app/migration.go index c5197ad9..b2ed8610 100644 --- a/internal/app/migration.go +++ b/internal/app/migration.go @@ -122,7 +122,7 @@ func migrateOldSettingsWithDir(configDir string, v *vault.Vault) error { return v.SetBridgeTLSCertKey(certPEM, keyPEM) } -func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error { +func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List, v *vault.Vault) error { logrus.Info("Migrating accounts") settings, err := locations.ProvideSettingsPath() @@ -134,8 +134,7 @@ func migrateOldAccounts(locations *locations.Locations, v *vault.Vault) error { if err != nil { return fmt.Errorf("failed to get helper: %w", err) } - - keychain, err := keychain.NewKeychain(helper, "bridge") + keychain, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper()) if err != nil { return fmt.Errorf("failed to create keychain: %w", err) } diff --git a/internal/app/migration_test.go b/internal/app/migration_test.go index e25f2a59..17284c08 100644 --- a/internal/app/migration_test.go +++ b/internal/app/migration_test.go @@ -35,7 +35,6 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/ProtonMail/proton-bridge/v3/pkg/algo" "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" - dockerCredentials "github.com/docker/docker-credential-helpers/credentials" "github.com/stretchr/testify/require" ) @@ -133,11 +132,9 @@ func TestKeychainMigration(t *testing.T) { } func TestUserMigration(t *testing.T) { - keychainHelper := keychain.NewTestHelper() + kcl := keychain.NewTestKeychainsList() - keychain.Helpers["mock"] = func(string) (dockerCredentials.Helper, error) { return keychainHelper, nil } - - kc, err := keychain.NewKeychain("mock", "bridge") + kc, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper()) require.NoError(t, err) require.NoError(t, kc.Put("brokenID", "broken")) @@ -178,7 +175,7 @@ func TestUserMigration(t *testing.T) { require.NoError(t, err) require.False(t, corrupt) - require.NoError(t, migrateOldAccounts(locations, v)) + require.NoError(t, migrateOldAccounts(locations, kcl, v)) require.Equal(t, []string{wantCredentials.UserID}, v.GetUserIDs()) require.NoError(t, v.GetUser(wantCredentials.UserID, func(u *vault.User) { diff --git a/internal/app/vault.go b/internal/app/vault.go index dc349c22..fd317a74 100644 --- a/internal/app/vault.go +++ b/internal/app/vault.go @@ -29,12 +29,12 @@ import ( "github.com/sirupsen/logrus" ) -func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error { +func WithVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler, fn func(*vault.Vault, bool, bool) error) error { logrus.Debug("Creating vault") defer logrus.Debug("Vault stopped") // Create the encVault. - encVault, insecure, corrupt, err := newVault(locations, panicHandler) + encVault, insecure, corrupt, err := newVault(locations, keychains, panicHandler) if err != nil { return fmt.Errorf("could not create vault: %w", err) } @@ -49,7 +49,7 @@ func WithVault(locations *locations.Locations, panicHandler async.PanicHandler, return fn(encVault, insecure, corrupt) } -func newVault(locations *locations.Locations, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) { +func newVault(locations *locations.Locations, keychains *keychain.List, panicHandler async.PanicHandler) (*vault.Vault, bool, bool, error) { vaultDir, err := locations.ProvideSettingsPath() if err != nil { return nil, false, false, fmt.Errorf("could not get vault dir: %w", err) @@ -62,7 +62,7 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) ( insecure bool ) - if key, err := loadVaultKey(vaultDir); err != nil { + if key, err := loadVaultKey(vaultDir, keychains); err != nil { logrus.WithError(err).Error("Could not load/create vault key") insecure = true @@ -85,13 +85,13 @@ func newVault(locations *locations.Locations, panicHandler async.PanicHandler) ( return vault, insecure, corrupt, nil } -func loadVaultKey(vaultDir string) ([]byte, error) { +func loadVaultKey(vaultDir string, keychains *keychain.List) ([]byte, error) { helper, err := vault.GetHelper(vaultDir) if err != nil { return nil, fmt.Errorf("could not get keychain helper: %w", err) } - kc, err := keychain.NewKeychain(helper, constants.KeyChainName) + kc, err := keychain.NewKeychain(helper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper()) if err != nil { return nil, fmt.Errorf("could not create keychain: %w", err) } diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index 8f44ee9a..6a91697f 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -45,6 +45,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/telemetry" "github.com/ProtonMail/proton-bridge/v3/internal/user" "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/bradenaw/juniper/xslices" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" @@ -82,6 +83,9 @@ type Bridge struct { newVersion *semver.Version newVersionLock safe.RWMutex + // keychains is the utils that own usable keychains found in the OS. + keychains *keychain.List + // focusService is used to raise the bridge window when needed. focusService *focus.Service @@ -138,6 +142,7 @@ func New( autostarter Autostarter, // the autostarter to manage autostart settings updater Updater, // the updater to fetch and install updates curVersion *semver.Version, // the current version of the bridge + keychains *keychain.List, // usable keychains apiURL string, // the URL of the API to use cookieJar http.CookieJar, // the cookie jar to use @@ -171,6 +176,7 @@ func New( autostarter, updater, curVersion, + keychains, panicHandler, reporter, @@ -204,6 +210,7 @@ func newBridge( autostarter Autostarter, updater Updater, curVersion *semver.Version, + keychains *keychain.List, panicHandler async.PanicHandler, reporter reporter.Reporter, @@ -256,6 +263,8 @@ func newBridge( newVersion: curVersion, newVersionLock: safe.NewRWMutex(), + keychains: keychains, + panicHandler: panicHandler, reporter: reporter, diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 6ce8b1bd..9540e02c 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -49,6 +49,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/user" "github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/tests" "github.com/bradenaw/juniper/xslices" imapid "github.com/emersion/go-imap-id" @@ -950,6 +951,7 @@ func withBridgeNoMocks( mocks.Autostarter, mocks.Updater, v2_3_0, + keychain.NewTestKeychainsList(), // The API stuff. apiURL, diff --git a/internal/bridge/heartbeat.go b/internal/bridge/heartbeat.go index 9b382c4c..a851ecb5 100644 --- a/internal/bridge/heartbeat.go +++ b/internal/bridge/heartbeat.go @@ -26,7 +26,6 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/telemetry" "github.com/ProtonMail/proton-bridge/v3/internal/vault" - "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/sirupsen/logrus" ) @@ -81,7 +80,7 @@ func (bridge *Bridge) SetLastHeartbeatSent(timestamp time.Time) error { } func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { - bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), keychain.DefaultHelper) + bridge.heartbeat = telemetry.NewHeartbeat(manager, 1143, 1025, bridge.GetGluonCacheDir(), bridge.keychains.GetDefaultHelper()) // Check for heartbeat when triggered. bridge.goHeartbeat = bridge.tasks.PeriodicOrTrigger(HeartbeatCheckInterval, 0, func(ctx context.Context) { @@ -104,7 +103,7 @@ func (bridge *Bridge) StartHeartbeat(manager telemetry.HeartbeatManager) { if val, err := bridge.GetKeychainApp(); err != nil { bridge.heartbeat.SetKeyChainPref(val) } else { - bridge.heartbeat.SetKeyChainPref(keychain.DefaultHelper) + bridge.heartbeat.SetKeyChainPref(bridge.keychains.GetDefaultHelper()) } bridge.heartbeat.SetPrevVersion(bridge.GetLastVersion().String()) diff --git a/internal/bridge/keychain.go b/internal/bridge/keychain.go new file mode 100644 index 00000000..94783f53 --- /dev/null +++ b/internal/bridge/keychain.go @@ -0,0 +1,24 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package bridge + +import "golang.org/x/exp/maps" + +func (bridge *Bridge) GetHelpersNames() []string { + return maps.Keys(bridge.keychains.GetHelpers()) +} diff --git a/internal/frontend/grpc/service_methods.go b/internal/frontend/grpc/service_methods.go index 48024cef..46ccfda4 100644 --- a/internal/frontend/grpc/service_methods.go +++ b/internal/frontend/grpc/service_methods.go @@ -33,10 +33,8 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/updater" - "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/ProtonMail/proton-bridge/v3/pkg/ports" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/runtime/protoimpl" @@ -712,7 +710,7 @@ func (s *Service) IsPortFree(_ context.Context, port *wrapperspb.Int32Value) (*w func (s *Service) AvailableKeychains(_ context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) { s.log.Debug("AvailableKeychains") - return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil + return &AvailableKeychainsResponse{Keychains: s.bridge.GetHelpersNames()}, nil } func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) { diff --git a/pkg/keychain/helper_darwin.go b/pkg/keychain/helper_darwin.go index d43ac8ae..ee21675d 100644 --- a/pkg/keychain/helper_darwin.go +++ b/pkg/keychain/helper_darwin.go @@ -31,14 +31,18 @@ const ( MacOSKeychain = "macos-keychain" ) -func init() { //nolint:gochecknoinits - Helpers = make(map[string]helperConstructor) +func listHelpers() (Helpers, string) { + helpers := make(Helpers) // MacOS always provides a keychain. - Helpers[MacOSKeychain] = newMacOSHelper + if isUsable(newMacOSHelper("")) { + helpers[MacOSKeychain] = newMacOSHelper + } else { + logrus.WithField("keychain", "MacOSKeychain").Warn("Keychain is not available.") + } // Use MacOSKeychain by default. - DefaultHelper = MacOSKeychain + return helpers, MacOSKeychain } func parseError(original error) error { diff --git a/pkg/keychain/helper_linux.go b/pkg/keychain/helper_linux.go index ea146fc7..1ab288e9 100644 --- a/pkg/keychain/helper_linux.go +++ b/pkg/keychain/helper_linux.go @@ -18,8 +18,6 @@ package keychain import ( - "reflect" - "github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/pass" "github.com/docker/docker-credential-helpers/secretservice" @@ -33,30 +31,37 @@ const ( SecretServiceDBus = "secret-service-dbus" ) -func init() { //nolint:gochecknoinits - Helpers = make(map[string]helperConstructor) +func listHelpers() (Helpers, string) { + helpers := make(Helpers) if isUsable(newDBusHelper("")) { - Helpers[SecretServiceDBus] = newDBusHelper + helpers[SecretServiceDBus] = newDBusHelper + } else { + logrus.WithField("keychain", "SecretServiceDBus").Warn("Keychain is not available.") } if _, err := execabs.LookPath("gnome-keyring"); err == nil && isUsable(newSecretServiceHelper("")) { - Helpers[SecretService] = newSecretServiceHelper + helpers[SecretService] = newSecretServiceHelper + } else { + logrus.WithField("keychain", "SecretService").Warn("Keychain is not available.") } if _, err := execabs.LookPath("pass"); err == nil && isUsable(newPassHelper("")) { - Helpers[Pass] = newPassHelper + helpers[Pass] = newPassHelper + } else { + logrus.WithField("keychain", "Pass").Warn("Keychain is not available.") } - DefaultHelper = SecretServiceDBus + defaultHelper := SecretServiceDBus // If Pass is available, use it by default. // Otherwise, if SecretService is available, use it by default. - if _, ok := Helpers[Pass]; ok { - DefaultHelper = Pass - } else if _, ok := Helpers[SecretService]; ok { - DefaultHelper = SecretService + if _, ok := helpers[Pass]; ok { + defaultHelper = Pass + } else if _, ok := helpers[SecretService]; ok { + defaultHelper = SecretService } + return helpers, defaultHelper } func newDBusHelper(string) (credentials.Helper, error) { @@ -70,36 +75,3 @@ func newPassHelper(string) (credentials.Helper, error) { func newSecretServiceHelper(string) (credentials.Helper, error) { return &secretservice.Secretservice{}, nil } - -// isUsable returns whether the credentials helper is usable. -func isUsable(helper credentials.Helper, err error) bool { - l := logrus.WithField("helper", reflect.TypeOf(helper)) - - if err != nil { - l.WithError(err).Warn("Keychain helper couldn't be created") - return false - } - - creds := &credentials.Credentials{ - ServerURL: "bridge/check", - Username: "check", - Secret: "check", - } - - if err := helper.Add(creds); err != nil { - l.WithError(err).Warn("Failed to add test credentials to keychain") - return false - } - - if _, _, err := helper.Get(creds.ServerURL); err != nil { - l.WithError(err).Warn("Failed to get test credentials from keychain") - return false - } - - if err := helper.Delete(creds.ServerURL); err != nil { - l.WithError(err).Warn("Failed to delete test credentials from keychain") - return false - } - - return true -} diff --git a/pkg/keychain/helper_windows.go b/pkg/keychain/helper_windows.go index b4b8ccbd..718b42f1 100644 --- a/pkg/keychain/helper_windows.go +++ b/pkg/keychain/helper_windows.go @@ -20,18 +20,21 @@ package keychain import ( "github.com/docker/docker-credential-helpers/credentials" "github.com/docker/docker-credential-helpers/wincred" + "github.com/sirupsen/logrus" ) const WindowsCredentials = "windows-credentials" -func init() { //nolint:gochecknoinits - Helpers = make(map[string]helperConstructor) - +func listHelpers() (Helpers, string) { + helpers := make(Helpers) // Windows always provides a keychain. - Helpers[WindowsCredentials] = newWinCredHelper - + if isUsable(newWinCredHelper("")) { + helpers[WindowsCredentials] = newWinCredHelper + } else { + logrus.WithField("keychain", "WindowsCredentials").Warn("Keychain is not available.") + } // Use WindowsCredentials by default. - DefaultHelper = WindowsCredentials + return helpers, WindowsCredentials } func newWinCredHelper(string) (credentials.Helper, error) { diff --git a/pkg/keychain/keychain.go b/pkg/keychain/keychain.go index 57b895c9..7110fa39 100644 --- a/pkg/keychain/keychain.go +++ b/pkg/keychain/keychain.go @@ -21,9 +21,12 @@ package keychain import ( "errors" "fmt" + "reflect" "sync" + "time" "github.com/docker/docker-credential-helpers/credentials" + "github.com/sirupsen/logrus" ) // helperConstructor constructs a keychain helperConstructor. @@ -38,28 +41,53 @@ var ( // ErrMacKeychainRebuild is returned on macOS with blocked or corrupted keychain. ErrMacKeychainRebuild = errors.New("keychain error -25293") - - // Helpers holds all discovered keychain helpers. It is populated in init(). - Helpers map[string]helperConstructor //nolint:gochecknoglobals - - // DefaultHelper is the default helper to use if the user hasn't yet set a preference. - DefaultHelper string //nolint:gochecknoglobals ) +type Helpers map[string]helperConstructor + +type List struct { + helpers Helpers + defaultHelper string + locker sync.Locker +} + +// NewList checks availability of every keychains detected on the User Operating System +// This will ask the user to unlock keychain(s) to check their usability. +// This should only be called once. +func NewList() *List { + var list = List{locker: &sync.Mutex{}} + list.helpers, list.defaultHelper = listHelpers() + return &list +} + +func (kcl *List) GetHelpers() Helpers { + kcl.locker.Lock() + defer kcl.locker.Unlock() + + return kcl.helpers +} + +func (kcl *List) GetDefaultHelper() string { + kcl.locker.Lock() + defer kcl.locker.Unlock() + + return kcl.defaultHelper +} + // NewKeychain creates a new native keychain. -func NewKeychain(preferred, keychainName string) (*Keychain, error) { +func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper string) (*Keychain, error) { // There must be at least one keychain helper available. - if len(Helpers) < 1 { + if len(helpers) < 1 { return nil, ErrNoKeychain } // If the preferred keychain is unsupported, fallback to the default one. - if _, ok := Helpers[preferred]; !ok { - preferred = DefaultHelper + if _, ok := helpers[preferred]; !ok { + preferred = defaultHelper } // Load the user's preferred keychain helper. - helperConstructor, ok := Helpers[preferred] + helperConstructor, ok := helpers[preferred] if !ok { return nil, ErrNoKeychain } @@ -163,3 +191,49 @@ func (kc *Keychain) Put(userID, secret string) error { func (kc *Keychain) secretURL(userID string) string { return fmt.Sprintf("%v/%v", kc.url, userID) } + +// isUsable returns whether the credentials helper is usable. +func isUsable(helper credentials.Helper, err error) bool { + l := logrus.WithField("helper", reflect.TypeOf(helper)) + + if err != nil { + l.WithError(err).Warn("Keychain helper couldn't be created") + return false + } + + creds := &credentials.Credentials{ + ServerURL: "bridge/check", + Username: "check", + Secret: "check", + } + + if err := retry(func() error { + return helper.Add(creds) + }); err != nil { + l.WithError(err).Warn("Failed to add test credentials to keychain") + return false + } + + if _, _, err := helper.Get(creds.ServerURL); err != nil { + l.WithError(err).Warn("Failed to get test credentials from keychain") + return false + } + + if err := helper.Delete(creds.ServerURL); err != nil { + l.WithError(err).Warn("Failed to delete test credentials from keychain") + return false + } + + return true +} + +func retry(condition func() error) error { + var maxRetry = 5 + for r := 0; ; r++ { + err := condition() + if err == nil || r >= maxRetry { + return err + } + time.Sleep(200 * time.Millisecond) + } +} diff --git a/pkg/keychain/test_helper.go b/pkg/keychain/test_helper.go index 505419d2..7a1fa6c7 100644 --- a/pkg/keychain/test_helper.go +++ b/pkg/keychain/test_helper.go @@ -17,10 +17,22 @@ package keychain -import "github.com/docker/docker-credential-helpers/credentials" +import ( + "sync" + + "github.com/docker/docker-credential-helpers/credentials" +) type TestHelper map[string]*credentials.Credentials +func NewTestKeychainsList() *List { + keychainHelper := NewTestHelper() + helpers := make(Helpers) + helpers["mock"] = func(string) (credentials.Helper, error) { return keychainHelper, nil } + var list = List{helpers: helpers, defaultHelper: "mock", locker: &sync.Mutex{}} + return &list +} + func NewTestHelper() TestHelper { return make(TestHelper) } diff --git a/tests/ctx_bridge_test.go b/tests/ctx_bridge_test.go index 0218bc99..e551510c 100644 --- a/tests/ctx_bridge_test.go +++ b/tests/ctx_bridge_test.go @@ -39,6 +39,7 @@ import ( "github.com/ProtonMail/proton-bridge/v3/internal/service" "github.com/ProtonMail/proton-bridge/v3/internal/useragent" "github.com/ProtonMail/proton-bridge/v3/internal/vault" + "github.com/ProtonMail/proton-bridge/v3/pkg/keychain" "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -153,6 +154,7 @@ func (t *testCtx) initBridge() (<-chan events.Event, error) { t.mocks.Autostarter, t.mocks.Updater, t.version, + keychain.NewTestKeychainsList(), // API stuff t.api.GetHostURL(), diff --git a/utils/vault-editor/main.go b/utils/vault-editor/main.go index 20a535d9..475ecad9 100644 --- a/utils/vault-editor/main.go +++ b/utils/vault-editor/main.go @@ -50,7 +50,7 @@ func main() { func readAction(c *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithVault(locations, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { + return app.WithVault(locations, nil, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { if _, err := os.Stdout.Write(vault.ExportJSON()); err != nil { return fmt.Errorf("failed to write vault: %w", err) } @@ -62,7 +62,7 @@ func readAction(c *cli.Context) error { func writeAction(c *cli.Context) error { return app.WithLocations(func(locations *locations.Locations) error { - return app.WithVault(locations, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { + return app.WithVault(locations, nil, async.NoopPanicHandler{}, func(vault *vault.Vault, insecure, corrupt bool) error { b, err := io.ReadAll(os.Stdin) if err != nil { return fmt.Errorf("failed to read vault: %w", err)