GODT-1158: Store full message bodies on disk

- GODT-1158: simple on-disk cache in store
- GODT-1158: better member naming in event loop
- GODT-1158: create on-disk cache during bridge setup
- GODT-1158: better job options
- GODT-1158: rename GetLiteral to GetRFC822
- GODT-1158: rename events -> currentEvents
- GODT-1158: unlock cache per-user
- GODT-1158: clean up cache after logout
- GODT-1158: randomized encrypted cache passphrase
- GODT-1158: Opt out of on-disk cache in settings
- GODT-1158: free space in cache
- GODT-1158: make tests compile
- GODT-1158: optional compression
- GODT-1158: cache custom location
- GODT-1158: basic capacity checker
- GODT-1158: cache free space config
- GODT-1158: only unlock cache if pmapi client is unlocked as well
- GODT-1158: simple background sync worker
- GODT-1158: set size/bodystructure when caching message
- GODT-1158: limit store db update blocking with semaphore
- GODT-1158: dumb 10-semaphore
- GODT-1158: properly handle delete; remove bad bodystructure handling
- GODT-1158: hacky fix for caching after logout... baaaaad
- GODT-1158: cache worker
- GODT-1158: compute body structure lazily
- GODT-1158: cache size in store
- GODT-1158: notify cacher when adding to store
- GODT-1158: 15 second store cache watcher
- GODT-1158: enable cacher
- GODT-1158: better cache worker starting/stopping
- GODT-1158: limit cacher to less concurrency than disk cache
- GODT-1158: message builder prio + pchan pkg
- GODT-1158: fix pchan, use in message builder
- GODT-1158: no sem in cacher (rely on message builder prio)
- GODT-1158: raise priority of existing jobs when requested
- GODT-1158: pending messages in on-disk cache
- GODT-1158: WIP just a note about deleting messages from disk cache
- GODT-1158: pending wait when trying to write
- GODT-1158: pending.add to return bool
- GODT-1225: Headers in bodystructure are stored as bytes.
- GODT-1158: fixing header caching
- GODT-1158: don't cache in background
- GODT-1158: all concurrency set in settings
- GODT-1158: worker pools inside message builder
- GODT-1158: fix linter issues
- GODT-1158: remove completed builds from builder
- GODT-1158: remove builder pool
- GODT-1158: cacher defer job done properly
- GODT-1158: fix linter
- GODT-1299: Continue with bodystructure build if deserialization failed
- GODT-1324: Delete messages from the cache when they are deleted on the server
- GODT-1158: refactor cache tests
- GODT-1158: move builder to app/bridge
- GODT-1306: Migrate cache on disk when location is changed (and delete when disabled)
This commit is contained in:
James Houlahan 2021-07-30 12:20:38 +02:00 committed by Jakub
parent 5cb893fc1b
commit 6bd0739013
79 changed files with 2911 additions and 1387 deletions

View File

@ -230,7 +230,7 @@ integration-test-bridge:
mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier,Storer > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
@ -288,7 +288,7 @@ run-nogui-cli: clean-vendor gofiles
PROTONMAIL_ENV=dev go run ${BUILD_FLAGS} cmd/${TARGET_CMD}/main.go ${RUN_FLAGS} -c
run-debug:
PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS}
PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} --noninteractive
run-qml-preview:
find internal/frontend/qml/ -iname '*qmlc' | xargs rm -f

1
TODO.md Normal file
View File

@ -0,0 +1 @@
- when cache is full, we need to stop the watcher? don't want to keep downloading messages and throwing them away when we try to cache them.

8
go.mod
View File

@ -32,7 +32,6 @@ require (
github.com/emersion/go-imap-move v0.0.0-20190710073258-6e5a51a5b342
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26
github.com/emersion/go-mbox v1.0.2
github.com/emersion/go-message v0.12.1-0.20201221184100-40c3f864532b
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
github.com/emersion/go-smtp v0.14.0
@ -45,7 +44,6 @@ require (
github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c // indirect
github.com/hashicorp/go-multierror v1.1.0
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
@ -55,13 +53,11 @@ require (
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285
github.com/sirupsen/logrus v1.7.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.0
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.6

23
go.sum
View File

@ -124,8 +124,6 @@ github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c h1:khcEdu1y
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c/go.mod h1:iApyhIQBiU4XFyr+3kdJyyGqle82TbQyuP2o+OZHrV0=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 h1:FiSb8+XBQQSkcX3ubr+1tAtlRJBYaFmRZqOAweZ9Wy8=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26/go.mod h1:+gnnZx3Mg3MnCzZrv0eZdp5puxXQUgGT/6N6L7ShKfM=
github.com/emersion/go-mbox v1.0.2 h1:tE/rT+lEugK9y0myEymCCHnwlZN04hlXPrbKkxRBA5I=
github.com/emersion/go-mbox v1.0.2/go.mod h1:Yp9IVuuOYLEuMv4yjgDHvhb5mHOcYH6x92Oas3QqEZI=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
@ -197,9 +195,6 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20190411002643-bd77b112433e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4=
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
@ -269,7 +264,6 @@ github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
@ -351,6 +345,8 @@ github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285 h1:d54EL9l+XteliUfUCGsEwwuk65dmmxX85VXF+9T6+50=
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285/go.mod h1:fxIDly1xtudczrZeOOlfaUvd2OPb2qZAPuWdU2BsBTk=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo=
@ -365,13 +361,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@ -401,12 +393,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e h1:G0DQ/TRQyrEZjtLlLwevFjaRiG8eeCMlq9WXQ2OO2bk=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e/go.mod h1:SUUR2j3aE1z6/g76SdD6NwACEpvCxb3fvG82eKbD6us=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d h1:hAZyEG2swPRWjF0kqqdGERXUazYnRJdAk4a58f14z7Y=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:7m8PDYDEtEVqfjoUQc2UrFqhG0CDmoVJjRlQxexndFc=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d h1:AJRoBel/g9cDS+yE8BcN3E+TDD/xNAguG21aoR8DAIE=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:mH55Ek7AZcdns5KPp99O0bg+78el64YCYWHiQKrOdt4=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@ -444,7 +430,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190418165655-df01cb2cc480/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -488,7 +473,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@ -521,9 +505,7 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -557,7 +539,6 @@ golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=

View File

@ -32,7 +32,9 @@ import (
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/internal/imap"
"github.com/ProtonMail/proton-bridge/internal/smtp"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@ -69,10 +71,21 @@ func New(base *base.Base) *cli.App {
func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
tlsConfig, err := loadTLSConfig(b)
if err != nil {
logrus.WithError(err).Fatal("Failed to load TLS config")
return err
}
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, bridge)
cache, err := loadCache(b)
if err != nil {
return err
}
builder := message.NewBuilder(
b.Settings.GetInt(settings.FetchWorkers),
b.Settings.GetInt(settings.AttachmentWorkers),
)
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, cache, builder, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
go func() {
@ -233,3 +246,35 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool)
f.NotifySilentUpdateInstalled()
}
// NOTE(GODT-1158): How big should in-memory cache be?
// NOTE(GODT-1158): How to handle cache location migration if user changes custom path?

// loadCache builds the message cache configured in settings: a bounded
// in-memory cache when the on-disk cache is disabled, otherwise an
// on-disk cache at either the user-chosen or the default location.
func loadCache(b *base.Base) (cache.Cache, error) {
	if !b.Settings.GetBool(settings.CacheEnabledKey) {
		// 100 MiB in-memory fallback when the on-disk cache is switched off.
		return cache.NewInMemoryCache(100 * (1 << 20)), nil
	}

	// NOTE(GODT-1158): If user changes compression setting we have to nuke the cache.
	var compressor cache.Compressor = &cache.NoopCompressor{}
	if b.Settings.GetBool(settings.CacheCompressionKey) {
		compressor = &cache.GZipCompressor{}
	}

	// Prefer a custom location from settings; otherwise fall back to the
	// default message cache directory.
	dir := b.Cache.GetDefaultMessageCacheDir()
	if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
		dir = customPath
	}

	return cache.NewOnDiskCache(dir, compressor, cache.Options{
		MinFreeAbs:      uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
		MinFreeRat:      b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
		ConcurrentRead:  b.Settings.GetInt(settings.CacheConcurrencyRead),
		ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
	})
}

View File

@ -28,17 +28,17 @@ import (
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/listener"
logrus "github.com/sirupsen/logrus"
)
var (
log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
)
var log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
type Bridge struct {
*users.Users
@ -52,11 +52,13 @@ type Bridge struct {
func New(
locations Locator,
cache Cacher,
s SettingsProvider,
cacheProvider CacheProvider,
setting SettingsProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
@ -64,7 +66,7 @@ func New(
) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) {
if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy()
}
@ -74,25 +76,25 @@ func New(
eventListener,
clientManager,
credStorer,
newStoreFactory(cache, sentryReporter, panicHandler, eventListener),
newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
)
b := &Bridge{
Users: u,
locations: locations,
settings: s,
settings: setting,
clientManager: clientManager,
updater: updater,
versioner: versioner,
}
if s.GetBool(settings.FirstStartKey) {
if setting.GetBool(settings.FirstStartKey) {
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric")
}
s.SetBool(settings.FirstStartKey, false)
setting.SetBool(settings.FirstStartKey, false)
}
go b.heartbeat()

View File

@ -23,47 +23,65 @@ import (
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
)
type storeFactory struct {
cache Cacher
cacheProvider CacheProvider
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
eventListener listener.Listener
storeCache *store.Cache
events *store.Events
cache cache.Cache
builder *message.Builder
}
func newStoreFactory(
cache Cacher,
cacheProvider CacheProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
) *storeFactory {
return &storeFactory{
cache: cache,
cacheProvider: cacheProvider,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
eventListener: eventListener,
storeCache: store.NewCache(cache.GetIMAPCachePath()),
events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
cache: cache,
builder: builder,
}
}
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
return store.New(
f.sentryReporter,
f.panicHandler,
user,
f.eventListener,
f.cache,
f.builder,
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
f.events,
)
}
// Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error {
storePath := getUserStorePath(f.cache.GetDBDir(), userID)
return store.RemoveStore(f.storeCache, storePath, userID)
return store.RemoveStore(
f.events,
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
userID,
)
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
fileName := fmt.Sprintf("mailbox-%v.db", userID)
return filepath.Join(storeDir, fileName)
return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
}

View File

@ -28,7 +28,7 @@ type Locator interface {
ClearUpdates() error
}
type Cacher interface {
type CacheProvider interface {
GetIMAPCachePath() string
GetDBDir() string
}
@ -38,6 +38,7 @@ type SettingsProvider interface {
Set(key string, value string)
GetBool(key string) bool
SetBool(key string, val bool)
GetInt(key string) int
}
type Updater interface {

View File

@ -45,6 +45,11 @@ func (c *Cache) GetDBDir() string {
return c.getCurrentCacheDir()
}
// GetDefaultMessageCacheDir returns the default directory used to store
// cached message files on disk ("messages" under the current cache dir).
func (c *Cache) GetDefaultMessageCacheDir() string {
	const messageSubdir = "messages"
	return filepath.Join(c.getCurrentCacheDir(), messageSubdir)
}
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Cache) GetIMAPCachePath() string {
return filepath.Join(c.getCurrentCacheDir(), "user_info.json")

View File

@ -100,18 +100,28 @@ func (p *keyValueStore) GetBool(key string) bool {
}
// GetInt returns the value stored under key parsed as an int.
// An unset/empty key yields 0. A value that fails to parse is logged and
// the (zero) result of Atoi is returned.
func (p *keyValueStore) GetInt(key string) int {
	// Read the value once: the original looked it up twice, which is a
	// redundant lookup and could race between the emptiness check and parse.
	raw := p.Get(key)
	if raw == "" {
		return 0
	}

	value, err := strconv.Atoi(raw)
	if err != nil {
		logrus.WithError(err).Error("Cannot parse int")
	}

	return value
}
// GetFloat64 returns the value stored under key parsed as a float64.
// An unset/empty key yields 0. A value that fails to parse is logged and
// the (zero) result of ParseFloat is returned.
func (p *keyValueStore) GetFloat64(key string) float64 {
	// Read the value once: the original looked it up twice, which is a
	// redundant lookup and could race between the emptiness check and parse.
	raw := p.Get(key)
	if raw == "" {
		return 0
	}

	value, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		logrus.WithError(err).Error("Cannot parse float64")
	}

	return value
}

View File

@ -43,6 +43,16 @@ const (
UpdateChannelKey = "update_channel"
RolloutKey = "rollout"
PreferredKeychainKey = "preferred_keychain"
CacheEnabledKey = "cache_enabled"
CacheCompressionKey = "cache_compression"
CacheLocationKey = "cache_location"
CacheMinFreeAbsKey = "cache_min_free_abs"
CacheMinFreeRatKey = "cache_min_free_rat"
CacheConcurrencyRead = "cache_concurrent_read"
CacheConcurrencyWrite = "cache_concurrent_write"
IMAPWorkers = "imap_workers"
FetchWorkers = "fetch_workers"
AttachmentWorkers = "attachment_workers"
)
type Settings struct {
@ -80,6 +90,16 @@ func (s *Settings) setDefaultValues() {
s.setDefault(UpdateChannelKey, "")
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here
s.setDefault(PreferredKeychainKey, "")
s.setDefault(CacheEnabledKey, "true")
s.setDefault(CacheCompressionKey, "true")
s.setDefault(CacheLocationKey, "")
s.setDefault(CacheMinFreeAbsKey, "250000000")
s.setDefault(CacheMinFreeRatKey, "")
s.setDefault(CacheConcurrencyRead, "16")
s.setDefault(CacheConcurrencyWrite, "16")
s.setDefault(IMAPWorkers, "16")
s.setDefault(FetchWorkers, "16")
s.setDefault(AttachmentWorkers, "16")
s.setDefault(APIPortKey, DefaultAPIPort)
s.setDefault(IMAPPortKey, DefaultIMAPPort)

View File

@ -128,6 +128,24 @@ func New( //nolint[funlen]
})
fe.AddCmd(dohCmd)
// Cache-On-Disk commands.
codCmd := &ishell.Cmd{Name: "local-cache",
Help: "manage the local encrypted message cache",
}
codCmd.AddCmd(&ishell.Cmd{Name: "enable",
Help: "enable the local cache",
Func: fe.enableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "disable",
Help: "disable the local cache",
Func: fe.disableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "change-location",
Help: "change the location of the local cache",
Func: fe.setCacheOnDiskLocation,
})
fe.AddCmd(codCmd)
// Updates commands.
updatesCmd := &ishell.Cmd{Name: "updates",
Help: "manage bridge updates",

View File

@ -19,6 +19,7 @@ package cli
import (
"fmt"
"os"
"strconv"
"strings"
@ -155,6 +156,67 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
}
}
// enableCacheOnDisk turns on the local encrypted message cache after user
// confirmation, resets the location to the default, and schedules a restart.
func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
	if f.settings.GetBool(settings.CacheEnabledKey) {
		f.Println("The local cache is already enabled.")
		return
	}

	if !f.yesNoQuestion("Are you sure you want to enable the local cache") {
		return
	}

	// Set this back to the default location before enabling.
	f.settings.Set(settings.CacheLocationKey, "")

	if err := f.bridge.EnableCache(); err != nil {
		f.Println("The local cache could not be enabled.")
		return
	}

	f.settings.SetBool(settings.CacheEnabledKey, true)

	f.restarter.SetToRestart()
	f.Stop()
}
// disableCacheOnDisk turns off the local encrypted message cache after user
// confirmation and schedules a restart.
func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
	if !f.settings.GetBool(settings.CacheEnabledKey) {
		f.Println("The local cache is already disabled.")
		return
	}

	if !f.yesNoQuestion("Are you sure you want to disable the local cache") {
		return
	}

	if err := f.bridge.DisableCache(); err != nil {
		f.Println("The local cache could not be disabled.")
		return
	}

	f.settings.SetBool(settings.CacheEnabledKey, false)

	f.restarter.SetToRestart()
	f.Stop()
}
// setCacheOnDiskLocation prompts for a new cache directory, migrates the
// existing cache there, persists the setting, and schedules a restart.
// The cache must already be enabled.
func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
	if !f.settings.GetBool(settings.CacheEnabledKey) {
		f.Println("The local cache must be enabled.")
		return
	}

	if current := f.settings.Get(settings.CacheLocationKey); current != "" {
		f.Println("The current local cache location is:", current)
	}

	location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable)
	if location == "" {
		return
	}

	if err := f.bridge.MigrateCache(f.settings.Get(settings.CacheLocationKey), location); err != nil {
		f.Println("The local cache location could not be changed.")
		return
	}

	f.settings.Set(settings.CacheLocationKey, location)

	f.restarter.SetToRestart()
	f.Stop()
}
func (f *frontendCLI) isPortFree(port string) bool {
port = strings.ReplaceAll(port, ":", "")
if port == "" || port == currentPort {
@ -171,3 +233,13 @@ func (f *frontendCLI) isPortFree(port string) bool {
}
return true
}
// NOTE(GODT-1158): Check free space in location.

// isCacheLocationUsable reports whether location exists and is a directory.
func (f *frontendCLI) isCacheLocationUsable(location string) bool {
	if info, err := os.Stat(location); err == nil {
		return info.IsDir()
	}

	return false
}

View File

@ -77,6 +77,9 @@ type Bridger interface {
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
AllowProxy()
DisallowProxy()
EnableCache() error
DisableCache() error
MigrateCache(from, to string) error
GetUpdateChannel() updater.UpdateChannel
SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error)
GetKeychainApp() string

View File

@ -37,21 +37,13 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/emersion/go-imap"
goIMAPBackend "github.com/emersion/go-imap/backend"
)
const (
// NOTE: Each fetch worker has its own set of attach workers so there can be up to 20*5=100 API requests at once.
// This is a reasonable limit to not overwhelm API while still maintaining as much parallelism as possible.
fetchWorkers = 20 // In how many workers to fetch message (group list on IMAP).
attachWorkers = 5 // In how many workers to fetch attachments (for one message).
buildWorkers = 20 // In how many workers to build messages.
)
type panicHandler interface {
HandlePanic()
}
@ -61,26 +53,32 @@ type imapBackend struct {
bridge bridger
updates *imapUpdates
eventListener listener.Listener
listWorkers int
users map[string]*imapUser
usersLocker sync.Locker
builder *message.Builder
imapCache map[string]map[string]string
imapCachePath string
imapCacheLock *sync.RWMutex
}
type settingsProvider interface {
GetInt(string) int
}
// NewIMAPBackend returns struct implementing go-imap/backend interface.
func NewIMAPBackend(
panicHandler panicHandler,
eventListener listener.Listener,
cache cacheProvider,
setting settingsProvider,
bridge *bridge.Bridge,
) *imapBackend { //nolint[golint]
bridgeWrap := newBridgeWrap(bridge)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener)
imapWorkers := setting.GetInt(settings.IMAPWorkers)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener, imapWorkers)
go backend.monitorDisconnectedUsers()
@ -92,6 +90,7 @@ func newIMAPBackend(
cache cacheProvider,
bridge bridger,
eventListener listener.Listener,
listWorkers int,
) *imapBackend {
return &imapBackend{
panicHandler: panicHandler,
@ -102,10 +101,9 @@ func newIMAPBackend(
users: map[string]*imapUser{},
usersLocker: &sync.Mutex{},
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
imapCachePath: cache.GetIMAPCachePath(),
imapCacheLock: &sync.RWMutex{},
listWorkers: listWorkers,
}
}

View File

@ -1,151 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"sort"
"sync"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
)
// key identifies a cached message and carries the bookkeeping data
// (age and size) that the eviction logic in SaveMail relies on.
type key struct {
	ID        string
	Timestamp int64 // milliseconds since epoch; refreshed on cache hit
	Size      int   // length of the raw message body in bytes
}

