feat: switch to proxy when need be
This commit is contained in:
parent
f239e8f3bf
commit
ce29d4d74e
|
@ -274,8 +274,11 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen]
|
|||
log.Error("Could not get credentials store: ", credentialsError)
|
||||
}
|
||||
|
||||
clientman := pmapi.NewClientManager(pmapifactory.GetClientConfig(cfg, eventListener))
|
||||
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, clientman, credentialsStore)
|
||||
clientConfig := pmapifactory.GetClientConfig(cfg.GetAPIConfig())
|
||||
cm := pmapi.NewClientManager(clientConfig)
|
||||
pmapifactory.SetClientRoundTripper(cm, clientConfig, eventListener)
|
||||
|
||||
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, cm, credentialsStore)
|
||||
imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance)
|
||||
smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance)
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ func New(
|
|||
// Allow DoH before starting bridge if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if pref.GetBool(preferences.AllowProxyKey) {
|
||||
AllowDoH()
|
||||
b.AllowProxy()
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
@ -178,15 +178,16 @@ func (b *Bridge) watchBridgeOutdated() {
|
|||
}
|
||||
}
|
||||
|
||||
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
|
||||
func (b *Bridge) watchUserAuths() {
|
||||
for auth := range b.clientManager.GetBridgeAuthChannel() {
|
||||
user, ok := b.hasUser(auth.UserID)
|
||||
logrus.WithField("token", auth.Auth.GenToken()).WithField("userID", auth.UserID).Info("Received auth from bridge auth channel")
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
if user, ok := b.hasUser(auth.UserID); ok {
|
||||
user.ReceiveAPIAuth(auth.Auth)
|
||||
} else {
|
||||
logrus.Info("User is not added to bridge yet")
|
||||
}
|
||||
|
||||
user.ReceiveAPIAuth(auth.Auth)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -274,7 +275,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||
apiClient := b.clientManager.GetClient(apiUser.ID)
|
||||
auth, err = apiClient.AuthRefresh(auth.GenToken())
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could refresh token in new client")
|
||||
log.WithError(err).Error("Could not refresh token in new client")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -298,6 +299,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
|
||||
return
|
||||
}
|
||||
b.users = append(b.users, user)
|
||||
}
|
||||
|
||||
// Set up the user auth and store (which we do for both new and existing users).
|
||||
|
@ -307,7 +309,6 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
|
|||
}
|
||||
|
||||
if !hasUser {
|
||||
b.users = append(b.users, user)
|
||||
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
|
||||
}
|
||||
|
||||
|
@ -475,16 +476,16 @@ func (b *Bridge) GetIMAPUpdatesChannel() chan interface{} {
|
|||
return b.idleUpdates
|
||||
}
|
||||
|
||||
// AllowDoH instructs bridge to use DoH to access an API proxy if necessary.
|
||||
// AllowProxy instructs bridge to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
|
||||
func AllowDoH() {
|
||||
pmapi.GlobalAllowDoH()
|
||||
func (b *Bridge) AllowProxy() {
|
||||
b.clientManager.AllowProxy()
|
||||
}
|
||||
|
||||
// DisallowDoH instructs bridge to not use DoH to access an API proxy if necessary.
|
||||
// DisallowProxy instructs bridge to not use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
|
||||
func DisallowDoH() {
|
||||
pmapi.GlobalDisallowDoH()
|
||||
func (b *Bridge) DisallowProxy() {
|
||||
b.clientManager.DisallowProxy()
|
||||
}
|
||||
|
||||
func (b *Bridge) updateCurrentUserAgent() {
|
||||
|
@ -493,7 +494,11 @@ func (b *Bridge) updateCurrentUserAgent() {
|
|||
|
||||
// hasUser returns whether the bridge currently has a user with ID `id`.
|
||||
func (b *Bridge) hasUser(id string) (user *User, ok bool) {
|
||||
logrus.WithField("id", id).Info("Checking whether bridge has given user")
|
||||
|
||||
for _, u := range b.users {
|
||||
logrus.WithField("id", u.ID()).Info("Found potential user")
|
||||
|
||||
if u.ID() == id {
|
||||
user, ok = u, true
|
||||
return
|
||||
|
|
|
@ -107,6 +107,8 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
|
|||
u.wasKeyringUnlocked = false
|
||||
u.unlockingKeyringLock.Unlock()
|
||||
|
||||
u.log.Info("Initialising user")
|
||||
|
||||
// Reload the user's credentials (if they log out and back in we need the new
|
||||
// version with the apitoken and mailbox password).
|
||||
creds, err := u.credStorer.Get(u.userID)
|
||||
|
@ -242,27 +244,19 @@ func (u *User) authorizeAndUnlock() (err error) {
|
|||
}
|
||||
|
||||
func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
if auth == nil {
|
||||
if err := u.logout(); err != nil {
|
||||
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API")
|
||||
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
|
||||
}
|
||||
u.isAuthorized = false
|
||||
return
|
||||
}
|
||||
|
||||
u.updateAPIToken(auth.GenToken())
|
||||
}
|
||||
|
||||
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be
|
||||
// called directly from other parts of the code, only from `ReceiveAPIAuth`.
|
||||
func (u *User) updateAPIToken(newRefreshToken string) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store")
|
||||
|
||||
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
|
||||
u.log.WithError(err).Error("Cannot update refresh token in credentials store")
|
||||
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
|
||||
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/connection"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/ports"
|
||||
|
@ -135,13 +134,13 @@ func (f *frontendCLI) toggleAllowProxy(c *ishell.Context) {
|
|||
f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.")
|
||||
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
|
||||
f.preferences.SetBool(preferences.AllowProxyKey, false)
|
||||
bridge.DisallowDoH()
|
||||
f.bridge.DisallowProxy()
|
||||
}
|
||||
} else {
|
||||
f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.")
|
||||
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
|
||||
f.preferences.SetBool(preferences.AllowProxyKey, true)
|
||||
bridge.AllowDoH()
|
||||
f.bridge.AllowProxy()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,6 +52,8 @@ type Bridger interface {
|
|||
DeleteUser(userID string, clearCache bool) error
|
||||
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
|
||||
ClearData() error
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
// BridgeUser is an interface of user needed by frontend.
|
||||
|
|
|
@ -21,11 +21,14 @@
|
|||
package pmapifactory
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
func GetClientConfig(config bridge.Configer, _ listener.Listener) *pmapi.ClientConfig {
|
||||
return config.GetAPIConfig()
|
||||
func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
|
||||
return clientConfig
|
||||
}
|
||||
|
||||
func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) {
|
||||
// Use the default roundtripper; do nothing.
|
||||
}
|
||||
|
|
|
@ -23,26 +23,13 @@ package pmapifactory
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.ClientConfig {
|
||||
clientConfig := config.GetAPIConfig()
|
||||
|
||||
pin := pmapi.NewPMAPIPinning(clientConfig.AppVersion)
|
||||
pin.ReportCertIssueLocal = func() {
|
||||
listener.Emit(events.TLSCertIssue, "")
|
||||
}
|
||||
|
||||
// This transport already has timeouts set governing the roundtrip:
|
||||
// - IdleConnTimeout: 5 * time.Minute,
|
||||
// - ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
// - ResponseHeaderTimeout: 30 * time.Second,
|
||||
clientConfig.Transport = pin.TransportWithPinning()
|
||||
|
||||
func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
|
||||
// We set additional timeouts/thresholds for the request as a whole:
|
||||
clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable).
|
||||
clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout.
|
||||
|
@ -50,3 +37,15 @@ func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.
|
|||
|
||||
return clientConfig
|
||||
}
|
||||
|
||||
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
|
||||
logrus.Info("Setting dialer with pinning")
|
||||
|
||||
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
|
||||
|
||||
pin.ReportCertIssueLocal = func() {
|
||||
listener.Emit(events.TLSCertIssue, "")
|
||||
}
|
||||
|
||||
cm.SetClientRoundTripper(pin.TransportWithPinning())
|
||||
}
|
||||
|
|
|
@ -39,7 +39,8 @@ var (
|
|||
// Two errors can be returned, ErrNoInternetConnection or ErrCanNotReachAPI.
|
||||
func CheckInternetConnection() error {
|
||||
client := &http.Client{
|
||||
Transport: pmapi.NewPMAPIPinning(pmapi.CurrentUserAgent).TransportWithPinning(),
|
||||
// TODO: Set transport properly! (Need access to ClientManager somehow)
|
||||
// Transport: pmapi.NewDialerWithPinning(pmapi.CurrentUserAgent).TransportWithPinning(),
|
||||
}
|
||||
|
||||
// Do not cumulate timeouts, use goroutines.
|
||||
|
@ -51,7 +52,8 @@ func CheckInternetConnection() error {
|
|||
go checkConnection(client, "http://protonstatus.com/vpn_status", retStatus)
|
||||
|
||||
// Check of API reachability also uses a fast endpoint.
|
||||
go checkConnection(client, pmapi.GlobalGetRootURL()+"/tests/ping", retAPI)
|
||||
// TODO: This should check active proxy, not the RootURL
|
||||
go checkConnection(client, pmapi.RootURL+"/tests/ping", retAPI)
|
||||
|
||||
errStatus := <-retStatus
|
||||
errAPI := <-retAPI
|
||||
|
|
|
@ -36,7 +36,7 @@ type osxkeychain struct {
|
|||
}
|
||||
|
||||
func newKeychain() (credentials.Helper, error) {
|
||||
log.Debug("creating osckeychain")
|
||||
log.Debug("Creating osckeychain")
|
||||
return &osxkeychain{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -24,14 +24,14 @@ import (
|
|||
)
|
||||
|
||||
func newKeychain() (credentials.Helper, error) {
|
||||
log.Debug("creating pass")
|
||||
log.Debug("Creating pass")
|
||||
passHelper := &pass.Pass{}
|
||||
passErr := checkPassIsUsable(passHelper)
|
||||
if passErr == nil {
|
||||
return passHelper, nil
|
||||
}
|
||||
|
||||
log.Debug("creating secretservice")
|
||||
log.Debug("Creating secretservice")
|
||||
sserviceHelper := &secretservice.Secretservice{}
|
||||
_, sserviceErr := sserviceHelper.List()
|
||||
if sserviceErr == nil {
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
)
|
||||
|
||||
func newKeychain() (credentials.Helper, error) {
|
||||
log.Debug("creating wincred")
|
||||
log.Debug("Creating wincred")
|
||||
return &wincred.Wincred{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -161,7 +161,7 @@ func ConstructAddress(headerEmail string, addressEmail string) string {
|
|||
|
||||
// GetAddresses requests all of current user addresses (without pagination).
|
||||
func (c *Client) GetAddresses() (addresses AddressList, err error) {
|
||||
req, err := NewRequest("GET", "/addresses", nil)
|
||||
req, err := c.NewRequest("GET", "/addresses", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R
|
|||
//
|
||||
// The returned created attachment contains the new attachment ID and its size.
|
||||
func (c *Client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) {
|
||||
req, w, err := NewMultipartRequest("POST", "/attachments")
|
||||
req, w, err := c.NewMultipartRequest("POST", "/attachments")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -213,7 +213,7 @@ type UpdateAttachmentSignatureReq struct {
|
|||
|
||||
func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err error) {
|
||||
updateReq := &UpdateAttachmentSignatureReq{signature}
|
||||
req, err := NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -228,7 +228,7 @@ func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err
|
|||
|
||||
// DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID.
|
||||
func (c *Client) DeleteAttachment(attID string) (err error) {
|
||||
req, err := NewRequest("DELETE", "/attachments/"+attID, nil)
|
||||
req, err := c.NewRequest("DELETE", "/attachments/"+attID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ func (c *Client) GetAttachment(id string) (att io.ReadCloser, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
req, err := NewRequest("GET", "/attachments/"+id, nil)
|
||||
req, err := c.NewRequest("GET", "/attachments/"+id, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -214,7 +214,7 @@ func (c *Client) AuthInfo(username string) (info *AuthInfo, err error) {
|
|||
Username: username,
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/auth/info", infoReq)
|
||||
req, err := c.NewJSONRequest("POST", "/auth/info", infoReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ func (c *Client) tryAuth(username, password string, info *AuthInfo, fallbackVers
|
|||
SRPSession: info.srpSession,
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/auth", authReq)
|
||||
req, err := c.NewJSONRequest("POST", "/auth", authReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -335,7 +335,7 @@ func (c *Client) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) {
|
|||
TwoFactorCode: twoFactorCode,
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
|
||||
req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -430,7 +430,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||
// UID must be set for `x-pm-uid` header field, see backend-communication#11
|
||||
c.uid = split[0]
|
||||
|
||||
req, err := NewJSONRequest("POST", "/auth/refresh", refreshReq)
|
||||
req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -450,13 +450,14 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
|
|||
return auth, err
|
||||
}
|
||||
|
||||
// Logout instructs the client manager to log out this client.
|
||||
func (c *Client) Logout() {
|
||||
c.cm.LogoutClient(c.userID)
|
||||
}
|
||||
|
||||
// logout logs the current user out.
|
||||
func (c *Client) logout() (err error) {
|
||||
req, err := NewRequest("DELETE", "/auth", nil)
|
||||
req, err := c.NewRequest("DELETE", "/auth", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -332,7 +332,9 @@ func TestClient_Logout(t *testing.T) {
|
|||
c.uid = testUID
|
||||
c.accessToken = testAccessToken
|
||||
|
||||
Ok(t, c.Logout())
|
||||
c.Logout()
|
||||
|
||||
// TODO: Check that the client is logged out and sensitive data is cleared eventually.
|
||||
}
|
||||
|
||||
func TestClient_DoUnauthorized(t *testing.T) {
|
||||
|
@ -355,7 +357,7 @@ func TestClient_DoUnauthorized(t *testing.T) {
|
|||
c.expiresAt = aLongTimeAgo
|
||||
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
|
||||
|
||||
req, err := NewRequest("GET", "/", nil)
|
||||
req, err := c.NewRequest("GET", "/", nil)
|
||||
Ok(t, err)
|
||||
|
||||
res, err := c.Do(req, true)
|
||||
|
|
|
@ -139,9 +139,9 @@ func (c *Client) Report(rep ReportReq) (err error) {
|
|||
var req *http.Request
|
||||
var w *MultipartWriter
|
||||
if len(rep.Attachments) > 0 {
|
||||
req, w, err = NewMultipartRequest("POST", "/reports/bug")
|
||||
req, w, err = c.NewMultipartRequest("POST", "/reports/bug")
|
||||
} else {
|
||||
req, err = NewJSONRequest("POST", "/reports/bug", rep)
|
||||
req, err = c.NewJSONRequest("POST", "/reports/bug", rep)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -202,7 +202,7 @@ func (c *Client) ReportCrash(stacktrace string) (err error) {
|
|||
OS: runtime.GOOS,
|
||||
Debug: stacktrace,
|
||||
}
|
||||
req, err := NewJSONRequest("POST", "/reports/crash", crashReq)
|
||||
req, err := c.NewJSONRequest("POST", "/reports/crash", crashReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -99,12 +99,12 @@ type ClientConfig struct {
|
|||
|
||||
// Client to communicate with API.
|
||||
type Client struct {
|
||||
cm *ClientManager
|
||||
client *http.Client
|
||||
cm *ClientManager
|
||||
hc *http.Client
|
||||
|
||||
uid string
|
||||
accessToken string
|
||||
userID string // Twice here because Username is not unique.
|
||||
userID string
|
||||
requestLocker sync.Locker
|
||||
keyLocker sync.Locker
|
||||
|
||||
|
@ -120,7 +120,7 @@ type Client struct {
|
|||
func newClient(cm *ClientManager, userID string) *Client {
|
||||
return &Client{
|
||||
cm: cm,
|
||||
client: getHTTPClient(cm.GetConfig()),
|
||||
hc: getHTTPClient(cm.GetConfig()),
|
||||
userID: userID,
|
||||
requestLocker: &sync.Mutex{},
|
||||
keyLocker: &sync.Mutex{},
|
||||
|
@ -132,12 +132,10 @@ func newClient(cm *ClientManager, userID string) *Client {
|
|||
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
|
||||
hc = &http.Client{Timeout: cfg.Timeout}
|
||||
|
||||
if cfg.Transport == nil && defaultTransport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if defaultTransport != nil {
|
||||
hc.Transport = defaultTransport
|
||||
if cfg.Transport == nil {
|
||||
if defaultTransport != nil {
|
||||
hc.Transport = defaultTransport
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -205,7 +203,7 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
|
|||
}
|
||||
|
||||
hasBody := len(bodyBuffer) > 0
|
||||
if res, err = c.client.Do(req); err != nil {
|
||||
if res, err = c.hc.Do(req); err != nil {
|
||||
if res == nil {
|
||||
c.log.WithError(err).Error("Cannot get response")
|
||||
err = ErrAPINotReachable
|
||||
|
|
|
@ -51,7 +51,7 @@ func TestClient_Do(t *testing.T) {
|
|||
}))
|
||||
defer s.Close()
|
||||
|
||||
req, err := NewRequest("GET", "/", nil)
|
||||
req, err := c.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal("Expected no error while creating request, got:", err)
|
||||
}
|
||||
|
@ -163,8 +163,8 @@ func TestClient_FirstReadTimeout(t *testing.T) {
|
|||
)
|
||||
defer finish()
|
||||
|
||||
c.client.Transport = &slowTransport{
|
||||
transport: c.client.Transport,
|
||||
c.hc.Transport = &slowTransport{
|
||||
transport: c.hc.Transport,
|
||||
firstBodySleep: requestTimeout,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,24 +1,32 @@
|
|||
package pmapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ClientManager is a manager of clients.
|
||||
type ClientManager struct {
|
||||
config *ClientConfig
|
||||
|
||||
clients map[string]*Client
|
||||
clientsLocker sync.Locker
|
||||
|
||||
tokens map[string]string
|
||||
tokensLocker sync.Locker
|
||||
|
||||
config *ClientConfig
|
||||
url string
|
||||
urlLocker sync.Locker
|
||||
|
||||
bridgeAuths chan ClientAuth
|
||||
clientAuths chan ClientAuth
|
||||
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
}
|
||||
|
||||
type ClientAuth struct {
|
||||
|
@ -33,13 +41,21 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||
}
|
||||
|
||||
cm = &ClientManager{
|
||||
config: config,
|
||||
|
||||
clients: make(map[string]*Client),
|
||||
clientsLocker: &sync.Mutex{},
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
config: config,
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
|
||||
tokens: make(map[string]string),
|
||||
tokensLocker: &sync.Mutex{},
|
||||
|
||||
url: RootURL,
|
||||
urlLocker: &sync.Mutex{},
|
||||
|
||||
bridgeAuths: make(chan ClientAuth),
|
||||
clientAuths: make(chan ClientAuth),
|
||||
|
||||
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
|
||||
}
|
||||
|
||||
go cm.forwardClientAuths()
|
||||
|
@ -47,6 +63,12 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
|
|||
return
|
||||
}
|
||||
|
||||
// SetClientRoundTripper sets the roundtripper used by clients created by this client manager.
|
||||
func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) {
|
||||
logrus.Info("Setting client roundtripper")
|
||||
cm.config.Transport = rt
|
||||
}
|
||||
|
||||
// GetClient returns a client for the given userID.
|
||||
// If the client does not exist already, it is created.
|
||||
func (cm *ClientManager) GetClient(userID string) *Client {
|
||||
|
@ -71,7 +93,7 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||
|
||||
go func() {
|
||||
if err := client.logout(); err != nil {
|
||||
// TODO: Try again!
|
||||
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
|
||||
logrus.WithError(err).Error("Client logout failed, not trying again")
|
||||
}
|
||||
client.clearSensitiveData()
|
||||
|
@ -81,6 +103,56 @@ func (cm *ClientManager) LogoutClient(userID string) {
|
|||
return
|
||||
}
|
||||
|
||||
// GetRootURL returns the root URL to make requests to.
|
||||
// It does not include the protocol i.e. no "https://".
|
||||
func (cm *ClientManager) GetRootURL() string {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
return cm.url
|
||||
}
|
||||
|
||||
// SetRootURL sets the root URL to make requests to.
|
||||
func (cm *ClientManager) SetRootURL(url string) {
|
||||
cm.urlLocker.Lock()
|
||||
defer cm.urlLocker.Unlock()
|
||||
|
||||
logrus.WithField("url", url).Info("Changing to a new root URL")
|
||||
|
||||
cm.url = url
|
||||
}
|
||||
|
||||
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
|
||||
func (cm *ClientManager) IsProxyAllowed() bool {
|
||||
return cm.allowProxy
|
||||
}
|
||||
|
||||
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
|
||||
func (cm *ClientManager) AllowProxy() {
|
||||
cm.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
|
||||
func (cm *ClientManager) DisallowProxy() {
|
||||
cm.allowProxy = false
|
||||
}
|
||||
|
||||
// FindProxy returns a usable proxy server.
|
||||
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
|
||||
logrus.Info("Attempting gto switch to a proxy")
|
||||
|
||||
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
|
||||
err = errors.Wrap(err, "failed to find usable proxy")
|
||||
return
|
||||
}
|
||||
|
||||
cm.SetRootURL(proxy)
|
||||
|
||||
// TODO: Disable after 24 hours.
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetConfig returns the config used to configure clients.
|
||||
func (cm *ClientManager) GetConfig() *ClientConfig {
|
||||
return cm.config
|
||||
|
@ -113,7 +185,7 @@ func (cm *ClientManager) setToken(userID, token string) {
|
|||
cm.tokensLocker.Lock()
|
||||
defer cm.tokensLocker.Unlock()
|
||||
|
||||
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token")
|
||||
logrus.WithField("userID", userID).Info("Updating refresh token")
|
||||
|
||||
cm.tokens[userID] = token
|
||||
}
|
||||
|
@ -127,16 +199,19 @@ func (cm *ClientManager) clearToken(userID string) {
|
|||
delete(cm.tokens, userID)
|
||||
}
|
||||
|
||||
// handleClientAuth
|
||||
// handleClientAuth updates or clears client authorisation based on auths received.
|
||||
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
|
||||
// TODO: Maybe want to logout the client in case of nil auth.
|
||||
// If we aren't managing this client, there's nothing to do.
|
||||
if _, ok := cm.clients[ca.UserID]; !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// If the auth is nil, we should clear the token.
|
||||
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
|
||||
if ca.Auth == nil {
|
||||
cm.clearToken(ca.UserID)
|
||||
} else {
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
||||
return
|
||||
}
|
||||
|
||||
cm.setToken(ca.UserID, ca.Auth.GenToken())
|
||||
}
|
||||
|
|
|
@ -24,9 +24,9 @@ import (
|
|||
|
||||
// RootURL is the API root URL.
|
||||
//
|
||||
// This can be changed using build flags: pmapi_local for "http://localhost/api",
|
||||
// pmapi_dev or pmapi_prod. Default is pmapi_prod.
|
||||
var RootURL = "https://api.protonmail.ch" //nolint[gochecknoglobals]
|
||||
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
|
||||
// Default is pmapi_prod.
|
||||
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
|
||||
|
||||
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
|
||||
// version and email client.
|
||||
|
|
|
@ -20,5 +20,5 @@
|
|||
package pmapi
|
||||
|
||||
func init() {
|
||||
RootURL = "https://dev.protonmail.com/api"
|
||||
RootURL = "dev.protonmail.com/api"
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
func init() {
|
||||
// Use port above 1000 which doesn't need root access to start anything on it.
|
||||
// Now the port is rounded pi. :-)
|
||||
RootURL = "http://127.0.0.1:3142/api"
|
||||
RootURL = "127.0.0.1:3142/api"
|
||||
|
||||
// TLS certificate is self-signed
|
||||
defaultTransport = &http.Transport{
|
||||
|
|
|
@ -119,7 +119,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
|
|||
if pageSize > 0 {
|
||||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
req, err := NewRequest("GET", "/contacts?"+v.Encode(), nil)
|
||||
req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -136,7 +136,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
|
|||
|
||||
// GetContactByID gets contact details specified by contact ID.
|
||||
func (c *Client) GetContactByID(id string) (contactDetail Contact, err error) {
|
||||
req, err := NewRequest("GET", "/contacts/"+id, nil)
|
||||
req, err := c.NewRequest("GET", "/contacts/"+id, nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -164,7 +164,7 @@ func (c *Client) GetContactsForExport(page int, pageSize int) (contacts []Contac
|
|||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
|
||||
req, err := NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
|
||||
req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -198,7 +198,7 @@ func (c *Client) GetAllContactsEmails(page int, pageSize int) (contactsEmails []
|
|||
v.Set("PageSize", strconv.Itoa(pageSize))
|
||||
}
|
||||
|
||||
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -221,7 +221,7 @@ func (c *Client) GetContactEmailByEmail(email string, page int, pageSize int) (c
|
|||
}
|
||||
v.Set("Email", email)
|
||||
|
||||
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -276,7 +276,7 @@ func (c *Client) AddContacts(cards ContactsCards, overwrite int, groups int, lab
|
|||
Labels: labels,
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/contacts", reqBody)
|
||||
req, err := c.NewJSONRequest("POST", "/contacts", reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ func (c *Client) UpdateContact(id string, cards []Card) (res *UpdateContactRespo
|
|||
reqBody := UpdateContactReq{
|
||||
Cards: cards,
|
||||
}
|
||||
req, err := NewJSONRequest("PUT", "/contacts/"+id, reqBody)
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -354,7 +354,7 @@ func (c *Client) modifyContactGroups(groupID string, modifyContactGroupsAction i
|
|||
Action: modifyContactGroupsAction,
|
||||
ContactEmailIDs: contactEmailIDs,
|
||||
}
|
||||
req, err := NewJSONRequest("PUT", "/contacts/group", reqBody)
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -377,7 +377,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
|
|||
IDs: ids,
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("PUT", "/contacts/delete", deleteReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -402,7 +402,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
|
|||
|
||||
// DeleteAllContacts deletes all contacts.
|
||||
func (c *Client) DeleteAllContacts() (err error) {
|
||||
req, err := NewRequest("DELETE", "/contacts", nil)
|
||||
req, err := c.NewRequest("DELETE", "/contacts", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func (c *Client) CountConversations(addressID string) (counts []*ConversationsCo
|
|||
if addressID != "" {
|
||||
reqURL += ("?AddressID=" + addressID)
|
||||
}
|
||||
req, err := NewRequest("GET", reqURL, nil)
|
||||
req, err := c.NewRequest("GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -112,51 +112,47 @@ type DialerWithPinning struct {
|
|||
// It is used only if set.
|
||||
ReportCertIssueLocal func()
|
||||
|
||||
// proxyManager manages API proxies.
|
||||
proxyManager *proxyManager
|
||||
// cm is used to find and switch to a proxy if necessary.
|
||||
cm *ClientManager
|
||||
|
||||
// A logger for logging messages.
|
||||
log logrus.FieldLogger
|
||||
}
|
||||
|
||||
func NewDialerWithPinning(reportURI string, report TLSReport) *DialerWithPinning {
|
||||
// NewDialerWithPinning constructs a new dialer with pinned certs.
|
||||
func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning {
|
||||
reportURI := "https://reports.protonmail.ch/reports/tls"
|
||||
|
||||
report := TLSReport{
|
||||
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
|
||||
IncludeSubdomains: false,
|
||||
ValidatedCertificateChain: []string{},
|
||||
ServedCertificateChain: []string{},
|
||||
AppVersion: appVersion,
|
||||
|
||||
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
|
||||
KnownPins: []string{
|
||||
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
|
||||
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
|
||||
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
|
||||
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
|
||||
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
|
||||
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
|
||||
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
|
||||
},
|
||||
}
|
||||
|
||||
log := logrus.WithField("pkg", "pmapi/tls-pinning")
|
||||
|
||||
proxyManager := newProxyManager(dohProviders, proxyQuery)
|
||||
|
||||
return &DialerWithPinning{
|
||||
isReported: false,
|
||||
reportURI: reportURI,
|
||||
report: report,
|
||||
proxyManager: proxyManager,
|
||||
log: log,
|
||||
cm: cm,
|
||||
isReported: false,
|
||||
reportURI: reportURI,
|
||||
report: report,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func NewPMAPIPinning(appVersion string) *DialerWithPinning {
|
||||
return NewDialerWithPinning(
|
||||
"https://reports.protonmail.ch/reports/tls",
|
||||
TLSReport{
|
||||
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
|
||||
IncludeSubdomains: false,
|
||||
ValidatedCertificateChain: []string{},
|
||||
ServedCertificateChain: []string{},
|
||||
AppVersion: appVersion,
|
||||
|
||||
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
|
||||
KnownPins: []string{
|
||||
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
|
||||
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
|
||||
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
|
||||
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
|
||||
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
|
||||
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
|
||||
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) {
|
||||
p.isReported = true
|
||||
|
||||
|
@ -231,6 +227,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
|||
return pemCerts
|
||||
}
|
||||
|
||||
// TransportWithPinning creates an http.Transport that checks fingerprints when dialing.
|
||||
func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
@ -258,7 +255,7 @@ func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
|
|||
// p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil.
|
||||
func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) {
|
||||
// If DoH is enabled, we hardfail on fingerprint mismatches.
|
||||
if globalIsDoHAllowed() && p.isReported {
|
||||
if p.cm.IsProxyAllowed() && p.isReported {
|
||||
return nil, ErrTLSMatch
|
||||
}
|
||||
|
||||
|
@ -283,6 +280,8 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c
|
|||
|
||||
// dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be.
|
||||
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
|
||||
p.log.Info("Dialing with proxy fallback")
|
||||
|
||||
var host, port string
|
||||
if host, port, err = net.SplitHostPort(address); err != nil {
|
||||
return
|
||||
|
@ -296,21 +295,18 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
|
|||
// If DoH is not allowed, give up. Or, if we are dialing something other than the API
|
||||
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
|
||||
// continuing since a proxy won't help us reach that.
|
||||
if !globalIsDoHAllowed() || host != stripProtocol(GlobalGetRootURL()) {
|
||||
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
|
||||
p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy")
|
||||
return
|
||||
}
|
||||
|
||||
// Find a new proxy.
|
||||
// Switch to a proxy and retry the dial.
|
||||
var proxy string
|
||||
if proxy, err = p.proxyManager.findProxy(); err != nil {
|
||||
|
||||
if proxy, err = p.cm.SwitchToProxy(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Switch to the proxy.
|
||||
p.log.WithField("proxy", proxy).Debug("Switching to proxy")
|
||||
p.proxyManager.useProxy(proxy)
|
||||
|
||||
// Retry dial with proxy.
|
||||
return p.dial(network, net.JoinHostPort(proxy, port))
|
||||
}
|
||||
|
||||
|
@ -329,7 +325,7 @@ func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err er
|
|||
|
||||
// If we are not dialing the standard API then we should skip cert verification checks.
|
||||
var tlsConfig *tls.Config = nil
|
||||
if address != stripProtocol(globalOriginalURL) {
|
||||
if address != RootURL {
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
|
||||
}
|
||||
|
||||
|
|
|
@ -179,7 +179,7 @@ func (c *Client) GetEvent(last string) (event *Event, err error) {
|
|||
func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) {
|
||||
var req *http.Request
|
||||
if last == "" {
|
||||
req, err = NewRequest("GET", "/events/latest", nil)
|
||||
req, err = c.NewRequest("GET", "/events/latest", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event,
|
|||
|
||||
event, err = res.Event, res.Err()
|
||||
} else {
|
||||
req, err = NewRequest("GET", "/events/"+last, nil)
|
||||
req, err = c.NewRequest("GET", "/events/"+last, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ type ImportMsgRes struct {
|
|||
func (c *Client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) {
|
||||
importReq := &ImportReq{Messages: reqs}
|
||||
|
||||
req, w, err := NewMultipartRequest("POST", "/import")
|
||||
req, w, err := c.NewMultipartRequest("POST", "/import")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ func (c *Client) PublicKeys(emails []string) (keys map[string]*pmcrypto.KeyRing,
|
|||
email = url.QueryEscape(email)
|
||||
|
||||
var req *http.Request
|
||||
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -90,7 +90,7 @@ func (c *Client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal
|
|||
email = url.QueryEscape(email)
|
||||
|
||||
var req *http.Request
|
||||
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ type KeySaltRes struct {
|
|||
// GetKeySalts sends request to get list of key salts (n.b. locked route).
|
||||
func (c *Client) GetKeySalts() (keySalts []KeySalt, err error) {
|
||||
var req *http.Request
|
||||
if req, err = NewRequest("GET", "/keys/salts", nil); err != nil {
|
||||
if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ func (c *Client) ListContactGroups() (labels []*Label, err error) {
|
|||
|
||||
// ListLabelType lists all labels created by the user.
|
||||
func (c *Client) ListLabelType(labelType int) (labels []*Label, err error) {
|
||||
req, err := NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
|
||||
req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ type LabelRes struct {
|
|||
// CreateLabel creates a new label.
|
||||
func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
|
||||
labelReq := &LabelReq{label}
|
||||
req, err := NewJSONRequest("POST", "/labels", labelReq)
|
||||
req, err := c.NewJSONRequest("POST", "/labels", labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
|
|||
// UpdateLabel updates a label.
|
||||
func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
|
||||
labelReq := &LabelReq{label}
|
||||
req, err := NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
|
|||
|
||||
// DeleteLabel deletes a label.
|
||||
func (c *Client) DeleteLabel(id string) (err error) {
|
||||
req, err := NewRequest("DELETE", "/labels/"+id, nil)
|
||||
req, err := c.NewRequest("DELETE", "/labels/"+id, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -468,7 +468,7 @@ type MessagesListRes struct {
|
|||
|
||||
// ListMessages gets message metadata.
|
||||
func (c *Client) ListMessages(filter *MessagesFilter) (msgs []*Message, total int, err error) {
|
||||
req, err := NewRequest("GET", "/messages", nil)
|
||||
req, err := c.NewRequest("GET", "/messages", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -500,7 +500,7 @@ func (c *Client) CountMessages(addressID string) (counts []*MessagesCount, err e
|
|||
if addressID != "" {
|
||||
reqURL += ("?AddressID=" + addressID)
|
||||
}
|
||||
req, err := NewRequest("GET", reqURL, nil)
|
||||
req, err := c.NewRequest("GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -522,7 +522,7 @@ type MessageRes struct {
|
|||
|
||||
// GetMessage retrieves a message.
|
||||
func (c *Client) GetMessage(id string) (msg *Message, err error) {
|
||||
req, err := NewRequest("GET", "/messages/"+id, nil)
|
||||
req, err := c.NewRequest("GET", "/messages/"+id, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -599,7 +599,7 @@ func (c *Client) SendMessage(id string, sendReq *SendMessageReq) (sent, parent *
|
|||
sendReq.Packages = []*MessagePackage{}
|
||||
}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/messages/"+id, sendReq)
|
||||
req, err := c.NewJSONRequest("POST", "/messages/"+id, sendReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -629,7 +629,7 @@ type DraftReq struct {
|
|||
func (c *Client) CreateDraft(m *Message, parent string, action int) (created *Message, err error) {
|
||||
createReq := &DraftReq{Message: m, ParentID: parent, Action: action, AttachmentKeyPackets: []string{}}
|
||||
|
||||
req, err := NewJSONRequest("POST", "/messages", createReq)
|
||||
req, err := c.NewJSONRequest("POST", "/messages", createReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -688,7 +688,7 @@ func (c *Client) doMessagesAction(action string, ids []string) (err error) {
|
|||
// You should not call this directly unless you know what you are doing (it can overload the server).
|
||||
func (c *Client) doMessagesActionInner(action string, ids []string) (err error) {
|
||||
actionReq := &MessagesActionReq{IDs: ids}
|
||||
req, err := NewJSONRequest("PUT", "/messages/"+action, actionReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/messages/"+action, actionReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -740,7 +740,7 @@ func (c *Client) LabelMessages(ids []string, label string) (err error) {
|
|||
|
||||
func (c *Client) labelMessages(ids []string, label string) (err error) {
|
||||
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
|
||||
req, err := NewJSONRequest("PUT", "/messages/label", labelReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/messages/label", labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -770,7 +770,7 @@ func (c *Client) UnlabelMessages(ids []string, label string) (err error) {
|
|||
|
||||
func (c *Client) unlabelMessages(ids []string, label string) (err error) {
|
||||
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
|
||||
req, err := NewJSONRequest("PUT", "/messages/unlabel", labelReq)
|
||||
req, err := c.NewJSONRequest("PUT", "/messages/unlabel", labelReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -793,7 +793,7 @@ func (c *Client) EmptyFolder(labelID, addressID string) (err error) {
|
|||
reqURL += ("&AddressID=" + addressID)
|
||||
}
|
||||
|
||||
req, err := NewRequest("DELETE", reqURL, nil)
|
||||
req, err := c.NewRequest("DELETE", reqURL, nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -28,7 +28,7 @@ func (c *Client) SendSimpleMetric(category, action, label string) (err error) {
|
|||
v.Set("Action", action)
|
||||
v.Set("Label", label)
|
||||
|
||||
req, err := NewRequest("GET", "/metrics?"+v.Encode(), nil)
|
||||
req, err := c.NewRequest("GET", "/metrics?"+v.Encode(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
|
@ -43,63 +42,8 @@ var dohProviders = []string{ //nolint[gochecknoglobals]
|
|||
"https://dns.google/dns-query",
|
||||
}
|
||||
|
||||
// globalAllowDoH controls whether or not to enable use of DoH/Proxy in pmapi.
|
||||
var globalAllowDoH = false // nolint[golint]
|
||||
|
||||
// globalProxyMutex allows threadsafe modification of proxy state.
|
||||
var globalProxyMutex = sync.RWMutex{} // nolint[golint]
|
||||
|
||||
// globalOriginalURL backs up the original API url so it can be restored later.
|
||||
var globalOriginalURL = RootURL // nolint[golint]
|
||||
|
||||
// globalIsDoHAllowed returns whether or not to use DoH.
|
||||
func globalIsDoHAllowed() bool { // nolint[golint]
|
||||
globalProxyMutex.RLock()
|
||||
defer globalProxyMutex.RUnlock()
|
||||
|
||||
return globalAllowDoH
|
||||
}
|
||||
|
||||
// GlobalAllowDoH enables DoH.
|
||||
func GlobalAllowDoH() { // nolint[golint]
|
||||
globalProxyMutex.Lock()
|
||||
defer globalProxyMutex.Unlock()
|
||||
|
||||
globalAllowDoH = true
|
||||
}
|
||||
|
||||
// GlobalDisallowDoH disables DoH and sets the RootURL back to what it was.
|
||||
func GlobalDisallowDoH() { // nolint[golint]
|
||||
globalProxyMutex.Lock()
|
||||
defer globalProxyMutex.Unlock()
|
||||
|
||||
globalAllowDoH = false
|
||||
RootURL = globalOriginalURL
|
||||
}
|
||||
|
||||
// globalSetRootURL sets the global RootURL.
|
||||
func globalSetRootURL(url string) { // nolint[golint]
|
||||
globalProxyMutex.Lock()
|
||||
defer globalProxyMutex.Unlock()
|
||||
|
||||
RootURL = url
|
||||
}
|
||||
|
||||
// GlobalGetRootURL returns the global RootURL.
|
||||
func GlobalGetRootURL() (url string) { // nolint[golint]
|
||||
globalProxyMutex.RLock()
|
||||
defer globalProxyMutex.RUnlock()
|
||||
|
||||
return RootURL
|
||||
}
|
||||
|
||||
// isProxyEnabled returns whether or not we are currently using a proxy.
|
||||
func isProxyEnabled() bool { // nolint[golint]
|
||||
return globalOriginalURL != GlobalGetRootURL()
|
||||
}
|
||||
|
||||
// proxyManager manages known proxies.
|
||||
type proxyManager struct {
|
||||
// proxyProvider manages known proxies.
|
||||
type proxyProvider struct {
|
||||
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
|
||||
dohLookup func(query, provider string) (urls []string, err error)
|
||||
|
||||
|
@ -113,10 +57,10 @@ type proxyManager struct {
|
|||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||
}
|
||||
|
||||
// newProxyManager creates a new proxyManager that queries the given DoH providers
|
||||
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
|
||||
// to retrieve DNS records for the given query string.
|
||||
func newProxyManager(providers []string, query string) (p *proxyManager) { // nolint[unparam]
|
||||
p = &proxyManager{
|
||||
func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam]
|
||||
p = &proxyProvider{
|
||||
providers: providers,
|
||||
query: query,
|
||||
useDuration: proxyRevertTime,
|
||||
|
@ -132,7 +76,7 @@ func newProxyManager(providers []string, query string) (p *proxyManager) { // no
|
|||
|
||||
// findProxy returns a new proxy domain which is not equal to the current RootURL.
|
||||
// It returns an error if the process takes longer than ProxySearchTime.
|
||||
func (p *proxyManager) findProxy() (proxy string, err error) {
|
||||
func (p *proxyProvider) findProxy() (proxy string, err error) {
|
||||
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
|
||||
return "", errors.New("not looking for a proxy, too soon")
|
||||
}
|
||||
|
@ -147,7 +91,7 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
|
|||
}
|
||||
|
||||
for _, proxy := range p.proxyCache {
|
||||
if proxy != stripProtocol(GlobalGetRootURL()) && p.canReach(proxy) {
|
||||
if p.canReach(proxy) {
|
||||
proxyResult <- proxy
|
||||
return
|
||||
}
|
||||
|
@ -171,25 +115,8 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// useProxy sets the proxy server to use. It returns to the original RootURL after 24 hours.
|
||||
func (p *proxyManager) useProxy(proxy string) {
|
||||
if !isProxyEnabled() {
|
||||
p.disableProxyAfter(p.useDuration)
|
||||
}
|
||||
|
||||
globalSetRootURL(https(proxy))
|
||||
}
|
||||
|
||||
// disableProxyAfter disables the proxy after the given amount of time.
|
||||
func (p *proxyManager) disableProxyAfter(d time.Duration) {
|
||||
go func() {
|
||||
<-time.After(d)
|
||||
globalSetRootURL(globalOriginalURL)
|
||||
}()
|
||||
}
|
||||
|
||||
// refreshProxyCache loads the latest proxies from the known providers.
|
||||
func (p *proxyManager) refreshProxyCache() error {
|
||||
func (p *proxyProvider) refreshProxyCache() error {
|
||||
logrus.Info("Refreshing proxy cache")
|
||||
|
||||
for _, provider := range p.providers {
|
||||
|
@ -197,7 +124,7 @@ func (p *proxyManager) refreshProxyCache() error {
|
|||
p.proxyCache = proxies
|
||||
|
||||
// We also want to allow bridge to switch back to the standard API at any time.
|
||||
p.proxyCache = append(p.proxyCache, globalOriginalURL)
|
||||
p.proxyCache = append(p.proxyCache, RootURL)
|
||||
|
||||
logrus.WithField("proxies", proxies).Info("Available proxies")
|
||||
|
||||
|
@ -210,9 +137,13 @@ func (p *proxyManager) refreshProxyCache() error {
|
|||
|
||||
// canReach returns whether we can reach the given url.
|
||||
// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname.
|
||||
func (p *proxyManager) canReach(url string) bool {
|
||||
func (p *proxyProvider) canReach(url string) bool {
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
|
||||
pinger := resty.New().
|
||||
SetHostURL(https(url)).
|
||||
SetHostURL(url).
|
||||
SetTimeout(p.lookupTimeout).
|
||||
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
|
||||
|
||||
|
@ -227,7 +158,7 @@ func (p *proxyManager) canReach(url string) bool {
|
|||
// It looks up DNS TXT records for the given query URL using the given DoH provider.
|
||||
// It returns a list of all found TXT records.
|
||||
// If the whole process takes more than ProxyQueryTime then an error is returned.
|
||||
func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
|
||||
func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
|
||||
dataResult := make(chan []string)
|
||||
errResult := make(chan error)
|
||||
go func() {
|
||||
|
@ -282,23 +213,3 @@ func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []strin
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
func stripProtocol(url string) string {
|
||||
if strings.HasPrefix(url, "https://") {
|
||||
return strings.TrimPrefix(url, "https://")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(url, "http://") {
|
||||
return strings.TrimPrefix(url, "http://")
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
func https(url string) string {
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
|
|
@ -32,14 +32,14 @@ const (
|
|||
TestGoogleProvider = "https://dns.google/dns-query"
|
||||
)
|
||||
|
||||
func TestProxyManager_FindProxy(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findProxy()
|
||||
|
@ -47,7 +47,7 @@ func TestProxyManager_FindProxy(t *testing.T) {
|
|||
require.Equal(t, proxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
|
@ -58,7 +58,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
|
|||
badProxy.Close()
|
||||
defer goodProxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil }
|
||||
|
||||
url, err := p.findProxy()
|
||||
|
@ -66,7 +66,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
|
|||
require.Equal(t, goodProxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
|
@ -77,21 +77,21 @@ func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) {
|
|||
badProxy.Close()
|
||||
anotherBadProxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil }
|
||||
|
||||
_, err := p.findProxy()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.lookupTimeout = time.Second
|
||||
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
|
||||
|
@ -100,7 +100,7 @@ func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
|
||||
func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
|
@ -109,7 +109,7 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
|
|||
}))
|
||||
defer slowProxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.findTimeout = time.Second
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
|
||||
|
@ -118,14 +118,14 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyManager_UseProxy(t *testing.T) {
|
||||
func TestProxyProvider_UseProxy(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findProxy()
|
||||
|
@ -135,7 +135,7 @@ func TestProxyManager_UseProxy(t *testing.T) {
|
|||
require.Equal(t, proxy.URL, GlobalGetRootURL())
|
||||
}
|
||||
|
||||
func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
|
||||
func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
|
@ -146,7 +146,7 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
|
|||
proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy3.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
url, err := p.findProxy()
|
||||
|
@ -173,14 +173,14 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
|
|||
require.Equal(t, proxy3.URL, GlobalGetRootURL())
|
||||
}
|
||||
|
||||
func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.useDuration = time.Second
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
|
@ -195,14 +195,14 @@ func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) {
|
|||
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
|
||||
}
|
||||
|
||||
func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
// Don't block the API here because we want it to be working so the test can find it.
|
||||
defer unblockAPI()
|
||||
|
||||
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findProxy()
|
||||
|
@ -225,7 +225,7 @@ func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachabl
|
|||
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
|
||||
}
|
||||
|
||||
func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
||||
func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
||||
blockAPI()
|
||||
defer unblockAPI()
|
||||
|
||||
|
@ -234,7 +234,7 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
|
|||
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer proxy2.Close()
|
||||
|
||||
p := newProxyManager([]string{"not used"}, "not used")
|
||||
p := newProxyProvider([]string{"not used"}, "not used")
|
||||
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
|
||||
// Find a proxy.
|
||||
|
@ -256,32 +256,32 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
|
|||
require.Equal(t, proxy2.URL, GlobalGetRootURL())
|
||||
}
|
||||
|
||||
func TestProxyManager_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyManager_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyManager_DoHLookup_FindProxy(t *testing.T) {
|
||||
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
|
||||
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
url, err := p.findProxy()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, url)
|
||||
}
|
||||
|
||||
func TestProxyManager_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
|
||||
p := newProxyManager([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
|
||||
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
|
||||
p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
|
||||
|
||||
url, err := p.findProxy()
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -26,8 +26,9 @@ import (
|
|||
)
|
||||
|
||||
// NewRequest creates a new request.
|
||||
func NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
|
||||
req, err = http.NewRequest(method, GlobalGetRootURL()+path, body)
|
||||
func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
|
||||
// TODO: Support other protocols (localhost needs http not https).
|
||||
req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body)
|
||||
if req != nil {
|
||||
req.Header.Set("User-Agent", CurrentUserAgent)
|
||||
}
|
||||
|
@ -35,13 +36,13 @@ func NewRequest(method, path string, body io.Reader) (req *http.Request, err err
|
|||
}
|
||||
|
||||
// NewJSONRequest create a new JSON request.
|
||||
func NewJSONRequest(method, path string, body interface{}) (*http.Request, error) {
|
||||
func (c *Client) NewJSONRequest(method, path string, body interface{}) (*http.Request, error) {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req, err := NewRequest(method, path, bytes.NewReader(b))
|
||||
req, err := c.NewRequest(method, path, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -70,7 +71,7 @@ func (w *MultipartWriter) Close() error {
|
|||
// that writing the request and sending it MUST be done in parallel. If the
|
||||
// request fails, subsequent writes to the multipart writer will fail with an
|
||||
// io.ErrClosedPipe error.
|
||||
func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) {
|
||||
func (c *Client) NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) {
|
||||
// The pipe will connect the multipart writer and the HTTP request body.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
|
@ -80,7 +81,7 @@ func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWr
|
|||
pw,
|
||||
}
|
||||
|
||||
req, err = NewRequest(method, path, pr)
|
||||
req, err = c.NewRequest(method, path, pr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ type UserSettings struct {
|
|||
|
||||
// GetUserSettings gets general settings.
|
||||
func (c *Client) GetUserSettings() (settings UserSettings, err error) {
|
||||
req, err := NewRequest("GET", "/settings", nil)
|
||||
req, err := c.NewRequest("GET", "/settings", nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -99,7 +99,7 @@ type MailSettings struct {
|
|||
|
||||
// GetMailSettings gets contact details specified by contact ID.
|
||||
func (c *Client) GetMailSettings() (settings MailSettings, err error) {
|
||||
req, err := NewRequest("GET", "/settings/mail", nil)
|
||||
req, err := c.NewRequest("GET", "/settings/mail", nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -93,7 +93,7 @@ func (u *User) KeyRing() *pmcrypto.KeyRing {
|
|||
|
||||
// UpdateUser retrieves details about user and loads its addresses.
|
||||
func (c *Client) UpdateUser() (user *User, err error) {
|
||||
req, err := NewRequest("GET", "/users", nil)
|
||||
req, err := c.NewRequest("GET", "/users", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue