feat(GODT-2801): Identity Service

Identity Service contains all the information related to user state,
addresses and keys.

This patch also introduces the `State` type which can be used by other
services to maintain their own copy of this state to avoid lock
contention.

Finally, there are currently no external facing methods via a CPC
interface. Those will added as needed once the refactoring of the
architecture is complete.
This commit is contained in:
Leander Beernaert 2023-07-21 16:33:35 +02:00
parent 11f6f84dd6
commit 040ddadb7a
6 changed files with 1007 additions and 1 deletions

View File

@ -282,12 +282,14 @@ mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/telemetry HeartbeatManager > internal/telemetry/mocks/mocks.go
cp internal/telemetry/mocks/mocks.go internal/bridge/mocks/telemetry_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go
EventSource,EventIDStore > internal/services/userevents/mocks/mocks.go
mockgen --package userevents github.com/ProtonMail/proton-bridge/v3/internal/services/userevents \
MessageSubscriber,LabelSubscriber,AddressSubscriber,RefreshSubscriber,UserSubscriber,UserUsedSpaceSubscriber > tmp
mv tmp internal/services/userevents/mocks_test.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/events EventPublisher \
> internal/events/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider \
> internal/services/useridentity/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog

View File

@ -0,0 +1,34 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package userevents
// Subscribable represents a type that allows the registration of event subscribers.
type Subscribable interface {
Subscribe(subscription Subscription)
Unsubscribe(subscription Subscription)
}
type NoOpSubscribable struct{}
func (n NoOpSubscribable) Subscribe(_ Subscription) {
// Does nothing
}
func (n NoOpSubscribable) Unsubscribe(_ Subscription) {
// Does nothing
}

View File

@ -0,0 +1,66 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity (interfaces: IdentityProvider)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
reflect "reflect"
proton "github.com/ProtonMail/go-proton-api"
gomock "github.com/golang/mock/gomock"
)
// MockIdentityProvider is a mock of IdentityProvider interface.
type MockIdentityProvider struct {
ctrl *gomock.Controller
recorder *MockIdentityProviderMockRecorder
}
// MockIdentityProviderMockRecorder is the mock recorder for MockIdentityProvider.
type MockIdentityProviderMockRecorder struct {
mock *MockIdentityProvider
}
// NewMockIdentityProvider creates a new mock instance.
func NewMockIdentityProvider(ctrl *gomock.Controller) *MockIdentityProvider {
mock := &MockIdentityProvider{ctrl: ctrl}
mock.recorder = &MockIdentityProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockIdentityProvider) EXPECT() *MockIdentityProviderMockRecorder {
return m.recorder
}
// GetAddresses mocks base method.
func (m *MockIdentityProvider) GetAddresses(arg0 context.Context) ([]proton.Address, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAddresses", arg0)
ret0, _ := ret[0].([]proton.Address)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAddresses indicates an expected call of GetAddresses.
func (mr *MockIdentityProviderMockRecorder) GetAddresses(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockIdentityProvider)(nil).GetAddresses), arg0)
}
// GetUser mocks base method.
func (m *MockIdentityProvider) GetUser(arg0 context.Context) (proton.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUser", arg0)
ret0, _ := ret[0].(proton.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUser indicates an expected call of GetUser.
func (mr *MockIdentityProviderMockRecorder) GetUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockIdentityProvider)(nil).GetUser), arg0)
}

View File