// oldestFirst orders keys by ascending timestamp so that the oldest
// entries are evicted first.
type oldestFirst []key

func (s oldestFirst) Len() int           { return len(s) }
func (s oldestFirst) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
func (s oldestFirst) Less(i, j int) bool { return s[i].Timestamp < s[j].Timestamp }

// cachedMessage is a single cache entry: the raw message body plus its
// parsed body structure.
type cachedMessage struct {
	key
	data      []byte
	structure pkgMsg.BodyStructure
}
//nolint[gochecknoglobals]
var (
	// cacheTimeLimit is the maximum age of a cache entry, in milliseconds
	// (one hour). It is a variable, not a constant, so tests can lower it.
	cacheTimeLimit = int64(1 * 60 * 60 * 1000) // milliseconds
	// cacheSizeLimit is the total cache budget in bytes.
	cacheSizeLimit = 100 * 1000 * 1000 // B - MUST be larger than email max size limit (~ 25 MB)
	// mailCache maps the cache ID (callers build it as userID+messageID)
	// to the cached message. Guarded by cacheMutex.
	mailCache = make(map[string]cachedMessage)
	// cacheMutex takes care of one single operation, whereas buildMutex takes
	// care of the whole action doing multiple operations. buildMutex will protect
	// you from asking server or decrypting or building the same message more
	// than once. When first request to build the message comes, it will block
	// all other build requests. When the first one is done, all others are
	// handled by cache, not doing anything twice. With cacheMutex we are safe
	// only to not mess up with the cache, but we could end up downloading and
	// building message twice.
	cacheMutex = &sync.Mutex{}
	buildMutex = &sync.Mutex{}
	// buildLocks is the set of message IDs currently being built; guarded
	// by buildMutex.
	buildLocks = map[string]interface{}{}
)
// isValidOrDel reports whether the entry is still within cacheTimeLimit.
// An expired entry is removed from mailCache as a side effect.
// Callers must hold cacheMutex.
func (m *cachedMessage) isValidOrDel() bool {
	expired := m.key.Timestamp+cacheTimeLimit < timestamp()
	if !expired {
		return true
	}
	delete(mailCache, m.key.ID)
	return false
}
func timestamp() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}
func Clear() {
mailCache = make(map[string]cachedMessage)
}
// BuildLock locks per message level, not on global level.
// Multiple different messages can be building at once.
func BuildLock(messageID string) {
	for {
		buildMutex.Lock()
		_, alreadyBuilding := buildLocks[messageID]
		if !alreadyBuilding {
			// Unlocked: claim the build slot and return.
			buildLocks[messageID] = struct{}{}
			buildMutex.Unlock()
			return
		}
		// Someone else holds the slot; release the map lock and poll again.
		buildMutex.Unlock()
		time.Sleep(10 * time.Millisecond)
	}
}
// BuildUnlock releases the per-message build slot taken by BuildLock.
func BuildUnlock(messageID string) {
	buildMutex.Lock()
	delete(buildLocks, messageID)
	buildMutex.Unlock()
}
// LoadMail returns a reader over the cached message body and its body
// structure. On a miss (or an expired entry) it returns an empty reader
// and a nil structure. A hit refreshes the entry's timestamp so that
// frequently used messages survive time-based eviction.
func LoadMail(mID string) (reader *bytes.Reader, structure *pkgMsg.BodyStructure) {
	reader = &bytes.Reader{}

	cacheMutex.Lock()
	defer cacheMutex.Unlock()

	if message, ok := mailCache[mID]; ok && message.isValidOrDel() {
		reader = bytes.NewReader(message.data)
		structure = &message.structure
		// Update timestamp to keep emails which are used often.
		// Map values are copies in Go, so the refreshed timestamp must
		// be written back to the map or the refresh is silently lost.
		message.Timestamp = timestamp()
		mailCache[mID] = message
	}

	return
}
// SaveMail stores the message body and structure under mID, first
// dropping expired entries and then evicting the oldest entries until
// the new message fits under cacheSizeLimit.
func SaveMail(mID string, msg []byte, structure *pkgMsg.BodyStructure) {
	cacheMutex.Lock()
	defer cacheMutex.Unlock()

	newMessage := cachedMessage{
		key: key{
			ID:        mID,
			Timestamp: timestamp(),
			Size:      len(msg),
		},
		data:      msg,
		structure: *structure,
	}

	// Remove old and reduce size.
	totalSize := 0
	messageList := []key{}
	for _, message := range mailCache {
		if message.isValidOrDel() {
			messageList = append(messageList, message.key)
			totalSize += message.key.Size
		}
	}
	sort.Sort(oldestFirst(messageList))

	// Evict oldest first until the new message fits. The len guard
	// prevents an index-out-of-range panic in the degenerate case of a
	// single message larger than cacheSizeLimit (the limit is documented
	// to exceed the max email size, but don't panic if that ever breaks).
	var oldest key
	for len(messageList) > 0 && totalSize+newMessage.key.Size >= cacheSizeLimit {
		oldest, messageList = messageList[0], messageList[1:]
		delete(mailCache, oldest.ID)
		totalSize -= oldest.Size
	}

	// Write new.
	mailCache[mID] = newMessage
}

View File

@ -1,98 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"fmt"
"testing"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/stretchr/testify/require"
)
var bs = &pkgMsg.BodyStructure{} //nolint[gochecknoglobals]
const testUID = "testmsg"
// TestSaveAndLoad checks that a message stored with SaveMail can be read
// back via LoadMail and that the raw bytes round-trip unchanged.
func TestSaveAndLoad(t *testing.T) {
	msg := []byte("Test message")
	SaveMail(testUID, msg, bs)
	require.Equal(t, mailCache[testUID].data, msg)
	reader, _ := LoadMail(testUID)
	require.Equal(t, reader.Len(), len(msg))
	stored := make([]byte, len(msg))
	// Read error ignored: the buffer length matches the reader length.
	_, _ = reader.Read(stored)
	require.Equal(t, stored, msg)
}
// TestMissing checks that LoadMail returns an empty reader for an
// unknown message ID rather than failing.
func TestMissing(t *testing.T) {
	reader, _ := LoadMail("non-existing")
	require.Equal(t, reader.Len(), 0)
}
// TestClearOld checks that entries older than cacheTimeLimit are evicted
// and no longer returned by LoadMail.
func TestClearOld(t *testing.T) {
	// Restore the package-level limit afterwards so sibling tests are not
	// affected by the artificially short expiry.
	origTimeLimit := cacheTimeLimit
	defer func() { cacheTimeLimit = origTimeLimit }()

	cacheTimeLimit = 10 // milliseconds

	msg := []byte("Test message")
	SaveMail(testUID, msg, bs)
	time.Sleep(100 * time.Millisecond)
	reader, _ := LoadMail(testUID)
	require.Equal(t, reader.Len(), 0)
}
// TestClearBig checks size-based eviction: the cache never holds more
// than the budget allows, and the oldest entries are evicted first.
func TestClearBig(t *testing.T) {
	r := require.New(t)

	// Restore the package-level limits afterwards so sibling tests are not
	// affected by the tiny size budget set here.
	origSizeLimit, origTimeLimit := cacheSizeLimit, cacheTimeLimit
	defer func() { cacheSizeLimit, cacheTimeLimit = origSizeLimit, origTimeLimit }()

	wantMessage := []byte("Test message")
	wantCacheSize := 3
	nTestMessages := wantCacheSize * wantCacheSize
	cacheSizeLimit = wantCacheSize*len(wantMessage) + 1
	cacheTimeLimit = int64(1 << 20) // be sure the message will survive
	// It should never have more than nSize items.
	for i := 0; i < nTestMessages; i++ {
		// Sleep briefly so each entry gets a distinct timestamp and the
		// eviction order is stable.
		time.Sleep(1 * time.Millisecond)
		SaveMail(fmt.Sprintf("%s%d", testUID, i), wantMessage, bs)
		r.LessOrEqual(len(mailCache), wantCacheSize, "cache too big when %d", i)
	}
	// Check that the oldest are deleted first.
	for i := 0; i < nTestMessages; i++ {
		iUID := fmt.Sprintf("%s%d", testUID, i)
		reader, _ := LoadMail(iUID)
		mail := mailCache[iUID]
		if i < (nTestMessages - wantCacheSize) {
			r.Zero(reader.Len(), "LoadMail should return empty, but have %s for %s time %d ", string(mail.data), iUID, mail.key.Timestamp)
		} else {
			stored := make([]byte, len(wantMessage))
			_, err := reader.Read(stored)
			r.NoError(err)
			r.Equal(wantMessage, stored, "LoadMail returned wrong message: %s for %s time %d", stored, iUID, mail.key.Timestamp)
		}
	}
}
// TestConcurency checks that concurrent SaveMail calls are safe (run the
// suite with -race to get the actual signal). The test must wait for all
// goroutines: without waiting, the writers could outlive the test and
// race with whatever runs next.
func TestConcurency(t *testing.T) {
	msg := []byte("Test message")

	done := make(chan struct{})
	for i := 0; i < 10; i++ {
		go func(i int) {
			SaveMail(fmt.Sprintf("%s%d", testUID, i), msg, bs)
			done <- struct{}{}
		}(i)
	}

	for i := 0; i < 10; i++ {
		<-done
	}
}

View File

@ -37,12 +37,10 @@ type imapMailbox struct {
storeUser storeUserProvider
storeAddress storeAddressProvider
storeMailbox storeMailboxProvider
builder *message.Builder
}
// newIMAPMailbox returns struct implementing go-imap/mailbox interface.
func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider, builder *message.Builder) *imapMailbox {
func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider) *imapMailbox {
return &imapMailbox{
panicHandler: panicHandler,
user: user,
@ -56,8 +54,6 @@ func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox stor
storeUser: user.storeUser,
storeAddress: user.storeAddress,
storeMailbox: storeMailbox,
builder: builder,
}
}

View File

