Introducing API block avoidance based on DoH

This commit is contained in:
Mateusz Markowicz 2020-07-03 11:24:18 +00:00
parent a8edd0b944
commit ca85c626fb
21 changed files with 775 additions and 29 deletions

View File

@ -6,7 +6,7 @@ plugins {
`android-library`
}
libVersion = Version(0, 1, 0)
libVersion = Version(0, 1, 1)
android()

View File

@ -7,7 +7,7 @@ plugins {
`kotlin-serialization`
}
libVersion = Version(0, 1, 0)
libVersion = Version(0, 1, 1)
android()
@ -18,15 +18,19 @@ dependencies {
implementation(
project(Module.kotlinUtil),
project(Module.sharedPreferencesUtil),
project(Module.networkDomain),
`kotlin-jdk7`,
`coroutines-core`,
`serialization`,
`retrofit-kotlin-serialization`,
squareup("retrofit2", "retrofit") version retrofitVersion,
squareup("okhttp3", "logging-interceptor") version okHttpVersion,
dependency("com.jakewharton.retrofit", module = "retrofit2-kotlinx-serialization-converter") version "0.5.0"
dependency("org.minidns", module = "minidns-hla") version "0.3.4",
dependency("commons-codec", module = "commons-codec") version "1.14",
dependency("com.datatheorem.android.trustkit", module = "trustkit") version "1.1.2"
)
testImplementation(

View File

@ -0,0 +1,40 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.data
import android.content.Context
import android.content.SharedPreferences
import me.proton.core.network.domain.NetworkPrefs
import me.proton.core.util.android.sharedpreferences.PreferencesProvider
import me.proton.core.util.android.sharedpreferences.list
import me.proton.core.util.android.sharedpreferences.long
import me.proton.core.util.android.sharedpreferences.string
class NetworkPrefsImpl(context: Context) : NetworkPrefs, PreferencesProvider {
override val preferences: SharedPreferences =
context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)
override var activeAltBaseUrl: String? by string()
override var lastPrimaryApiFail: Long by long(default = Long.MIN_VALUE)
override var alternativeBaseUrls: List<String>? by list()
companion object {
private const val PREFS_NAME = "me.proton.core.network"
}
}

View File

@ -17,8 +17,14 @@
*/
package me.proton.core.network.data
import com.datatheorem.android.trustkit.config.PublicKeyPin
import okhttp3.CertificatePinner
import okhttp3.OkHttpClient
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext
import javax.net.ssl.X509TrustManager
/**
* Inits given okhttp builder with pinning.
@ -33,3 +39,39 @@ internal fun initPinning(okBuilder: OkHttpClient.Builder, host: String, pins: Ar
.build()
okBuilder.certificatePinner(pinner)
}
/**
* Inits given okhttp builder with leaf SPKI pinning. Accepts certificate chain iff leaf certificate
* SPKI matches one of the [spkiPins].
*
* @param okBuilder builder to introduce pinning to.
* @param spkiPins list of sha-256 SPKI hashes.
*/
internal fun initSPKIleafPinning(builder: OkHttpClient.Builder, spkiPins: List<String>) {
val trustManager = LeafSPKIPinningTrustManager(spkiPins)
val sslContext = SSLContext.getInstance("TLS")
sslContext.init(null, arrayOf(trustManager), null)
builder.sslSocketFactory(sslContext.socketFactory, trustManager)
builder.hostnameVerifier(HostnameVerifier { _, _ ->
// Verification is based solely on SPKI pinning of leaf certificate
true
})
}
internal class LeafSPKIPinningTrustManager(pinnedSPKIHashes: List<String>) : X509TrustManager {
private val pins: List<PublicKeyPin> = pinnedSPKIHashes.map { PublicKeyPin(it) }
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
if (PublicKeyPin(chain.first()) !in pins)
throw CertificateException("Pin verification failed")
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<X509Certificate?>?, authType: String?) {
throw CertificateException("Client certificates not supported!")
}
override fun getAcceptedIssuers(): Array<X509Certificate?>? = arrayOfNulls(0)
}

View File

@ -150,6 +150,9 @@ internal class ProtonApiBackend<Api : BaseRetrofitApi>(
}
}
override suspend fun isPotentiallyBlocked(): Boolean =
invokeInternal { ping() }.isPotentialBlocking
companion object {
private const val MAX_ERROR_BYTES = 1_000_000L
}

