// Source: proton-bridge/internal/users/users_test.go
// (viewer metadata: 320 lines, 11 KiB, Go)

// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
import (
"fmt"
"io/ioutil"
"os"
"runtime/debug"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
r "github.com/stretchr/testify/require"
)
// TestMain sets the logrus level from the VERBOSITY environment variable and
// then runs the package's tests:
//   - "fatal" raises the level to logrus.FatalLevel,
//   - "trace" lowers it to logrus.TraceLevel,
//   - any other value (or none) keeps logrus' current level.
func TestMain(m *testing.M) {
	switch os.Getenv("VERBOSITY") {
	case "fatal":
		logrus.SetLevel(logrus.FatalLevel)
	case "trace":
		logrus.SetLevel(logrus.TraceLevel)
	}
	os.Exit(m.Run())
}
// Fixtures shared by the Users tests.
//
// testCredentials and testCredentialsSplit describe connected users in
// combined and split address mode respectively. The *Disconnected variants
// carry an empty API token and an empty mailbox password.
var (
	testAuthRefresh = &pmapi.Auth{ //nolint[gochecknoglobals]
		UserID: "user",
		AuthRefresh: pmapi.AuthRefresh{
			UID:          "uid",
			AccessToken:  "acc",
			RefreshToken: "ref",
		},
	}

	testCredentials = &credentials.Credentials{ //nolint[gochecknoglobals]
		UserID:                "user",
		Name:                  "username",
		Emails:                "user@pm.me",
		APIToken:              "uid:acc",
		MailboxPassword:       []byte("pass"),
		BridgePassword:        "0123456789abcdef",
		Version:               "v1",
		Timestamp:             123456789,
		IsHidden:              false,
		IsCombinedAddressMode: true,
	}

	testCredentialsSplit = &credentials.Credentials{ //nolint[gochecknoglobals]
		UserID:                "users",
		Name:                  "usersname",
		Emails:                "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
		APIToken:              "uid:acc",
		MailboxPassword:       []byte("pass"),
		BridgePassword:        "0123456789abcdef",
		Version:               "v1",
		Timestamp:             123456789,
		IsHidden:              false,
		IsCombinedAddressMode: false,
	}

	testCredentialsDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
		UserID:                "userDisconnected",
		Name:                  "username",
		Emails:                "user@pm.me",
		APIToken:              "",
		MailboxPassword:       []byte{},
		BridgePassword:        "0123456789abcdef",
		Version:               "v1",
		Timestamp:             123456789,
		IsHidden:              false,
		IsCombinedAddressMode: true,
	}

	testCredentialsSplitDisconnected = &credentials.Credentials{ //nolint[gochecknoglobals]
		UserID:                "usersDisconnected",
		Name:                  "usersname",
		Emails:                "users@pm.me;anotheruser@pm.me;alsouser@pm.me",
		APIToken:              "",
		MailboxPassword:       []byte{},
		BridgePassword:        "0123456789abcdef",
		Version:               "v1",
		Timestamp:             123456789,
		IsHidden:              false,
		IsCombinedAddressMode: false,
	}

	testPMAPIUser = &pmapi.User{ //nolint[gochecknoglobals]
		ID:   "user",
		Name: "username",
	}

	testPMAPIUserDisconnected = &pmapi.User{ //nolint[gochecknoglobals]
		ID:   "userDisconnected",
		Name: "username",
	}

	testPMAPIAddress = &pmapi.Address{ //nolint[gochecknoglobals]
		ID:      "testAddressID",
		Type:    pmapi.OriginalAddress,
		Email:   "user@pm.me",
		Receive: true,
	}

	// Addresses matching the three emails of testCredentialsSplit.
	testPMAPIAddresses = []*pmapi.Address{ //nolint[gochecknoglobals]
		{ID: "usersAddress1ID", Email: "users@pm.me", Receive: true, Type: pmapi.OriginalAddress},
		{ID: "usersAddress2ID", Email: "anotheruser@pm.me", Receive: true, Type: pmapi.AliasAddress},
		{ID: "usersAddress3ID", Email: "alsouser@pm.me", Receive: true, Type: pmapi.AliasAddress},
	}

	// The directive must have no space after "//" to be recognised by the
	// linter, matching the other fixtures in this block.
	testPMAPIEvent = &pmapi.Event{ //nolint[gochecknoglobals]
		EventID: "ACXDmTaBub14w==",
	}
)
// mocks bundles the gomock controller and the mock collaborators that the
// Users tests wire together; initMocks builds it.
type mocks struct {
	t    *testing.T
	ctrl *gomock.Controller

	locator          *usersmocks.MockLocator
	PanicHandler     *usersmocks.MockPanicHandler
	credentialsStore *usersmocks.MockCredentialsStorer
	storeMaker       *usersmocks.MockStoreMaker
	eventListener    *usersmocks.MockListener
	clientManager    *pmapimocks.MockManager
	pmapiClient      *pmapimocks.MockClient

	// storeCache is a real store cache rather than a mock; initMocks backs it
	// with a temporary file.
	storeCache *store.Cache
}
// initMocks creates a gomock controller and every mock collaborator used by
// the Users tests.
//
// When VERBOSITY=trace, the controller reports through fullStackReporter, so
// every failure is printed and fatal ones also dump a stack trace.
//
// The store maker mock builds real stores. Each store gets a fresh temporary
// database file, and all stores share one store cache that is itself backed by
// a temporary file. Only the temporary files' paths are needed, so their
// descriptors are closed immediately instead of being leaked for the lifetime
// of the test binary.
func initMocks(t *testing.T) mocks {
	var mockCtrl *gomock.Controller
	if os.Getenv("VERBOSITY") == "trace" {
		mockCtrl = gomock.NewController(&fullStackReporter{t})
	} else {
		mockCtrl = gomock.NewController(t)
	}

	cacheFile, err := ioutil.TempFile("", "bridge-store-cache-*.db")
	r.NoError(t, err, "could not get temporary file for store cache")
	r.NoError(t, cacheFile.Close(), "could not close temporary file for store cache")

	m := mocks{
		t:                t,
		ctrl:             mockCtrl,
		locator:          usersmocks.NewMockLocator(mockCtrl),
		PanicHandler:     usersmocks.NewMockPanicHandler(mockCtrl),
		credentialsStore: usersmocks.NewMockCredentialsStorer(mockCtrl),
		storeMaker:       usersmocks.NewMockStoreMaker(mockCtrl),
		eventListener:    usersmocks.NewMockListener(mockCtrl),
		clientManager:    pmapimocks.NewMockManager(mockCtrl),
		pmapiClient:      pmapimocks.NewMockClient(mockCtrl),
		storeCache:       store.NewCache(cacheFile.Name()),
	}

	// Called during clean-up.
	m.PanicHandler.EXPECT().HandlePanic().AnyTimes()

	// Set up store factory.
	m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
		var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
		dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
		r.NoError(t, err, "could not get temporary file for store db")
		r.NoError(t, dbFile.Close(), "could not close temporary file for store db")
		return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
	}).AnyTimes()
	m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()

	return m
}
// fullStackReporter is a gomock.TestReporter that echoes every reported
// failure to stdout before forwarding it to the wrapped test. For fatal
// failures it first dumps the stack of the calling goroutine.
// initMocks installs it when VERBOSITY=trace.
type fullStackReporter struct {
	T testing.TB
}