@ -19,21 +19,13 @@ package imap
import (
"bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func (im *imapMailbox) getMessage(
storeMessage storeMessageProvider,
items []imap.FetchItem,
msgBuildCountHistogram *msgBuildCountHistogram,
) (msg *imap.Message, err error) {
func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem) (msg *imap.Message, err error) {
msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message")
@ -69,9 +61,12 @@ func (im *imapMailbox) getMessage(
// There is no point having message older than RFC itself, it's not possible.
msg.InternalDate = message.SanitizeMessageDate(m.Time)
case imap.FetchRFC822Size:
if msg.Size, err = im.getSize(storeMessage); err != nil {
size, err := storeMessage.GetRFC822Size()
if err != nil {
return nil, err
}
msg.Size = size
case imap.FetchUid:
if msg.Uid, err = storeMessage.UID(); err != nil {
return nil, err
@ -79,7 +74,7 @@ func (im *imapMailbox) getMessage(
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default:
if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil {
if err = im.getLiteralForSection(item, msg, storeMessage); err != nil {
return
}
}
@ -88,35 +83,7 @@ func (im *imapMailbox) getMessage(
return msg, err
}
// getSize returns cached size or it will build the message, save the size in
// DB and then returns the size after build.
//
// We are storing size in DB as part of pmapi messages metada. The size
// attribute on the server represents size of encrypted body. The value is
// cleared in Bridge and the final decrypted size (including header, attachment
// and MIME structure) is computed after building the message.
func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) {
m := storeMessage.Message()
if m.Size <= 0 {
im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body")
// We are sure the size is not a problem right now. Clients
// might not first check sizes of all messages so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude getting size from the
// counting and see build count as real message build.
if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil {
return 0, err
}
}
return uint32(m.Size), nil
}
func (im *imapMailbox) getLiteralForSection(
itemSection imap.FetchItem,
msg *imap.Message,
storeMessage storeMessageProvider,
msgBuildCountHistogram *msgBuildCountHistogram,
) error {
func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider) error {
section, err := imap.ParseBodySectionName(itemSection)
if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
@ -124,7 +91,7 @@ func (im *imapMailbox) getLiteralForSection(
}
var literal imap.Literal
if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil {
if literal, err = im.getMessageBodySection(storeMessage, section); err != nil {
return err
}
@ -149,88 +116,25 @@ func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build.
if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
if bs, _, err = im.getBodyAndStructure(storeMessage); err != nil {
return
}
}
return
}
func (im *imapMailbox) getBodyAndStructure(
storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram,
) (
structure *message.BodyStructure, bodyReader *bytes.Reader, err error,
) {
m := storeMessage.Message()
id := im.storeUser.UserID() + m.ID
cache.BuildLock(id)
defer cache.BuildUnlock(id)
bodyReader, structure = cache.LoadMail(id)
// return the message which was found in cache
if bodyReader.Len() != 0 && structure != nil {
return structure, bodyReader, nil
func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider) (*message.BodyStructure, *bytes.Reader, error) {
rfc822, err := storeMessage.GetRFC822()
if err != nil {
return nil, nil, err
}
structure, body, err := im.buildMessage(m)
bodyReader = bytes.NewReader(body)
size := int64(len(body))
l := im.log.WithField("newSize", size).WithField("msgID", m.ID)
if err != nil || structure == nil || size == 0 {
l.WithField("hasStructure", structure != nil).Warn("Failed to build message")
return structure, bodyReader, err
structure, err := storeMessage.GetBodyStructure()
if err != nil {
return nil, nil, err
}
// Save the size, body structure and header even for messages which
// were unable to decrypt. Hence they doesn't have to be computed every
// time.
m.Size = size
cacheMessageInStore(storeMessage, structure, body, l)
if msgBuildCountHistogram != nil {
times, errCount := storeMessage.IncreaseBuildCount()
if errCount != nil {
l.WithError(errCount).Warn("Cannot increase build count")
}
msgBuildCountHistogram.add(times)
}
// Drafts can change therefore we don't want to cache them.
if !isMessageInDraftFolder(m) {
cache.SaveMail(id, body, structure)
}
return structure, bodyReader, err
}
// cacheMessageInStore persists build results (size, body structure, header)
// into the store DB so they need not be recomputed on the next fetch.
// All failures are logged and swallowed: caching here is best-effort.
//
// NOTE(review): structure is dereferenced unconditionally below; the only
// visible caller returns early when structure is nil, but a nil structure
// passed in would panic — confirm before reusing this helper elsewhere.
func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) {
	m := storeMessage.Message()
	if errSize := storeMessage.SetSize(m.Size); errSize != nil {
		l.WithError(errSize).Warn("Cannot update size while building")
	}
	// Drafts can change, so their body structure is deliberately not cached.
	if structure != nil && !isMessageInDraftFolder(m) {
		if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil {
			l.WithError(errStruct).Warn("Cannot update bodystructure while building")
		}
	}
	header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
	if errHead == nil && len(header) != 0 {
		if errStore := storeMessage.SetHeader(header); errStore != nil {
			l.WithError(errStore).Warn("Cannot update header in store")
		}
	} else {
		l.WithError(errHead).Warn("Cannot get header bytes from structure")
	}
}
func isMessageInDraftFolder(m *pmapi.Message) bool {
for _, labelID := range m.LabelIDs {
if labelID == pmapi.DraftLabel {
return true
}
}
return false
return structure, bytes.NewReader(rfc822), nil
}
// This will download message (or read from cache) and pick up the section,
@ -246,11 +150,7 @@ func isMessageInDraftFolder(m *pmapi.Message) bool {
// For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`.
func (im *imapMailbox) getMessageBodySection(
storeMessage storeMessageProvider,
section *imap.BodySectionName,
msgBuildCountHistogram *msgBuildCountHistogram,
) (imap.Literal, error) {
func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider, section *imap.BodySectionName) (imap.Literal, error) {
var header []byte
var response []byte
@ -260,7 +160,7 @@ func (im *imapMailbox) getMessageBodySection(
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader()
} else {
structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram)
structure, bodyReader, err := im.getBodyAndStructure(storeMessage)
if err != nil {
return nil, err
}
@ -276,7 +176,7 @@ func (im *imapMailbox) getMessageBodySection(
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough
case section.Specifier == imap.HeaderSpecifier:
header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path)
header, err = structure.GetSectionHeaderBytes(section.Path)
default:
err = errors.New("Unknown specifier " + string(section.Specifier))
}
@ -293,30 +193,3 @@ func (im *imapMailbox) getMessageBodySection(
// Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil
}
// buildMessage from PM to IMAP.
// It submits a build job to the shared builder (download + decrypt +
// assemble the RFC822 literal) and then derives the IMAP body structure
// from the built literal.
func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
	body, err := im.builder.NewJobWithOptions(
		context.Background(),
		im.user.client(),
		m.ID,
		message.JobOptions{
			IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
			SanitizeDate:           true, // Whether to replace all dates before 1970 with RFC822's birthdate.
			AddInternalID:          true, // Whether to include MessageID as X-Pm-Internal-Id.
			AddExternalID:          true, // Whether to include ExternalID as X-Pm-External-Id.
			AddMessageDate:         true, // Whether to include message time as X-Pm-Date.
			AddMessageIDReference:  true, // Whether to include the MessageID in References.
		},
	).GetResult()
	if err != nil {
		return nil, nil, err
	}
	structure, err := message.NewBodyStructure(bytes.NewReader(body))
	if err != nil {
		return nil, nil, err
	}
	return structure, body, nil
}

View File

@ -479,11 +479,16 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
}
// Filter by size (only if size was already calculated).
if m.Size > 0 {
if criteria.Larger != 0 && m.Size <= int64(criteria.Larger) {
size, err := storeMessage.GetRFC822Size()
if err != nil {
return nil, err
}
if size > 0 {
if criteria.Larger != 0 && int64(size) <= int64(criteria.Larger) {
continue
}
if criteria.Smaller != 0 && m.Size >= int64(criteria.Smaller) {
if criteria.Smaller != 0 && int64(size) >= int64(criteria.Smaller) {
continue
}
}
@ -513,13 +518,12 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
//
// Messages must be sent to msgResponse. When the function returns, msgResponse must be closed.
func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error {
msgBuildCountHistogram := newMsgBuildCountHistogram()
return im.logCommand(func() error {
return im.listMessages(isUID, seqSet, items, msgResponse, msgBuildCountHistogram)
}, "FETCH", isUID, seqSet, items, msgBuildCountHistogram)
return im.listMessages(isUID, seqSet, items, msgResponse)
}, "FETCH", isUID, seqSet, items)
}
func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message, msgBuildCountHistogram *msgBuildCountHistogram) (err error) { //nolint[funlen]
func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) (err error) { //nolint[funlen]
defer func() {
close(msgResponse)
if err != nil {
@ -564,7 +568,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err
}
msg, err := im.getMessage(storeMessage, items, msgBuildCountHistogram)
msg, err := im.getMessage(storeMessage, items)
if err != nil {
err = fmt.Errorf("list message build: %v", err)
l.WithField("metaID", storeMessage.ID()).Error(err)
@ -594,7 +598,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil
}
err = parallel.RunParallel(fetchWorkers, input, processCallback, collectCallback)
err = parallel.RunParallel(im.user.backend.listWorkers, input, processCallback, collectCallback)
if err != nil {
return err
}

View File

@ -1,65 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
	"fmt"
	"sort"
	"sync"
)
// msgBuildCountHistogram is used to analyse and log the number of repetitive
// downloads of requested messages per one fetch. The number of builds per each
// messageID is stored in persistent database. The msgBuildCountHistogram will
// take this number for each message in ongoing fetch and create histogram of
// repeats.
//
// Example: During `fetch 1:300` there were
// - 100 messages were downloaded first time
// - 100 messages were downloaded second time
// - 99 messages were downloaded 10th times
// - 1 messages were downloaded 100th times.
type msgBuildCountHistogram struct {
	// Key represents how many times message was build.
	// Value stores how many messages are build X times based on the key.
	counts map[uint32]uint32
	lock   sync.Locker
}

// newMsgBuildCountHistogram returns an empty histogram ready for use.
func newMsgBuildCountHistogram() *msgBuildCountHistogram {
	return &msgBuildCountHistogram{
		counts: map[uint32]uint32{},
		lock:   &sync.Mutex{},
	}
}

// String renders the histogram as comma-separated "[n]:count" pairs.
// Keys are sorted so the output is deterministic (Go map iteration order
// is random), and counts is read under the lock so String can be called
// safely while add is still recording.
func (c *msgBuildCountHistogram) String() string {
	c.lock.Lock()
	defer c.lock.Unlock()

	keys := make([]uint32, 0, len(c.counts))
	for nRebuild := range c.counts {
		keys = append(keys, nRebuild)
	}
	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })

	res := ""
	for _, nRebuild := range keys {
		if res != "" {
			res += ", "
		}
		res += fmt.Sprintf("[%d]:%d", nRebuild, c.counts[nRebuild])
	}
	return res
}

// add records that a message was built for the nRebuild-th time.
func (c *msgBuildCountHistogram) add(nRebuild uint32) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.counts[nRebuild]++
}

View File

@ -80,7 +80,6 @@ type storeMailboxProvider interface {
GetDelimiter() string
GetMessage(apiID string) (storeMessageProvider, error)
FetchMessage(apiID string) (storeMessageProvider, error)
LabelMessages(apiID []string) error
UnlabelMessages(apiID []string) error
MarkMessagesRead(apiID []string) error
@ -100,14 +99,12 @@ type storeMessageProvider interface {
Message() *pmapi.Message
IsMarkedDeleted() bool
SetSize(int64) error
SetHeader([]byte) error
GetHeader() []byte
GetRFC822() ([]byte, error)
GetRFC822Size() (uint32, error)
GetMIMEHeader() textproto.MIMEHeader
IsFullHeaderCached() bool
SetBodyStructure(*pkgMsg.BodyStructure) error
GetBodyStructure() (*pkgMsg.BodyStructure, error)
IncreaseBuildCount() (uint32, error)
}
type storeUserWrap struct {
@ -165,7 +162,3 @@ func newStoreMailboxWrap(mailbox *store.Mailbox) *storeMailboxWrap {
func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.GetMessage(apiID)
}
func (s *storeMailboxWrap) FetchMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.FetchMessage(apiID)
}

View File

@ -135,7 +135,7 @@ func (iu *imapUser) ListMailboxes(showOnlySubcribed bool) ([]goIMAPBackend.Mailb
if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) {
continue
}
mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder)
mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox)
mailboxes = append(mailboxes, mailbox)
}
@ -167,7 +167,7 @@ func (iu *imapUser) GetMailbox(name string) (mb goIMAPBackend.Mailbox, err error
return
}
return newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder), nil
return newIMAPMailbox(iu.panicHandler, iu, storeMailbox), nil
}
// CreateMailbox creates a new mailbox.

View File

@ -18,99 +18,113 @@
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
// Cache caches the last event IDs for all accounts (there should be only one instance).
type Cache struct {
// cache is map from userID => key (such as last event) => value (such as event ID).
cache map[string]map[string]string
path string
lock *sync.RWMutex
}
const passphraseKey = "passphrase"
// NewCache constructs a new cache at the given path.
func NewCache(path string) *Cache {
return &Cache{
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Cache) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.loadCache(); err != nil {
log.WithError(err).Warn("Problem to load store cache")
}
if c.cache == nil {
c.cache = map[string]map[string]string{}
}
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
return c.cache[userID]["events"]
}
func (c *Cache) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
c.cache[userID]["events"] = eventID
return c.saveCache()
}
func (c *Cache) loadCache() error {
if c.cache != nil {
return nil
}
f, err := os.Open(c.path)
// UnlockCache unlocks the cache for the user with the given keyring.
func (store *Store) UnlockCache(kr *crypto.KeyRing) error {
passphrase, err := store.getCachePassphrase()
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.cache)
}
if passphrase == nil {
if passphrase, err = crypto.RandomToken(32); err != nil {
return err
}
func (c *Cache) saveCache() error {
if c.cache == nil {
return errors.New("events: cannot save cache: cache is nil")
enc, err := kr.Encrypt(crypto.NewPlainMessage(passphrase), nil)
if err != nil {
return err
}
if err := store.setCachePassphrase(enc.GetBinary()); err != nil {
return err
}
} else {
dec, err := kr.Decrypt(crypto.NewPGPMessage(passphrase), nil, crypto.GetUnixTime())
if err != nil {
return err
}
passphrase = dec.GetBinary()
}
f, err := os.Create(c.path)
if err := store.cache.Unlock(store.user.ID(), passphrase); err != nil {
return err
}
store.cacher.start()
return nil
}
// getCachePassphrase reads the (PGP-encrypted) cache passphrase from the
// store DB. It returns nil with no error when none has been stored yet.
//
// NOTE(review): bbolt's Get returns memory that is only valid for the
// lifetime of the transaction; confirm the slice does not need to be
// copied before it escapes the View closure.
func (store *Store) getCachePassphrase() ([]byte, error) {
	var passphrase []byte
	if err := store.db.View(func(tx *bolt.Tx) error {
		passphrase = tx.Bucket(cachePassphraseBucket).Get([]byte(passphraseKey))
		return nil
	}); err != nil {
		return nil, err
	}
	return passphrase, nil
}

// setCachePassphrase persists the (PGP-encrypted) cache passphrase in the
// store DB, overwriting any previous value.
func (store *Store) setCachePassphrase(passphrase []byte) error {
	return store.db.Update(func(tx *bolt.Tx) error {
		return tx.Bucket(cachePassphraseBucket).Put([]byte(passphraseKey), passphrase)
	})
}

// clearCachePassphrase removes the cache passphrase from the store DB
// (used when the user's cache is being torn down).
func (store *Store) clearCachePassphrase() error {
	return store.db.Update(func(tx *bolt.Tx) error {
		return tx.Bucket(cachePassphraseBucket).Delete([]byte(passphraseKey))
	})
}
// getCachedMessage returns the built RFC822 literal for the message,
// serving it from the cache when present. On a miss it builds the message
// with foreground priority and stores the result; a failure to write the
// cache is logged but does not fail the request.
func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
	if store.cache.Has(store.user.ID(), messageID) {
		return store.cache.Get(store.user.ID(), messageID)
	}
	job, done := store.newBuildJob(messageID, message.ForegroundPriority)
	defer done()
	literal, err := job.GetResult()
	if err != nil {
		return nil, err
	}
	// NOTE(GODT-1158): No need to block until cache has been set; do this async?
	if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil {
		logrus.WithError(err).Error("Failed to cache message")
	}
	return literal, nil
}
// IsCached returns whether the given message already exists in the cache
// for this store's user. It never triggers a build.
func (im *Store) IsCached(messageID string) bool {
	return im.cache.Has(im.user.ID(), messageID)
}
// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache.
// It builds with background priority.
func (store *Store) BuildAndCacheMessage(messageID string) error {
job, done := store.newBuildJob(messageID, message.BackgroundPriority)
defer done()
literal, err := job.GetResult()
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.cache)
}
func (c *Cache) clearCacheUser(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache == nil {
log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user from event loop cache")
delete(c.cache, userID)
return c.saveCache()
return store.cache.Set(store.user.ID(), messageID, literal)
}

73
internal/store/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,73 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOnDiskCacheNoCompression runs the shared cache contract test
// against the on-disk cache with the pass-through compressor.
func TestOnDiskCacheNoCompression(t *testing.T) {
	cache, err := NewOnDiskCache(t.TempDir(), &NoopCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
	require.NoError(t, err)
	testCache(t, cache)
}

// TestOnDiskCacheGZipCompression runs the shared cache contract test
// against the on-disk cache with gzip compression enabled.
func TestOnDiskCacheGZipCompression(t *testing.T) {
	cache, err := NewOnDiskCache(t.TempDir(), &GZipCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
	require.NoError(t, err)
	testCache(t, cache)
}

// TestInMemoryCache runs the shared cache contract test against the
// in-memory cache implementation.
func TestInMemoryCache(t *testing.T) {
	testCache(t, NewInMemoryCache(1<<20))
}
// testCache exercises the Cache contract shared by every implementation:
// per-user unlock, set/get round trip, Has, Rem, and whole-user Delete,
// using two users to check entries are isolated per user.
func testCache(t *testing.T, cache Cache) {
	assert.NoError(t, cache.Unlock("userID1", []byte("my secret passphrase")))
	assert.NoError(t, cache.Unlock("userID2", []byte("my other passphrase")))
	getSetCachedMessage(t, cache, "userID1", "messageID1", "some secret")
	assert.True(t, cache.Has("userID1", "messageID1"))
	getSetCachedMessage(t, cache, "userID2", "messageID2", "some other secret")
	assert.True(t, cache.Has("userID2", "messageID2"))
	assert.NoError(t, cache.Rem("userID1", "messageID1"))
	assert.False(t, cache.Has("userID1", "messageID1"))
	assert.NoError(t, cache.Rem("userID2", "messageID2"))
	assert.False(t, cache.Has("userID2", "messageID2"))
	assert.NoError(t, cache.Delete("userID1"))
	assert.NoError(t, cache.Delete("userID2"))
}

// getSetCachedMessage stores secret under (userID, messageID) and asserts
// it reads back byte-for-byte unchanged.
func getSetCachedMessage(t *testing.T, cache Cache, userID, messageID, secret string) {
	assert.NoError(t, cache.Set(userID, messageID, []byte(secret)))
	data, err := cache.Get(userID, messageID)
	assert.NoError(t, err)
	assert.Equal(t, []byte(secret), data)
}

33
internal/store/cache/compressor.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
// Compressor transforms message literals on their way into the cache
// (Compress) and back out of it (Decompress).
type Compressor interface {
	Compress([]byte) ([]byte, error)
	Decompress([]byte) ([]byte, error)
}

// NoopCompressor is a pass-through Compressor: data is cached unmodified.
type NoopCompressor struct{}

// Compress returns the input unchanged.
func (NoopCompressor) Compress(dec []byte) ([]byte, error) {
	return dec, nil
}

// Decompress returns the input unchanged.
func (NoopCompressor) Decompress(cmp []byte) ([]byte, error) {
	return cmp, nil
}

60
internal/store/cache/compressor_gzip.go vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"compress/gzip"
)
// GZipCompressor is a Compressor backed by the standard library's
// compress/gzip package.
type GZipCompressor struct{}

// Compress gzips the given bytes at the default compression level.
func (GZipCompressor) Compress(raw []byte) ([]byte, error) {
	var out bytes.Buffer

	zw := gzip.NewWriter(&out)

	if _, err := zw.Write(raw); err != nil {
		return nil, err
	}

	// Close flushes remaining data and writes the gzip footer; without it
	// the stream would be truncated.
	if err := zw.Close(); err != nil {
		return nil, err
	}

	return out.Bytes(), nil
}

// Decompress gunzips the given bytes, returning the original literal.
func (GZipCompressor) Decompress(compressed []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer

	if _, err := out.ReadFrom(zr); err != nil {
		return nil, err
	}

	// Close verifies the gzip checksum read from the stream.
	if err := zr.Close(); err != nil {
		return nil, err
	}

	return out.Bytes(), nil
}

244
internal/store/cache/disk.go vendored Normal file
View File

@ -0,0 +1,244 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/semaphore"
"github.com/ricochet2200/go-disk-usage/du"
)
var ErrLowSpace = errors.New("not enough free space left on device")
// onDiskCache is a Cache implementation that stores encrypted (and
// optionally compressed) message literals as individual files on disk,
// one directory per user.
type onDiskCache struct {
	path string  // root directory of the cache
	opts Options // free-space thresholds and concurrency limits

	// gcm maps userID to the AEAD used to encrypt/decrypt that user's
	// messages; entries are created by Unlock.
	gcm map[string]cipher.AEAD
	cmp Compressor

	// rsem/wsem bound the number of concurrent file reads/writes.
	rsem, wsem semaphore.Semaphore

	// pending tracks file paths currently being written so readers can
	// wait for writes to complete.
	pending *pending

	// diskSize/diskFree are an approximation of the device capacity and
	// remaining free space, refreshed by update().
	diskSize uint64
	diskFree uint64

	// once coalesces concurrent update() calls; lock guards the fields above.
	once *sync.Once
	lock sync.Mutex
}
// NewOnDiskCache creates (if necessary) the cache root directory at the
// given path and returns an on-disk Cache using the given compressor and
// options. The initial disk size/free values are sampled once at startup
// and refreshed lazily afterwards.
func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) {
	if err := os.MkdirAll(path, 0700); err != nil {
		return nil, err
	}

	usage := du.NewDiskUsage(path)

	// NOTE(GODT-1158): use Available() or Free()?
	return &onDiskCache{
		path: path,
		opts: opts,

		gcm: make(map[string]cipher.AEAD),
		cmp: cmp,

		rsem: semaphore.New(opts.ConcurrentRead),
		wsem: semaphore.New(opts.ConcurrentWrite),

		pending: newPending(),

		diskSize: usage.Size(),
		diskFree: usage.Available(),

		once: &sync.Once{},
	}, nil
}
// Unlock derives an AES-256-GCM AEAD from the SHA-256 of the given
// passphrase, creates the user's cache directory, and registers the AEAD
// for the user. It must be called before Get/Set/Has for that user.
func (c *onDiskCache) Unlock(userID string, passphrase []byte) error {
	hash := sha256.New()

	if _, err := hash.Write(passphrase); err != nil {
		return err
	}

	// Renamed from "aes" so the local doesn't shadow the crypto/aes package.
	block, err := aes.NewCipher(hash.Sum(nil))
	if err != nil {
		return err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return err
	}

	if err := os.MkdirAll(c.getUserPath(userID), 0700); err != nil {
		return err
	}

	// Guard the map write: Unlock can be called per-user from concurrent
	// goroutines, and unsynchronized map writes are a data race.
	c.lock.Lock()
	defer c.lock.Unlock()

	c.gcm[userID] = gcm

	return nil
}
// Delete removes the user's entire cache directory (all cached messages),
// then schedules a refresh of the free-space estimate.
func (c *onDiskCache) Delete(userID string) error {
	defer c.update()

	return os.RemoveAll(c.getUserPath(userID))
}
// Has returns whether the given message exists in the cache.
//
// Any stat failure (including "not exist") is reported as "not cached";
// the caller will then simply rebuild and re-cache the message. The
// previous behaviour of panicking on unexpected errors could crash the
// whole bridge on a transient filesystem error.
func (c *onDiskCache) Has(userID, messageID string) bool {
	// Wait in case the file is currently being written.
	c.pending.wait(c.getMessagePath(userID, messageID))

	c.rsem.Lock()
	defer c.rsem.Unlock()

	if _, err := os.Stat(c.getMessagePath(userID, messageID)); err != nil {
		return false
	}

	return true
}
// Get reads, decrypts and decompresses the given message from the cache.
// It returns an error if the user has not been unlocked or if the stored
// ciphertext is malformed (previously either case would panic).
func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
	c.lock.Lock()
	gcm, ok := c.gcm[userID]
	c.lock.Unlock()

	if !ok {
		return nil, errors.New("cache is locked for this user")
	}

	enc, err := c.readFile(c.getMessagePath(userID, messageID))
	if err != nil {
		return nil, err
	}

	// The nonce is stored as a prefix of the file; guard against truncated
	// files before slicing.
	if len(enc) < gcm.NonceSize() {
		return nil, errors.New("cached message is too short to contain a nonce")
	}

	cmp, err := gcm.Open(nil, enc[:gcm.NonceSize()], enc[gcm.NonceSize():], nil)
	if err != nil {
		return nil, err
	}

	return c.cmp.Decompress(cmp)
}
// Set compresses, encrypts and writes the given message literal to the
// cache. If the user has not been unlocked an error is returned
// (previously this would panic on a nil AEAD). If there is not enough
// free space the message is silently dropped.
func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
	c.lock.Lock()
	gcm, ok := c.gcm[userID]
	c.lock.Unlock()

	if !ok {
		return errors.New("cache is locked for this user")
	}

	nonce := make([]byte, gcm.NonceSize())

	if _, err := rand.Read(nonce); err != nil {
		return err
	}

	cmp, err := c.cmp.Compress(literal)
	if err != nil {
		return err
	}

	// NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event?
	if !c.hasSpace(len(cmp)) {
		return nil
	}

	// Seal appends the ciphertext to the nonce so the nonce is stored as
	// the file prefix (see Get).
	return c.writeFile(c.getMessagePath(userID, messageID), gcm.Seal(nonce, nonce, cmp, nil))
}
// Rem removes the given message file from the cache, then schedules a
// refresh of the free-space estimate.
func (c *onDiskCache) Rem(userID, messageID string) error {
	defer c.update()

	return os.Remove(c.getMessagePath(userID, messageID))
}
// readFile reads the file at the given path, bounded by the read
// semaphore.
//
// NOTE(review): the pending wait happens while holding a read slot, so a
// long-running write can stall one reader slot — presumably acceptable
// given the semaphore sizes; confirm if read starvation is ever observed.
func (c *onDiskCache) readFile(path string) ([]byte, error) {
	c.rsem.Lock()
	defer c.rsem.Unlock()

	// Wait before reading in case the file is currently being written.
	c.pending.wait(path)

	return ioutil.ReadFile(filepath.Clean(path))
}
// writeFile writes b to the given path, bounded by the write semaphore.
// Concurrent writes to the same path are coalesced via the pending map:
// the loser waits for the winner and returns nil.
func (c *onDiskCache) writeFile(path string, b []byte) error {
	c.wsem.Lock()
	defer c.wsem.Unlock()

	// Mark the file as currently being written.
	// If it's already being written, wait for it to be done and return nil.
	// NOTE(GODT-1158): Let's hope it succeeded...
	if ok := c.pending.add(path); !ok {
		c.pending.wait(path)
		return nil
	}
	defer c.pending.done(path)

	// Reduce the approximate free space (update it exactly later).
	// Clamp at zero: diskFree is unsigned, and subtracting past zero would
	// wrap around to a huge value, defeating the low-space checks.
	c.lock.Lock()
	if size := uint64(len(b)); c.diskFree > size {
		c.diskFree -= size
	} else {
		c.diskFree = 0
	}
	c.lock.Unlock()

	// Update the diskFree eventually.
	defer c.update()

	// NOTE(GODT-1158): What happens when this fails? Should be fixed eventually.
	return ioutil.WriteFile(filepath.Clean(path), b, 0600)
}
// hasSpace reports whether a write of the given size would keep the free
// space above the configured absolute and relative thresholds.
//
// Bug fix: the original computed c.diskFree-uint64(size) directly; when
// size exceeded the remaining free space the unsigned subtraction wrapped
// around to a huge value, making the check PASS exactly when it should
// fail. We now reject any write larger than the remaining free space
// before subtracting.
func (c *onDiskCache) hasSpace(size int) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	if uint64(size) > c.diskFree {
		return false
	}

	free := c.diskFree - uint64(size)

	if c.opts.MinFreeAbs > 0 && free < c.opts.MinFreeAbs {
		return false
	}

	if c.opts.MinFreeRat > 0 && float64(free)/float64(c.diskSize) < c.opts.MinFreeRat {
		return false
	}

	return true
}
// update refreshes the free-space estimate in a background goroutine.
// Concurrent calls are coalesced through c.once so at most one disk-usage
// query runs at a time; once it completes, the Once is replaced so a
// later call can refresh again.
//
// Fix: the original read c.once without holding c.lock while the Do
// callback replaced it under the lock — a data race. We now snapshot the
// pointer under the lock before calling Do.
func (c *onDiskCache) update() {
	go func() {
		c.lock.Lock()
		once := c.once
		c.lock.Unlock()

		once.Do(func() {
			c.lock.Lock()
			defer c.lock.Unlock()

			// Update the free space.
			c.diskFree = du.NewDiskUsage(c.path).Available()

			// Reset the Once object (so we can update again).
			c.once = &sync.Once{}
		})
	}()
}
// getUserPath returns the directory holding the given user's cached
// messages; the userID is hashed so no identifying name appears on disk.
func (c *onDiskCache) getUserPath(userID string) string {
	return filepath.Join(c.path, getHash(userID))
}