@ -0,0 +1,270 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package useridentity
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
type IdentityProvider interface {
GetUser(ctx context.Context) (proton.User, error)
GetAddresses(ctx context.Context) ([]proton.Address, error)
}
// Service contains all the data required to establish the user identity. This
// includes all the user's information as well as mail addresses and keys.
type Service struct {
eventService userevents.Subscribable
eventPublisher events.EventPublisher
log *logrus.Entry
identity State
userSubscriber *userevents.UserChanneledSubscriber
addressSubscriber *userevents.AddressChanneledSubscriber
usedSpaceSubscriber *userevents.UserUsedSpaceChanneledSubscriber
refreshSubscriber *userevents.RefreshChanneledSubscriber
}
func NewService(
ctx context.Context,
service userevents.Subscribable,
user proton.User,
eventPublisher events.EventPublisher,
provider IdentityProvider,
) (*Service, error) {
addresses, err := provider.GetAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get addresses: %w", err)
}
subscriberName := fmt.Sprintf("identity-%v", user.ID)
return &Service{
eventService: service,
identity: NewState(user, addresses, provider),
eventPublisher: eventPublisher,
log: logrus.WithFields(logrus.Fields{
"service": "user-identity",
"user": user.ID,
}),
userSubscriber: userevents.NewUserSubscriber(subscriberName),
refreshSubscriber: userevents.NewRefreshSubscriber(subscriberName),
addressSubscriber: userevents.NewAddressSubscriber(subscriberName),
usedSpaceSubscriber: userevents.NewUserUsedSpaceSubscriber(subscriberName),
}, nil
}
func (s *Service) Start(group *async.Group) {
group.Once(func(ctx context.Context) {
s.run(ctx)
})
}
func (s *Service) run(ctx context.Context) {
s.log.WithFields(logrus.Fields{
"numAddr": len(s.identity.Addresses),
}).Info("Starting user identity service")
s.registerSubscription()
defer s.unregisterSubscription()
for {
select {
case <-ctx.Done():
return
case evt, ok := <-s.userSubscriber.OnEventCh():
if !ok {
continue
}
evt.Consume(func(user proton.User) error {
s.onUserEvent(ctx, user)
return nil
})
case evt, ok := <-s.refreshSubscriber.OnEventCh():
if !ok {
continue
}
evt.Consume(func(_ proton.RefreshFlag) error {
return s.onRefreshEvent(ctx)
})
case evt, ok := <-s.usedSpaceSubscriber.OnEventCh():
if !ok {
continue
}
evt.Consume(func(usedSpace int) error {
s.onUserSpaceChanged(ctx, usedSpace)
return nil
})
case evt, ok := <-s.addressSubscriber.OnEventCh():
if !ok {
continue
}
evt.Consume(func(events []proton.AddressEvent) error {
return s.onAddressEvent(ctx, events)
})
}
}
}
func (s *Service) registerSubscription() {
s.eventService.Subscribe(userevents.Subscription{
Refresh: s.refreshSubscriber,
User: s.userSubscriber,
Address: s.addressSubscriber,
UserUsedSpace: s.usedSpaceSubscriber,
})
}
func (s *Service) unregisterSubscription() {
s.eventService.Unsubscribe(userevents.Subscription{
Refresh: s.refreshSubscriber,
User: s.userSubscriber,
Address: s.addressSubscriber,
UserUsedSpace: s.usedSpaceSubscriber,
})
}
func (s *Service) onUserEvent(ctx context.Context, user proton.User) {
s.log.WithField("username", logging.Sensitive(user.Name)).Info("Handling user event")
s.identity.OnUserEvent(user)
s.eventPublisher.PublishEvent(ctx, events.UserChanged{
UserID: user.ID,
})
}
func (s *Service) onRefreshEvent(ctx context.Context) error {
s.log.Info("Handling refresh event")
if err := s.identity.OnRefreshEvent(ctx); err != nil {
s.log.WithError(err).Error("Failed to handle refresh event")
return err
}
s.eventPublisher.PublishEvent(ctx, events.UserRefreshed{
UserID: s.identity.User.ID,
CancelEventPool: false,
})
return nil
}
func (s *Service) onUserSpaceChanged(ctx context.Context, value int) {
s.log.Info("Handling User Space Changed event")
if !s.identity.OnUserSpaceChanged(value) {
return
}
s.eventPublisher.PublishEvent(ctx, events.UsedSpaceChanged{
UserID: s.identity.User.ID,
UsedSpace: value,
})
}
func (s *Service) onAddressEvent(ctx context.Context, addressEvents []proton.AddressEvent) error {
s.log.Infof("Handling Address Events (%v)", len(addressEvents))
for idx, event := range addressEvents {
switch event.Action {
case proton.EventCreate:
s.log.WithFields(logrus.Fields{
"index": idx,
"addressID": event.ID,
"email": logging.Sensitive(event.Address.Email),
}).Info("Handling address created event")
if s.identity.OnAddressCreated(event) == AddressUpdateCreated {
s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{
UserID: s.identity.User.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
}
case proton.EventUpdate, proton.EventUpdateFlags:
addr, status := s.identity.OnAddressUpdated(event)
switch status {
case AddressUpdateCreated:
s.eventPublisher.PublishEvent(ctx, events.UserAddressCreated{
UserID: s.identity.User.ID,
AddressID: addr.ID,
Email: addr.Email,
})
case AddressUpdateUpdated:
s.eventPublisher.PublishEvent(ctx, events.UserAddressUpdated{
UserID: s.identity.User.ID,
AddressID: addr.ID,
Email: addr.Email,
})
case AddressUpdateDisabled:
s.eventPublisher.PublishEvent(ctx, events.UserAddressDisabled{
UserID: s.identity.User.ID,
AddressID: addr.ID,
Email: addr.Email,
})
case AddressUpdateEnabled:
s.eventPublisher.PublishEvent(ctx, events.UserAddressEnabled{
UserID: s.identity.User.ID,
AddressID: addr.ID,
Email: addr.Email,
})
case AddressUpdateNoop:
continue
case AddressUpdateDeleted:
s.log.Warnf("Unexpected address update status after update event %v", status)
continue
}
case proton.EventDelete:
if addr, status := s.identity.OnAddressDeleted(event); status == AddressUpdateDeleted {
s.eventPublisher.PublishEvent(ctx, events.UserAddressDeleted{
UserID: s.identity.User.ID,
AddressID: event.ID,
Email: addr.Email,
})
}
}
}
return nil
}
func sortAddresses(addr []proton.Address) []proton.Address {
slices.SortFunc(addr, func(a, b proton.Address) bool {
return a.Order < b.Order
})
return addr
}
func buildAddressMapFromSlice(addr []proton.Address) map[string]proton.Address {
return usertypes.GroupBy(addr, func(addr proton.Address) string { return addr.ID })
}

View File

@ -0,0 +1,447 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package useridentity
import (
"context"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
const TestUserID = "MyUserID"
func TestService_OnUserEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserChanged{UserID: TestUserID})).Times(1)
service.onUserEvent(context.Background(), newTestUser())
}
func TestService_OnUserSpaceChanged(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UsedSpaceChanged{UserID: TestUserID, UsedSpace: 1024})).Times(1)
// Original value, no changes.
service.onUserSpaceChanged(context.Background(), 0)
// New value, event should be published.
service.onUserSpaceChanged(context.Background(), 1024)
require.Equal(t, 1024, service.identity.User.UsedSpace)
}
func TestService_OnRefreshEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, provider := newTestService(t, mockCtrl)
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserRefreshed{UserID: TestUserID, CancelEventPool: false})).Times(1)
newUser := newTestUserRefreshed()
newAddresses := newTestAddressesRefreshed()
{
getUserCall := provider.EXPECT().GetUser(gomock.Any()).Times(1).Return(newUser, nil)
provider.EXPECT().GetAddresses(gomock.Any()).After(getUserCall).Times(1).Return(newAddresses, nil)
}
// Original value, no changes.
require.NoError(t, service.onRefreshEvent(context.Background()))
require.Equal(t, newUser, service.identity.User)
require.Equal(t, newAddresses, service.identity.AddressesSorted)
}
func TestService_OnAddressCreated(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
newAddress := proton.Address{
ID: "NewAddrID",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressCreated{
UserID: TestUserID,
AddressID: newAddress.ID,
Email: newAddress.Email,
})).Times(1)
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventCreate,
},
Address: newAddress,
},
})
require.NoError(t, err)
require.Contains(t, service.identity.Addresses, newAddress.ID)
}
func TestService_OnAddressCreatedDisabledDoesNotProduceEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, _, _ := newTestService(t, mockCtrl)
newAddress := proton.Address{
ID: "Address1",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventCreate,
},
Address: newAddress,
},
})
require.NoError(t, err)
require.Contains(t, service.identity.Addresses, newAddress.ID)
}
func TestService_OnAddressCreatedDuplicateDoesNotProduceEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, _, _ := newTestService(t, mockCtrl)
newAddress := proton.Address{
ID: "NewAddrID",
Email: "new@bar.com",
Status: proton.AddressStatusDisabled,
}
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventCreate,
},
Address: newAddress,
},
})
require.NoError(t, err)
require.Contains(t, service.identity.Addresses, newAddress.ID)
}
func TestService_OnAddressUpdated(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
newAddress := proton.Address{
ID: "Address1",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressUpdated{
UserID: TestUserID,
AddressID: newAddress.ID,
Email: newAddress.Email,
})).Times(1)
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventUpdate,
},
Address: newAddress,
},
})
require.NoError(t, err)
require.Equal(t, newAddress, service.identity.Addresses[newAddress.ID])
}
func TestService_OnAddressUpdatedDisableFollowedByEnable(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
newAddressDisabled := proton.Address{
ID: "Address1",
Email: "new@bar.com",
Status: proton.AddressStatusDisabled,
}
newAddressEnabled := proton.Address{
ID: "Address1",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
{
disabledCall := eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressDisabled{
UserID: TestUserID,
AddressID: newAddressDisabled.ID,
Email: newAddressDisabled.Email,
})).Times(1)
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressEnabled{
UserID: TestUserID,
AddressID: newAddressEnabled.ID,
Email: newAddressEnabled.Email,
})).Times(1).After(disabledCall)
}
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventUpdate,
},
Address: newAddressDisabled,
},
})
require.NoError(t, err)
require.Equal(t, newAddressDisabled, service.identity.Addresses[newAddressEnabled.ID])
err = service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventUpdate,
},
Address: newAddressEnabled,
},
})
require.NoError(t, err)
require.Equal(t, newAddressEnabled, service.identity.Addresses[newAddressEnabled.ID])
}
func TestService_OnAddressUpdateCreatedIfNotExists(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
newAddress := proton.Address{
ID: "NewAddrID",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressCreated{
UserID: TestUserID,
AddressID: newAddress.ID,
Email: newAddress.Email,
})).Times(1)
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: "",
Action: proton.EventUpdate,
},
Address: newAddress,
},
})
require.NoError(t, err)
require.Contains(t, service.identity.Addresses, newAddress.ID)
}
func TestService_OnAddressDeleted(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, eventPublisher, _ := newTestService(t, mockCtrl)
address := proton.Address{
ID: "Address1",
Email: "foo@bar.com",
Status: proton.AddressStatusEnabled,
}
eventPublisher.EXPECT().PublishEvent(gomock.Any(), gomock.Eq(events.UserAddressDeleted{
UserID: TestUserID,
AddressID: address.ID,
Email: address.Email,
})).Times(1)
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: address.ID,
Action: proton.EventDelete,
},
},
})
require.NoError(t, err)
require.NotContains(t, service.identity.Addresses, address.ID)
}
func TestService_OnAddressDeleteDisabledDoesNotProduceEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, _, _ := newTestService(t, mockCtrl)
address := proton.Address{
ID: "Address2",
Email: "foo2@bar.com",
Status: proton.AddressStatusDisabled,
}
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: address.ID,
Action: proton.EventDelete,
},
},
})
require.NoError(t, err)
require.NotContains(t, service.identity.Addresses, address.ID)
}
func TestService_OnAddressDeletedUnknownDoesNotProduceEvent(t *testing.T) {
mockCtrl := gomock.NewController(t)
service, _, _ := newTestService(t, mockCtrl)
address := proton.Address{
ID: "UnknownID",
Email: "new@bar.com",
Status: proton.AddressStatusEnabled,
}
err := service.onAddressEvent(context.Background(), []proton.AddressEvent{
{
EventItem: proton.EventItem{
ID: address.ID,
Action: proton.EventDelete,
},
Address: address,
},
})
require.NoError(t, err)
}
func newTestService(t *testing.T, mockCtrl *gomock.Controller) (*Service, *mocks2.MockEventPublisher, *mocks.MockIdentityProvider) {
subscribable := &userevents.NoOpSubscribable{}
eventPublisher := mocks2.NewMockEventPublisher(mockCtrl)
provider := mocks.NewMockIdentityProvider(mockCtrl)
user := newTestUser()
provider.EXPECT().GetAddresses(gomock.Any()).Times(1).Return(newTestAddresses(), nil)
service, err := NewService(context.Background(), subscribable, user, eventPublisher, provider)
require.NoError(t, err)
return service, eventPublisher, provider
}
func newTestUser() proton.User {
return proton.User{
ID: TestUserID,
Name: "Foo",
DisplayName: "Foo",
Email: "foo@bar",
Keys: nil,
UsedSpace: 0,
MaxSpace: 0,
MaxUpload: 0,
Credit: 0,
Currency: "",
}
}
func newTestUserRefreshed() proton.User {
return proton.User{
ID: TestUserID,
Name: "Alternate",
DisplayName: "Universe",
Email: "foo2@bar",
Keys: nil,
UsedSpace: 0,
MaxSpace: 0,
MaxUpload: 0,
Credit: 0,
Currency: "USD",
}
}
func newTestAddresses() []proton.Address {
return []proton.Address{
{
ID: "Address1",
Email: "foo@bar.com",
Status: proton.AddressStatusEnabled,
Type: 0,
Order: 0,
DisplayName: "",
Keys: nil,
},
{
ID: "Address2",
Email: "foo2@bar.com",
Status: proton.AddressStatusDisabled,
Type: 0,
Order: 1,
DisplayName: "",
Keys: nil,
},
}
}
func newTestAddressesRefreshed() []proton.Address {
return []proton.Address{
{
ID: "Address1",
Email: "foo@bar.com",
Status: proton.AddressStatusEnabled,
Type: 0,
Order: 2,
DisplayName: "FOo barish",
Keys: nil,
},
{
ID: "Address2",
Email: "foo2@bar.com",
Status: proton.AddressStatusDisabled,
Type: 0,
Order: 4,
DisplayName: "New display name",
Keys: nil,
},
}
}

