167 lines
5.4 KiB
Go
167 lines
5.4 KiB
Go
// Copyright (c) 2022 Proton AG
|
|
//
|
|
// This file is part of Proton Mail Bridge.
|
|
//
|
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package dialer
|
|
|
|
import (
|
|
"context"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ProtonMail/go-proton-api"
|
|
"github.com/ProtonMail/go-proton-api/server"
|
|
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
|
a "github.com/stretchr/testify/assert"
|
|
r "github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func getRootURL() string {
|
|
return "https://mail-api.proton.me"
|
|
}
|
|
|
|
func TestTLSPinValid(t *testing.T) {
|
|
called, _, _, _, cm := createClientWithPinningDialer(getRootURL())
|
|
|
|
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", []byte("password")) //nolint:dogsled
|
|
|
|
checkTLSIssueHandler(t, 0, called)
|
|
}
|
|
|
|
func TestTLSPinBackup(t *testing.T) {
|
|
called, _, _, checker, cm := createClientWithPinningDialer(getRootURL())
|
|
copyTrustedPins(checker)
|
|
checker.trustedPins[1] = checker.trustedPins[0]
|
|
checker.trustedPins[0] = ""
|
|
|
|
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", []byte("password")) //nolint:dogsled
|
|
|
|
checkTLSIssueHandler(t, 0, called)
|
|
}
|
|
|
|
func TestTLSPinInvalid(t *testing.T) {
|
|
s := server.New()
|
|
defer s.Close()
|
|
|
|
called, _, _, _, cm := createClientWithPinningDialer(s.GetHostURL())
|
|
|
|
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", []byte("password")) //nolint:dogsled
|
|
|
|
checkTLSIssueHandler(t, 1, called)
|
|
}
|
|
|
|
func TestTLSPinNoMatch(t *testing.T) {
|
|
skipIfProxyIsSet(t)
|
|
|
|
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
|
|
|
|
copyTrustedPins(checker)
|
|
for i := 0; i < len(checker.trustedPins); i++ {
|
|
checker.trustedPins[i] = "testing"
|
|
}
|
|
|
|
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", []byte("password")) //nolint:dogsled
|
|
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", []byte("password")) //nolint:dogsled
|
|
|
|
// Check that it will be reported only once per session, but notified every time.
|
|
r.Equal(t, 1, len(reporter.sentReports))
|
|
checkTLSIssueHandler(t, 2, called)
|
|
}
|
|
|
|
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
|
|
skipIfProxyIsSet(t)
|
|
|
|
_, dialer, _, _, _ := createClientWithPinningDialer("") //nolint:dogsled
|
|
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
|
|
r.Error(t, err, "expected dial to fail because of wrong public key")
|
|
}
|
|
|
|
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
|
|
skipIfProxyIsSet(t)
|
|
|
|
_, dialer, _, checker, _ := createClientWithPinningDialer("")
|
|
copyTrustedPins(checker)
|
|
checker.trustedPins = append(checker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
|
|
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
|
|
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
|
|
}
|
|
|
|
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
|
|
skipIfProxyIsSet(t)
|
|
|
|
_, dialer, _, checker, _ := createClientWithPinningDialer("")
|
|
copyTrustedPins(checker)
|
|
checker.trustedPins = append(checker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
|
_, err := dialer.DialTLSContext(context.Background(), "tcp", "self-signed.badssl.com:443")
|
|
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
|
}
|
|
|
|
func createClientWithPinningDialer(hostURL string) (*atomicUint64, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *proton.Manager) {
|
|
called := &atomicUint64{}
|
|
|
|
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
|
|
checker := NewTLSPinChecker(TrustedAPIPins)
|
|
dialer := NewPinningTLSDialer(NewBasicTLSDialer(hostURL), reporter, checker)
|
|
|
|
go func() {
|
|
for range dialer.GetTLSIssueCh() {
|
|
called.add(1)
|
|
}
|
|
}()
|
|
|
|
return called, dialer, reporter, checker, proton.New(
|
|
proton.WithHostURL(hostURL),
|
|
proton.WithTransport(CreateTransportWithDialer(dialer)),
|
|
)
|
|
}
|
|
|
|
func copyTrustedPins(pinChecker *TLSPinChecker) {
|
|
copiedPins := make([]string, len(pinChecker.trustedPins))
|
|
copy(copiedPins, pinChecker.trustedPins)
|
|
pinChecker.trustedPins = copiedPins
|
|
}
|
|
|
|
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast uint64, called *atomicUint64) {
|
|
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
|
|
a.Eventually(
|
|
t,
|
|
func() bool {
|
|
if wantCalledAtLeast == 0 {
|
|
return called.load() == 0
|
|
}
|
|
// Dialer can do more attempts resulting in more calls.
|
|
return called.load() >= wantCalledAtLeast
|
|
},
|
|
time.Second,
|
|
10*time.Millisecond,
|
|
)
|
|
// Repeated again so it generates nice message.
|
|
if wantCalledAtLeast == 0 {
|
|
r.Equal(t, uint64(0), called.load())
|
|
} else {
|
|
r.GreaterOrEqual(t, called.load(), wantCalledAtLeast)
|
|
}
|
|
}
|
|
|
|
type atomicUint64 struct {
|
|
v uint64
|
|
}
|
|
|
|
func (x *atomicUint64) load() uint64 { return atomic.LoadUint64(&x.v) }
|
|
|
|
func (x *atomicUint64) add(delta uint64) uint64 { return atomic.AddUint64(&x.v, delta) }
|