feat: switch to proxy when need be

This commit is contained in:
James Houlahan 2020-04-01 17:20:03 +02:00
parent f239e8f3bf
commit ce29d4d74e
36 changed files with 311 additions and 320 deletions

View File

@ -274,8 +274,11 @@ func run(context *cli.Context) (contextError error) { // nolint[funlen]
log.Error("Could not get credentials store: ", credentialsError)
}
clientman := pmapi.NewClientManager(pmapifactory.GetClientConfig(cfg, eventListener))
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, clientman, credentialsStore)
clientConfig := pmapifactory.GetClientConfig(cfg.GetAPIConfig())
cm := pmapi.NewClientManager(clientConfig)
pmapifactory.SetClientRoundTripper(cm, clientConfig, eventListener)
bridgeInstance := bridge.New(cfg, pref, panicHandler, eventListener, Version, cm, credentialsStore)
imapBackend := imap.NewIMAPBackend(panicHandler, eventListener, cfg, bridgeInstance)
smtpBackend := smtp.NewSMTPBackend(panicHandler, eventListener, pref, bridgeInstance)

View File

@ -97,7 +97,7 @@ func New(
// Allow DoH before starting bridge if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if pref.GetBool(preferences.AllowProxyKey) {
AllowDoH()
b.AllowProxy()
}
go func() {
@ -178,15 +178,16 @@ func (b *Bridge) watchBridgeOutdated() {
}
}
// watchUserAuths receives auths from the client manager and sends them to the appropriate user.
func (b *Bridge) watchUserAuths() {
for auth := range b.clientManager.GetBridgeAuthChannel() {
user, ok := b.hasUser(auth.UserID)
logrus.WithField("token", auth.Auth.GenToken()).WithField("userID", auth.UserID).Info("Received auth from bridge auth channel")
if !ok {
continue
if user, ok := b.hasUser(auth.UserID); ok {
user.ReceiveAPIAuth(auth.Auth)
} else {
logrus.Info("User is not added to bridge yet")
}
user.ReceiveAPIAuth(auth.Auth)
}
}
@ -274,7 +275,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
apiClient := b.clientManager.GetClient(apiUser.ID)
auth, err = apiClient.AuthRefresh(auth.GenToken())
if err != nil {
log.WithError(err).Error("Could refresh token in new client")
log.WithError(err).Error("Could not refresh token in new client")
return
}
@ -298,6 +299,7 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
log.WithField("user", apiUser.ID).WithError(err).Error("Could not create user")
return
}
b.users = append(b.users, user)
}
// Set up the user auth and store (which we do for both new and existing users).
@ -307,7 +309,6 @@ func (b *Bridge) FinishLogin(loginClient PMAPIProvider, auth *pmapi.Auth, mbPass
}
if !hasUser {
b.users = append(b.users, user)
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel))
}
@ -475,16 +476,16 @@ func (b *Bridge) GetIMAPUpdatesChannel() chan interface{} {
return b.idleUpdates
}
// AllowDoH instructs bridge to use DoH to access an API proxy if necessary.
// AllowProxy instructs bridge to use DoH to access an API proxy if necessary.
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
func AllowDoH() {
pmapi.GlobalAllowDoH()
func (b *Bridge) AllowProxy() {
b.clientManager.AllowProxy()
}
// DisallowDoH instructs bridge to not use DoH to access an API proxy if necessary.
// DisallowProxy instructs bridge to not use DoH to access an API proxy if necessary.
// It also needs to work before bridge is initialised (because we may need to use the proxy at startup).
func DisallowDoH() {
pmapi.GlobalDisallowDoH()
func (b *Bridge) DisallowProxy() {
b.clientManager.DisallowProxy()
}
func (b *Bridge) updateCurrentUserAgent() {
@ -493,7 +494,11 @@ func (b *Bridge) updateCurrentUserAgent() {
// hasUser returns whether the bridge currently has a user with ID `id`.
func (b *Bridge) hasUser(id string) (user *User, ok bool) {
logrus.WithField("id", id).Info("Checking whether bridge has given user")
for _, u := range b.users {
logrus.WithField("id", u.ID()).Info("Found potential user")
if u.ID() == id {
user, ok = u, true
return

View File

@ -107,6 +107,8 @@ func (u *User) init(idleUpdates chan interface{}) (err error) {
u.wasKeyringUnlocked = false
u.unlockingKeyringLock.Unlock()
u.log.Info("Initialising user")
// Reload the user's credentials (if they log out and back in we need the new
// version with the apitoken and mailbox password).
creds, err := u.credStorer.Get(u.userID)
@ -242,27 +244,19 @@ func (u *User) authorizeAndUnlock() (err error) {
}
func (u *User) ReceiveAPIAuth(auth *pmapi.Auth) {
u.lock.Lock()
defer u.lock.Unlock()
if auth == nil {
if err := u.logout(); err != nil {
u.log.WithError(err).Error("Cannot logout user after receiving empty auth from API")
u.log.WithError(err).Error("Failed to logout user after receiving empty auth from API")
}
u.isAuthorized = false
return
}
u.updateAPIToken(auth.GenToken())
}
// updateAPIToken is helper for updating the token in keychain. It's not supposed to be
// called directly from other parts of the code, only from `ReceiveAPIAuth`.
func (u *User) updateAPIToken(newRefreshToken string) {
u.lock.Lock()
defer u.lock.Unlock()
u.log.WithField("token", newRefreshToken).Info("Saving token to credentials store")
if err := u.credStorer.UpdateToken(u.userID, newRefreshToken); err != nil {
u.log.WithError(err).Error("Cannot update refresh token in credentials store")
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
return
}

View File

@ -22,7 +22,6 @@ import (
"strconv"
"strings"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/pkg/connection"
"github.com/ProtonMail/proton-bridge/pkg/ports"
@ -135,13 +134,13 @@ func (f *frontendCLI) toggleAllowProxy(c *ishell.Context) {
f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
f.preferences.SetBool(preferences.AllowProxyKey, false)
bridge.DisallowDoH()
f.bridge.DisallowProxy()
}
} else {
f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
f.preferences.SetBool(preferences.AllowProxyKey, true)
bridge.AllowDoH()
f.bridge.AllowProxy()
}
}
}

View File

@ -52,6 +52,8 @@ type Bridger interface {
DeleteUser(userID string, clearCache bool) error
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
ClearData() error
AllowProxy()
DisallowProxy()
}
// BridgeUser is an interface of user needed by frontend.

View File

@ -21,11 +21,14 @@
package pmapifactory
import (
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func GetClientConfig(config bridge.Configer, _ listener.Listener) *pmapi.ClientConfig {
return config.GetAPIConfig()
func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
return clientConfig
}
func SetClientRoundTripper(_ *pmapi.ClientManager, _ *pmapi.ClientConfig, _ listener.Listener) {
// Use the default roundtripper; do nothing.
}

View File

@ -23,26 +23,13 @@ package pmapifactory
import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/sirupsen/logrus"
)
func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.ClientConfig {
clientConfig := config.GetAPIConfig()
pin := pmapi.NewPMAPIPinning(clientConfig.AppVersion)
pin.ReportCertIssueLocal = func() {
listener.Emit(events.TLSCertIssue, "")
}
// This transport already has timeouts set governing the roundtrip:
// - IdleConnTimeout: 5 * time.Minute,
// - ExpectContinueTimeout: 500 * time.Millisecond,
// - ResponseHeaderTimeout: 30 * time.Second,
clientConfig.Transport = pin.TransportWithPinning()
func GetClientConfig(clientConfig *pmapi.ClientConfig) *pmapi.ClientConfig {
// We set additional timeouts/thresholds for the request as a whole:
clientConfig.Timeout = 10 * time.Minute // Overall request timeout (~25MB / 10 mins => ~40kB/s, should be reasonable).
clientConfig.FirstReadTimeout = 30 * time.Second // 30s to match 30s response header timeout.
@ -50,3 +37,15 @@ func GetClientConfig(config bridge.Configer, listener listener.Listener) *pmapi.
return clientConfig
}
func SetClientRoundTripper(cm *pmapi.ClientManager, cfg *pmapi.ClientConfig, listener listener.Listener) {
logrus.Info("Setting dialer with pinning")
pin := pmapi.NewDialerWithPinning(cm, cfg.AppVersion)
pin.ReportCertIssueLocal = func() {
listener.Emit(events.TLSCertIssue, "")
}
cm.SetClientRoundTripper(pin.TransportWithPinning())
}

View File

@ -39,7 +39,8 @@ var (
// Two errors can be returned, ErrNoInternetConnection or ErrCanNotReachAPI.
func CheckInternetConnection() error {
client := &http.Client{
Transport: pmapi.NewPMAPIPinning(pmapi.CurrentUserAgent).TransportWithPinning(),
// TODO: Set transport properly! (Need access to ClientManager somehow)
// Transport: pmapi.NewDialerWithPinning(pmapi.CurrentUserAgent).TransportWithPinning(),
}
// Do not cumulate timeouts, use goroutines.
@ -51,7 +52,8 @@ func CheckInternetConnection() error {
go checkConnection(client, "http://protonstatus.com/vpn_status", retStatus)
// Check of API reachability also uses a fast endpoint.
go checkConnection(client, pmapi.GlobalGetRootURL()+"/tests/ping", retAPI)
// TODO: This should check active proxy, not the RootURL
go checkConnection(client, pmapi.RootURL+"/tests/ping", retAPI)
errStatus := <-retStatus
errAPI := <-retAPI

View File

@ -36,7 +36,7 @@ type osxkeychain struct {
}
func newKeychain() (credentials.Helper, error) {
log.Debug("creating osckeychain")
log.Debug("Creating osckeychain")
return &osxkeychain{}, nil
}

View File

@ -24,14 +24,14 @@ import (
)
func newKeychain() (credentials.Helper, error) {
log.Debug("creating pass")
log.Debug("Creating pass")
passHelper := &pass.Pass{}
passErr := checkPassIsUsable(passHelper)
if passErr == nil {
return passHelper, nil
}
log.Debug("creating secretservice")
log.Debug("Creating secretservice")
sserviceHelper := &secretservice.Secretservice{}
_, sserviceErr := sserviceHelper.List()
if sserviceErr == nil {

View File

@ -23,7 +23,7 @@ import (
)
func newKeychain() (credentials.Helper, error) {
log.Debug("creating wincred")
log.Debug("Creating wincred")
return &wincred.Wincred{}, nil
}

View File

@ -161,7 +161,7 @@ func ConstructAddress(headerEmail string, addressEmail string) string {
// GetAddresses requests all of current user addresses (without pagination).
func (c *Client) GetAddresses() (addresses AddressList, err error) {
req, err := NewRequest("GET", "/addresses", nil)
req, err := c.NewRequest("GET", "/addresses", nil)
if err != nil {
return
}

View File

@ -179,7 +179,7 @@ func writeAttachment(w *multipart.Writer, att *Attachment, r io.Reader, sig io.R
//
// The returned created attachment contains the new attachment ID and its size.
func (c *Client) CreateAttachment(att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) {
req, w, err := NewMultipartRequest("POST", "/attachments")
req, w, err := c.NewMultipartRequest("POST", "/attachments")
if err != nil {
return
}
@ -213,7 +213,7 @@ type UpdateAttachmentSignatureReq struct {
func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err error) {
updateReq := &UpdateAttachmentSignatureReq{signature}
req, err := NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq)
req, err := c.NewJSONRequest("PUT", "/attachments/"+attachmentID+"/signature", updateReq)
if err != nil {
return
}
@ -228,7 +228,7 @@ func (c *Client) UpdateAttachmentSignature(attachmentID, signature string) (err
// DeleteAttachment removes an attachment. message is the message ID, att is the attachment ID.
func (c *Client) DeleteAttachment(attID string) (err error) {
req, err := NewRequest("DELETE", "/attachments/"+attID, nil)
req, err := c.NewRequest("DELETE", "/attachments/"+attID, nil)
if err != nil {
return
}
@ -249,7 +249,7 @@ func (c *Client) GetAttachment(id string) (att io.ReadCloser, err error) {
return
}
req, err := NewRequest("GET", "/attachments/"+id, nil)
req, err := c.NewRequest("GET", "/attachments/"+id, nil)
if err != nil {
return
}

View File

@ -214,7 +214,7 @@ func (c *Client) AuthInfo(username string) (info *AuthInfo, err error) {
Username: username,
}
req, err := NewJSONRequest("POST", "/auth/info", infoReq)
req, err := c.NewJSONRequest("POST", "/auth/info", infoReq)
if err != nil {
return
}
@ -257,7 +257,7 @@ func (c *Client) tryAuth(username, password string, info *AuthInfo, fallbackVers
SRPSession: info.srpSession,
}
req, err := NewJSONRequest("POST", "/auth", authReq)
req, err := c.NewJSONRequest("POST", "/auth", authReq)
if err != nil {
return
}
@ -335,7 +335,7 @@ func (c *Client) Auth2FA(twoFactorCode string, auth *Auth) (*Auth2FA, error) {
TwoFactorCode: twoFactorCode,
}
req, err := NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
req, err := c.NewJSONRequest("POST", "/auth/2fa", auth2FAReq)
if err != nil {
return nil, err
}
@ -430,7 +430,7 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
// UID must be set for `x-pm-uid` header field, see backend-communication#11
c.uid = split[0]
req, err := NewJSONRequest("POST", "/auth/refresh", refreshReq)
req, err := c.NewJSONRequest("POST", "/auth/refresh", refreshReq)
if err != nil {
return
}
@ -450,13 +450,14 @@ func (c *Client) AuthRefresh(uidAndRefreshToken string) (auth *Auth, err error)
return auth, err
}
// Logout instructs the client manager to log out this client.
func (c *Client) Logout() {
c.cm.LogoutClient(c.userID)
}
// logout logs the current user out.
func (c *Client) logout() (err error) {
req, err := NewRequest("DELETE", "/auth", nil)
req, err := c.NewRequest("DELETE", "/auth", nil)
if err != nil {
return
}

View File

@ -332,7 +332,9 @@ func TestClient_Logout(t *testing.T) {
c.uid = testUID
c.accessToken = testAccessToken
Ok(t, c.Logout())
c.Logout()
// TODO: Check that the client is logged out and sensitive data is cleared eventually.
}
func TestClient_DoUnauthorized(t *testing.T) {
@ -355,7 +357,7 @@ func TestClient_DoUnauthorized(t *testing.T) {
c.expiresAt = aLongTimeAgo
c.cm.tokens[c.userID] = testUID + ":" + testRefreshToken
req, err := NewRequest("GET", "/", nil)
req, err := c.NewRequest("GET", "/", nil)
Ok(t, err)
res, err := c.Do(req, true)

View File

@ -139,9 +139,9 @@ func (c *Client) Report(rep ReportReq) (err error) {
var req *http.Request
var w *MultipartWriter
if len(rep.Attachments) > 0 {
req, w, err = NewMultipartRequest("POST", "/reports/bug")
req, w, err = c.NewMultipartRequest("POST", "/reports/bug")
} else {
req, err = NewJSONRequest("POST", "/reports/bug", rep)
req, err = c.NewJSONRequest("POST", "/reports/bug", rep)
}
if err != nil {
return
@ -202,7 +202,7 @@ func (c *Client) ReportCrash(stacktrace string) (err error) {
OS: runtime.GOOS,
Debug: stacktrace,
}
req, err := NewJSONRequest("POST", "/reports/crash", crashReq)
req, err := c.NewJSONRequest("POST", "/reports/crash", crashReq)
if err != nil {
return
}

View File

@ -99,12 +99,12 @@ type ClientConfig struct {
// Client to communicate with API.
type Client struct {
cm *ClientManager
client *http.Client
cm *ClientManager
hc *http.Client
uid string
accessToken string
userID string // Twice here because Username is not unique.
userID string
requestLocker sync.Locker
keyLocker sync.Locker
@ -120,7 +120,7 @@ type Client struct {
func newClient(cm *ClientManager, userID string) *Client {
return &Client{
cm: cm,
client: getHTTPClient(cm.GetConfig()),
hc: getHTTPClient(cm.GetConfig()),
userID: userID,
requestLocker: &sync.Mutex{},
keyLocker: &sync.Mutex{},
@ -132,12 +132,10 @@ func newClient(cm *ClientManager, userID string) *Client {
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) {
hc = &http.Client{Timeout: cfg.Timeout}
if cfg.Transport == nil && defaultTransport == nil {
return
}
if defaultTransport != nil {
hc.Transport = defaultTransport
if cfg.Transport == nil {
if defaultTransport != nil {
hc.Transport = defaultTransport
}
return
}
@ -205,7 +203,7 @@ func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthori
}
hasBody := len(bodyBuffer) > 0
if res, err = c.client.Do(req); err != nil {
if res, err = c.hc.Do(req); err != nil {
if res == nil {
c.log.WithError(err).Error("Cannot get response")
err = ErrAPINotReachable

View File

@ -51,7 +51,7 @@ func TestClient_Do(t *testing.T) {
}))
defer s.Close()
req, err := NewRequest("GET", "/", nil)
req, err := c.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal("Expected no error while creating request, got:", err)
}
@ -163,8 +163,8 @@ func TestClient_FirstReadTimeout(t *testing.T) {
)
defer finish()
c.client.Transport = &slowTransport{
transport: c.client.Transport,
c.hc.Transport = &slowTransport{
transport: c.hc.Transport,
firstBodySleep: requestTimeout,
}

View File

@ -1,24 +1,32 @@
package pmapi
import (
"net/http"
"sync"
"github.com/getsentry/raven-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// ClientManager is a manager of clients.
type ClientManager struct {
config *ClientConfig
clients map[string]*Client
clientsLocker sync.Locker
tokens map[string]string
tokensLocker sync.Locker
config *ClientConfig
url string
urlLocker sync.Locker
bridgeAuths chan ClientAuth
clientAuths chan ClientAuth
allowProxy bool
proxyProvider *proxyProvider
}
type ClientAuth struct {
@ -33,13 +41,21 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
}
cm = &ClientManager{
config: config,
clients: make(map[string]*Client),
clientsLocker: &sync.Mutex{},
tokens: make(map[string]string),
tokensLocker: &sync.Mutex{},
config: config,
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth),
tokens: make(map[string]string),
tokensLocker: &sync.Mutex{},
url: RootURL,
urlLocker: &sync.Mutex{},
bridgeAuths: make(chan ClientAuth),
clientAuths: make(chan ClientAuth),
proxyProvider: newProxyProvider(dohProviders, proxyQuery),
}
go cm.forwardClientAuths()
@ -47,6 +63,12 @@ func NewClientManager(config *ClientConfig) (cm *ClientManager) {
return
}
// SetClientRoundTripper sets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) SetClientRoundTripper(rt http.RoundTripper) {
logrus.Info("Setting client roundtripper")
cm.config.Transport = rt
}
// GetClient returns a client for the given userID.
// If the client does not exist already, it is created.
func (cm *ClientManager) GetClient(userID string) *Client {
@ -71,7 +93,7 @@ func (cm *ClientManager) LogoutClient(userID string) {
go func() {
if err := client.logout(); err != nil {
// TODO: Try again!
// TODO: Try again! This should loop until it succeeds (might fail the first time due to internet).
logrus.WithError(err).Error("Client logout failed, not trying again")
}
client.clearSensitiveData()
@ -81,6 +103,56 @@ func (cm *ClientManager) LogoutClient(userID string) {
return
}
// GetRootURL returns the root URL to make requests to.
// It does not include the protocol i.e. no "https://".
func (cm *ClientManager) GetRootURL() string {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
return cm.url
}
// SetRootURL sets the root URL to make requests to.
func (cm *ClientManager) SetRootURL(url string) {
cm.urlLocker.Lock()
defer cm.urlLocker.Unlock()
logrus.WithField("url", url).Info("Changing to a new root URL")
cm.url = url
}
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be.
func (cm *ClientManager) IsProxyAllowed() bool {
return cm.allowProxy
}
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
func (cm *ClientManager) AllowProxy() {
cm.allowProxy = true
}
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
func (cm *ClientManager) DisallowProxy() {
cm.allowProxy = false
}
// FindProxy returns a usable proxy server.
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
logrus.Info("Attempting gto switch to a proxy")
if proxy, err = cm.proxyProvider.findProxy(); err != nil {
err = errors.Wrap(err, "failed to find usable proxy")
return
}
cm.SetRootURL(proxy)
// TODO: Disable after 24 hours.
return
}
// GetConfig returns the config used to configure clients.
func (cm *ClientManager) GetConfig() *ClientConfig {
return cm.config
@ -113,7 +185,7 @@ func (cm *ClientManager) setToken(userID, token string) {
cm.tokensLocker.Lock()
defer cm.tokensLocker.Unlock()
logrus.WithField("userID", userID).WithField("token", token).Info("Updating refresh token")
logrus.WithField("userID", userID).Info("Updating refresh token")
cm.tokens[userID] = token
}
@ -127,16 +199,19 @@ func (cm *ClientManager) clearToken(userID string) {
delete(cm.tokens, userID)
}
// handleClientAuth
// handleClientAuth updates or clears client authorisation based on auths received.
func (cm *ClientManager) handleClientAuth(ca ClientAuth) {
// TODO: Maybe want to logout the client in case of nil auth.
// If we aren't managing this client, there's nothing to do.
if _, ok := cm.clients[ca.UserID]; !ok {
return
}
// If the auth is nil, we should clear the token.
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself.
if ca.Auth == nil {
cm.clearToken(ca.UserID)
} else {
cm.setToken(ca.UserID, ca.Auth.GenToken())
return
}
cm.setToken(ca.UserID, ca.Auth.GenToken())
}

View File

@ -24,9 +24,9 @@ import (
// RootURL is the API root URL.
//
// This can be changed using build flags: pmapi_local for "http://localhost/api",
// pmapi_dev or pmapi_prod. Default is pmapi_prod.
var RootURL = "https://api.protonmail.ch" //nolint[gochecknoglobals]
// This can be changed using build flags: pmapi_local for "localhost/api", pmapi_dev or pmapi_prod.
// Default is pmapi_prod.
var RootURL = "api.protonmail.ch" //nolint[gochecknoglobals]
// CurrentUserAgent is the default User-Agent for go-pmapi lib. This can be changed to program
// version and email client.

View File

@ -20,5 +20,5 @@
package pmapi
func init() {
RootURL = "https://dev.protonmail.com/api"
RootURL = "dev.protonmail.com/api"
}

View File

@ -27,7 +27,7 @@ import (
func init() {
// Use port above 1000 which doesn't need root access to start anything on it.
// Now the port is rounded pi. :-)
RootURL = "http://127.0.0.1:3142/api"
RootURL = "127.0.0.1:3142/api"
// TLS certificate is self-signed
defaultTransport = &http.Transport{

View File

@ -119,7 +119,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
if pageSize > 0 {
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := NewRequest("GET", "/contacts?"+v.Encode(), nil)
req, err := c.NewRequest("GET", "/contacts?"+v.Encode(), nil)
if err != nil {
return
@ -136,7 +136,7 @@ func (c *Client) GetContacts(page int, pageSize int) (contacts []*Contact, err e
// GetContactByID gets contact details specified by contact ID.
func (c *Client) GetContactByID(id string) (contactDetail Contact, err error) {
req, err := NewRequest("GET", "/contacts/"+id, nil)
req, err := c.NewRequest("GET", "/contacts/"+id, nil)
if err != nil {
return
@ -164,7 +164,7 @@ func (c *Client) GetContactsForExport(page int, pageSize int) (contacts []Contac
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
req, err := c.NewRequest("GET", "/contacts/export?"+v.Encode(), nil)
if err != nil {
return
@ -198,7 +198,7 @@ func (c *Client) GetAllContactsEmails(page int, pageSize int) (contactsEmails []
v.Set("PageSize", strconv.Itoa(pageSize))
}
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil {
return
}
@ -221,7 +221,7 @@ func (c *Client) GetContactEmailByEmail(email string, page int, pageSize int) (c
}
v.Set("Email", email)
req, err := NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
req, err := c.NewRequest("GET", "/contacts/emails?"+v.Encode(), nil)
if err != nil {
return
}
@ -276,7 +276,7 @@ func (c *Client) AddContacts(cards ContactsCards, overwrite int, groups int, lab
Labels: labels,
}
req, err := NewJSONRequest("POST", "/contacts", reqBody)
req, err := c.NewJSONRequest("POST", "/contacts", reqBody)
if err != nil {
return
}
@ -306,7 +306,7 @@ func (c *Client) UpdateContact(id string, cards []Card) (res *UpdateContactRespo
reqBody := UpdateContactReq{
Cards: cards,
}
req, err := NewJSONRequest("PUT", "/contacts/"+id, reqBody)
req, err := c.NewJSONRequest("PUT", "/contacts/"+id, reqBody)
if err != nil {
return
}
@ -354,7 +354,7 @@ func (c *Client) modifyContactGroups(groupID string, modifyContactGroupsAction i
Action: modifyContactGroupsAction,
ContactEmailIDs: contactEmailIDs,
}
req, err := NewJSONRequest("PUT", "/contacts/group", reqBody)
req, err := c.NewJSONRequest("PUT", "/contacts/group", reqBody)
if err != nil {
return
}
@ -377,7 +377,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
IDs: ids,
}
req, err := NewJSONRequest("PUT", "/contacts/delete", deleteReq)
req, err := c.NewJSONRequest("PUT", "/contacts/delete", deleteReq)
if err != nil {
return
}
@ -402,7 +402,7 @@ func (c *Client) DeleteContacts(ids []string) (err error) {
// DeleteAllContacts deletes all contacts.
func (c *Client) DeleteAllContacts() (err error) {
req, err := NewRequest("DELETE", "/contacts", nil)
req, err := c.NewRequest("DELETE", "/contacts", nil)
if err != nil {
return
}

View File

@ -36,7 +36,7 @@ func (c *Client) CountConversations(addressID string) (counts []*ConversationsCo
if addressID != "" {
reqURL += ("?AddressID=" + addressID)
}
req, err := NewRequest("GET", reqURL, nil)
req, err := c.NewRequest("GET", reqURL, nil)
if err != nil {
return
}

View File

@ -112,51 +112,47 @@ type DialerWithPinning struct {
// It is used only if set.
ReportCertIssueLocal func()
// proxyManager manages API proxies.
proxyManager *proxyManager
// cm is used to find and switch to a proxy if necessary.
cm *ClientManager
// A logger for logging messages.
log logrus.FieldLogger
}
func NewDialerWithPinning(reportURI string, report TLSReport) *DialerWithPinning {
// NewDialerWithPinning constructs a new dialer with pinned certs.
func NewDialerWithPinning(cm *ClientManager, appVersion string) *DialerWithPinning {
reportURI := "https://reports.protonmail.ch/reports/tls"
report := TLSReport{
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
IncludeSubdomains: false,
ValidatedCertificateChain: []string{},
ServedCertificateChain: []string{},
AppVersion: appVersion,
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
KnownPins: []string{
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
},
}
log := logrus.WithField("pkg", "pmapi/tls-pinning")
proxyManager := newProxyManager(dohProviders, proxyQuery)
return &DialerWithPinning{
isReported: false,
reportURI: reportURI,
report: report,
proxyManager: proxyManager,
log: log,
cm: cm,
isReported: false,
reportURI: reportURI,
report: report,
log: log,
}
}
func NewPMAPIPinning(appVersion string) *DialerWithPinning {
return NewDialerWithPinning(
"https://reports.protonmail.ch/reports/tls",
TLSReport{
EffectiveExpirationDate: time.Now().Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339),
IncludeSubdomains: false,
ValidatedCertificateChain: []string{},
ServedCertificateChain: []string{},
AppVersion: appVersion,
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;)
KnownPins: []string{
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // proxy main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // proxy backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // proxy backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // proxy backup 3
},
},
)
}
func (p *DialerWithPinning) reportCertIssue(connState tls.ConnectionState) {
p.isReported = true
@ -231,6 +227,7 @@ func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
return pemCerts
}
// TransportWithPinning creates an http.Transport that checks fingerprints when dialing.
func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -258,7 +255,7 @@ func (p *DialerWithPinning) TransportWithPinning() *http.Transport {
// p.ReportCertIssueLocal() and p.reportCertIssueRemote() if they are not nil.
func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (conn net.Conn, err error) {
// If DoH is enabled, we hardfail on fingerprint mismatches.
if globalIsDoHAllowed() && p.isReported {
if p.cm.IsProxyAllowed() && p.isReported {
return nil, ErrTLSMatch
}
@ -283,6 +280,8 @@ func (p *DialerWithPinning) dialAndCheckFingerprints(network, address string) (c
// dialWithProxyFallback tries to dial the given address but falls back to alternative proxies if need be.
func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn net.Conn, err error) {
p.log.Info("Dialing with proxy fallback")
var host, port string
if host, port, err = net.SplitHostPort(address); err != nil {
return
@ -296,21 +295,18 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
// If DoH is not allowed, give up. Or, if we are dialing something other than the API
// (e.g. we dial protonmail.com/... to check for updates), there's also no point in
// continuing since a proxy won't help us reach that.
if !globalIsDoHAllowed() || host != stripProtocol(GlobalGetRootURL()) {
if !p.cm.IsProxyAllowed() || host != p.cm.GetRootURL() {
p.log.WithField("useProxy", p.cm.IsProxyAllowed()).Info("Dial failed but not switching to proxy")
return
}
// Find a new proxy.
// Switch to a proxy and retry the dial.
var proxy string
if proxy, err = p.proxyManager.findProxy(); err != nil {
if proxy, err = p.cm.SwitchToProxy(); err != nil {
return
}
// Switch to the proxy.
p.log.WithField("proxy", proxy).Debug("Switching to proxy")
p.proxyManager.useProxy(proxy)
// Retry dial with proxy.
return p.dial(network, net.JoinHostPort(proxy, port))
}
@ -329,7 +325,7 @@ func (p *DialerWithPinning) dial(network, address string) (conn net.Conn, err er
// If we are not dialing the standard API then we should skip cert verification checks.
var tlsConfig *tls.Config = nil
if address != stripProtocol(globalOriginalURL) {
if address != RootURL {
tlsConfig = &tls.Config{InsecureSkipVerify: true} // nolint[gosec]
}

View File

@ -179,7 +179,7 @@ func (c *Client) GetEvent(last string) (event *Event, err error) {
func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event, err error) {
var req *http.Request
if last == "" {
req, err = NewRequest("GET", "/events/latest", nil)
req, err = c.NewRequest("GET", "/events/latest", nil)
if err != nil {
return
}
@ -191,7 +191,7 @@ func (c *Client) getEvent(last string, numberOfMergedEvents int) (event *Event,
event, err = res.Event, res.Err()
} else {
req, err = NewRequest("GET", "/events/"+last, nil)
req, err = c.NewRequest("GET", "/events/"+last, nil)
if err != nil {
return
}

View File

@ -120,7 +120,7 @@ type ImportMsgRes struct {
func (c *Client) Import(reqs []*ImportMsgReq) (resps []*ImportMsgRes, err error) {
importReq := &ImportReq{Messages: reqs}
req, w, err := NewMultipartRequest("POST", "/import")
req, w, err := c.NewMultipartRequest("POST", "/import")
if err != nil {
return
}

View File

@ -57,7 +57,7 @@ func (c *Client) PublicKeys(emails []string) (keys map[string]*pmcrypto.KeyRing,
email = url.QueryEscape(email)
var req *http.Request
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil {
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
return
}
@ -90,7 +90,7 @@ func (c *Client) GetPublicKeysForEmail(email string) (keys []PublicKey, internal
email = url.QueryEscape(email)
var req *http.Request
if req, err = NewRequest("GET", "/keys?Email="+email, nil); err != nil {
if req, err = c.NewRequest("GET", "/keys?Email="+email, nil); err != nil {
return
}
@ -123,7 +123,7 @@ type KeySaltRes struct {
// GetKeySalts sends request to get list of key salts (n.b. locked route).
func (c *Client) GetKeySalts() (keySalts []KeySalt, err error) {
var req *http.Request
if req, err = NewRequest("GET", "/keys/salts", nil); err != nil {
if req, err = c.NewRequest("GET", "/keys/salts", nil); err != nil {
return
}

View File

@ -103,7 +103,7 @@ func (c *Client) ListContactGroups() (labels []*Label, err error) {
// ListLabelType lists all labels created by the user.
func (c *Client) ListLabelType(labelType int) (labels []*Label, err error) {
req, err := NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
req, err := c.NewRequest("GET", fmt.Sprintf("/labels?%d", labelType), nil)
if err != nil {
return
}
@ -129,7 +129,7 @@ type LabelRes struct {
// CreateLabel creates a new label.
func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
labelReq := &LabelReq{label}
req, err := NewJSONRequest("POST", "/labels", labelReq)
req, err := c.NewJSONRequest("POST", "/labels", labelReq)
if err != nil {
return
}
@ -146,7 +146,7 @@ func (c *Client) CreateLabel(label *Label) (created *Label, err error) {
// UpdateLabel updates a label.
func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
labelReq := &LabelReq{label}
req, err := NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
req, err := c.NewJSONRequest("PUT", "/labels/"+label.ID, labelReq)
if err != nil {
return
}
@ -162,7 +162,7 @@ func (c *Client) UpdateLabel(label *Label) (updated *Label, err error) {
// DeleteLabel deletes a label.
func (c *Client) DeleteLabel(id string) (err error) {
req, err := NewRequest("DELETE", "/labels/"+id, nil)
req, err := c.NewRequest("DELETE", "/labels/"+id, nil)
if err != nil {
return
}

View File

@ -468,7 +468,7 @@ type MessagesListRes struct {
// ListMessages gets message metadata.
func (c *Client) ListMessages(filter *MessagesFilter) (msgs []*Message, total int, err error) {
req, err := NewRequest("GET", "/messages", nil)
req, err := c.NewRequest("GET", "/messages", nil)
if err != nil {
return
}
@ -500,7 +500,7 @@ func (c *Client) CountMessages(addressID string) (counts []*MessagesCount, err e
if addressID != "" {
reqURL += ("?AddressID=" + addressID)
}
req, err := NewRequest("GET", reqURL, nil)
req, err := c.NewRequest("GET", reqURL, nil)
if err != nil {
return
}
@ -522,7 +522,7 @@ type MessageRes struct {
// GetMessage retrieves a message.
func (c *Client) GetMessage(id string) (msg *Message, err error) {
req, err := NewRequest("GET", "/messages/"+id, nil)
req, err := c.NewRequest("GET", "/messages/"+id, nil)
if err != nil {
return
}
@ -599,7 +599,7 @@ func (c *Client) SendMessage(id string, sendReq *SendMessageReq) (sent, parent *
sendReq.Packages = []*MessagePackage{}
}
req, err := NewJSONRequest("POST", "/messages/"+id, sendReq)
req, err := c.NewJSONRequest("POST", "/messages/"+id, sendReq)
if err != nil {
return
}
@ -629,7 +629,7 @@ type DraftReq struct {
func (c *Client) CreateDraft(m *Message, parent string, action int) (created *Message, err error) {
createReq := &DraftReq{Message: m, ParentID: parent, Action: action, AttachmentKeyPackets: []string{}}
req, err := NewJSONRequest("POST", "/messages", createReq)
req, err := c.NewJSONRequest("POST", "/messages", createReq)
if err != nil {
return
}
@ -688,7 +688,7 @@ func (c *Client) doMessagesAction(action string, ids []string) (err error) {
// You should not call this directly unless you know what you are doing (it can overload the server).
func (c *Client) doMessagesActionInner(action string, ids []string) (err error) {
actionReq := &MessagesActionReq{IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/"+action, actionReq)
req, err := c.NewJSONRequest("PUT", "/messages/"+action, actionReq)
if err != nil {
return
}
@ -740,7 +740,7 @@ func (c *Client) LabelMessages(ids []string, label string) (err error) {
func (c *Client) labelMessages(ids []string, label string) (err error) {
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/label", labelReq)
req, err := c.NewJSONRequest("PUT", "/messages/label", labelReq)
if err != nil {
return
}
@ -770,7 +770,7 @@ func (c *Client) UnlabelMessages(ids []string, label string) (err error) {
func (c *Client) unlabelMessages(ids []string, label string) (err error) {
labelReq := &LabelMessagesReq{LabelID: label, IDs: ids}
req, err := NewJSONRequest("PUT", "/messages/unlabel", labelReq)
req, err := c.NewJSONRequest("PUT", "/messages/unlabel", labelReq)
if err != nil {
return
}
@ -793,7 +793,7 @@ func (c *Client) EmptyFolder(labelID, addressID string) (err error) {
reqURL += ("&AddressID=" + addressID)
}
req, err := NewRequest("DELETE", reqURL, nil)
req, err := c.NewRequest("DELETE", reqURL, nil)
if err != nil {
return

View File

@ -28,7 +28,7 @@ func (c *Client) SendSimpleMetric(category, action, label string) (err error) {
v.Set("Action", action)
v.Set("Label", label)
req, err := NewRequest("GET", "/metrics?"+v.Encode(), nil)
req, err := c.NewRequest("GET", "/metrics?"+v.Encode(), nil)
if err != nil {
return
}

View File

@ -21,7 +21,6 @@ import (
"crypto/tls"
"encoding/base64"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
@ -43,63 +42,8 @@ var dohProviders = []string{ //nolint[gochecknoglobals]
"https://dns.google/dns-query",
}
// globalAllowDoH controls whether or not to enable use of DoH/Proxy in pmapi.
var globalAllowDoH = false // nolint[golint]
// globalProxyMutex allows threadsafe modification of proxy state.
var globalProxyMutex = sync.RWMutex{} // nolint[golint]
// globalOriginalURL backs up the original API url so it can be restored later.
var globalOriginalURL = RootURL // nolint[golint]
// globalIsDoHAllowed returns whether or not to use DoH.
func globalIsDoHAllowed() bool { // nolint[golint]
globalProxyMutex.RLock()
defer globalProxyMutex.RUnlock()
return globalAllowDoH
}
// GlobalAllowDoH enables DoH.
func GlobalAllowDoH() { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
globalAllowDoH = true
}
// GlobalDisallowDoH disables DoH and sets the RootURL back to what it was.
func GlobalDisallowDoH() { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
globalAllowDoH = false
RootURL = globalOriginalURL
}
// globalSetRootURL sets the global RootURL.
func globalSetRootURL(url string) { // nolint[golint]
globalProxyMutex.Lock()
defer globalProxyMutex.Unlock()
RootURL = url
}
// GlobalGetRootURL returns the global RootURL.
func GlobalGetRootURL() (url string) { // nolint[golint]
globalProxyMutex.RLock()
defer globalProxyMutex.RUnlock()
return RootURL
}
// isProxyEnabled returns whether or not we are currently using a proxy.
func isProxyEnabled() bool { // nolint[golint]
return globalOriginalURL != GlobalGetRootURL()
}
// proxyManager manages known proxies.
type proxyManager struct {
// proxyProvider manages known proxies.
type proxyProvider struct {
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
dohLookup func(query, provider string) (urls []string, err error)
@ -113,10 +57,10 @@ type proxyManager struct {
lastLookup time.Time // The time at which we last attempted to find a proxy.
}
// newProxyManager creates a new proxyManager that queries the given DoH providers
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string.
func newProxyManager(providers []string, query string) (p *proxyManager) { // nolint[unparam]
p = &proxyManager{
func newProxyProvider(providers []string, query string) (p *proxyProvider) { // nolint[unparam]
p = &proxyProvider{
providers: providers,
query: query,
useDuration: proxyRevertTime,
@ -132,7 +76,7 @@ func newProxyManager(providers []string, query string) (p *proxyManager) { // no
// findProxy returns a new proxy domain which is not equal to the current RootURL.
// It returns an error if the process takes longer than ProxySearchTime.
func (p *proxyManager) findProxy() (proxy string, err error) {
func (p *proxyProvider) findProxy() (proxy string, err error) {
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon")
}
@ -147,7 +91,7 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
}
for _, proxy := range p.proxyCache {
if proxy != stripProtocol(GlobalGetRootURL()) && p.canReach(proxy) {
if p.canReach(proxy) {
proxyResult <- proxy
return
}
@ -171,25 +115,8 @@ func (p *proxyManager) findProxy() (proxy string, err error) {
}
}
// useProxy sets the proxy server to use. It returns to the original RootURL after 24 hours.
func (p *proxyManager) useProxy(proxy string) {
if !isProxyEnabled() {
p.disableProxyAfter(p.useDuration)
}
globalSetRootURL(https(proxy))
}
// disableProxyAfter disables the proxy after the given amount of time.
func (p *proxyManager) disableProxyAfter(d time.Duration) {
go func() {
<-time.After(d)
globalSetRootURL(globalOriginalURL)
}()
}
// refreshProxyCache loads the latest proxies from the known providers.
func (p *proxyManager) refreshProxyCache() error {
func (p *proxyProvider) refreshProxyCache() error {
logrus.Info("Refreshing proxy cache")
for _, provider := range p.providers {
@ -197,7 +124,7 @@ func (p *proxyManager) refreshProxyCache() error {
p.proxyCache = proxies
// We also want to allow bridge to switch back to the standard API at any time.
p.proxyCache = append(p.proxyCache, globalOriginalURL)
p.proxyCache = append(p.proxyCache, RootURL)
logrus.WithField("proxies", proxies).Info("Available proxies")
@ -210,9 +137,13 @@ func (p *proxyManager) refreshProxyCache() error {
// canReach returns whether we can reach the given url.
// NOTE: we skip cert verification to stop it complaining that cert name doesn't match hostname.
func (p *proxyManager) canReach(url string) bool {
func (p *proxyProvider) canReach(url string) bool {
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
pinger := resty.New().
SetHostURL(https(url)).
SetHostURL(url).
SetTimeout(p.lookupTimeout).
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) // nolint[gosec]
@ -227,7 +158,7 @@ func (p *proxyManager) canReach(url string) bool {
// It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records.
// If the whole process takes more than ProxyQueryTime then an error is returned.
func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
func (p *proxyProvider) defaultDoHLookup(query, dohProvider string) (data []string, err error) {
dataResult := make(chan []string)
errResult := make(chan error)
go func() {
@ -282,23 +213,3 @@ func (p *proxyManager) defaultDoHLookup(query, dohProvider string) (data []strin
return
}
}
func stripProtocol(url string) string {
if strings.HasPrefix(url, "https://") {
return strings.TrimPrefix(url, "https://")
}
if strings.HasPrefix(url, "http://") {
return strings.TrimPrefix(url, "http://")
}
return url
}
func https(url string) string {
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
return url
}

View File

@ -32,14 +32,14 @@ const (
TestGoogleProvider = "https://dns.google/dns-query"
)
func TestProxyManager_FindProxy(t *testing.T) {
func TestProxyProvider_FindProxy(t *testing.T) {
blockAPI()
defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy()
@ -47,7 +47,7 @@ func TestProxyManager_FindProxy(t *testing.T) {
require.Equal(t, proxy.URL, url)
}
func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
blockAPI()
defer unblockAPI()
@ -58,7 +58,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
badProxy.Close()
defer goodProxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil }
url, err := p.findProxy()
@ -66,7 +66,7 @@ func TestProxyManager_FindProxy_ChooseReachableProxy(t *testing.T) {
require.Equal(t, goodProxy.URL, url)
}
func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) {
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
blockAPI()
defer unblockAPI()
@ -77,21 +77,21 @@ func TestProxyManager_FindProxy_FailIfNoneReachable(t *testing.T) {
badProxy.Close()
anotherBadProxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil }
_, err := p.findProxy()
require.Error(t, err)
}
func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) {
func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
blockAPI()
defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.lookupTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
@ -100,7 +100,7 @@ func TestProxyManager_FindProxy_LookupTimeout(t *testing.T) {
require.Error(t, err)
}
func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) {
blockAPI()
defer unblockAPI()
@ -109,7 +109,7 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
}))
defer slowProxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.findTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
@ -118,14 +118,14 @@ func TestProxyManager_FindProxy_FindTimeout(t *testing.T) {
require.Error(t, err)
}
func TestProxyManager_UseProxy(t *testing.T) {
func TestProxyProvider_UseProxy(t *testing.T) {
blockAPI()
defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy()
@ -135,7 +135,7 @@ func TestProxyManager_UseProxy(t *testing.T) {
require.Equal(t, proxy.URL, GlobalGetRootURL())
}
func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
blockAPI()
defer unblockAPI()
@ -146,7 +146,7 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
proxy3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy3.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
url, err := p.findProxy()
@ -173,14 +173,14 @@ func TestProxyManager_UseProxy_MultipleTimes(t *testing.T) {
require.Equal(t, proxy3.URL, GlobalGetRootURL())
}
func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) {
func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
blockAPI()
defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.useDuration = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
@ -195,14 +195,14 @@ func TestProxyManager_UseProxy_RevertAfterTime(t *testing.T) {
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
}
func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
// Don't block the API here because we want it to be working so the test can find it.
defer unblockAPI()
proxy := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy()
@ -225,7 +225,7 @@ func TestProxyManager_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachabl
require.Equal(t, globalOriginalURL, GlobalGetRootURL())
}
func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
blockAPI()
defer unblockAPI()
@ -234,7 +234,7 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
proxy2 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer proxy2.Close()
p := newProxyManager([]string{"not used"}, "not used")
p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
// Find a proxy.
@ -256,32 +256,32 @@ func TestProxyManager_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlo
require.Equal(t, proxy2.URL, GlobalGetRootURL())
}
func TestProxyManager_DoHLookup_Quad9(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
records, err := p.dohLookup(TestDoHQuery, TestQuad9Provider)
require.NoError(t, err)
require.NotEmpty(t, records)
}
func TestProxyManager_DoHLookup_Google(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
records, err := p.dohLookup(TestDoHQuery, TestGoogleProvider)
require.NoError(t, err)
require.NotEmpty(t, records)
}
func TestProxyManager_DoHLookup_FindProxy(t *testing.T) {
p := newProxyManager([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy()
require.NoError(t, err)
require.NotEmpty(t, url)
}
func TestProxyManager_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
p := newProxyManager([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy()
require.NoError(t, err)

View File

@ -26,8 +26,9 @@ import (
)
// NewRequest creates a new request.
func NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
req, err = http.NewRequest(method, GlobalGetRootURL()+path, body)
func (c *Client) NewRequest(method, path string, body io.Reader) (req *http.Request, err error) {
// TODO: Support other protocols (localhost needs http not https).
req, err = http.NewRequest(method, "https://"+c.cm.GetRootURL()+path, body)
if req != nil {
req.Header.Set("User-Agent", CurrentUserAgent)
}
@ -35,13 +36,13 @@ func NewRequest(method, path string, body io.Reader) (req *http.Request, err err
}
// NewJSONRequest create a new JSON request.
func NewJSONRequest(method, path string, body interface{}) (*http.Request, error) {
func (c *Client) NewJSONRequest(method, path string, body interface{}) (*http.Request, error) {
b, err := json.Marshal(body)
if err != nil {
panic(err)
}
req, err := NewRequest(method, path, bytes.NewReader(b))
req, err := c.NewRequest(method, path, bytes.NewReader(b))
if err != nil {
return nil, err
}
@ -70,7 +71,7 @@ func (w *MultipartWriter) Close() error {
// that writing the request and sending it MUST be done in parallel. If the
// request fails, subsequent writes to the multipart writer will fail with an
// io.ErrClosedPipe error.
func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) {
func (c *Client) NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWriter, err error) {
// The pipe will connect the multipart writer and the HTTP request body.
pr, pw := io.Pipe()
@ -80,7 +81,7 @@ func NewMultipartRequest(method, path string) (req *http.Request, w *MultipartWr
pw,
}
req, err = NewRequest(method, path, pr)
req, err = c.NewRequest(method, path, pr)
if err != nil {
return
}

View File

@ -45,7 +45,7 @@ type UserSettings struct {
// GetUserSettings gets general settings.
func (c *Client) GetUserSettings() (settings UserSettings, err error) {
req, err := NewRequest("GET", "/settings", nil)
req, err := c.NewRequest("GET", "/settings", nil)
if err != nil {
return
@ -99,7 +99,7 @@ type MailSettings struct {
// GetMailSettings gets contact details specified by contact ID.
func (c *Client) GetMailSettings() (settings MailSettings, err error) {
req, err := NewRequest("GET", "/settings/mail", nil)
req, err := c.NewRequest("GET", "/settings/mail", nil)
if err != nil {
return

View File

@ -93,7 +93,7 @@ func (u *User) KeyRing() *pmcrypto.KeyRing {
// UpdateUser retrieves details about user and loads its addresses.
func (c *Client) UpdateUser() (user *User, err error) {
req, err := NewRequest("GET", "/users", nil)
req, err := c.NewRequest("GET", "/users", nil)
if err != nil {
return
}