View File

@ -0,0 +1,187 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package useridentity
import (
"context"
"fmt"
"strings"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"golang.org/x/exp/maps"
)
// State holds all the required user identity state. The idea of this type is that
// it can be replicated across all services to avoid lock contention. The only
// requirement is that the service with the respective events.
type State struct {
AddressesSorted []proton.Address
Addresses map[string]proton.Address
User proton.User
provider IdentityProvider
}
func NewState(
user proton.User,
addresses []proton.Address,
provider IdentityProvider,
) State {
addressMap := buildAddressMapFromSlice(addresses)
return State{
AddressesSorted: sortAddresses(maps.Values(addressMap)),
Addresses: addressMap,
User: user,
provider: provider,
}
}
func NewStateFromProvider(ctx context.Context, provider IdentityProvider) (State, error) {
user, err := provider.GetUser(ctx)
if err != nil {
return State{}, fmt.Errorf("failed to get user: %w", err)
}
addresses, err := provider.GetAddresses(ctx)
if err != nil {
return State{}, fmt.Errorf("failed to get user addresses: %w", err)
}
return NewState(user, addresses, provider), nil
}
// GetAddr returns the address for the given email address.
func (s *State) GetAddr(email string) (proton.Address, error) {
for _, addr := range s.AddressesSorted {
if strings.EqualFold(addr.Email, usertypes.SanitizeEmail(email)) {
return addr, nil
}
}
return proton.Address{}, fmt.Errorf("address %s not found", email)
}
// GetPrimaryAddr returns the primary address for this user.
func (s *State) GetPrimaryAddr() (proton.Address, error) {
if len(s.AddressesSorted) == 0 {
return proton.Address{}, fmt.Errorf("no addresses available")
}
return s.AddressesSorted[0], nil
}
func (s *State) OnUserEvent(user proton.User) {
s.User = user
}
func (s *State) OnRefreshEvent(ctx context.Context) error {
user, err := s.provider.GetUser(ctx)
if err != nil {
return fmt.Errorf("failed to get user:%w", err)
}
addresses, err := s.provider.GetAddresses(ctx)
if err != nil {
return fmt.Errorf("failed to get addresses:%w", err)
}
s.User = user
s.Addresses = buildAddressMapFromSlice(addresses)
s.AddressesSorted = sortAddresses(maps.Values(s.Addresses))
return nil
}
func (s *State) OnUserSpaceChanged(value int) bool {
if s.User.UsedSpace == value {
return false
}
s.User.UsedSpace = value
return true
}
type AddressUpdate int
const (
AddressUpdateNoop AddressUpdate = iota
AddressUpdateCreated
AddressUpdateEnabled
AddressUpdateDisabled
AddressUpdateUpdated
AddressUpdateDeleted
)
func (s *State) OnAddressCreated(event proton.AddressEvent) AddressUpdate {
if _, ok := s.Addresses[event.Address.ID]; ok {
return AddressUpdateNoop
}
s.Addresses[event.Address.ID] = event.Address
s.AddressesSorted = sortAddresses(maps.Values(s.Addresses))
if event.Address.Status != proton.AddressStatusEnabled {
return AddressUpdateNoop
}
return AddressUpdateCreated
}
func (s *State) OnAddressUpdated(event proton.AddressEvent) (proton.Address, AddressUpdate) {
// Address does not exist create it.
oldAddr, ok := s.Addresses[event.Address.ID]
if !ok {
return event.Address, s.OnAddressCreated(event)
}
s.Addresses[event.Address.ID] = event.Address
s.AddressesSorted = sortAddresses(maps.Values(s.Addresses))
switch {
// If the address was newly enabled:
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
return event.Address, AddressUpdateEnabled
// If the address was newly disabled:
case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled:
return event.Address, AddressUpdateDisabled
// Otherwise it's just an update:
default:
return event.Address, AddressUpdateUpdated
}
}
func (s *State) OnAddressDeleted(event proton.AddressEvent) (proton.Address, AddressUpdate) {
addr, ok := s.Addresses[event.ID]
if !ok {
return proton.Address{}, AddressUpdateNoop
}
delete(s.Addresses, event.ID)
s.AddressesSorted = sortAddresses(maps.Values(s.Addresses))
if addr.Status != proton.AddressStatusEnabled {
return proton.Address{}, AddressUpdateNoop
}
return addr, AddressUpdateDeleted
}