// getMessagePath returns the file path for the given message within the
// user's cache directory; the messageID is hashed like the userID.
func (c *onDiskCache) getMessagePath(userID, messageID string) string {
	return filepath.Join(c.getUserPath(userID), getHash(messageID))
}

33
internal/store/cache/hash.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/sha256"
"encoding/hex"
)
// getHash returns the lowercase hex encoding of the SHA-256 digest of
// name; it is used to derive stable, filesystem-safe path components.
func getHash(name string) string {
	sum := sha256.Sum256([]byte(name))

	return hex.EncodeToString(sum[:])
}

104
internal/store/cache/memory.go vendored Normal file
View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"errors"
"sync"
)
// inMemoryCache is a Cache implementation that keeps message literals in
// a nested map (userID -> messageID -> literal), bounded by a total byte
// limit.
type inMemoryCache struct {
	lock sync.RWMutex
	// data maps userID -> messageID -> message literal.
	data map[string]map[string][]byte
	// size is the current total bytes stored; limit is the maximum allowed.
	size, limit int
}

// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data.
// NOTE(GODT-1158): Make this threadsafe.
func NewInMemoryCache(limit int) Cache {
	return &inMemoryCache{
		data: make(map[string]map[string][]byte),
		limit: limit,
	}
}
// Unlock prepares the cache to hold messages for the given user. The
// passphrase is unused for the in-memory variant.
//
// Fixes vs. original: the map write is now done under the lock (the file's
// own NOTE asks for thread safety), and an existing user map is kept
// rather than replaced — replacing it dropped cached data while leaving
// c.size counting the dropped bytes forever.
func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.data[userID] == nil {
		c.data[userID] = make(map[string][]byte)
	}

	return nil
}
// Delete drops all cached messages for the given user and reclaims their
// contribution to the cache's size accounting.
func (c *inMemoryCache) Delete(userID string) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	messages := c.data[userID]
	delete(c.data, userID)

	for _, literal := range messages {
		c.size -= len(literal)
	}

	return nil
}
// Has returns whether the given message exists in the cache.
func (c *inMemoryCache) Has(userID, messageID string) bool {
	_, err := c.Get(userID, messageID)

	return err == nil
}
// Get returns the cached literal for the given message, or an error if it
// is not present.
func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
	c.lock.RLock()
	defer c.lock.RUnlock()

	if literal, ok := c.data[userID][messageID]; ok {
		return literal, nil
	}

	return nil, errors.New("no such message in cache")
}
// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently?
// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache.

// Set stores the literal for the given message, silently dropping it when
// the configured byte limit would be exceeded.
//
// Fixes vs. original: writing for a user that was never unlocked no
// longer panics on a nil inner map (the message is dropped instead), and
// overwriting an existing message no longer double-counts its size.
func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.size+len(literal) > c.limit {
		return nil
	}

	messages, ok := c.data[userID]
	if !ok {
		// User not unlocked; assignment into a nil map would panic.
		return nil
	}

	if old, ok := messages[messageID]; ok {
		c.size -= len(old)
	}

	c.size += len(literal)
	messages[messageID] = literal

	return nil
}
// Rem removes the given message from the cache and reclaims its bytes
// from the size accounting; removing a missing message is a no-op.
func (c *inMemoryCache) Rem(userID, messageID string) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if literal, ok := c.data[userID][messageID]; ok {
		c.size -= len(literal)
		delete(c.data[userID], messageID)
	}

	return nil
}

25
internal/store/cache/options.go vendored Normal file
View File

@ -0,0 +1,25 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
// Options configures the on-disk cache.
type Options struct {
	// MinFreeAbs is the minimum absolute free space (bytes) that must
	// remain on the device after a write; 0 disables the check.
	MinFreeAbs uint64
	// MinFreeRat is the minimum free/total space ratio that must remain
	// after a write; 0 disables the check.
	MinFreeRat float64
	// ConcurrentRead/ConcurrentWrite bound the number of simultaneous
	// file reads/writes.
	ConcurrentRead int
	ConcurrentWrite int
}

61
internal/store/cache/pending.go vendored Normal file
View File

@ -0,0 +1,61 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import "sync"
// pending tracks file paths with an in-flight write so that readers (and
// competing writers) can block until the write completes. Each tracked
// path has a channel that is closed when the write is done.
type pending struct {
	lock sync.Mutex
	path map[string]chan struct{}
}

// newPending returns an empty pending-path tracker.
func newPending() *pending {
	return &pending{path: make(chan struct{} /* per path */, 0) == nil, path: make(map[string]chan struct{})}
}
// add marks path as having an in-flight write. It returns false when the
// path is already being written (the caller should wait instead), true
// when the caller now owns the write and must later call done.
func (p *pending) add(path string) bool {
	p.lock.Lock()
	defer p.lock.Unlock()

	if _, alreadyWriting := p.path[path]; alreadyWriting {
		return false
	}

	p.path[path] = make(chan struct{})

	return true
}
// wait blocks until any in-flight write for path completes; if no write
// is in flight it returns immediately.
func (p *pending) wait(path string) {
	p.lock.Lock()
	ch := p.path[path]
	p.lock.Unlock()

	// A nil channel means nothing is pending for this path.
	if ch != nil {
		<-ch
	}
}
// done marks the write for path as finished: it removes the entry and
// closes its channel, releasing every goroutine blocked in wait.
//
// Note the defer ordering: close's argument p.path[path] is evaluated at
// defer time (before the delete below), so the correct channel is closed
// even though the map entry is gone by the time the defer runs.
func (p *pending) done(path string) {
	p.lock.Lock()
	defer p.lock.Unlock()

	defer close(p.path[path])

	delete(p.path, path)
}

51
internal/store/cache/pending_test.go vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPending verifies that wait blocks until done is called for the
// matching path: each waiter goroutine may only send its ID after its own
// done call, so the IDs must arrive in done-call order.
func TestPending(t *testing.T) {
	pending := newPending()

	pending.add("1")
	pending.add("2")
	pending.add("3")

	resCh := make(chan string)

	go func() { pending.wait("1"); resCh <- "1" }()
	go func() { pending.wait("2"); resCh <- "2" }()
	go func() { pending.wait("3"); resCh <- "3" }()

	pending.done("1")
	assert.Equal(t, "1", <-resCh)

	pending.done("2")
	assert.Equal(t, "2", <-resCh)

	pending.done("3")
	assert.Equal(t, "3", <-resCh)
}
// TestPendingUnknown verifies that waiting on a path that was never added
// returns immediately instead of blocking forever.
func TestPendingUnknown(t *testing.T) {
	newPending().wait("this is not currently being waited")
}

28
internal/store/cache/types.go vendored Normal file
View File

@ -0,0 +1,28 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
// Cache stores message literals keyed by user and message ID. Users must
// be unlocked before their messages can be read or written.
type Cache interface {
	// Unlock prepares (and, for encrypted caches, keys) the user's storage.
	Unlock(userID string, passphrase []byte) error
	// Delete removes all cached messages for the user.
	Delete(userID string) error
	// Has reports whether the given message is cached.
	Has(userID, messageID string) bool
	// Get returns the cached literal for the given message.
	Get(userID, messageID string) ([]byte, error)
	// Set stores the literal for the given message.
	Set(userID, messageID string, literal []byte) error
	// Rem removes the given message from the cache.
	Rem(userID, messageID string) error
}

View File

@ -0,0 +1,63 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import "time"
// StartWatcher starts a background goroutine that periodically (every 3
// minutes) scans all message IDs in the store and enqueues a cacher job
// for any message not yet in the on-disk cache. The goroutine runs until
// store.done is closed (see stopWatcher) or the message-ID listing fails.
func (store *Store) StartWatcher() {
	store.done = make(chan struct{})

	go func() {
		ticker := time.NewTicker(3 * time.Minute)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				// NOTE(GODT-1158): Race condition here? What if DB was already closed?
				messageIDs, err := store.getAllMessageIDs()
				if err != nil {
					// Listing failed (e.g. DB closed); stop watching.
					return
				}

				for _, messageID := range messageIDs {
					if !store.IsCached(messageID) {
						store.cacher.newJob(messageID)
					}
				}

			case <-store.done:
				return
			}
		}
	}()
}
// stopWatcher stops the watcher goroutine by closing store.done. It is a
// no-op if the watcher was never started or was already stopped.
//
// The select prefers the ready receive case over default: if done is
// already closed the receive succeeds and we return without a double
// close. NOTE(review): two concurrent stopWatcher calls could still both
// take the default branch and double-close — presumably callers serialize
// this; confirm.
func (store *Store) stopWatcher() {
	if store.done == nil {
		return
	}

	select {
	default:
		close(store.done)

	case <-store.done:
		return
	}
}

View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sync"
"github.com/sirupsen/logrus"
)
// Cacher builds and caches messages in the background. Jobs are message
// IDs fed through the jobs channel; the worker loop started by start()
// consumes them until done is closed by stop().
type Cacher struct {
	storer Storer
	jobs chan string
	done chan struct{}
	// started gates newJob; NOTE(review): it is read/written from multiple
	// goroutines without synchronization — a data race under -race; confirm
	// whether start/stop/newJob are serialized by the caller.
	started bool
	// wg counts enqueued jobs so stop() can wait for in-flight builds.
	wg *sync.WaitGroup
}

// Storer is the subset of store behaviour the cacher needs: checking
// whether a message is cached and building+caching it.
type Storer interface {
	IsCached(messageID string) bool
	BuildAndCacheMessage(messageID string) error
}
// newCacher returns a Cacher that uses the given storer to build and
// persist messages. The caller must invoke start() before sending jobs.
func newCacher(storer Storer) *Cacher {
	cacher := &Cacher{
		storer: storer,
		wg: &sync.WaitGroup{},
	}

	cacher.jobs = make(chan string)
	cacher.done = make(chan struct{})

	return cacher
}
// newJob sends a new job to the cacher if it's running.
//
// The send happens in a spawned goroutine so newJob never blocks the
// caller. NOTE(review): if stop() is called while such a goroutine is
// still blocked sending on jobs, wg was already incremented but the job
// will never be consumed, so stop()'s wg.Wait() could hang — verify
// stop() ordering with in-flight newJob calls.
func (cacher *Cacher) newJob(messageID string) {
	if !cacher.started {
		return
	}

	select {
	case <-cacher.done:
		// Cacher has been stopped; drop the job.
		return

	default:
		if !cacher.storer.IsCached(messageID) {
			cacher.wg.Add(1)

			go func() { cacher.jobs <- messageID }()
		}
	}
}
// start launches the worker loop: each received job is handled in its own
// goroutine until done is closed. Concurrency is effectively bounded
// upstream (the message builder's priority queue), not here.
func (cacher *Cacher) start() {
	cacher.started = true

	go func() {
		for {
			select {
			case messageID := <-cacher.jobs:
				go cacher.handleJob(messageID)

			case <-cacher.done:
				return
			}
		}
	}()
}
// handleJob builds and caches a single message, logging the outcome, and
// always releases the waitgroup slot reserved by newJob.
func (cacher *Cacher) handleJob(messageID string) {
	defer cacher.wg.Done()

	err := cacher.storer.BuildAndCacheMessage(messageID)
	if err != nil {
		logrus.WithError(err).Error("Failed to build and cache message")
		return
	}

	logrus.WithField("messageID", messageID).Trace("Message cached")
}
// stop waits for all enqueued jobs to finish, then closes done so the
// worker loop exits. Calling stop multiple times is safe: the select
// returns early when done is already closed.
func (cacher *Cacher) stop() {
	cacher.started = false

	// Wait for in-flight jobs before shutting down the worker loop.
	cacher.wg.Wait()

	select {
	case <-cacher.done:
		return

	default:
		close(cacher.done)
	}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
)
// withTestCacher runs doTest against a started Cacher backed by a gomock
// Storer, handling controller and cacher lifecycle. The deferred stop()
// also acts as a barrier: it waits for enqueued jobs, so expectations set
// inside doTest are satisfied before ctrl.Finish() runs.
func withTestCacher(t *testing.T, doTest func(storer *storemocks.MockStorer, cacher *Cacher)) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock storer used to build/cache messages.
	storer := storemocks.NewMockStorer(ctrl)

	// Create a new cacher pointing to the fake store.
	cacher := newCacher(storer)

	// Start the cacher and wait for it to stop.
	cacher.start()
	defer cacher.stop()

	doTest(storer, cacher)
}
// TestCacher: an uncached message should be built and cached.
func TestCacher(t *testing.T) {
	// If the message is not yet cached, we should expect to try to build and cache it.
	withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
		storer.EXPECT().IsCached("messageID").Return(false)
		storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)

		cacher.newJob("messageID")
	})
}

