feat: refresh expired access tokens in one goroutine
This commit is contained in:
parent
40e96b9d1e
commit
3f32fd95e0
|
@ -5,10 +5,11 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
credentials "github.com/ProtonMail/proton-bridge/internal/bridge/credentials"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockConfiger is a mock of Configer interface
|
||||
|
|
|
@ -528,7 +528,6 @@ func (u *User) Logout() (err error) {
|
|||
u.wasKeyringUnlocked = false
|
||||
u.unlockingKeyringLock.Unlock()
|
||||
|
||||
// TODO: Is this necessary or could it be done by ClientManager when a nil auth is received?
|
||||
u.client().Logout()
|
||||
|
||||
if err = u.credStorer.Logout(u.userID); err != nil {
|
||||
|
|
|
@ -74,11 +74,8 @@ func (l *listener) Add(eventName string, channel chan<- string) {
|
|||
if l.channels == nil {
|
||||
l.channels = make(map[string][]chan<- string)
|
||||
}
|
||||
if _, ok := l.channels[eventName]; ok {
|
||||
l.channels[eventName] = append(l.channels[eventName], channel)
|
||||
} else {
|
||||
l.channels[eventName] = []chan<- string{channel}
|
||||
}
|
||||
|
||||
l.channels[eventName] = append(l.channels[eventName], channel)
|
||||
}
|
||||
|
||||
// Remove removes an event listener.
|
||||
|
|
|
@ -470,7 +470,7 @@ func (c *client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||
return auth, err
|
||||
}
|
||||
|
||||
// TODO: Should this even be a client method? Or just a method on the client manager?
|
||||
// Logout instructs the client manager to log this client out.
|
||||
func (c *client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ type ClientManager struct {
|
|||
tokensLocker sync.Locker
|
||||
|
||||
expirations map[string]*tokenExpiration
|
||||
expiredTokens chan string
|
||||
expirationsLocker sync.Locker
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
|
@ -76,6 +77,7 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||
tokensLocker: &sync.Mutex{},
|
||||
|
||||
expirations: make(map[string]*tokenExpiration),
|
||||
expiredTokens: make(chan string),
|
||||
expirationsLocker: &sync.Mutex{},
|
||||
|
||||
host: RootURL,
|
||||
|
@ -97,6 +99,8 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||
|
||||
go cm.forwardClientAuths()
|
||||
|
||||
go cm.watchTokenExpirations()
|
||||
|
||||
return cm
|
||||
}
|
||||
|
||||
|
@ -131,8 +135,10 @@ func (cm *ClientManager) GetAnonymousClient() Client {
|
|||
|
||||
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared.
|
||||
func (cm *ClientManager) LogoutClient(userID string) {
|
||||
client, ok := cm.clients[userID]
|
||||
cm.clientsLocker.Lock()
|
||||
defer cm.clientsLocker.Unlock()
|
||||
|
||||
client, ok := cm.clients[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -140,13 +146,16 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||
delete(cm.clients, userID)
|
||||
|
||||
go func() {
|
||||
if !strings.HasPrefix(userID, "anonymous-") {
|
||||
for client.DeleteAuth() == ErrAPINotReachable {
|
||||
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
|
||||
}
|
||||
defer client.ClearData()
|
||||
defer cm.clearToken(userID)
|
||||
|
||||
if strings.HasPrefix(userID, "anonymous-") {
|
||||
return
|
||||
}
|
||||
|
||||
for client.DeleteAuth() == ErrAPINotReachable {
|
||||
cm.log.Warn("Logging out client failed because API was not reachable, retrying...")
|
||||
}
|
||||
client.ClearData()
|
||||
cm.clearToken(userID)
|
||||
}()
|
||||
}
|
||||
|
||||
|
@ -281,9 +290,6 @@ func (cm *ClientManager) setToken(userID, token string, expiration time.Duration
|
|||
cm.tokens[userID] = token
|
||||
|
||||
cm.setTokenExpiration(userID, expiration)
|
||||
|
||||
// TODO: This should be one go routine per all tokens.
|
||||
go cm.watchTokenExpiration(userID)
|
||||
}
|
||||
|
||||
// setTokenExpiration will ensure the token is refreshed if it expires.
|
||||
|
@ -292,6 +298,9 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
|
|||
cm.expirationsLocker.Lock()
|
||||
defer cm.expirationsLocker.Unlock()
|
||||
|
||||
// Reduce the expiration by one minute so we can do the refresh with enough time to spare.
|
||||
expiration -= time.Minute
|
||||
|
||||
if exp, ok := cm.expirations[userID]; ok {
|
||||
exp.timer.Stop()
|
||||
close(exp.cancel)
|
||||
|
@ -301,6 +310,16 @@ func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Durat
|
|||
timer: time.NewTimer(expiration),
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func(expiration *tokenExpiration) {
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
cm.expiredTokens <- userID
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||
}
|
||||
}(cm.expirations[userID])
|
||||
}
|
||||
|
||||
func (cm *ClientManager) clearToken(userID string) {
|
||||
|
@ -324,30 +343,35 @@ func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
|||
}
|
||||
|
||||
// If the auth is nil, we should clear the token.
|
||||
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
|
||||
if ca.Auth == nil {
|
||||
cm.clearToken(ca.UserID)
|
||||
go cm.LogoutClient(ca.UserID)
|
||||
return
|
||||
}
|
||||
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second)
|
||||
}
|
||||
|
||||
func (cm *ClientManager) watchTokenExpiration(userID string) {
|
||||
cm.expirationsLocker.Lock()
|
||||
expiration := cm.expirations[userID]
|
||||
cm.expirationsLocker.Unlock()
|
||||
func (cm *ClientManager) watchTokenExpirations() {
|
||||
for userID := range cm.expiredTokens {
|
||||
log := cm.log.WithField("userID", userID)
|
||||
|
||||
select {
|
||||
case <-expiration.timer.C:
|
||||
cm.log.WithField("userID", userID).Info("Auth token expired! Refreshing")
|
||||
if _, err := cm.clients[userID].AuthRefresh(cm.tokens[userID]); err != nil {
|
||||
cm.log.WithField("userID", userID).
|
||||
WithError(err).
|
||||
Error("Token refresh failed before expiration")
|
||||
log.Info("Auth token expired! Refreshing")
|
||||
|
||||
client, ok := cm.clients[userID]
|
||||
if !ok {
|
||||
log.Warn("Can't refresh expired token because there is no such client")
|
||||
continue
|
||||
}
|
||||
|
||||
case <-expiration.cancel:
|
||||
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
|
||||
token, ok := cm.tokens[userID]
|
||||
if !ok {
|
||||
log.Warn("Can't refresh expired token because there is no such token")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := client.AuthRefresh(token); err != nil {
|
||||
log.WithError(err).Error("Failed to refresh expired token")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,12 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
crypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
|
|
Loading…
Reference in New Issue