View File

@ -27,5 +27,20 @@ object Constants {
"sha256/drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE=",
"sha256/YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54=",
"sha256/AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw=")
/**
* SPKI pins for alternative Proton API leaf certificates (SHA-256).
*/
val ALTERNATIVE_API_SPKI_PINS = listOf(
"EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM=",
"iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA=",
"MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA=",
"C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw=")
/**
* DNS over HTTPS services urls.
*/
val DOH_PROVIDERS_URLS =
arrayOf("https://dns11.quad9.net/dns-query/", "https://dns.google/dns-query/")
}

View File

@ -18,22 +18,26 @@
package me.proton.core.network.data.di
import android.content.Context
import android.net.Uri
import com.jakewharton.retrofit2.converter.kotlinx.serialization.asConverterFactory
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ObsoleteCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.plus
import me.proton.core.network.data.NetworkManagerImpl
import me.proton.core.network.data.NetworkPrefsImpl
import me.proton.core.network.data.ProtonApiBackend
import me.proton.core.network.data.doh.DnsOverHttpsProviderRFC8484
import me.proton.core.network.data.initPinning
import me.proton.core.network.data.initSPKIleafPinning
import me.proton.core.network.data.protonApi.BaseRetrofitApi
import me.proton.core.network.domain.ApiClient
import me.proton.core.network.domain.ApiErrorHandler
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.ApiManagerImpl
import me.proton.core.network.domain.DohApiHandler
import me.proton.core.network.domain.DohProvider
import me.proton.core.network.domain.NetworkManager
import me.proton.core.network.domain.NetworkPrefs
import me.proton.core.network.domain.ProtonForceUpdateHandler
import me.proton.core.network.domain.RefreshTokenHandler
import me.proton.core.network.domain.UserData
@ -41,6 +45,7 @@ import me.proton.core.util.kotlin.ProtonCoreConfig
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.logging.HttpLoggingInterceptor
import java.net.URI
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
@ -53,12 +58,17 @@ class ApiFactory(
private val baseUrl: String,
private val apiClient: ApiClient,
private val networkManager: NetworkManager,
private val prefs: NetworkPrefs,
scope: CoroutineScope
) {
@OptIn(ObsoleteCoroutinesApi::class)
private val mainScope = scope + newSingleThreadContext("core.network.main")
init {
requireNotNull(URI(baseUrl).host)
}
/**
* Instantiates ApiManager for given [Api] interface and user.
*
@ -76,10 +86,10 @@ class ApiFactory(
certificatePins: Array<String> = Constants.DEFAULT_PINS
): ApiManager<Api> {
val pinningStrategy = { builder: OkHttpClient.Builder ->
initPinning(builder, Uri.parse(baseUrl).host!!, certificatePins)
initPinning(builder, URI(baseUrl).host, certificatePins)
}
val primaryBackend = ProtonApiBackend(
baseUrl.toString(),
baseUrl,
apiClient,
userData,
baseOkHttpClient,
@ -88,10 +98,28 @@ class ApiFactory(
networkManager,
pinningStrategy
)
val dohProvider = DohProvider()
val errorHandlers =
createBaseErrorHandlers<Api>(userData, ::javaMonoClockMs, mainScope) + clientErrorHandlers
return ApiManagerImpl(apiClient, primaryBackend, dohProvider, networkManager, errorHandlers, ::javaMonoClockMs)
val alternativePinningStrategy = { builder: OkHttpClient.Builder ->
initSPKIleafPinning(builder, Constants.ALTERNATIVE_API_SPKI_PINS)
}
val dohApiHandler = DohApiHandler(apiClient, primaryBackend, dohProvider, prefs, ::javaWallClockMs) { baseUrl ->
ProtonApiBackend(
baseUrl,
apiClient,
userData,
baseOkHttpClient,
listOf(jsonConverter),
interfaceClass,
networkManager,
alternativePinningStrategy
)
}
return ApiManagerImpl(
apiClient, primaryBackend, dohApiHandler, networkManager, errorHandlers, ::javaMonoClockMs)
}
internal val jsonConverter =
@ -119,8 +147,18 @@ class ApiFactory(
ProtonForceUpdateHandler(apiClient)
)
private val dohProvider by lazy {
val dohServices = Constants.DOH_PROVIDERS_URLS.map {
DnsOverHttpsProviderRFC8484(baseOkHttpClient, baseUrl, networkManager)
}
DohProvider(baseUrl, apiClient, dohServices, mainScope, prefs, ::javaMonoClockMs)
}
private fun javaMonoClockMs(): Long =
TimeUnit.NANOSECONDS.toMillis(System.nanoTime())
private fun javaWallClockMs(): Long =
System.currentTimeMillis()
}
/**
@ -128,3 +166,9 @@ class ApiFactory(
*/
fun NetworkManager(context: Context): NetworkManager =
NetworkManagerImpl(context.applicationContext)
/**
* Factory method to create persistent storage of preferences for network module.
*/
fun NetworkPrefs(context: Context): NetworkPrefs =
NetworkPrefsImpl(context.applicationContext)

View File

@ -0,0 +1,107 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.data.doh
import me.proton.core.network.data.safeApiCall
import me.proton.core.network.domain.ApiResult
import me.proton.core.network.domain.DohService
import me.proton.core.network.domain.NetworkManager
import okhttp3.OkHttpClient
import okhttp3.ResponseBody
import org.apache.commons.codec.binary.Base32
import org.apache.commons.codec.binary.Base64
import org.minidns.dnsmessage.DnsMessage
import org.minidns.dnsmessage.Question
import org.minidns.record.Record
import org.minidns.record.TXT
import retrofit2.Converter
import retrofit2.Retrofit
import java.lang.reflect.Type
import java.net.URI
import java.net.URISyntaxException
import java.util.concurrent.TimeUnit
class DnsOverHttpsProviderRFC8484(
baseOkHttpClient: OkHttpClient,
private val baseUrl: String,
private val networkManager: NetworkManager
) : DohService {
private val api: DnsOverHttpsRetrofitApi
init {
require(baseUrl.endsWith('/'))
val converterFactory = object : Converter.Factory() {
override fun responseBodyConverter(
type: Type,
annotations: Array<Annotation>,
retrofit: Retrofit
): Converter<ResponseBody, *>? = Converter<ResponseBody, DnsMessage> { body ->
body.use {
DnsMessage(it.bytes())
}
}
}
val httpClientBuilder = baseOkHttpClient.newBuilder()
.connectTimeout(TIMEOUT_S, TimeUnit.SECONDS)
.writeTimeout(TIMEOUT_S, TimeUnit.SECONDS)
.readTimeout(TIMEOUT_S, TimeUnit.SECONDS)
val okClient = httpClientBuilder.build()
api = Retrofit.Builder()
.baseUrl(baseUrl)
.client(okClient)
.addConverterFactory(converterFactory)
.build()
.create(DnsOverHttpsRetrofitApi::class.java)
}
override suspend fun getAlternativeBaseUrls(primaryBaseUrl: String): List<String>? {
val primaryURI = URI(primaryBaseUrl)
val base32domain = Base32().encodeAsString(primaryURI.host.toByteArray()).trim('=')
val question = Question("d$base32domain.protonpro.xyz", Record.TYPE.TXT)
val queryMessage = DnsMessage.builder()
.setRecursionDesired(true)
.setQuestion(question)
.build()
val queryMessageBase64 = Base64(true).encodeToString(
queryMessage.toArray())
val response = safeApiCall(networkManager, api) {
api.getServers(baseUrl.removeSuffix("/"), queryMessageBase64)
}
if (response is ApiResult.Success) {
val answers = response.value.answerSection
return try {
answers
.mapNotNull { (it.payload as? TXT)?.text }
.map { URI("https", it, primaryURI.path, null).toString() }
.takeIf { it.isNotEmpty() }
} catch (e: URISyntaxException) {
null
}
}
return null
}
companion object {
private const val TIMEOUT_S = 10L
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.data.doh
import org.minidns.dnsmessage.DnsMessage
import retrofit2.http.GET
import retrofit2.http.Headers
import retrofit2.http.Query
import retrofit2.http.Url
interface DnsOverHttpsRetrofitApi {
@Headers("Accept: application/dns-message")
@GET
suspend fun getServers(@Url url: String, @Query("dns") base64DnsMessage: String): DnsMessage
}

View File

@ -17,10 +17,10 @@
*/
package me.proton.core.network.data
import android.content.Context
import io.mockk.MockKAnnotations
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.impl.annotations.MockK
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
@ -31,6 +31,7 @@ import kotlinx.coroutines.test.runBlockingTest
import me.proton.core.network.data.di.ApiFactory
import me.proton.core.network.data.util.MockApiClient
import me.proton.core.network.data.util.MockNetworkManager
import me.proton.core.network.data.util.MockNetworkPrefs
import me.proton.core.network.data.util.MockUserData
import me.proton.core.network.data.util.TestResult
import me.proton.core.network.data.util.TestRetrofitApi
@ -38,47 +39,68 @@ import me.proton.core.network.domain.ApiBackend
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.ApiManagerImpl
import me.proton.core.network.domain.ApiResult
import me.proton.core.network.domain.DohApiHandler
import me.proton.core.network.domain.DohProvider
import me.proton.core.network.domain.DohService
import me.proton.core.network.domain.NetworkPrefs
import me.proton.core.network.domain.NetworkStatus
import me.proton.core.network.domain.ProtonForceUpdateHandler
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue
@ExperimentalCoroutinesApi
internal class ApiManagerTests {
private val baseUrl = "https://primary.com/"
private val proxy1url = "https://proxy1.com/"
private val success5foo = ApiResult.Success(TestResult(5, "foo"))
private lateinit var apiFactory: ApiFactory
private lateinit var apiClient: MockApiClient
private lateinit var user: MockUserData
private lateinit var networkManager: MockNetworkManager
private lateinit var apiManager: ApiManager<TestRetrofitApi>
private lateinit var dohApiHandler: DohApiHandler<TestRetrofitApi>
@MockK
private lateinit var backend: ProtonApiBackend<TestRetrofitApi>
@MockK private lateinit var backend: ProtonApiBackend<TestRetrofitApi>
@MockK private lateinit var altBackend1: ProtonApiBackend<TestRetrofitApi>
@MockK private lateinit var dohService: DohService
private var time = 0L
private var wallTime = 0L
private lateinit var prefs: NetworkPrefs
@BeforeTest
fun before() {
MockKAnnotations.init(this)
time = 0L
prefs = MockNetworkPrefs()
apiClient = MockApiClient()
networkManager = MockNetworkManager()
networkManager.networkStatus = NetworkStatus.Unmetered
val scope = CoroutineScope(TestCoroutineDispatcher())
apiFactory = ApiFactory("https://example.com/", apiClient, networkManager, scope)
apiFactory = ApiFactory(baseUrl, apiClient, networkManager, prefs, scope)
user = MockUserData()
val dohProvider = DohProvider()
coEvery { dohService.getAlternativeBaseUrls(any()) } returns listOf(proxy1url)
val dohProvider = DohProvider(baseUrl, apiClient, listOf(dohService), scope, prefs, ::time)
dohApiHandler = DohApiHandler(apiClient, backend, dohProvider, prefs, ::wallTime) {
altBackend1
}
ApiManagerImpl.failRequestBeforeTimeMs = Long.MIN_VALUE
apiManager = ApiManagerImpl(apiClient, backend, dohProvider, networkManager,
apiManager = ApiManagerImpl(apiClient, backend, dohApiHandler, networkManager,
apiFactory.createBaseErrorHandlers(user, ::time, scope), ::time)
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Success(TestResult(5, "foo"))
every { altBackend1.baseUrl } returns proxy1url
}
@Test
@ -249,4 +271,105 @@ internal class ApiManagerTests {
val result3 = apiManager.invoke { test() }
assertTrue(result3 is ApiResult.Success)
}
@Test
fun `basic doh scenario`() = runBlockingTest {
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Timeout(true)
coEvery { backend.isPotentiallyBlocked() } returns true
coEvery { altBackend1.invoke<TestResult>(any()) } returns success5foo
val result1 = apiManager.invoke { test() }
assertTrue(result1 is ApiResult.Success)
assertEquals(altBackend1, dohApiHandler.activeAltBackend)
val result2 = apiManager.invoke { test() }
assertTrue(result2 is ApiResult.Success)
// There was no call to primary backend as altBackend1 is active
coVerify(exactly = 1) {
backend.invoke<TestResult>(any())
}
// After proxy is no longer valid, attempt primary backend again
wallTime += apiClient.proxyValidityPeriodMs
assertNull(dohApiHandler.activeAltBackend)
apiManager.invoke { test() }
coVerify(exactly = 2) {
backend.invoke<TestResult>(any())
}
}
@Test
fun `test doh ping ok`() = runBlockingTest {
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Connection(true)
// when isPotentiallyBlocked == false DoH logic won't be applied
coEvery { backend.isPotentiallyBlocked() } returns false
coEvery { altBackend1.invoke<TestResult>(any()) } returns success5foo
val result = apiManager.invoke { test() }
// Accept the error when pinging primary api succeeds
assertTrue(result is ApiResult.Error.Connection)
}
@Test
fun `test doh off`() = runBlockingTest {
apiClient.shouldUseDoh = false
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Connection(true)
coEvery { backend.isPotentiallyBlocked() } returns true
coEvery { altBackend1.invoke<TestResult>(any()) } returns success5foo
val result = apiManager.invoke { test() }
// Doh is off, no proxy should be called
assertTrue(result is ApiResult.Error.Connection)
coVerify(exactly = 0) { altBackend1.invoke<TestResult>(any()) }
}
@Test
fun `test no DoH on client error`() = runBlockingTest {
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Http(400, "")
coEvery { backend.isPotentiallyBlocked() } returns true
coEvery { altBackend1.invoke<TestResult>(any()) } returns success5foo
val result = apiManager.invoke { test() }
// HTTP 400 shouldn't trigger DoH
assertTrue(result is ApiResult.Error.Http)
coVerify(exactly = 0) { altBackend1.invoke<TestResult>(any()) }
}
@Test
fun `test DoH timeout`() = runBlockingTest {
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Connection(true)
coEvery { backend.isPotentiallyBlocked() } returns true
coEvery { altBackend1.invoke<TestResult>(any()) } coAnswers {
delay(apiClient.dohTimeoutMs)
success5foo
}
val result = apiManager.invoke { test() }
assertTrue(result is ApiResult.Error.Timeout)
}
@Test
fun `test doh proxy refresh throttling`() = runBlockingTest {
coEvery { backend.invoke<TestResult>(any()) } returns ApiResult.Error.Connection(true)
coEvery { backend.isPotentiallyBlocked() } returns true
coEvery { altBackend1.invoke<TestResult>(any()) } returns ApiResult.Error.Connection(true)
val result = apiManager.invoke { test() }
assertTrue(result is ApiResult.Error.Connection)
val result2 = apiManager.invoke { test() }
assertTrue(result2 is ApiResult.Error.Connection)
time += DohProvider.MIN_REFRESH_INTERVAL_MS
val result3 = apiManager.invoke { test() }
assertTrue(result3 is ApiResult.Error.Connection)
coVerify(exactly = 2) {
dohService.getAlternativeBaseUrls(any())
}
}
}

View File

@ -0,0 +1,91 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.data
import io.mockk.MockKAnnotations
import io.mockk.every
import io.mockk.impl.annotations.MockK
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.runBlocking
import me.proton.core.network.data.doh.DnsOverHttpsProviderRFC8484
import me.proton.core.network.domain.NetworkManager
import okhttp3.OkHttpClient
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import okio.Buffer
import org.minidns.dnsmessage.DnsMessage
import org.minidns.record.Record
import org.minidns.record.TXT
import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
// Can't use runBlockingTest with MockWebServer. See:
// https://github.com/square/retrofit/issues/3330
// https://github.com/Kotlin/kotlinx.coroutines/issues/1204
@ExperimentalCoroutinesApi
internal class DohProviderTests {
private val domain = "example.com"
lateinit var webServer: MockWebServer
lateinit var dohProvider: DnsOverHttpsProviderRFC8484
private var isNetworkAvailable = true
@MockK
lateinit var networkManager: NetworkManager
@BeforeTest
fun before() {
MockKAnnotations.init(this)
every { networkManager.isConnectedToNetwork() } returns isNetworkAvailable
isNetworkAvailable = true
webServer = MockWebServer()
val okHttpClient = OkHttpClient.Builder().build()
dohProvider = DnsOverHttpsProviderRFC8484(
okHttpClient,
webServer.url("/").toString(),
networkManager
)
}
@AfterTest
fun after() {
webServer.shutdown()
}
@Test
fun `test ok call`() = runBlocking {
val txtBytes = "proxy.com".toByteArray()
val txtBlob = byteArrayOf(txtBytes.size.toByte(), *txtBytes)
val dnsMessage = DnsMessage.builder()
.addAnswer(Record("", Record.TYPE.TXT, Record.CLASS.IN, 0L, TXT(txtBlob), false))
.build()
val response = MockResponse()
.setResponseCode(200)
.addHeader("Content-Type", "application/dns-message")
.setBody(Buffer().write(dnsMessage.toArray()))
webServer.enqueue(response)
val result = dohProvider.getAlternativeBaseUrls("https://$domain/")!!
assertEquals(listOf("https://proxy.com/"), result)
assertEquals("application/dns-message", webServer.takeRequest().headers["Accept"])
}
}

View File

@ -17,7 +17,6 @@
*/
package me.proton.core.network.data
import android.content.Context
import io.mockk.MockKAnnotations
import io.mockk.every
import io.mockk.impl.annotations.MockK
@ -27,12 +26,14 @@ import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineDispatcher
import me.proton.core.network.data.di.ApiFactory
import me.proton.core.network.data.util.MockApiClient
import me.proton.core.network.data.util.MockNetworkPrefs
import me.proton.core.network.data.util.MockUserData
import me.proton.core.network.data.util.TestRetrofitApi
import me.proton.core.network.data.util.prepareResponse
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.ApiResult
import me.proton.core.network.domain.NetworkManager
import me.proton.core.network.domain.NetworkPrefs
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import retrofit2.converter.scalars.ScalarsConverterFactory
@ -54,12 +55,15 @@ internal class ProtonApiBackendTests {
@MockK
lateinit var networkManager: NetworkManager
private lateinit var prefs: NetworkPrefs
@BeforeTest
fun before() {
MockKAnnotations.init(this)
val client = MockApiClient()
val scope = CoroutineScope(TestCoroutineDispatcher())
apiFactory = ApiFactory("https://example.com/", client, networkManager, scope)
prefs = MockNetworkPrefs()
apiFactory = ApiFactory("https://example.com/", client, networkManager, prefs, scope)
val user = MockUserData()
every { networkManager.isConnectedToNetwork() } returns isNetworkAvailable

View File

@ -18,6 +18,7 @@
package me.proton.core.network.data.util
import me.proton.core.network.domain.ApiClient
import me.proton.core.network.domain.NetworkPrefs
import me.proton.core.network.domain.UserData
class MockUserData : UserData {
@ -48,3 +49,9 @@ class MockApiClient : ApiClient {
forceUpdated = true
}
}
class MockNetworkPrefs : NetworkPrefs {
override var activeAltBaseUrl: String? = null
override var lastPrimaryApiFail: Long = Long.MIN_VALUE
override var alternativeBaseUrls: List<String>? = null
}

View File

@ -7,7 +7,7 @@ plugins {
`kotlin-serialization`
}
libVersion = Version(0, 1, 0)
libVersion = Version(0, 1, 1)
dependencies {

View File

@ -36,6 +36,12 @@ interface ApiBackend<Api> {
suspend fun refreshTokens(): ApiResult<Tokens>
data class Tokens(val refresh: String, val access: String)
/**
* Lightweight call checking if API might be blocked.
* @return [true] if API is not reachable and error might indicate blocking.
*/
suspend fun isPotentiallyBlocked(): Boolean
/**
* Makes API call defined with [block] lambda.
*

View File

@ -17,6 +17,8 @@
*/
package me.proton.core.network.domain
import java.util.concurrent.TimeUnit
/**
* Represents the client of the library. Enables 2-way communication between the lib and the client.
*/
@ -49,6 +51,21 @@ interface ApiClient {
*/
val dohTimeoutMs: Long get() = 60_000L
/**
* How long alternative API proxy will be used before primary API is attempted again.
*/
val proxyValidityPeriodMs: Long get() = TimeUnit.DAYS.toMillis(1)
/**
* Timeout for DoH queries.
*/
val dohServiceTimeoutMs: Long get() = TimeUnit.SECONDS.toMillis(10)
/**
* Timeout for refreshing proxy list (can span multiple DoH queries).
*/
val dohProxyRefreshTimeoutMs: Long get() = TimeUnit.SECONDS.toMillis(30)
/**
* Retry count for exponential backoff.
*/

View File

@ -27,9 +27,8 @@ import kotlin.random.Random
*
* @param Api API interface
* @property client [ApiClient] for client-library integration.
* @property primaryBackend [ApiBackend] for regular API calls, when it fails [dohProvider] will
* be used to deliver [ApiBackend] for proxies.
* @property dohProvider [DohProvider] instance to deliver alternative [ApiBackend]s.
* @property primaryBackend [ApiBackend] for regular API calls.
* @property dohApiHandler [DohApiHandler] instance to handle DoH logic for API calls.
* @property networkManager [NetworkManager] for connectivity checks.
* @property errorHandlers list of [ApiErrorHandler] for call error recovery.
* @property monoClockMs Monotonic clock with millisecond resolution.
@ -37,7 +36,7 @@ import kotlin.random.Random
class ApiManagerImpl<Api>(
private val client: ApiClient,
private val primaryBackend: ApiBackend<Api>,
private val dohProvider: DohProvider,
private val dohApiHandler: DohApiHandler<Api>,
private val networkManager: NetworkManager,
private val errorHandlers: List<ApiErrorHandler<Api>>,
private val monoClockMs: () -> Long
@ -61,18 +60,14 @@ class ApiManagerImpl<Api>(
forceNoRetryOnConnectionErrors ->
handledCall(primaryBackend, call)
client.shouldUseDoh ->
ApiResult.withTimeout(client.dohTimeoutMs) {
callWithDoH(call)
ApiResult.withTimeout(client.dohProxyRefreshTimeoutMs) {
dohApiHandler(::handledCall, call)
}
else ->
callWithBackoff(call)
}
}
private suspend fun <T> callWithDoH(call: ApiManager.Call<Api, T>): ApiResult<T> {
return handledCall(primaryBackend, call) // TODO: DoH logic
}
private suspend fun <T> callWithBackoff(call: ApiManager.Call<Api, T>): ApiResult<T> {
var retryCount = 0
while (true) {

View File

@ -0,0 +1,114 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.domain
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withTimeoutOrNull
/**
* Responsible for making an API call according to DoH feature logic: when our API seems blocked
* will refresh the list of alternative proxies with DoH queries and try to repeat a call on those
* proxies.
*/
class DohApiHandler<Api>(
private val apiClient: ApiClient,
private val primaryBackend: ApiBackend<Api>,
private val dohProvider: DohProvider,
private val prefs: NetworkPrefs,
private val wallClockMs: () -> Long,
private val createAltBackend: (baseUrl: String) -> ApiBackend<Api>
) {
// Active proxy backend or null if we should use our primary backend.
var activeAltBackend: ApiBackend<Api>? = null
get() {
// If alt backend is outdated reset it so that primary backend is attempted.
if (wallClockMs() - prefs.lastPrimaryApiFail >= apiClient.proxyValidityPeriodMs) {
field = null
} else if (field == null) {
val baseUrl = prefs.activeAltBaseUrl
if (baseUrl != null)
activeAltBackend = createAltBackend(baseUrl)
}
return field
}
set(value) {
field = value
prefs.activeAltBaseUrl = value?.baseUrl
}
/**
* Makes an API [call] according to DoH feature logic.
* @param callHandler Function that should be used to make a call with a reachable
* backend.
*/
suspend operator fun <T> invoke(
callHandler: suspend (ApiBackend<Api>, ApiManager.Call<Api, T>) -> ApiResult<T>,
call: ApiManager.Call<Api, T>
): ApiResult<T> {
val activeBackend = activeAltBackend ?: primaryBackend
val result = callHandler(activeBackend, call)
return if (!result.isPotentialBlocking)
result
else coroutineScope {
// Ping primary backend (to make sure failure wasn't a random network error rather than
// an actual block) parallel with refreshing proxy list
val isPotentiallyBlockedAsync = async {
primaryBackend.isPotentiallyBlocked()
}
val dohRefresh = async {
withTimeoutOrNull(apiClient.dohProxyRefreshTimeoutMs) {
dohProvider.refreshAlternatives()
}
}
// If ping on primary api succeeded don't fallback to proxy
val isPotentiallyBlocked = isPotentiallyBlockedAsync.await()
if (isPotentiallyBlocked) {
dohRefresh.await()
if (activeBackend == primaryBackend)
prefs.lastPrimaryApiFail = wallClockMs()
else
activeAltBackend = null
callWithAlternatives(callHandler, call) ?: result
} else {
dohRefresh.cancel()
activeAltBackend = null
result
}
}
}
private suspend fun <T> callWithAlternatives(
callHandler: suspend (ApiBackend<Api>, ApiManager.Call<Api, T>) -> ApiResult<T>,
call: ApiManager.Call<Api, T>
): ApiResult<T>? {
val alternatives = prefs.alternativeBaseUrls?.shuffled()
alternatives?.forEach { baseUrl ->
val backend = createAltBackend(baseUrl)
val result = callHandler(backend, call)
if (!result.isPotentialBlocking) {
activeAltBackend = backend
return result
}
}
return null
}
}

View File

@ -17,5 +17,58 @@
*/
package me.proton.core.network.domain
// TODO: refreshing/storing proxy list for DoH
class DohProvider
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import java.util.concurrent.TimeUnit
/**
* Gets the list of alternative baseUrls for Proton API.
*/
interface DohService {
suspend fun getAlternativeBaseUrls(primaryBaseUrl: String): List<String>?
}
/**
* Refreshes alternative urls for [baseUrl] using given list of DoH services ([dohServices]). Makes
* sure that only one refresh operation takes place at one time for given baseUrl. Single instance
* should exist per baseUrl.
*/
class DohProvider(
private val baseUrl: String,
private val apiClient: ApiClient,
private val dohServices: List<DohService>,
private val networkMainScope: CoroutineScope,
private val prefs: NetworkPrefs,
private val monoClockMs: () -> Long
) {
private var ongoingRefresh: Deferred<Unit>? = null
private var lastRefresh = Long.MIN_VALUE
suspend fun refreshAlternatives() = withContext(networkMainScope.coroutineContext) {
if (monoClockMs() >= lastRefresh + MIN_REFRESH_INTERVAL_MS) {
ongoingRefresh = ongoingRefresh ?: async(start = CoroutineStart.LAZY) {
for (service in dohServices) {
val success = withTimeoutOrNull(apiClient.dohServiceTimeoutMs) {
val result = service.getAlternativeBaseUrls(baseUrl)
if (result != null)
prefs.alternativeBaseUrls = result
result != null
}
if (success == true)
break
}
lastRefresh = monoClockMs()
ongoingRefresh = null
}
ongoingRefresh!!.join()
}
}
companion object {
val MIN_REFRESH_INTERVAL_MS = TimeUnit.MINUTES.toMillis(10)
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.network.domain
/**
* Persistently stored preferences for network module.
*/
interface NetworkPrefs {
/** Base url of currently active Proton API proxy. [null] if primary API is in use */
var activeAltBaseUrl: String?
/** Timestamp for last primary API fail. After a defined period (see
* [ApiClient.proxyValidityPeriodMs]) primary API will be attempted again. */
var lastPrimaryApiFail: Long
/** List of base urls for Proton API proxies returned in the last DoH query. */
var alternativeBaseUrls: List<String>?
}

View File

@ -54,6 +54,20 @@ fun SharedPreferencesDelegationExtensions.int(key: String? = null) = optional(
setter = { k, v -> edit { putInt(k, v) } }
)
/** @return ( by Delegation ) Mutable Property of type [Long] */
fun SharedPreferencesDelegationExtensions.long(default: Long, key: String? = null) = required(
explicitKey = key,
getter = { k -> getLong(k, default) },
setter = { k, v -> edit { putLong(k, v) } }
)
/** @return ( by Delegation ) Mutable Property of type Nullable [Long] */
fun SharedPreferencesDelegationExtensions.long(key: String? = null) = optional(
explicitKey = key,
getter = { k -> getLong(k) },
setter = { k, v -> edit { putLong(k, v) } }
)
/** @return ( by Delegation ) Mutable Property of type [String] */
fun SharedPreferencesDelegationExtensions.string(default: String, key: String? = null) = required(
explicitKey = key,
@ -149,6 +163,8 @@ fun SharedPreferences.boolean(default: Boolean, key: String? = null) = ext.boole
fun SharedPreferences.boolean(key: String? = null) = ext.boolean(key)
fun SharedPreferences.int(default: Int, key: String? = null) = ext.int(default, key)
fun SharedPreferences.int(key: String? = null) = ext.int(key)
fun SharedPreferences.long(default: Long, key: String? = null) = ext.long(default, key)
fun SharedPreferences.long(key: String? = null) = ext.long(key)
fun SharedPreferences.string(default: String, key: String? = null) = ext.string(default, key)
fun SharedPreferences.string(key: String? = null) = ext.string(key)
@NeedSerializable