// Errorf prints the formatted message prefixed with "err: " and marks the
// wrapped test as failed without stopping it.
func (fr *fullStackReporter) Errorf(format string, args ...interface{}) {
	line := "err: " + format + "\n"
	fmt.Printf(line, args...)
	fr.T.Fail()
}

// Fatalf prints the current goroutine's stack, then the formatted message
// prefixed with "fail: ", and finally stops the wrapped test via FailNow.
func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
	debug.PrintStack()
	line := "fail: " + format + "\n"
	fmt.Printf(line, args...)
	fr.T.FailNow()
}
// testNewUsersWithUsers builds a Users instance whose credentials store lists
// two connected users, testCredentials and testCredentialsSplit. Each is loaded
// through mockLoadingConnectedUser, and the store event loop is mocked to do
// nothing.
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
	storedUserIDs := []string{testCredentials.UserID, testCredentialsSplit.UserID}
	m.credentialsStore.EXPECT().List().Return(storedUserIDs, nil)

	mockLoadingConnectedUser(m, testCredentials)
	mockLoadingConnectedUser(m, testCredentialsSplit)
	mockEventLoopNoAction(m)

	return testNewUsers(t, m)
}
// testNewUsers constructs a Users instance from the mocks in m. Each event
// channel must be requested exactly once (gomock's default call count). The
// function then waits briefly so that asynchronous mock calls are captured.
func testNewUsers(t *testing.T, m mocks) *Users { //nolint[unparam]
	listener := m.eventListener
	listener.EXPECT().ProvideChannel(events.UpgradeApplicationEvent)
	listener.EXPECT().ProvideChannel(events.InternetOnEvent)

	created := New(m.locator, m.PanicHandler, m.eventListener, m.clientManager, m.credentialsStore, m.storeMaker)

	// Let any goroutines spawned during construction reach their mocks.
	waitForEvents()

	return created
}
// waitForEvents blocks for a fixed 100ms.
//
// Functions under test can return before the goroutines they start have called
// the listener mock (for example, login triggering the first-sync event).
// Pausing before the test ends lets those calls land, so gomock can detect
// missing or surplus calls.
func waitForEvents() {
	const settleDelay = 100 * time.Millisecond
	time.Sleep(settleDelay)
}
// cleanUpUsersData clears the store of every user held by b.
func cleanUpUsersData(b *Users) {
	for _, u := range b.users {
		// Best effort: a failure to clear one store is ignored so the
		// remaining users are still cleaned up.
		_ = u.clearStore()
	}
}
// mockAddingConnectedUser registers, in strict order, the client and
// credentials-store calls made while finishing the login of testCredentials.
// It then registers the expectations for initialising a connected user.
func mockAddingConnectedUser(m mocks) {
	client := m.pmapiClient
	credStore := m.credentialsStore
	mailboxPassword := testCredentials.MailboxPassword
	addressEmails := []string{testPMAPIAddress.Email}

	gomock.InOrder(
		// Mock of users.FinishLogin.
		client.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
		client.EXPECT().Unlock(gomock.Any(), mailboxPassword).Return(nil),
		client.EXPECT().CurrentUser(gomock.Any()).Return(testPMAPIUser, nil),
		client.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
		credStore.EXPECT().Add(
			testCredentials.UserID, testCredentials.Name,
			testAuthRefresh.UID, testAuthRefresh.RefreshToken,
			mailboxPassword, addressEmails,
		).Return(testCredentials, nil),
		credStore.EXPECT().Get(testCredentials.UserID).Return(testCredentials, nil),
	)

	mockInitConnectedUser(m)
}
// mockLoadingConnectedUser expects, in order:
//  1. creds to be read from the credentials store,
//  2. a client to be created with the "uid"/"acc" token pair,
//  3. the refreshed token to be written back for creds.UserID.
//
// It then registers the expectations for initialising a connected user.
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
	refreshed := &pmapi.AuthRefresh{
		UID:          "uid",
		AccessToken:  "acc",
		RefreshToken: "ref",
	}
	userID := creds.UserID

	// Mock of users.loadUsersFromCredentialsStore.
	fetched := m.credentialsStore.EXPECT().Get(userID).Return(creds, nil)
	connected := m.clientManager.EXPECT().NewClientWithRefresh(gomock.Any(), "uid", "acc").Return(m.pmapiClient, refreshed, nil)
	stored := m.credentialsStore.EXPECT().UpdateToken(userID, refreshed.UID, refreshed.RefreshToken).Return(creds, nil)
	gomock.InOrder(fetched, connected, stored)

	mockInitConnectedUser(m)
}
// mockInitConnectedUser registers the expectations for initialising a
// connected user. One auth-refresh handler is added, and IsUnlocked always
// reports true. During store initialisation, labels are listed, messages are
// counted and addresses are fetched, in that order.
func mockInitConnectedUser(m mocks) {
	client := m.pmapiClient

	// Mock of user initialisation.
	client.EXPECT().AddAuthRefreshHandler(gomock.Any())
	client.EXPECT().IsUnlocked().Return(true).AnyTimes()

	// Mock of store initialisation.
	listLabels := client.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil)
	countMessages := client.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil)
	addresses := client.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress})
	gomock.InOrder(listLabels, countMessages, addresses)
}
// mockLoadingDisconnectedUser expects creds to be read from the credentials
// store and then a client to be created with empty strings and a zero
// time.Time, in that order. It then registers the expectations for
// initialising a disconnected user.
func mockLoadingDisconnectedUser(m mocks, creds *credentials.Credentials) {
	var zeroTime time.Time

	// Mock of users.loadUsersFromCredentialsStore.
	fetched := m.credentialsStore.EXPECT().Get(creds.UserID).Return(creds, nil)
	created := m.clientManager.EXPECT().NewClient("", "", "", zeroTime).Return(m.pmapiClient)
	gomock.InOrder(fetched, created)

	mockInitDisconnectedUser(m)
}
// mockInitDisconnectedUser registers, in order, the client calls made while
// initialising a user with no valid session:
//  1. an auth-refresh handler is added,
//  2. ListLabels fails with an "ErrUnauthorized" error,
//  3. Addresses returns nil.
func mockInitDisconnectedUser(m mocks) {
	client := m.pmapiClient
	unauthorized := errors.New("ErrUnauthorized")

	gomock.InOrder(
		// Mock of user initialisation.
		client.EXPECT().AddAuthRefreshHandler(gomock.Any()),
		// Mock of store initialisation for the unauthorized user.
		client.EXPECT().ListLabels(gomock.Any()).Return(nil, unauthorized),
		client.EXPECT().Addresses().Return(nil),
	)
}
// mockEventLoopNoAction sets up mocks for starting the store's event loop (in
// store.New). The loop runs in another goroutine, so these calls may happen at
// any time and any number of times.
//
// Fetching either the initial ("") or the returned event ID yields
// testPMAPIEvent, which sets only EventID. Listing messages always returns an
// empty result.
func mockEventLoopNoAction(m mocks) {
	client := m.pmapiClient

	for _, eventID := range []string{"", testPMAPIEvent.EventID} {
		client.EXPECT().GetEvent(gomock.Any(), eventID).Return(testPMAPIEvent, nil).AnyTimes()
	}
	client.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
}