// TestCacherAlreadyCached: a cached message must not be rebuilt (no
// BuildAndCacheMessage expectation is set, so a call would fail the mock).
func TestCacherAlreadyCached(t *testing.T) {
	// If the message is already cached, we should not try to build it.
	withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
		storer.EXPECT().IsCached("messageID").Return(true)

		cacher.newJob("messageID")
	})
}

// TestCacherFail: a build failure is logged and must not crash the cacher.
func TestCacherFail(t *testing.T) {
	// If building the message fails, we should not try to cache it.
	withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
		storer.EXPECT().IsCached("messageID").Return(false)
		storer.EXPECT().BuildAndCacheMessage("messageID").Return(errors.New("failed to build message"))

		cacher.newJob("messageID")
	})
}
// TestCacherStop verifies that jobs sent after stop() are dropped (no
// mock expectations are registered for them, so any call would fail) and
// that repeated stop() calls are safe.
func TestCacherStop(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock storer used to build/cache messages.
	storer := storemocks.NewMockStorer(ctrl)

	// Create a new cacher pointing to the fake store.
	cacher := newCacher(storer)

	// Start the cacher.
	cacher.start()

	// Send a job -- this should succeed.
	storer.EXPECT().IsCached("messageID").Return(false)
	storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
	cacher.newJob("messageID")

	// Stop the cacher.
	cacher.stop()

	// Send more jobs -- these should all be dropped.
	cacher.newJob("messageID2")
	cacher.newJob("messageID3")
	cacher.newJob("messageID4")
	cacher.newJob("messageID5")

	// Stopping the cacher multiple times is safe.
	cacher.stop()
	cacher.stop()
	cacher.stop()
	cacher.stop()
}

View File

@ -34,7 +34,7 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -49,7 +49,7 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -61,7 +61,7 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})

View File

@ -38,7 +38,7 @@ const (
)
type eventLoop struct {
cache *Cache
currentEvents *Events
currentEventID string
currentEvent *pmapi.Event
pollCh chan chan struct{}
@ -51,26 +51,26 @@ type eventLoop struct {
log *logrus.Entry
store *Store
user BridgeUser
events listener.Listener
store *Store
user BridgeUser
listener listener.Listener
}
func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
func newEventLoop(currentEvents *Events, store *Store, user BridgeUser, listener listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop")
return &eventLoop{
cache: cache,
currentEventID: cache.getEventID(user.ID()),
currentEvents: currentEvents,
currentEventID: currentEvents.getEventID(user.ID()),
pollCh: make(chan chan struct{}),
isRunning: false,
log: eventLog,
store: store,
user: user,
events: events,
store: store,
user: user,
listener: listener,
}
}
@ -89,7 +89,7 @@ func (loop *eventLoop) setFirstEventID() (err error) {
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
if err = loop.currentEvents.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
loop.log.WithError(err).Error("Could not set latest event ID in user cache")
return
}
@ -229,7 +229,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
if err != nil && isFdCloseToULimit() {
l.Warn("Ulimit reached")
loop.events.Emit(bridgeEvents.RestartBridgeEvent, "")
loop.listener.Emit(bridgeEvents.RestartBridgeEvent, "")
err = nil
}
@ -291,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// This allows the event loop to continue to function (unless the cache was broken
// and bridge stopped, in which case it will start from the old event ID anyway).
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil {
if err = loop.currentEvents.setEventID(loop.user.ID(), event.EventID); err != nil {
return false, errors.Wrap(err, "failed to save event ID to cache")
}
}
@ -371,7 +371,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
switch addressEvent.Action {
case pmapi.EventCreate:
log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
loop.listener.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
case pmapi.EventUpdate:
oldAddress := oldList.ByID(addressEvent.ID)
@ -383,7 +383,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was updated")
if addressEvent.Address.Receive != oldAddress.Receive {
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
}
case pmapi.EventDelete:
@ -396,7 +396,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was deleted")
loop.user.CloseConnection(email)
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
case pmapi.EventUpdateFlags:
log.Error("EventUpdateFlags for address event is uknown operation")
}

View File

@ -53,7 +53,7 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
More: false,
}, nil),
)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
// Event loop runs in goroutine started during store creation (newStoreNoEvents).
// Force to run the next event.
@ -78,7 +78,7 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
subject := "old subject"
newSubject := "new subject"
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: subject,
})
@ -106,7 +106,7 @@ func TestEventLoopDeletionNotPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},
@ -133,7 +133,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},

116
internal/store/events.go Normal file
View File

@ -0,0 +1,116 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
)
// Events caches the last event IDs for all accounts (there should be only one instance).
type Events struct {
	// eventMap is map from userID => key (such as last event) => value (such as event ID).
	// It is loaded lazily from the JSON file at path on first access.
	eventMap map[string]map[string]string

	path string // JSON file the map is persisted to
	lock *sync.RWMutex
}

// NewEvents constructs a new event cache at the given path. The backing
// file is not read until the first getEventID call.
func NewEvents(path string) *Events {
	return &Events{
		path: path,
		lock: &sync.RWMutex{},
	}
}
// getEventID returns the last processed event ID for the given user,
// lazily loading the persisted event map on first use. Missing maps are
// initialized so the lookup (and later setEventID calls) are safe.
func (c *Events) getEventID(userID string) string {
	c.lock.Lock()
	defer c.lock.Unlock()

	if err := c.loadEvents(); err != nil {
		log.WithError(err).Warn("Problem to load store events")
	}

	if c.eventMap == nil {
		c.eventMap = make(map[string]map[string]string)
	}

	userMap, ok := c.eventMap[userID]
	if !ok {
		userMap = make(map[string]string)
		c.eventMap[userID] = userMap
	}

	return userMap["events"]
}
// setEventID records the last processed event ID for the given user and
// persists the whole map to disk.
//
// Fix: if setEventID is called before getEventID ever ran, c.eventMap is
// still nil and the original code panicked assigning into the nil outer
// map; it is now initialized here as well.
func (c *Events) setEventID(userID, eventID string) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.eventMap == nil {
		c.eventMap = map[string]map[string]string{}
	}

	if c.eventMap[userID] == nil {
		c.eventMap[userID] = map[string]string{}
	}

	c.eventMap[userID]["events"] = eventID

	return c.saveEvents()
}
// loadEvents reads the persisted event map from disk. It is a no-op once
// the map is populated. Callers must hold c.lock. An os.Open error is
// returned (and merely logged by the caller) — on first run the file
// simply doesn't exist yet.
func (c *Events) loadEvents() error {
	if c.eventMap != nil {
		return nil
	}

	f, err := os.Open(c.path)
	if err != nil {
		return err
	}
	defer f.Close() //nolint[errcheck]

	return json.NewDecoder(f).Decode(&c.eventMap)
}
// saveEvents writes the current event map to disk as JSON. Callers must
// hold c.lock.
//
// Fix: on this write path the file's Close error is no longer discarded —
// a failed close can mean buffered data never reached disk, which would
// silently lose the saved event IDs.
func (c *Events) saveEvents() error {
	if c.eventMap == nil {
		return errors.New("events: cannot save events: events map is nil")
	}

	f, err := os.Create(c.path)
	if err != nil {
		return err
	}

	if err := json.NewEncoder(f).Encode(c.eventMap); err != nil {
		_ = f.Close()
		return err
	}

	return f.Close()
}
// clearUserEvents removes the given user's entry from the event map and
// persists the change. A nil map (nothing was ever loaded or set) is
// treated as "nothing to clear" and logged, not an error.
func (c *Events) clearUserEvents(userID string) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.eventMap == nil {
		log.WithField("user", userID).Warning("Cannot clear user events: event map is nil")
		return nil
	}

	log.WithField("user", userID).Trace("Removing user events from event loop")

	delete(c.eventMap, userID)

	return c.saveEvents()
}

View File

@ -107,7 +107,7 @@ func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Sto
func TestMailboxCountRemove(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
testCounts := []*pmapi.MessagesCount{
{LabelID: "label1", Total: 100, Unread: 0},

View File

@ -35,7 +35,7 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -80,7 +80,7 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -67,40 +67,19 @@ func (message *Message) Message() *pmapi.Message {
return message.msg
}
// IsMarkedDeleted returns true if message is marked as deleted for specific
// mailbox.
// IsMarkedDeleted returns true if message is marked as deleted for specific mailbox.
func (message *Message) IsMarkedDeleted() bool {
isMarkedAsDeleted := false
err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
var isMarkedAsDeleted bool
if err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil
return nil
})
if err != nil {
}); err != nil {
message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.")
return false
}
return isMarkedAsDeleted
}
// SetSize updates the information about size of decrypted message which can be
// used for IMAP. This should not trigger any IMAP update.
// NOTE: The size from the server corresponds to pure body bytes. Hence it
// should not be used. The correct size has to be calculated from decrypted and
// built message.
func (message *Message) SetSize(size int64) error {
message.msg.Size = size
txUpdate := func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
}
stored.Size = size
return message.store.txPutMessage(
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
return isMarkedAsDeleted
}
// SetContentTypeAndHeader updates the information about content type and
@ -112,7 +91,7 @@ func (message *Message) SetSize(size int64) error {
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
message.msg.MIMEType = mimeType
message.msg.Header = header
txUpdate := func(tx *bolt.Tx) error {
return message.store.db.Update(func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
@ -123,34 +102,26 @@ func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Hea
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
}
// SetHeader checks header can be parsed and if yes it stores header bytes in
// database.
func (message *Message) SetHeader(header []byte) error {
_, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(header))).ReadMIMEHeader()
if err != nil {
return err
}
return message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(headersBucket).Put([]byte(message.ID()), header)
})
}
// IsFullHeaderCached will check that valid full header is stored in DB.
func (message *Message) IsFullHeaderCached() bool {
header, err := message.getRawHeader()
return err == nil && header != nil
}
func (message *Message) getRawHeader() (raw []byte, err error) {
err = message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(headersBucket).Get([]byte(message.ID()))
var raw []byte
err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
return nil
})
return
return err == nil && raw != nil
}
func (message *Message) getRawHeader() ([]byte, error) {
bs, err := message.GetBodyStructure()
if err != nil {
return nil, err
}
return bs.GetMailHeaderBytes()
}
// GetHeader will return cached header from DB.
@ -178,44 +149,79 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
return header
}
// SetBodyStructure stores serialized body structure in database.
func (message *Message) SetBodyStructure(bs *pkgMsg.BodyStructure) error {
txUpdate := func(tx *bolt.Tx) error {
return message.store.txPutBodyStructure(
tx.Bucket(bodystructureBucket),
message.ID(), bs,
)
}
return message.store.db.Update(txUpdate)
}
// GetBodyStructure returns the message's body structure.
// It checks first if it's in the store. If it is, it returns it from store,
// otherwise it computes it from the message cache (and saves the result to the store).
func (message *Message) GetBodyStructure() (*pkgMsg.BodyStructure, error) {
var raw []byte
// GetBodyStructure deserializes body structure from database. If body structure
// is not in database it returns nil error and nil body structure. If error
// occurs it returns nil body structure.
func (message *Message) GetBodyStructure() (bs *pkgMsg.BodyStructure, err error) {
txRead := func(tx *bolt.Tx) error {
bs, err = message.store.txGetBodyStructure(
tx.Bucket(bodystructureBucket),
message.ID(),
)
return err
}
if err = message.store.db.View(txRead); err != nil {
if err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
return nil
}); err != nil {
return nil, err
}
if len(raw) > 0 {
// If not possible to deserialize just continue with build.
if bs, err := pkgMsg.DeserializeBodyStructure(raw); err == nil {
return bs, nil
}
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return nil, err
}
bs, err := pkgMsg.NewBodyStructure(bytes.NewReader(literal))
if err != nil {
return nil, err
}
if raw, err = bs.Serialize(); err != nil {
return nil, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(bodystructureBucket).Put([]byte(message.ID()), raw)
}); err != nil {
return nil, err
}
return bs, nil
}
func (message *Message) IncreaseBuildCount() (times uint32, err error) {
txUpdate := func(tx *bolt.Tx) error {
times, err = message.store.txIncreaseMsgBuildCount(
tx.Bucket(msgBuildCountBucket),
message.ID(),
)
return err
}
if err = message.store.db.Update(txUpdate); err != nil {
// GetRFC822 returns the raw message literal.
func (message *Message) GetRFC822() ([]byte, error) {
return message.store.getCachedMessage(message.ID())
}
// GetRFC822Size returns the size of the raw message literal.
func (message *Message) GetRFC822Size() (uint32, error) {
var raw []byte
if err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(sizeBucket).Get([]byte(message.ID()))
return nil
}); err != nil {
return 0, err
}
return times, nil
if len(raw) > 0 {
return btoi(raw), nil
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return 0, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(sizeBucket).Put([]byte(message.ID()), itob(uint32(len(literal))))
}); err != nil {
return 0, err
}
return uint32(len(literal)), nil
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier,Storer)
// Package mocks is a generated GoMock package.
package mocks
@ -318,3 +318,54 @@ func (mr *MockChangeNotifierMockRecorder) UpdateMessage(arg0, arg1, arg2, arg3,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5)
}
// MockStorer is a mock of Storer interface.
type MockStorer struct {
ctrl *gomock.Controller
recorder *MockStorerMockRecorder
}
// MockStorerMockRecorder is the mock recorder for MockStorer.
type MockStorerMockRecorder struct {
mock *MockStorer
}
// NewMockStorer creates a new mock instance.
func NewMockStorer(ctrl *gomock.Controller) *MockStorer {
mock := &MockStorer{ctrl: ctrl}
mock.recorder = &MockStorerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorer) EXPECT() *MockStorerMockRecorder {
return m.recorder
}
// BuildAndCacheMessage mocks base method.
func (m *MockStorer) BuildAndCacheMessage(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildAndCacheMessage", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// BuildAndCacheMessage indicates an expected call of BuildAndCacheMessage.
func (mr *MockStorerMockRecorder) BuildAndCacheMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildAndCacheMessage", reflect.TypeOf((*MockStorer)(nil).BuildAndCacheMessage), arg0)
}
// IsCached mocks base method.
func (m *MockStorer) IsCached(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsCached", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// IsCached indicates an expected call of IsCached.
func (mr *MockStorerMockRecorder) IsCached(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStorer)(nil).IsCached), arg0)
}

View File

@ -26,8 +26,11 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -52,19 +55,21 @@ var (
// Database structure:
// * metadata
// * {messageID} -> message data (subject, from, to, time, body size, ...)
// * {messageID} -> message data (subject, from, to, time, ...)
// * headers
// * {messageID} -> header bytes
// * bodystructure
// * {messageID} -> message body structure
// * msgbuildcount
// * {messageID} -> uint32 number of message builds to track re-sync issues
// * size
// * {messageID} -> uint32 value
// * counts
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
// * address_info
// * {index} -> {address, addressID}
// * address_mode
// * mode -> string split or combined
// * cache_passphrase
// * passphrase -> cache passphrase (pgp encrypted message)
// * mailboxes_version
// * version -> uint32 value
// * sync_state
@ -79,19 +84,20 @@ var (
// * {messageID} -> uint32 imapUID
// * deleted_ids (can be missing or have no keys)
// * {messageID} -> true
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
headersBucket = []byte("headers") //nolint[gochecknoglobals]
bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
msgBuildCountBucket = []byte("msgbuildcount") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
headersBucket = []byte("headers") //nolint[gochecknoglobals]
bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
sizeBucket = []byte("size") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
cachePassphraseBucket = []byte("cache_passphrase") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
// ErrNoSuchAPIID when mailbox does not have API ID.
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
@ -117,18 +123,23 @@ func exposeContextForSMTP() context.Context {
type Store struct {
sentryReporter *sentry.Reporter
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
eventLoop *eventLoop
currentEvents *Events
log *logrus.Entry
cache *Cache
filePath string
db *bolt.DB
lock *sync.RWMutex
addresses map[string]*Address
notifier ChangeNotifier
builder *message.Builder
cache cache.Cache
cacher *Cacher
done chan struct{}
isSyncRunning bool
syncCooldown cooldown
addressMode addressMode
@ -139,12 +150,14 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter,
panicHandler PanicHandler,
user BridgeUser,
events listener.Listener,
listener listener.Listener,
cache cache.Cache,
builder *message.Builder,
path string,
cache *Cache,
currentEvents *Events,
) (store *Store, err error) {
if user == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
if user == nil || listener == nil || currentEvents == nil {
return nil, fmt.Errorf("missing parameters - user: %v, listener: %v, currentEvents: %v", user, listener, currentEvents)
}
l := log.WithField("user", user.ID())
@ -160,21 +173,29 @@ func New( // nolint[funlen]
bdb, err := openBoltDatabase(path)
if err != nil {
err = errors.Wrap(err, "failed to open store database")
return
return nil, errors.Wrap(err, "failed to open store database")
}
store = &Store{
sentryReporter: sentryReporter,
panicHandler: panicHandler,
user: user,
cache: cache,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
log: l,
currentEvents: currentEvents,
log: l,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
builder: builder,
cache: cache,
}
// Create a new cacher. It's not started yet.
// NOTE(GODT-1158): I hate this circular dependency store->cacher->store :(
store.cacher = newCacher(store)
// Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes.
store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute)
@ -188,7 +209,7 @@ func New( // nolint[funlen]
}
if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, user, events)
store.eventLoop = newEventLoop(currentEvents, store, user, listener)
go func() {
defer store.panicHandler.HandlePanic()
store.eventLoop.start()
@ -216,10 +237,11 @@ func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
metadataBucket,
headersBucket,
bodystructureBucket,
msgBuildCountBucket,
sizeBucket,
countsBucket,
addressInfoBucket,
addressModeBucket,
cachePassphraseBucket,
syncStateBucket,
mailboxesBucket,
mboxVersionBucket,
@ -365,6 +387,24 @@ func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label)
return
}
// newBuildJob returns a new build job for the given message using the store's message builder.
func (store *Store) newBuildJob(messageID string, priority int) (*message.Job, pool.DoneFunc) {
return store.builder.NewJobWithOptions(
context.Background(),
store.client(),
messageID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
priority,
)
}
// Close stops the event loop and closes the database to free the file.
func (store *Store) Close() error {
store.lock.Lock()
@ -381,12 +421,21 @@ func (store *Store) CloseEventLoop() {
}
func (store *Store) close() error {
// Stop the watcher first before closing the database.
store.stopWatcher()
// Stop the cacher.
store.cacher.stop()
// Stop the event loop.
store.CloseEventLoop()
// Close the database.
return store.db.Close()
}
// Remove closes and removes the database file and clears the cache file.
func (store *Store) Remove() (err error) {
func (store *Store) Remove() error {
store.lock.Lock()
defer store.lock.Unlock()
@ -394,22 +443,34 @@ func (store *Store) Remove() (err error) {
var result *multierror.Error
if err = store.close(); err != nil {
if err := store.close(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
}
if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil {
if err := RemoveStore(store.currentEvents, store.filePath, store.user.ID()); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
}
if err := store.RemoveCache(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove cache"))
}
return result.ErrorOrNil()
}
func (store *Store) RemoveCache() error {
if err := store.clearCachePassphrase(); err != nil {
logrus.WithError(err).Error("Failed to clear cache passphrase")
}
return store.cache.Delete(store.user.ID())
}
// RemoveStore removes the database file and clears the cache file.
func RemoveStore(cache *Cache, path, userID string) error {
func RemoveStore(currentEvents *Events, path, userID string) error {
var result *multierror.Error
if err := cache.clearCacheUser(userID); err != nil {
if err := currentEvents.clearUserEvents(userID); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
}

View File

@ -23,13 +23,17 @@ import (
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
@ -139,7 +143,7 @@ type mocksForStore struct {
store *Store
tmpDir string
cache *Cache
cache *Events
}
func initMocks(tb testing.TB) (*mocksForStore, func()) {
@ -162,7 +166,7 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
require.NoError(tb, err)
cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
mocks.cache = NewCache(cacheFile)
mocks.cache = NewEvents(cacheFile)
return mocks, func() {
if err := recover(); err != nil {
@ -176,13 +180,14 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
}
}
func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
mocks.user.EXPECT().ID().Return("userID").AnyTimes()
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes()
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
@ -213,6 +218,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.panicHandler,
mocks.user,
mocks.events,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,
)

View File

@ -27,7 +27,6 @@ import (
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -154,11 +153,6 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
return false, err
}
msgSize := message.Size
if msgSize == 0 {
msgSize = int64(len(message.Body))
}
var attSize int64
for _, att := range attachments {
b, err := ioutil.ReadAll(att.encReader)
@ -169,7 +163,7 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
att.encReader = bytes.NewBuffer(b)
}
return msgSize+attSize <= maxUpload, nil
return int64(len(message.Body))+attSize <= maxUpload, nil
}
func (store *Store) getDraftAction(message *pmapi.Message) int {
@ -237,39 +231,6 @@ func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Messag
return nil
}
func (store *Store) txPutBodyStructure(bsBucket *bolt.Bucket, msgID string, bs *pkgMsg.BodyStructure) error {
raw, err := bs.Serialize()
if err != nil {
return err
}
err = bsBucket.Put([]byte(msgID), raw)
if err != nil {
return errors.Wrap(err, "cannot put bodystructure bucket")
}
return nil
}
func (store *Store) txGetBodyStructure(bsBucket *bolt.Bucket, msgID string) (*pkgMsg.BodyStructure, error) {
raw := bsBucket.Get([]byte(msgID))
if len(raw) == 0 {
return nil, nil
}
return pkgMsg.DeserializeBodyStructure(raw)
}
func (store *Store) txIncreaseMsgBuildCount(b *bolt.Bucket, msgID string) (uint32, error) {
key := []byte(msgID)
count := uint32(0)
raw := b.Get(key)
if raw != nil {
count = btoi(raw)
}
count++
return count, b.Put(key, itob(count))
}
// createOrUpdateMessageEvent is helper to create only one message with
// createOrUpdateMessagesEvent.
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
@ -287,7 +248,7 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
b := tx.Bucket(metadataBucket)
for _, msg := range msgs {
clearNonMetadata(msg)
txUpdateMetadaFromDB(b, msg, store.log)
txUpdateMetadataFromDB(b, msg, store.log)
}
return nil
})
@ -341,6 +302,11 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
return err
}
// Notify the cacher that it should start caching messages.
for _, msg := range msgs {
store.cacher.newJob(msg.ID)
}
return nil
}
@ -351,16 +317,12 @@ func clearNonMetadata(onlyMeta *pmapi.Message) {
onlyMeta.Attachments = nil
}
// txUpdateMetadaFromDB changes the the onlyMeta data.
// txUpdateMetadataFromDB changes the onlyMeta data.
// If there is a stored message in metaBucket, the header and MIMEType are
// not changed if already set. To change these:
// * contentType and header must be updated by Message.SetContentTypeAndHeader.
func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
// Size attribute on the server is counting encrypted data. We need to compute
// "real" size of decrypted data. Negative values will be processed during fetch.
onlyMeta.Size = -1
func txUpdateMetadataFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
msgb := metaBucket.Get([]byte(onlyMeta.ID))
if msgb == nil {
return
@ -378,8 +340,7 @@ func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log
return
}
// Keep already calculated size and content type.
onlyMeta.Size = stored.Size
// Keep content type.
onlyMeta.MIMEType = stored.MIMEType
if stored.Header != "" && stored.Header != "(No Header)" {
tmpMsg, err := mail.ReadMessage(
@ -401,6 +362,12 @@ func (store *Store) deleteMessageEvent(apiID string) error {
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
func (store *Store) deleteMessagesEvent(apiIDs []string) error {
for _, messageID := range apiIDs {
if err := store.cache.Rem(store.UserID(), messageID); err != nil {
logrus.WithError(err).Error("Failed to remove message from cache")
}
}
return store.db.Update(func(tx *bolt.Tx) error {
for _, apiID := range apiIDs {
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {

View File

@ -33,7 +33,7 @@ func TestGetAllMessageIDs(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -47,7 +47,7 @@ func TestGetMessageFromDB(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
@ -72,7 +72,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
@ -81,12 +81,10 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
// Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", msg.Body)
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
a.Equal(t, int64(-1), msg.Size)
a.Equal(t, "", msg.MIMEType)
a.Equal(t, make(mail.Header), msg.Header)
// Change the calculated data.
wantSize := int64(42)
wantMIMEType := "plain-text"
wantHeader := mail.Header{
"Key": []string{"value"},
@ -94,13 +92,11 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
require.Nil(t, err)
require.Nil(t, storeMsg.SetSize(wantSize))
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
// Check calculated data.
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
@ -109,7 +105,6 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
}
@ -118,7 +113,7 @@ func TestDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
@ -129,8 +124,7 @@ func TestDeleteMessage(t *testing.T) {
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
require.Nil(t, m.store.createOrUpdateMessageEvent(getTestMessage(id, subject, sender, unread, labelIDs)))
}
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
@ -142,7 +136,6 @@ func getTestMessage(id, subject, sender string, unread bool, labelIDs []string)
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
Size: 12345,
Body: "body of message",
Attachments: []*pmapi.Attachment{{
ID: "attachment1",
@ -162,7 +155,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil)
@ -181,7 +174,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil)

View File

@ -30,7 +30,7 @@ func TestLoadSaveSyncState(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})

View File

@ -107,6 +107,21 @@ func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) erro
return err
}
// If the client is already unlocked, we can unlock the store cache as well.
if client.IsUnlocked() {
kr, err := client.GetUserKeyRing()
if err != nil {
return err
}
if err := u.store.UnlockCache(kr); err != nil {
return err
}
// NOTE(GODT-1158): If using in-memory cache we probably shouldn't start the watcher?
u.store.StartWatcher()
}
return nil
}

View File

@ -32,7 +32,7 @@ func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -50,7 +50,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
// Ignore any sync on background.
@ -76,7 +76,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode.
// Mock change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
@ -98,7 +98,7 @@ func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -115,7 +115,7 @@ func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -133,7 +133,7 @@ func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
@ -144,7 +144,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
@ -187,7 +187,7 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!")

View File

@ -64,7 +64,7 @@ func TestNewUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials)

View File

@ -31,7 +31,7 @@ func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
@ -43,7 +43,7 @@ func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)

View File

@ -18,13 +18,15 @@
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
func testNewUser(t *testing.T, m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker)

View File

@ -20,12 +20,13 @@ package users
import (
"context"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener"
@ -225,6 +226,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password []by
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
}
// will go and unlock cache if not already done
if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
}
@ -341,9 +343,6 @@ func (u *Users) ClearData() error {
result = multierror.Append(result, err)
}
// Need to clear imap cache otherwise fetch response will be remembered from previous test.
imapcache.Clear()
return result
}
@ -366,6 +365,7 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
if err := user.closeStore(); err != nil {
log.WithError(err).Error("Failed to close user store")
}
if clearStore {
// Clear cache after closing connections (done in logout).
if err := user.clearStore(); err != nil {
@ -427,6 +427,41 @@ func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy()
}
func (u *Users) EnableCache() error {
// NOTE(GODT-1158): Check for available size before enabling.
return nil
}
func (u *Users) MigrateCache(from, to string) error {
// NOTE(GODT-1158): Is it enough to just close the store? Do we need to force-close the cacher too?
for _, user := range u.users {
if err := user.closeStore(); err != nil {
logrus.WithError(err).Error("Failed to close user's store")
}
}
// Ensure the parent directory exists.
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
return err
}
return os.Rename(from, to)
}
// DisableCache removes each user's on-disk message cache.
// Removal failures are logged but do not abort the loop, so every user's
// cache gets a removal attempt; the method itself currently never fails.
//
// NOTE(GODT-1158): Is it an error if we can't remove a user's cache?
func (u *Users) DisableCache() error {
	for _, user := range u.users {
		if err := user.store.RemoveCache(); err != nil {
			logrus.WithError(err).Error("Failed to remove user's message cache")
		}
	}

	return nil
}
// hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users {

View File

@ -49,7 +49,7 @@ func TestUsersFinishLoginNewUser(t *testing.T) {
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
mockAddingConnectedUser(m)
mockAddingConnectedUser(t, m)
mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
@ -74,7 +74,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
@ -95,7 +95,7 @@ func TestUsersFinishLoginConnectedUser(t *testing.T) {
// Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user.

View File

@ -49,7 +49,7 @@ func TestNewUsersWithConnectedUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
}
@ -71,7 +71,7 @@ func TestNewUsersWithUsers(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"io/ioutil"
"os"
"runtime"
"runtime/debug"
"testing"
"time"
@ -28,10 +29,13 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -42,9 +46,11 @@ func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel)
}
if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel)
}
os.Exit(m.Run())
}
@ -151,7 +157,7 @@ type mocks struct {
clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient
storeCache *store.Cache
storeCache *store.Events
}
func initMocks(t *testing.T) mocks {
@ -178,7 +184,7 @@ func initMocks(t *testing.T) mocks {
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
storeCache: store.NewEvents(cacheFile.Name()),
}
// Called during clean-up.
@ -187,9 +193,20 @@ func initMocks(t *testing.T) mocks {
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
dbFile, err := ioutil.TempFile(t.TempDir(), "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
return store.New(
sentryReporter,
m.PanicHandler,
user,
m.eventListener,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
dbFile.Name(),
m.storeCache,
)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
@ -212,8 +229,8 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(m, testCredentialsSplit)
mockLoadingConnectedUser(t, m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentialsSplit)
mockEventLoopNoAction(m)
return testNewUsers(t, m)
@ -245,7 +262,7 @@ func cleanUpUsersData(b *Users) {
}
}
func mockAddingConnectedUser(m mocks) {
func mockAddingConnectedUser(t *testing.T, m mocks) {
gomock.InOrder(
// Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
@ -256,10 +273,10 @@ func mockAddingConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
}
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
func mockLoadingConnectedUser(t *testing.T, m mocks, creds *credentials.Credentials) {
authRefresh := &pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
@ -273,10 +290,10 @@ func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
}
func mockInitConnectedUser(m mocks) {
func mockInitConnectedUser(t *testing.T, m mocks) {
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
@ -286,6 +303,7 @@ func mockInitConnectedUser(m mocks) {
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.pmapiClient.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes(),
)
}

View File

@ -20,10 +20,12 @@ package message
import (
"context"
"io"
"io/ioutil"
"sync"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/pkg/errors"
)
@ -32,11 +34,15 @@ var (
ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found")
)
const (
BackgroundPriority = 1 << iota
ForegroundPriority
)
type Builder struct {
reqs chan fetchReq
done chan struct{}
jobs map[string]*BuildJob
locker sync.Mutex
pool *pool.Pool
jobs map[string]*Job
lock sync.Mutex
}
type Fetcher interface {
@ -48,111 +54,159 @@ type Fetcher interface {
// NewBuilder creates a new builder which manages the given number of fetch/attach/build workers.
// - fetchWorkers: the number of workers which fetch messages from API
// - attachWorkers: the number of workers which fetch attachments from API.
// - buildWorkers: the number of workers which decrypt/build RFC822 message literals.
//
// NOTE: Each fetch worker spawns a unique set of attachment workers!
// There can therefore be up to fetchWorkers*attachWorkers simultaneous API connections.
//
// The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information.
//
// Call (*Builder).Done to shut down the builder and stop all workers.
func NewBuilder(fetchWorkers, attachWorkers, buildWorkers int) *Builder {
b := newBuilder()
func NewBuilder(fetchWorkers, attachWorkers int) *Builder {
attacherPool := pool.New(attachWorkers, newAttacherWorkFunc())
fetchReqCh, fetchResCh := startFetchWorkers(fetchWorkers, attachWorkers)
buildReqCh, buildResCh := startBuildWorkers(buildWorkers)
fetcherPool := pool.New(fetchWorkers, newFetcherWorkFunc(attacherPool))
go func() {
defer close(fetchReqCh)
for {
select {
case req := <-b.reqs:
fetchReqCh <- req
case <-b.done:
return
}
}
}()
go func() {
defer close(buildReqCh)
for res := range fetchResCh {
if res.err != nil {
b.jobFailure(res.messageID, res.err)
} else {
buildReqCh <- res
}
}
}()
go func() {
for res := range buildResCh {
if res.err != nil {
b.jobFailure(res.messageID, res.err)
} else {
b.jobSuccess(res.messageID, res.literal)
}
}
}()
return b
}
func newBuilder() *Builder {
return &Builder{
reqs: make(chan fetchReq),
done: make(chan struct{}),
jobs: make(map[string]*BuildJob),
pool: fetcherPool,
jobs: make(map[string]*Job),
}
}
// NewJob tells the builder to begin building the message with the given ID.
// The result (or any error which occurred during building) can be retrieved from the returned job when available.
func (b *Builder) NewJob(ctx context.Context, api Fetcher, messageID string) *BuildJob {
return b.NewJobWithOptions(ctx, api, messageID, JobOptions{})
func (builder *Builder) NewJob(ctx context.Context, fetcher Fetcher, messageID string, prio int) (*Job, pool.DoneFunc) {
	// Delegate to NewJobWithOptions with default options; prio orders the
	// job relative to others in the worker pool.
	return builder.NewJobWithOptions(ctx, fetcher, messageID, JobOptions{}, prio)
}
// NewJobWithOptions creates a new job with custom options. See NewJob for more information.
func (b *Builder) NewJobWithOptions(ctx context.Context, api Fetcher, messageID string, opts JobOptions) *BuildJob {
b.locker.Lock()
defer b.locker.Unlock()
func (builder *Builder) NewJobWithOptions(ctx context.Context, fetcher Fetcher, messageID string, opts JobOptions, prio int) (*Job, pool.DoneFunc) {
builder.lock.Lock()
defer builder.lock.Unlock()
if job, ok := b.jobs[messageID]; ok {
return job
if job, ok := builder.jobs[messageID]; ok {
if job.GetPriority() < prio {
job.SetPriority(prio)
}
return job, job.done
}
b.jobs[messageID] = newBuildJob(messageID)
job, done := builder.pool.NewJob(
&fetchReq{
fetcher: fetcher,
messageID: messageID,
options: opts,
},
prio,
)
go func() { b.reqs <- fetchReq{ctx: ctx, api: api, messageID: messageID, opts: opts} }()
buildJob := &Job{
Job: job,
done: done,
}
return b.jobs[messageID]
builder.jobs[messageID] = buildJob
return buildJob, func() {
builder.lock.Lock()
defer builder.lock.Unlock()
// Remove the job from the builder.
delete(builder.jobs, messageID)
// And mark it as done.
done()
}
}
// Done shuts down the builder and stops all workers.
func (b *Builder) Done() {
b.locker.Lock()
defer b.locker.Unlock()
close(b.done)
func (builder *Builder) Done() {
// NOTE(GODT-1158): Stop worker pool.
}
func (b *Builder) jobSuccess(messageID string, literal []byte) {
b.locker.Lock()
defer b.locker.Unlock()
b.jobs[messageID].postSuccess(literal)
delete(b.jobs, messageID)
// fetchReq is the payload handed to the fetcher worker pool: it identifies
// which message to download and how to build its RFC822 literal.
type fetchReq struct {
	fetcher   Fetcher    // API used to download the message and its attachments.
	messageID string     // ID of the message to fetch.
	options   JobOptions // Options controlling how the literal is built.
}
func (b *Builder) jobFailure(messageID string, err error) {
b.locker.Lock()
defer b.locker.Unlock()
b.jobs[messageID].postFailure(err)
delete(b.jobs, messageID)
// attachReq is the payload handed to the attacher worker pool: it carries an
// already-fetched message whose attachments should be downloaded.
type attachReq struct {
	fetcher Fetcher        // API used to download the attachment bodies.
	message *pmapi.Message // Message whose attachments are fetched.
}
// Job wraps a pool.Job whose result is an RFC822 message literal.
type Job struct {
	*pool.Job

	// done releases the underlying pool job; kept so the builder can wrap
	// it with its own cleanup when handing the job out.
	done pool.DoneFunc
}
// GetResult blocks until the job has finished, then returns the built RFC822
// literal or any error which occurred while building it.
func (job *Job) GetResult() ([]byte, error) {
	res, err := job.Job.GetResult()
	if err != nil {
		return nil, err
	}

	// Use a checked assertion so an unexpected (or nil) result produces the
	// same explicit "bad response type" panic used by the worker functions
	// rather than a bare type-assertion runtime error.
	literal, ok := res.([]byte)
	if !ok {
		panic("bad response type")
	}

	return literal, nil
}
// newAttacherWorkFunc returns the work function run by the attacher pool.
// It downloads every attachment of the message in the payload and returns
// them as a map from attachment ID to raw body bytes.
//
// NOTE(review): the priority argument is unused here and requests use
// context.Background(), so cancellation is not propagated — confirm intended.
func newAttacherWorkFunc() pool.WorkFunc {
	return func(payload interface{}, prio int) (interface{}, error) {
		req, ok := payload.(*attachReq)
		if !ok {
			panic("bad payload type")
		}

		res := make(map[string][]byte, len(req.message.Attachments))

		for _, att := range req.message.Attachments {
			rc, err := req.fetcher.GetAttachment(context.Background(), att.ID)
			if err != nil {
				return nil, err
			}

			b, err := ioutil.ReadAll(rc)
			if err != nil {
				// Close the body before bailing out so the reader
				// isn't leaked on a failed read.
				_ = rc.Close()
				return nil, err
			}

			if err := rc.Close(); err != nil {
				return nil, err
			}

			res[att.ID] = b
		}

		return res, nil
	}
}
// newFetcherWorkFunc returns the work function run by the fetcher pool.
// It downloads the message named in the payload, fetches its attachments via
// the given attacher pool (at the same priority), looks up the keyring for
// the message's address, and builds the RFC822 literal.
func newFetcherWorkFunc(attacherPool *pool.Pool) pool.WorkFunc {
	return func(payload interface{}, prio int) (interface{}, error) {
		req, ok := payload.(*fetchReq)
		if !ok {
			panic("bad payload type")
		}

		msg, err := req.fetcher.GetMessage(context.Background(), req.messageID)
		if err != nil {
			return nil, err
		}

		// Attachments are fetched through the attacher pool so total API
		// concurrency stays bounded by the pool sizes.
		attJob, attDone := attacherPool.NewJob(&attachReq{
			fetcher: req.fetcher,
			message: msg,
		}, prio)
		defer attDone()

		val, err := attJob.GetResult()
		if err != nil {
			return nil, err
		}

		attData, ok := val.(map[string][]byte)
		if !ok {
			panic("bad response type")
		}

		kr, err := req.fetcher.KeyRingForAddressID(msg.AddressID)
		if err != nil {
			// Wrap rather than discard the lookup error so callers still
			// match ErrNoSuchKeyRing but the underlying cause isn't lost
			// (this matches the behaviour of the old build worker).
			return nil, errors.Wrap(ErrNoSuchKeyRing, err.Error())
		}

		return buildRFC822(kr, msg, attData, req.options)
	}
}

View File

@ -1,89 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"sync"
"github.com/pkg/errors"
)
type buildRes struct {
messageID string
literal []byte
err error
}
func newBuildResSuccess(messageID string, literal []byte) buildRes {
return buildRes{
messageID: messageID,
literal: literal,
}
}
func newBuildResFailure(messageID string, err error) buildRes {
return buildRes{
messageID: messageID,
err: err,
}
}
// startBuildWorkers starts the given number of build workers.
// These workers decrypt and build messages into RFC822 literals.
// Two channels are returned:
// - buildReqCh: used to send work items to the worker pool
// - buildResCh: used to receive work results from the worker pool
func startBuildWorkers(buildWorkers int) (chan fetchRes, chan buildRes) {
buildReqCh := make(chan fetchRes)
buildResCh := make(chan buildRes)
go func() {
defer close(buildResCh)
var wg sync.WaitGroup
wg.Add(buildWorkers)
for workerID := 0; workerID < buildWorkers; workerID++ {
go buildWorker(buildReqCh, buildResCh, &wg)
}
wg.Wait()
}()
return buildReqCh, buildResCh
}
func buildWorker(buildReqCh <-chan fetchRes, buildResCh chan<- buildRes, wg *sync.WaitGroup) {
defer wg.Done()
for req := range buildReqCh {
l := log.
WithField("addrID", req.msg.AddressID).
WithField("msgID", req.msg.ID)
if kr, err := req.api.KeyRingForAddressID(req.msg.AddressID); err != nil {
l.WithError(err).Warn("Cannot find keyring for address")
buildResCh <- newBuildResFailure(req.msg.ID, errors.Wrap(ErrNoSuchKeyRing, err.Error()))
} else if literal, err := buildRFC822(kr, req.msg, req.atts, req.opts); err != nil {
l.WithError(err).Warn("Build failed")
buildResCh <- newBuildResFailure(req.msg.ID, err)
} else {
buildResCh <- newBuildResSuccess(req.msg.ID, literal)
}
}
}

View File

@ -1,141 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"context"
"io/ioutil"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/parallel"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type fetchReq struct {
ctx context.Context
api Fetcher
messageID string
opts JobOptions
}
type fetchRes struct {
fetchReq
msg *pmapi.Message
atts [][]byte
err error
}
func newFetchResSuccess(req fetchReq, msg *pmapi.Message, atts [][]byte) fetchRes {
return fetchRes{
fetchReq: req,
msg: msg,
atts: atts,
}
}
func newFetchResFailure(req fetchReq, err error) fetchRes {
return fetchRes{
fetchReq: req,
err: err,
}
}
// startFetchWorkers starts the given number of fetch workers.
// These workers download message and attachment data from API.
// Each fetch worker will use up to the given number of attachment workers to download attachments.
// Two channels are returned:
// - fetchReqCh: used to send work items to the worker pool
// - fetchResCh: used to receive work results from the worker pool
func startFetchWorkers(fetchWorkers, attachWorkers int) (chan fetchReq, chan fetchRes) {
fetchReqCh := make(chan fetchReq)
fetchResCh := make(chan fetchRes)
go func() {
defer close(fetchResCh)
var wg sync.WaitGroup
wg.Add(fetchWorkers)
for workerID := 0; workerID < fetchWorkers; workerID++ {
go fetchWorker(fetchReqCh, fetchResCh, attachWorkers, &wg)
}
wg.Wait()
}()
return fetchReqCh, fetchResCh
}
func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachWorkers int, wg *sync.WaitGroup) {
defer wg.Done()
for req := range fetchReqCh {
msg, atts, err := fetchMessage(req, attachWorkers)
if err != nil {
fetchResCh <- newFetchResFailure(req, err)
} else {
fetchResCh <- newFetchResSuccess(req, msg, atts)
}
}
}
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
msg, err := req.api.GetMessage(req.ctx, req.messageID)
if err != nil {
return nil, nil, err
}
attList := make([]interface{}, len(msg.Attachments))
for i, att := range msg.Attachments {
attList[i] = att.ID
}
process := func(value interface{}) (interface{}, error) {
rc, err := req.api.GetAttachment(req.ctx, value.(string))
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
if err := rc.Close(); err != nil {
return nil, err
}
return b, nil
}
attData := make([][]byte, len(msg.Attachments))
collect := func(idx int, value interface{}) error {
attData[idx] = value.([]byte) //nolint[forcetypeassert] we wan't to panic here
return nil
}
if err := parallel.RunParallel(attachWorkers, attList, process, collect); err != nil {
return nil, nil, err
}
return msg, attData, nil
}

View File

@ -25,35 +25,3 @@ type JobOptions struct {
AddMessageDate bool // Whether to include message time as X-Pm-Date.
AddMessageIDReference bool // Whether to include the MessageID in References.
}
type BuildJob struct {
messageID string
literal []byte
err error
done chan struct{}
}
func newBuildJob(messageID string) *BuildJob {
return &BuildJob{
messageID: messageID,
done: make(chan struct{}),
}
}
// GetResult returns the build result or any error which occurred during building.
// If the result is not ready yet, it blocks.
func (job *BuildJob) GetResult() ([]byte, error) {
<-job.done
return job.literal, job.err
}
func (job *BuildJob) postSuccess(literal []byte) {
job.literal = literal
close(job.done)
}
func (job *BuildJob) postFailure(err error) {
job.err = err
close(job.done)
}

View File

@ -34,7 +34,7 @@ import (
"github.com/pkg/errors"
)
func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData [][]byte, opts JobOptions) ([]byte, error) {
func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) {
switch {
case len(msg.Attachments) > 0:
return buildMultipartRFC822(kr, msg, attData, opts)
@ -80,7 +80,7 @@ func buildSimpleRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions)
func buildMultipartRFC822(
kr *crypto.KeyRing,
msg *pmapi.Message,
attData [][]byte,
attData map[string][]byte,
opts JobOptions,
) ([]byte, error) {
boundary := newBoundary(msg.ID)
@ -103,13 +103,13 @@ func buildMultipartRFC822(
attachData [][]byte
)
for i, att := range msg.Attachments {
for _, att := range msg.Attachments {
if att.Disposition == pmapi.DispositionInline {
inlineAtts = append(inlineAtts, att)
inlineData = append(inlineData, attData[i])
inlineData = append(inlineData, attData[att.ID])
} else {
attachAtts = append(attachAtts, att)
attachData = append(attachData, attData[i])
attachData = append(attachData, attData[att.ID])
}
}

File diff suppressed because it is too large Load Diff

View File

@ -38,9 +38,10 @@ type BodyStructure map[string]*SectionInfo
// SectionInfo is used to hold data about parts of each section.
type SectionInfo struct {
Header textproto.MIMEHeader
Header []byte
Start, BSize, Size, Lines int
reader io.Reader
isHeaderReadFinished bool
}
// Read will also count the final size of section.
@ -48,9 +49,38 @@ func (si *SectionInfo) Read(p []byte) (n int, err error) {
n, err = si.reader.Read(p)
si.Size += n
si.Lines += bytes.Count(p, []byte("\n"))
si.readHeader(p)
return
}
// readHeader appends read data to Header until the blank line separating the
// header from the body is found, then stops collecting.
//
// NOTE(review): both index checks use `i > 0`, so a message whose very first
// bytes are the blank separator (i.e. an empty header) never terminates and
// the whole body accumulates in Header — confirm whether that input occurs.
// NOTE(review): each call rescans the accumulated buffer from the start,
// which is quadratic for large headers; fine for typical header sizes.
func (si *SectionInfo) readHeader(p []byte) {
	// Already found the end of the header; nothing more to collect.
	if si.isHeaderReadFinished {
		return
	}

	si.Header = append(si.Header, p...)

	// CRLF endings: the header/body boundary "\r\n\r\n" is found via its
	// tail "\n\r\n" (the previous line's "\n" plus the blank "\r\n").
	if i := bytes.Index(si.Header, []byte("\n\r\n")); i > 0 {
		si.Header = si.Header[:i+3]
		si.isHeaderReadFinished = true
		return
	}

	// textproto works also with simple line ending so we should be liberal
	// as well.
	if i := bytes.Index(si.Header, []byte("\n\n")); i > 0 {
		si.Header = si.Header[:i+2]
		si.isHeaderReadFinished = true
	}
}
// GetMIMEHeader parses the raw Header bytes and returns them as a MIME header.
func (si *SectionInfo) GetMIMEHeader() (textproto.MIMEHeader, error) {
	reader := textproto.NewReader(bufio.NewReader(bytes.NewReader(si.Header)))

	return reader.ReadMIMEHeader()
}
func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) {
structure = &BodyStructure{}
err = structure.Parse(reader)
@ -93,14 +123,15 @@ func (bs *BodyStructure) parseAllChildSections(r io.Reader, currentPath []int, s
bufInfo := bufio.NewReader(info)
tp := textproto.NewReader(bufInfo)
if info.Header, err = tp.ReadMIMEHeader(); err != nil {
tpHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
bodyInfo := &SectionInfo{reader: tp.R}
bodyReader := bufio.NewReader(bodyInfo)
mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type"))
mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
// If multipart, call getAllParts, else read to count lines.
if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" {
@ -260,9 +291,9 @@ func (bs *BodyStructure) GetMailHeader() (header textproto.MIMEHeader, err error
}
// GetMailHeaderBytes returns the bytes with main mail header.
// Warning: It can contain extra lines or multipart comment.
func (bs *BodyStructure) GetMailHeaderBytes(wholeMail io.ReadSeeker) (header []byte, err error) {
return bs.GetSectionHeaderBytes(wholeMail, []int{})
// Warning: It can contain extra lines.
func (bs *BodyStructure) GetMailHeaderBytes() ([]byte, error) {
	return bs.GetSectionHeaderBytes([]int{})
}
func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) {
@ -283,22 +314,21 @@ func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byt
}
// GetSectionHeader returns the mime header of specified section.
func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (header textproto.MIMEHeader, err error) {
func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (textproto.MIMEHeader, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return
return nil, err
}
header = info.Header
return
return info.GetMIMEHeader()
}
func (bs *BodyStructure) GetSectionHeaderBytes(wholeMail io.ReadSeeker, sectionPath []int) (header []byte, err error) {
// GetSectionHeaderBytes returns raw header bytes of specified section.
func (bs *BodyStructure) GetSectionHeaderBytes(sectionPath []int) ([]byte, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return
return nil, err
}
headerLength := info.Size - info.BSize
return goToOffsetAndReadNBytes(wholeMail, info.Start, headerLength)
return info.Header, nil
}
// IMAPBodyStructure will prepare imap bodystructure recurently for given part.
@ -309,7 +339,12 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
return
}
mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type"))
tpHeader, err := info.GetMIMEHeader()
if err != nil {
return
}
mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
mediaTypeSep := strings.Split(mediaType, "/")
@ -324,19 +359,19 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
Lines: uint32(info.Lines),
}
if val := info.Header.Get("Content-ID"); val != "" {
if val := tpHeader.Get("Content-ID"); val != "" {
imapBS.Id = val
}
if val := info.Header.Get("Content-Transfer-Encoding"); val != "" {
if val := tpHeader.Get("Content-Transfer-Encoding"); val != "" {
imapBS.Encoding = val
}
if val := info.Header.Get("Content-Description"); val != "" {
if val := tpHeader.Get("Content-Description"); val != "" {
imapBS.Description = val
}
if val := info.Header.Get("Content-Disposition"); val != "" {
if val := tpHeader.Get("Content-Disposition"); val != "" {
imapBS.Disposition = val
}

View File

@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"io/ioutil"
"net/textproto"
"path/filepath"
"runtime"
"sort"
@ -71,7 +70,9 @@ func TestParseBodyStructure(t *testing.T) {
debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines")
for _, path := range paths {
sec := (*bs)[path]
contentType := (*bs)[path].Header.Get("Content-Type")
header, err := sec.GetMIMEHeader()
require.NoError(t, err)
contentType := header.Get("Content-Type")
debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines)
require.Equal(t, expectedStructure[path], contentType)
}
@ -100,7 +101,9 @@ func TestParseBodyStructurePGP(t *testing.T) {
haveStructure := map[string]string{}
for path := range *bs {
haveStructure[path] = (*bs)[path].Header.Get("Content-Type")
header, err := (*bs)[path].GetMIMEHeader()
require.NoError(t, err)
haveStructure[path] = header.Get("Content-Type")
}
require.Equal(t, expectedStructure, haveStructure)
@ -192,7 +195,7 @@ Content-Type: plain/text
r.NoError(err, debug(wantPath, info, haveBody))
r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody))
haveHeader, err := bs.GetSectionHeaderBytes(strings.NewReader(wantMail), wantPath)
haveHeader, err := bs.GetSectionHeaderBytes(wantPath)
r.NoError(err, debug(wantPath, info, haveHeader))
r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader))
}
@ -211,7 +214,7 @@ Content-Type: multipart/mixed; boundary="0000MAIN"
bs, err := NewBodyStructure(structReader)
require.NoError(t, err)
haveHeader, err := bs.GetMailHeaderBytes(strings.NewReader(sampleMail))
haveHeader, err := bs.GetMailHeaderBytes()
require.NoError(t, err)
require.Equal(t, wantHeader, haveHeader)
}
@ -533,18 +536,14 @@ func TestBodyStructureSerialize(t *testing.T) {
r := require.New(t)
want := &BodyStructure{
"1": {
Header: textproto.MIMEHeader{
"Content": []string{"type"},
},
Start: 1,
Size: 2,
BSize: 3,
Lines: 4,
Header: []byte("Content: type"),
Start: 1,
Size: 2,
BSize: 3,
Lines: 4,
},
"1.1.1": {
Header: textproto.MIMEHeader{
"X-Pm-Key": []string{"id"},
},
Header: []byte("X-Pm-Key: id"),
Start: 11,
Size: 12,
BSize: 13,
@ -562,3 +561,32 @@ func TestBodyStructureSerialize(t *testing.T) {
(*want)["1.1.1"].reader = nil
r.Equal(want, have)
}
// TestSectionInfoReadHeader checks that the main mail header is extracted
// verbatim for both LF- and CRLF-terminated messages.
func TestSectionInfoReadHeader(t *testing.T) {
	r := require.New(t)

	cases := []struct {
		wantHeader, mail string
	}{
		{
			"key1: val1\nkey2: val2\n\n",
			"key1: val1\nkey2: val2\n\nbody is here\n\nand it is not confused",
		},
		{
			"key1:\n val1\n\n",
			"key1:\n val1\n\nbody is here",
		},
		{
			"key1: val1\r\nkey2: val2\r\n\r\n",
			"key1: val1\r\nkey2: val2\r\n\r\nbody is here\r\n\r\nand it is not confused",
		},
	}

	for _, tc := range cases {
		bs, err := NewBodyStructure(strings.NewReader(tc.mail))
		r.NoError(err, "case %q", tc.mail)

		haveHeader, err := bs.GetMailHeaderBytes()
		r.NoError(err, "case %q", tc.mail)
		r.Equal(tc.wantHeader, string(haveHeader), "case %q", tc.mail)
	}
}

131
pkg/pchan/pchan.go Normal file
View File

@ -0,0 +1,131 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
)
// PChan is a priority channel: values are pushed with an integer priority and
// popped highest-priority first.
type PChan struct {
	lock  sync.Mutex // Guards items and each item's prio.
	items []*Item    // Pending items; re-sorted by priority when popping.

	// ready receives one signal per pushed item; done marks the channel closed.
	ready, done chan struct{}
}

// Item is a single queued value together with its priority. Its done channel
// is closed once the item has been popped.
type Item struct {
	ch   *PChan
	val  interface{}
	prio int
	done chan struct{}
}
// Wait blocks until the item has been popped from the channel.
func (item *Item) Wait() {
	<-item.done
}
// GetPriority returns the item's current priority, taking the channel lock so
// it is safe against concurrent SetPriority calls.
func (item *Item) GetPriority() int {
	item.ch.lock.Lock()
	defer item.ch.lock.Unlock()

	return item.prio
}
// SetPriority changes the item's priority and re-sorts the channel's pending
// items (ascending) so the change takes effect for subsequent pops.
// NOTE(review): pop re-sorts anyway, so this sort is redundant but harmless.
func (item *Item) SetPriority(priority int) {
	item.ch.lock.Lock()
	defer item.ch.lock.Unlock()

	item.prio = priority

	sort.Slice(item.ch.items, func(i, j int) bool {
		return item.ch.items[i].prio < item.ch.items[j].prio
	})
}
// New returns an empty, open PChan ready for concurrent Push/Pop.
func New() *PChan {
	return &PChan{
		ready: make(chan struct{}),
		done:  make(chan struct{}),
	}
}
// Push enqueues the value with the given priority and returns the queued item
// (whose Wait method blocks until it is popped). The ready signal is sent
// asynchronously after the item is enqueued, via the deferred notify.
func (ch *PChan) Push(val interface{}, prio int) *Item {
	defer ch.notify()

	return ch.push(val, prio)
}
// Pop blocks until an item is available or the channel is closed. It returns
// the highest-priority value, its priority, and true; once the channel is
// closed it returns (nil, 0, false).
// NOTE(review): if a ready signal and close are both pending, select may pick
// either, so items pushed around Close time may never be delivered.
func (ch *PChan) Pop() (interface{}, int, bool) {
	select {
	case <-ch.ready:
		val, prio := ch.pop()
		return val, prio, true

	case <-ch.done:
		return nil, 0, false
	}
}
// Close marks the channel as done, causing pending and future Pop calls to
// return false. It holds the channel lock around the check-and-close so that
// two concurrent Close calls cannot both observe "not yet closed" and then
// both close ch.done, which would panic on the second close.
func (ch *PChan) Close() {
	ch.lock.Lock()
	defer ch.lock.Unlock()

	select {
	case <-ch.done:
		// Already closed; nothing to do.
	default:
		close(ch.done)
	}
}
// push appends a new item holding the given value and priority under the
// channel lock and returns it.
func (ch *PChan) push(val interface{}, prio int) *Item {
	ch.lock.Lock()
	defer ch.lock.Unlock()

	item := &Item{
		ch:   ch,
		val:  val,
		prio: prio,
		done: make(chan struct{}),
	}

	ch.items = append(ch.items, item)

	return item
}
// pop removes the highest-priority pending item, closes its done channel to
// release any Wait callers, and returns its value and priority.
// It must only be called after receiving a ready signal: it assumes items is
// non-empty and will panic otherwise.
// NOTE(review): sorting on every pop is O(n log n); a linear max-scan or a
// heap would do, but n is presumably small.
func (ch *PChan) pop() (interface{}, int) {
	ch.lock.Lock()
	defer ch.lock.Unlock()

	// Sort ascending so the highest-priority item ends up last.
	sort.Slice(ch.items, func(i, j int) bool {
		return ch.items[i].prio < ch.items[j].prio
	})

	var item *Item

	item, ch.items = ch.items[len(ch.items)-1], ch.items[:len(ch.items)-1]

	// Closing done wakes anyone blocked in (*Item).Wait.
	defer close(item.done)

	return item.val, item.prio
}
// notify asynchronously signals that an item is ready to be popped. The
// signalling goroutine also watches done so it exits once the channel is
// closed: the previous unconditional send blocked forever (leaking one
// goroutine per item that was never popped) after Close.
func (ch *PChan) notify() {
	go func() {
		select {
		case ch.ready <- struct{}{}:
		case <-ch.done:
		}
	}()
}

123
pkg/pchan/pchan_test.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPChanConcurrentPush checks that concurrent pushes are all recorded and
// that items are subsequently popped in descending priority order.
func TestPChanConcurrentPush(t *testing.T) {
	ch := New()

	var wg sync.WaitGroup

	// Push five items, each from its own goroutine.
	for i := 1; i <= 5; i++ {
		wg.Add(1)

		go func(v int) {
			defer wg.Done()
			ch.Push(v, v)
		}(i)
	}

	// Wait for every push to land.
	wg.Wait()

	// All five items should now be queued.
	require.Len(t, ch.items, 5)

	// Highest priority comes out first.
	for want := 5; want >= 1; want-- {
		assert.Equal(t, want, getValue(t, ch))
	}
}
// TestPChanConcurrentPop verifies that items pushed one at a time are consumed
// by concurrently waiting poppers in push order.
func TestPChanConcurrentPop(t *testing.T) {
	ch := New()
	var wg sync.WaitGroup
	// We are going to test with 5 additional goroutines.
	wg.Add(5)
	// Make a list to store the results in.
	var res list
	// Start 5 concurrent pops; these consume any items pushed.
	go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
	go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
	go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
	go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
	go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
	// Push and block; items should be popped immediately by the waiting goroutines.
	ch.Push(1, 1).Wait()
	ch.Push(2, 2).Wait()
	ch.Push(3, 3).Wait()
	ch.Push(4, 4).Wait()
	ch.Push(5, 5).Wait()
	// Wait until every popper goroutine has recorded its result.
	wg.Wait()
	// Each Push(...).Wait() returned only once its item was popped, so the
	// pops happened in ascending order.
	// NOTE(review): the item's done channel is closed inside pop, before the
	// popper goroutine appends to res, so a slow goroutine could record its
	// value out of order — this assertion looks potentially flaky; confirm.
	assert.True(t, sort.IntsAreSorted(res.items))
}
// TestPChanClose checks that Pop succeeds while the channel is open and
// reports failure with a nil value after Close.
func TestPChanClose(t *testing.T) {
	ch := New()

	// An item pushed while open can be popped.
	go ch.Push(1, 1)

	val, _, ok := ch.Pop()
	assert.True(t, ok)
	assert.Equal(t, 1, val)

	// After closing, Pop returns (nil, _, false).
	ch.Close()

	val, _, ok = ch.Pop()
	assert.False(t, ok)
	assert.Nil(t, val)
}
// list is a mutex-guarded slice of ints for collecting results from multiple
// goroutines.
type list struct {
	items []int
	mu    sync.Mutex
}

// append adds val to the list; safe for concurrent use.
func (l *list) append(val int) {
	l.mu.Lock()
	l.items = append(l.items, val)
	l.mu.Unlock()
}
// getValue pops one value from ch, asserting the channel was still open, and
// returns it as an int.
func getValue(t *testing.T, ch *PChan) int {
	value, _, open := ch.Pop()
	assert.True(t, open)

	return value.(int)
}

View File

@ -71,6 +71,7 @@ type Client interface {
GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
GetUserKeyRing() (*crypto.KeyRing, error)
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
}

View File

@ -175,7 +175,6 @@ type Message struct {
CCList []*mail.Address
BCCList []*mail.Address
Time int64 // Unix time
Size int64
NumAttachments int
ExpirationTime int64 // Unix time
SpamScore int

View File

@ -362,6 +362,21 @@ func (mr *MockClientMockRecorder) GetPublicKeysForEmail(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1)
}
// GetUserKeyRing mocks base method.
// NOTE(review): this looks like GoMock (mockgen) generated code — prefer
// regenerating from the Client interface over hand-editing.
func (m *MockClient) GetUserKeyRing() (*crypto.KeyRing, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetUserKeyRing")
	ret0, _ := ret[0].(*crypto.KeyRing)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}
// GetUserKeyRing indicates an expected call of GetUserKeyRing.
func (mr *MockClientMockRecorder) GetUserKeyRing() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKeyRing", reflect.TypeOf((*MockClient)(nil).GetUserKeyRing))
}
// Import mocks base method.
func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
m.ctrl.T.Helper()

View File

@ -20,6 +20,7 @@ package pmapi
import (
"context"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/getsentry/sentry-go"
"github.com/go-resty/resty/v2"
"github.com/pkg/errors"
@ -138,3 +139,12 @@ func (c *client) CurrentUser(ctx context.Context) (*User, error) {
return c.UpdateUser(ctx)
}
// GetUserKeyRing returns the client's user keyring, or an error if the
// keyring has not been made available (e.g. the client is not yet unlocked).
// (The previous comment was a copy-paste of the CurrentUser doc.)
func (c *client) GetUserKeyRing() (*crypto.KeyRing, error) {
	if c.userKeyRing == nil {
		return nil, errors.New("user keyring is not available")
	}
	return c.userKeyRing, nil
}

129
pkg/pool/pool.go Normal file
View File

@ -0,0 +1,129 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pool
import "github.com/ProtonMail/proton-bridge/pkg/pchan"
// WorkFunc is the unit of work executed by pool workers. It receives the job
// request and the job's priority at pop time, returning a result or an error.
type WorkFunc func(interface{}, int) (interface{}, error)

// DoneFunc is returned alongside a job; the caller must invoke it once it is
// finished with the job's result so the worker that produced it is released.
type DoneFunc func()

// Pool executes jobs on a fixed set of workers in priority order.
type Pool struct {
	jobCh *pchan.PChan
}
// New creates a pool of size worker goroutines that execute queued jobs in
// priority order using work. Workers run until the job channel is closed.
func New(size int, work WorkFunc) *Pool {
	jobCh := pchan.New()

	for i := 0; i < size; i++ {
		go worker(jobCh, work)
	}

	return &Pool{jobCh: jobCh}
}

// worker pops jobs until the channel is closed, executes each with work and
// publishes the result, then blocks until the consumer marks the job done.
func worker(jobCh *pchan.PChan, work WorkFunc) {
	for {
		val, prio, ok := jobCh.Pop()
		if !ok {
			return
		}

		// Only *Job values are ever pushed by NewJob; anything else is a
		// programmer error. (Previous message said "bad result type", which
		// was misleading: this is the job, not a result.)
		job, ok := val.(*Job)
		if !ok {
			panic("pool: job channel contained a non-job value")
		}

		if res, err := work(job.req, prio); err != nil {
			job.postFailure(err)
		} else {
			job.postSuccess(res)
		}

		// Hold the worker until the consumer signals it is finished with the
		// result; this throttles the pool to the consumer's pace.
		job.waitDone()
	}
}
// NewJob queues req on the pool with the given priority. It returns the job —
// whose result is retrieved with GetResult — and a DoneFunc that MUST be
// called once the result is no longer needed; until then the worker that ran
// the job stays blocked.
func (pool *Pool) NewJob(req interface{}, prio int) (*Job, DoneFunc) {
	job := newJob(req)
	job.setItem(pool.jobCh.Push(job, prio))
	return job, job.markDone
}
// Job is a single request queued on a Pool together with its eventual result.
type Job struct {
	req interface{} // request handed to the WorkFunc
	res interface{} // result, valid once ready is closed
	err error // failure, valid once ready is closed
	item *pchan.Item // queue entry, used to read and adjust priority
	ready, done chan struct{} // ready: result published; done: caller finished
}
// newJob wraps req in a Job with fresh signalling channels.
func newJob(req interface{}) *Job {
	job := &Job{req: req}
	job.ready = make(chan struct{})
	job.done = make(chan struct{})

	return job
}
// GetResult blocks until the job has been executed, then returns its result
// or error. It may be called any number of times.
func (job *Job) GetResult() (interface{}, error) {
	<-job.ready
	return job.res, job.err
}
// GetPriority returns the job's current priority in the queue.
func (job *Job) GetPriority() int {
	return job.item.GetPriority()
}

// SetPriority changes the job's priority, reordering it in the queue.
// NOTE(review): both accessors assume setItem has been called (as NewJob
// does); a job that was never queued would nil-panic here.
func (job *Job) SetPriority(prio int) {
	job.item.SetPriority(prio)
}
// postSuccess records the job's result and marks it ready; must be called at
// most once (a second close of ready would panic).
func (job *Job) postSuccess(res interface{}) {
	defer close(job.ready)
	job.res = res
}

// postFailure records the job's error and marks it ready; must be called at
// most once.
func (job *Job) postFailure(err error) {
	defer close(job.ready)
	job.err = err
}

// setItem attaches the queue entry backing this job, enabling the priority
// accessors.
func (job *Job) setItem(item *pchan.Item) {
	job.item = item
}
// markDone signals that the caller is finished with the job's result,
// releasing the worker blocked in waitDone. Repeated sequential calls are
// safe no-ops.
// NOTE(review): the check-then-close is not atomic, so two concurrent calls
// could both close done and panic — confirm the DoneFunc is only ever invoked
// from a single goroutine.
func (job *Job) markDone() {
	select {
	case <-job.done:
		return
	default:
		close(job.done)
	}
}

// waitDone blocks the worker until the caller invokes markDone.
func (job *Job) waitDone() {
	<-job.done
}

45
pkg/pool/pool_test.go Normal file
View File

@ -0,0 +1,45 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pool_test
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPool checks that a two-worker echo pool returns each job's request as
// its result, regardless of collection order.
func TestPool(t *testing.T) {
	// An echo pool: every job's result is simply its request. (Local renamed
	// from "pool" to avoid shadowing the package name.)
	echo := pool.New(2, func(req interface{}, prio int) (interface{}, error) { return req, nil })

	job1, done1 := echo.NewJob("echo", 1)
	defer done1()

	job2, done2 := echo.NewJob("this", 1)
	defer done2()

	// Results may be collected out of submission order.
	res2, err := job2.GetResult()
	require.NoError(t, err)

	res1, err := job1.GetResult()
	require.NoError(t, err)

	assert.Equal(t, "echo", res1)
	assert.Equal(t, "this", res2)
}

View File

@ -0,0 +1,53 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package semaphore
import "sync"
type Semaphore struct {
ch chan struct{}
wg sync.WaitGroup
}
func New(max int) Semaphore {
return Semaphore{ch: make(chan struct{}, max)}
}
func (sem *Semaphore) Lock() {
sem.ch <- struct{}{}
}
func (sem *Semaphore) Unlock() {
<-sem.ch
}
func (sem *Semaphore) Go(fn func()) {
sem.Lock()
sem.wg.Add(1)
go func() {
defer sem.Unlock()
defer sem.wg.Done()
fn()
}()
}
func (sem *Semaphore) Wait() {
sem.wg.Wait()
}

View File

@ -21,11 +21,14 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
@ -61,18 +64,28 @@ func (ctx *TestContext) RestartBridge() error {
}
// newBridgeInstance creates a new bridge instance configured to use the given config/credstore.
// NOTE(GODT-1158): Need some tests with on-disk cache as well! Configurable in feature file or envvar?
func newBridgeInstance(
t *bddT,
locations bridge.Locator,
cache bridge.Cacher,
settings *fakeSettings,
cacheProvider bridge.CacheProvider,
fakeSettings *fakeSettings,
credStore users.CredentialsStorer,
eventListener listener.Listener,
clientManager pmapi.Manager,
) *bridge.Bridge {
sentryReporter := sentry.NewReporter("bridge", constants.Version, useragent.New())
panicHandler := &panicHandler{t: t}
updater := newFakeUpdater()
versioner := newFakeVersioner()
return bridge.New(locations, cache, settings, sentryReporter, panicHandler, eventListener, clientManager, credStore, updater, versioner)
return bridge.New(
locations,
cacheProvider,
fakeSettings,
sentry.NewReporter("bridge", constants.Version, useragent.New()),
&panicHandler{t: t},
eventListener,
cache.NewInMemoryCache(100*(1<<20)),
message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)),
clientManager,
credStore,
newFakeUpdater(),
newFakeVersioner(),
)
}

View File

@ -58,7 +58,7 @@ func (ctx *TestContext) withIMAPServer() {
port := ctx.settings.GetInt(settings.IMAPPortKey)
tls, _ := tls.New(settingsPath).GetConfig()
backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.bridge)
backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge)
server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener)
go server.ListenAndServe()

View File

@ -125,6 +125,10 @@ func (api *FakePMAPI) Addresses() pmapi.AddressList {
return *api.addresses
}
// GetUserKeyRing returns the fake client's user keyring.
// NOTE(review): unlike the real client, this never returns the
// "keyring not available" error — confirm tests don't need that path.
func (api *FakePMAPI) GetUserKeyRing() (*crypto.KeyRing, error) {
	return api.userKeyRing, nil
}
// KeyRingForAddressID returns the fake keyring registered for addrID.
// NOTE(review): an unknown address yields (nil, nil) rather than an error.
func (api *FakePMAPI) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
	return api.addrKeyRing[addrID], nil
}