Integrated AccountManager with modules.

This commit is contained in:
Neil Marietta 2020-10-28 19:32:21 +01:00
parent 755dc1b8a0
commit bf33818a5a
59 changed files with 1090 additions and 773 deletions

View File

@ -28,9 +28,11 @@ import dagger.hilt.android.qualifiers.ApplicationContext
import me.proton.core.account.data.repository.AccountRepositoryImpl
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.accountmanager.data.AccountManagerImpl
import me.proton.core.accountmanager.data.SessionManagerImpl
import me.proton.core.accountmanager.data.db.AccountManagerDatabase
import me.proton.core.accountmanager.domain.AccountManager
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.domain.repository.AuthRepository
import me.proton.core.data.crypto.KeyStoreStringCrypto
import me.proton.core.data.crypto.StringCrypto
import me.proton.core.domain.entity.Product
@ -63,8 +65,19 @@ object AccountManagerModule {
@Provides
@Singleton
fun provideAccountManagerImpl(product: Product, accountRepository: AccountRepository): AccountManagerImpl =
AccountManagerImpl(product, accountRepository)
fun provideAccountManagerImpl(
product: Product,
accountRepository: AccountRepository,
authRepository: AuthRepository
): AccountManagerImpl =
AccountManagerImpl(product, accountRepository, authRepository)
@Provides
@Singleton
fun provideSessionManagerImpl(
accountRepository: AccountRepository
): SessionManagerImpl =
SessionManagerImpl(accountRepository)
}
@Module
@ -78,8 +91,8 @@ interface AccountManagerBindModule {
fun bindAccountWorkflowHandler(accountManagerImpl: AccountManagerImpl): AccountWorkflowHandler
@Binds
fun bindSessionProvider(accountManagerImpl: AccountManagerImpl): SessionProvider
fun bindSessionProvider(sessionManagerImpl: SessionManagerImpl): SessionProvider
@Binds
fun bindSessionListener(accountManagerImpl: AccountManagerImpl): SessionListener
fun bindSessionListener(sessionManagerImpl: SessionManagerImpl): SessionListener
}

View File

@ -21,34 +21,32 @@ package me.proton.core.accountmanager.data
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.account.domain.entity.isReady
import me.proton.core.account.domain.entity.isSecondFactorNeeded
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.accountmanager.domain.AccountManager
import me.proton.core.accountmanager.domain.onSessionStateChanged
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.domain.repository.AuthRepository
import me.proton.core.domain.arch.extension.onEntityChanged
import me.proton.core.domain.entity.Product
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionListener
import me.proton.core.network.domain.session.SessionProvider
import java.util.concurrent.ConcurrentHashMap
class AccountManagerImpl constructor(
product: Product,
private val accountRepository: AccountRepository
) : AccountManager(product), AccountWorkflowHandler, SessionProvider, SessionListener {
private val humanVerificationDetails: ConcurrentHashMap<SessionId, HumanVerificationDetails?> = ConcurrentHashMap()
private val accountRepository: AccountRepository,
private val authRepository: AuthRepository
) : AccountManager(product), AccountWorkflowHandler {
private suspend fun removeSession(sessionId: SessionId) {
// TODO: Revoke session (fire and forget, need Auth module).
authRepository.revokeSession(sessionId)
accountRepository.deleteSession(sessionId)
}
@ -100,7 +98,7 @@ class AccountManagerImpl constructor(
override fun onHumanVerificationNeeded(): Flow<Pair<Account, HumanVerificationDetails?>> =
onSessionStateChanged(SessionState.HumanVerificationNeeded)
.map { it to it.sessionId?.let { id -> humanVerificationDetails[id] } }
.map { it to it.sessionId?.let { id -> accountRepository.getHumanVerificationDetails(id) } }
override fun getPrimaryUserId(): Flow<UserId?> =
accountRepository.getPrimaryUserId()
@ -111,30 +109,31 @@ class AccountManagerImpl constructor(
// region AccountWorkflowHandler
override suspend fun handleSession(account: Account, session: Session) {
val initializing = when {
account.state == AccountState.TwoPassModeNeeded -> true
account.sessionState == SessionState.SecondFactorNeeded -> true
else -> false
}
val updatedAccount = if (initializing) account.copy(state = AccountState.Initializing) else account
accountRepository.createOrUpdateAccountSession(updatedAccount, session)
// Account state must be != Ready if SecondFactorNeeded.
val state = if (account.isReady() && account.isSecondFactorNeeded()) AccountState.NotReady else account.state
accountRepository.createOrUpdateAccountSession(account.copy(state = state), session)
}
override suspend fun handleTwoPassModeSuccess(sessionId: SessionId) {
accountRepository.updateAccountState(sessionId, AccountState.TwoPassModeSuccess)
accountRepository.updateAccountState(sessionId, AccountState.Ready)
accountRepository.getAccountOrNull(sessionId)?.let { account ->
if (account.sessionState == SessionState.Authenticated)
accountRepository.updateAccountState(sessionId, AccountState.Ready)
}
}
override suspend fun handleTwoPassModeFailed(sessionId: SessionId) {
accountRepository.updateAccountState(sessionId, AccountState.TwoPassModeFailed)
disableAccount(sessionId)
}
override suspend fun handleSecondFactorSuccess(sessionId: SessionId, updatedScopes: List<String>) {
accountRepository.updateSessionScopes(sessionId, updatedScopes)
accountRepository.updateSessionState(sessionId, SessionState.SecondFactorSuccess)
accountRepository.updateSessionState(sessionId, SessionState.Authenticated)
accountRepository.updateAccountState(sessionId, AccountState.Ready)
accountRepository.getAccountOrNull(sessionId)?.let { account ->
if (account.state != AccountState.TwoPassModeNeeded)
accountRepository.updateAccountState(sessionId, AccountState.Ready)
}
}
override suspend fun handleSecondFactorFailed(sessionId: SessionId) {
@ -152,51 +151,7 @@ class AccountManagerImpl constructor(
override suspend fun handleHumanVerificationFailed(sessionId: SessionId) {
accountRepository.updateSessionHeaders(sessionId, null, null)
accountRepository.updateSessionState(sessionId, SessionState.HumanVerificationFailed)
disableAccount(sessionId)
}
// endregion
// region SessionListener
override suspend fun onSessionTokenRefreshed(session: Session) {
accountRepository.updateSessionToken(session.sessionId, session.accessToken, session.refreshToken)
accountRepository.updateSessionState(session.sessionId, SessionState.Authenticated)
}
override suspend fun onSessionForceLogout(session: Session) {
accountRepository.updateSessionState(session.sessionId, SessionState.ForceLogout)
accountRepository.getAccountOrNull(session.sessionId)?.let { account ->
accountRepository.updateAccountState(account.userId, AccountState.Disabled)
}
}
override suspend fun onHumanVerificationNeeded(
session: Session,
details: HumanVerificationDetails?
): SessionListener.HumanVerificationResult {
humanVerificationDetails[session.sessionId] = details
accountRepository.updateSessionState(session.sessionId, SessionState.HumanVerificationNeeded)
accountRepository.updateAccountState(session.sessionId, AccountState.Initializing)
// Wait for HumanVerification Success or Failure.
val state = accountRepository.getAccount(session.sessionId)
.map { it?.sessionState }
.filter { it == SessionState.HumanVerificationSuccess || it == SessionState.HumanVerificationFailed }
.first()
return when (state) {
null -> SessionListener.HumanVerificationResult.Failure
SessionState.HumanVerificationSuccess -> SessionListener.HumanVerificationResult.Success
else -> SessionListener.HumanVerificationResult.Failure
}
}
// endregion
// region SessionProvider
override fun getSession(sessionId: SessionId): Session? = accountRepository.getSessionOrNull(sessionId)
// endregion
}

View File

@ -0,0 +1,84 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.accountmanager.data
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.accountmanager.domain.SessionManager
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionListener
class SessionManagerImpl(
private val accountRepository: AccountRepository
) : SessionManager {
// region SessionListener
override suspend fun onSessionTokenRefreshed(session: Session) {
accountRepository.updateSessionToken(session.sessionId, session.accessToken, session.refreshToken)
accountRepository.updateSessionState(session.sessionId, SessionState.Authenticated)
}
override suspend fun onSessionForceLogout(session: Session) {
accountRepository.updateSessionState(session.sessionId, SessionState.ForceLogout)
accountRepository.getAccountOrNull(session.sessionId)?.let { account ->
accountRepository.updateAccountState(account.userId, AccountState.Disabled)
}
}
override suspend fun onHumanVerificationNeeded(
session: Session,
details: HumanVerificationDetails?
): SessionListener.HumanVerificationResult {
accountRepository.setHumanVerificationDetails(session.sessionId, details)
accountRepository.updateAccountState(session.sessionId, AccountState.NotReady)
accountRepository.updateSessionState(session.sessionId, SessionState.HumanVerificationNeeded)
// Wait for HumanVerification Success or Failure.
val state = accountRepository.getAccount(session.sessionId)
.map { it?.sessionState }
.filter { it == SessionState.HumanVerificationSuccess || it == SessionState.HumanVerificationFailed }
.first()
return when (state) {
null -> SessionListener.HumanVerificationResult.Failure
SessionState.HumanVerificationSuccess -> SessionListener.HumanVerificationResult.Success
else -> SessionListener.HumanVerificationResult.Failure
}
}
// endregion
// region SessionProvider
override suspend fun getSession(sessionId: SessionId): Session? =
accountRepository.getSessionOrNull(sessionId)
override suspend fun getSessionId(userId: UserId): SessionId? =
accountRepository.getSessionIdOrNull(userId)
// endregion
}

View File

@ -17,27 +17,19 @@
*/
package me.proton.core.accountmanager.data
import io.mockk.MockKAnnotations
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.impl.annotations.RelaxedMockK
import io.mockk.slot
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runBlockingTest
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.domain.entity.Product
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
import me.proton.core.network.domain.humanverification.VerificationMethod
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionListener
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
@ -46,9 +38,6 @@ class AccountManagerImplTest {
private lateinit var accountManager: AccountManagerImpl
@RelaxedMockK
private lateinit var accountRepository: AccountRepository
private val session1 = Session(
sessionId = SessionId("session1"),
accessToken = "accessToken",
@ -66,145 +55,25 @@ class AccountManagerImplTest {
sessionState = SessionState.Authenticated
)
private val flowOfAccountLists = mutableListOf<List<Account>>()
private val flowOfSessionLists = mutableListOf<List<Session>>()
@Suppress("LongMethod")
private fun setupUpdateListsFlow() {
val userIdSlot = slot<UserId>()
val sessionIdSlot = slot<SessionId>()
val accountStateSlot = slot<AccountState>()
val sessionStateSlot = slot<SessionState>()
val updatedScopesSlot = slot<List<String>>()
val tokenTypeSlot = slot<String>()
val tokenCodeSlot = slot<String>()
val accessTokenSlot = slot<String>()
val refreshTokenSlot = slot<String>()
// Initial state.
flowOfAccountLists.clear()
flowOfSessionLists.clear()
flowOfAccountLists.add(listOf(account1))
flowOfSessionLists.add(listOf(session1))
// For each updateAccountState -> emit a new updated List<Account> from getAccounts().
coEvery { accountRepository.updateAccountState(capture(userIdSlot), capture(accountStateSlot)) } answers {
flowOfAccountLists.add(
listOf(
flowOfAccountLists.last().first { it.userId == userIdSlot.captured }.copy(
userId = userIdSlot.captured,
state = accountStateSlot.captured
)
)
)
}
// For each updateSessionState -> emit a new updated List<Account> from getAccounts().
coEvery { accountRepository.updateSessionState(capture(sessionIdSlot), capture(sessionStateSlot)) } answers {
flowOfAccountLists.add(
listOf(
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
sessionState = sessionStateSlot.captured
)
)
)
}
// For each updateSessionScopes -> emit a new updated List<Session> from getSessions().
coEvery { accountRepository.updateSessionScopes(capture(sessionIdSlot), capture(updatedScopesSlot)) } answers {
flowOfSessionLists.add(
listOf(
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
scopes = updatedScopesSlot.captured
)
)
)
}
// For each updateSessionHeaders -> emit a new updated List<Session> from getSessions().
coEvery {
accountRepository.updateSessionHeaders(
capture(sessionIdSlot),
capture(tokenTypeSlot),
capture(tokenCodeSlot)
)
} answers {
flowOfSessionLists.add(
listOf(
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
headers = HumanVerificationHeaders(
tokenTypeSlot.captured,
tokenCodeSlot.captured
)
)
)
)
}
// For each updateSessionToken -> emit a new updated List<Session> from getSessions().
coEvery {
accountRepository.updateSessionToken(
capture(sessionIdSlot),
capture(accessTokenSlot),
capture(refreshTokenSlot)
)
} answers {
flowOfSessionLists.add(
listOf(
session1.copy(
sessionId = sessionIdSlot.captured,
accessToken = accessTokenSlot.captured,
refreshToken = refreshTokenSlot.captured
)
)
)
}
// Emit last state with same id if exist.
coEvery { accountRepository.getAccountOrNull(capture(userIdSlot)) } answers {
flowOfAccountLists.last().firstOrNull { it.userId == userIdSlot.captured }
}
coEvery { accountRepository.getAccountOrNull(capture(sessionIdSlot)) } answers {
flowOfAccountLists.last().firstOrNull { it.sessionId == sessionIdSlot.captured }
}
// Emit all state with same id if exist.
coEvery { accountRepository.getAccount(capture(sessionIdSlot)) } answers {
val filteredLists = flowOfAccountLists.map { list ->
list.firstOrNull { it.sessionId == sessionIdSlot.captured }
}
flowOf(*filteredLists.toTypedArray())
}
// Finally, emit all flow of Lists.
every { accountRepository.getAccounts() } answers {
flowOf(*flowOfAccountLists.toTypedArray())
}
every { accountRepository.getSessions() } answers {
flowOf(*flowOfSessionLists.toTypedArray())
}
}
private val mocks = RepositoryMocks(session1, account1)
@Before
fun beforeEveryTest() {
MockKAnnotations.init(this)
mocks.init()
accountManager = AccountManagerImpl(Product.Calendar, accountRepository)
accountManager = AccountManagerImpl(Product.Calendar, mocks.accountRepository, mocks.authRepository)
}
@Test
fun `add user with session`() = runBlockingTest {
accountManager.addAccount(account1, session1)
coVerify(exactly = 1) { accountRepository.createOrUpdateAccountSession(any(), any()) }
coVerify(exactly = 1) { mocks.accountRepository.createOrUpdateAccountSession(any(), any()) }
}
@Test
fun `on account state changed`() = runBlockingTest {
every { accountRepository.getAccounts() } returns flowOf(
every { mocks.accountRepository.getAccounts() } returns flowOf(
listOf(account1),
listOf(account1),
listOf(account1.copy(state = AccountState.Disabled))
@ -221,7 +90,7 @@ class AccountManagerImplTest {
@Test
fun `on session state changed`() = runBlockingTest {
every { accountRepository.getAccounts() } returns flowOf(
every { mocks.accountRepository.getAccounts() } returns flowOf(
listOf(account1),
listOf(account1),
listOf(account1.copy(sessionState = SessionState.ForceLogout))
@ -238,9 +107,9 @@ class AccountManagerImplTest {
@Test
fun `on handleTwoPassModeSuccess`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
accountManager.handleTwoPassModeSuccess(account1.userId)
accountManager.handleTwoPassModeSuccess(account1.sessionId!!)
val stateLists = accountManager.onAccountStateChanged().toList()
assertEquals(3, stateLists.size)
@ -251,20 +120,19 @@ class AccountManagerImplTest {
@Test
fun `on handleTwoPassModeFailed`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
accountManager.handleTwoPassModeFailed(account1.userId)
accountManager.handleTwoPassModeFailed(account1.sessionId!!)
val stateLists = accountManager.onAccountStateChanged().toList()
assertEquals(3, stateLists.size)
assertEquals(2, stateLists.size)
assertEquals(account1.state, stateLists[0].state)
assertEquals(AccountState.TwoPassModeFailed, stateLists[1].state)
assertEquals(AccountState.Disabled, stateLists[2].state)
}
@Test
fun `on handleSecondFactorSuccess`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
val newScopes = listOf("scope1", "scope2")
@ -288,7 +156,8 @@ class AccountManagerImplTest {
@Test
fun `on handleSecondFactorFailed`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
mocks.setupAuthRepository()
accountManager.handleSecondFactorFailed(session1.sessionId)
@ -305,7 +174,7 @@ class AccountManagerImplTest {
@Test
fun `on handleHumanVerificationSuccess`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
val tokenType = "newTokenType"
val tokenCode = "newTokenCode"
@ -331,107 +200,17 @@ class AccountManagerImplTest {
@Test
fun `on handleHumanVerificationFailed`() = runBlockingTest {
setupUpdateListsFlow()
mocks.setupAccountRepository()
accountManager.handleHumanVerificationFailed(session1.sessionId)
val stateLists = accountManager.onAccountStateChanged().toList()
assertEquals(2, stateLists.size)
assertEquals(1, stateLists.size)
assertEquals(account1.state, stateLists[0].state)
assertEquals(AccountState.Disabled, stateLists[1].state)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.HumanVerificationFailed, sessionStateLists[1].sessionState)
}
@Test
fun `on onSessionTokenRefreshed`() = runBlockingTest {
setupUpdateListsFlow()
val newAccessToken = "newAccessToken"
val newRefreshToken = "newRefreshToken"
accountManager.onSessionTokenRefreshed(
session1.refreshWith(
accessToken = newAccessToken,
refreshToken = newRefreshToken
)
)
val sessionLists = accountManager.getSessions().toList()
assertEquals(2, sessionLists.size)
assertEquals(session1.accessToken, sessionLists[0][0].accessToken)
assertEquals(session1.refreshToken, sessionLists[0][0].refreshToken)
assertEquals(newAccessToken, sessionLists[1][0].accessToken)
assertEquals(newRefreshToken, sessionLists[1][0].refreshToken)
}
@Test
fun `on onSessionForceLogout`() = runBlockingTest {
setupUpdateListsFlow()
accountManager.onSessionForceLogout(session1)
val stateLists = accountManager.onAccountStateChanged().toList()
assertEquals(2, stateLists.size)
assertEquals(account1.state, stateLists[0].state)
assertEquals(AccountState.Disabled, stateLists[1].state)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.ForceLogout, sessionStateLists[1].sessionState)
}
@Test
fun `on onHumanVerificationNeeded success`() = runBlockingTest {
setupUpdateListsFlow()
val humanVerificationDetails = HumanVerificationDetails(
verificationMethods = listOf(VerificationMethod.EMAIL),
captchaVerificationToken = null
)
coEvery { accountRepository.getAccount(any<SessionId>()) } returns flowOf(
account1,
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
account1.copy(sessionState = SessionState.HumanVerificationSuccess)
)
val result = accountManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
assertEquals(SessionListener.HumanVerificationResult.Success, result)
}
@Test
fun `on onHumanVerificationNeeded failed`() = runBlockingTest {
setupUpdateListsFlow()
val humanVerificationDetails = HumanVerificationDetails(
verificationMethods = listOf(VerificationMethod.EMAIL),
captchaVerificationToken = null
)
coEvery { accountRepository.getAccount(any<SessionId>()) } returns flowOf(
account1,
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
account1.copy(sessionState = SessionState.HumanVerificationFailed)
)
val result = accountManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
assertEquals(SessionListener.HumanVerificationResult.Failure, result)
}
}

View File

@ -0,0 +1,193 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.accountmanager.data
import io.mockk.MockKAnnotations
import io.mockk.coEvery
import io.mockk.every
import io.mockk.impl.annotations.RelaxedMockK
import io.mockk.slot
import kotlinx.coroutines.flow.flowOf
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.auth.domain.repository.AuthRepository
import me.proton.core.domain.arch.DataResult
import me.proton.core.domain.arch.ResponseSource
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
class RepositoryMocks(
private val session: Session,
private val account: Account
) {
@RelaxedMockK
lateinit var accountRepository: AccountRepository
@RelaxedMockK
lateinit var authRepository: AuthRepository
private val flowOfAccountLists = mutableListOf<List<Account>>()
private val flowOfSessionLists = mutableListOf<List<Session>>()
fun init() {
MockKAnnotations.init(this)
}
@Suppress("LongMethod")
fun setupAccountRepository() {
val userIdSlot = slot<UserId>()
val sessionIdSlot = slot<SessionId>()
val accountStateSlot = slot<AccountState>()
val sessionStateSlot = slot<SessionState>()
val updatedScopesSlot = slot<List<String>>()
val tokenTypeSlot = slot<String>()
val tokenCodeSlot = slot<String>()
val accessTokenSlot = slot<String>()
val refreshTokenSlot = slot<String>()
// Initial state.
flowOfAccountLists.clear()
flowOfSessionLists.clear()
flowOfAccountLists.add(listOf(account))
flowOfSessionLists.add(listOf(session))
// For each updateAccountState -> emit a new updated List<Account> from getAccounts().
coEvery { accountRepository.updateAccountState(capture(userIdSlot), capture(accountStateSlot)) } answers {
flowOfAccountLists.add(
listOf(
flowOfAccountLists.last().first { it.userId == userIdSlot.captured }.copy(
userId = userIdSlot.captured,
state = accountStateSlot.captured
)
)
)
}
coEvery { accountRepository.updateAccountState(capture(sessionIdSlot), capture(accountStateSlot)) } answers {
flowOfAccountLists.add(
listOf(
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
state = accountStateSlot.captured
)
)
)
}
// For each updateSessionState -> emit a new updated List<Account> from getAccounts().
coEvery { accountRepository.updateSessionState(capture(sessionIdSlot), capture(sessionStateSlot)) } answers {
flowOfAccountLists.add(
listOf(
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
sessionState = sessionStateSlot.captured
)
)
)
}
// For each updateSessionScopes -> emit a new updated List<Session> from getSessions().
coEvery { accountRepository.updateSessionScopes(capture(sessionIdSlot), capture(updatedScopesSlot)) } answers {
flowOfSessionLists.add(
listOf(
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
scopes = updatedScopesSlot.captured
)
)
)
}
// For each updateSessionHeaders -> emit a new updated List<Session> from getSessions().
coEvery {
accountRepository.updateSessionHeaders(
capture(sessionIdSlot),
capture(tokenTypeSlot),
capture(tokenCodeSlot)
)
} answers {
flowOfSessionLists.add(
listOf(
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
sessionId = sessionIdSlot.captured,
headers = HumanVerificationHeaders(
tokenTypeSlot.captured,
tokenCodeSlot.captured
)
)
)
)
}
// For each updateSessionToken -> emit a new updated List<Session> from getSessions().
coEvery {
accountRepository.updateSessionToken(
capture(sessionIdSlot),
capture(accessTokenSlot),
capture(refreshTokenSlot)
)
} answers {
flowOfSessionLists.add(
listOf(
session.copy(
sessionId = sessionIdSlot.captured,
accessToken = accessTokenSlot.captured,
refreshToken = refreshTokenSlot.captured
)
)
)
}
// Emit last state with same id if exist.
coEvery { accountRepository.getAccountOrNull(capture(userIdSlot)) } answers {
flowOfAccountLists.last().firstOrNull { it.userId == userIdSlot.captured }
}
coEvery { accountRepository.getAccountOrNull(capture(sessionIdSlot)) } answers {
flowOfAccountLists.last().firstOrNull { it.sessionId == sessionIdSlot.captured }
}
// Emit all state with same id if exist.
coEvery { accountRepository.getAccount(capture(sessionIdSlot)) } answers {
val filteredLists = flowOfAccountLists.map { list ->
list.firstOrNull { it.sessionId == sessionIdSlot.captured }
}
flowOf(*filteredLists.toTypedArray())
}
// Finally, emit all flow of Lists.
every { accountRepository.getAccounts() } answers {
flowOf(*flowOfAccountLists.toTypedArray())
}
every { accountRepository.getSessions() } answers {
flowOf(*flowOfSessionLists.toTypedArray())
}
}
fun setupAuthRepository() {
val sessionIdSlot = slot<SessionId>()
// Assume revokeSession is done successfully.
coEvery { authRepository.revokeSession(capture(sessionIdSlot)) } answers {
DataResult.Success(ResponseSource.Remote, true)
}
}
}

View File

@ -0,0 +1,159 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.accountmanager.data
import io.mockk.coEvery
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runBlockingTest
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.domain.entity.Product
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
import me.proton.core.network.domain.humanverification.VerificationMethod
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionListener
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
class SessionManagerImplTest {
private lateinit var accountManager: AccountManagerImpl
private lateinit var sessionManager: SessionManagerImpl
private val session1 = Session(
sessionId = SessionId("session1"),
accessToken = "accessToken",
refreshToken = "refreshToken",
scopes = listOf("full", "calendar", "mail"),
headers = HumanVerificationHeaders("tokenType", "tokenCode")
)
private val account1 = Account(
userId = UserId("user1"),
username = "username",
email = "test@example.com",
state = AccountState.Ready,
sessionId = session1.sessionId,
sessionState = SessionState.Authenticated
)
private val mocks = RepositoryMocks(session1, account1)
@Before
fun beforeEveryTest() {
mocks.init()
accountManager = AccountManagerImpl(Product.Calendar, mocks.accountRepository, mocks.authRepository)
sessionManager = SessionManagerImpl(mocks.accountRepository)
}
@Test
fun `on onSessionTokenRefreshed`() = runBlockingTest {
mocks.setupAccountRepository()
val newAccessToken = "newAccessToken"
val newRefreshToken = "newRefreshToken"
sessionManager.onSessionTokenRefreshed(
session1.refreshWith(
accessToken = newAccessToken,
refreshToken = newRefreshToken
)
)
val sessionLists = accountManager.getSessions().toList()
assertEquals(2, sessionLists.size)
assertEquals(session1.accessToken, sessionLists[0][0].accessToken)
assertEquals(session1.refreshToken, sessionLists[0][0].refreshToken)
assertEquals(newAccessToken, sessionLists[1][0].accessToken)
assertEquals(newRefreshToken, sessionLists[1][0].refreshToken)
}
@Test
fun `on onSessionForceLogout`() = runBlockingTest {
mocks.setupAccountRepository()
sessionManager.onSessionForceLogout(session1)
val stateLists = accountManager.onAccountStateChanged().toList()
assertEquals(2, stateLists.size)
assertEquals(account1.state, stateLists[0].state)
assertEquals(AccountState.Disabled, stateLists[1].state)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.ForceLogout, sessionStateLists[1].sessionState)
}
@Test
fun `on onHumanVerificationNeeded success`() = runBlockingTest {
mocks.setupAccountRepository()
val humanVerificationDetails = HumanVerificationDetails(
verificationMethods = listOf(VerificationMethod.EMAIL),
captchaVerificationToken = null
)
coEvery { mocks.accountRepository.getAccount(any<SessionId>()) } returns flowOf(
account1,
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
account1.copy(sessionState = SessionState.HumanVerificationSuccess)
)
val result = sessionManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
assertEquals(SessionListener.HumanVerificationResult.Success, result)
}
@Test
fun `on onHumanVerificationNeeded failed`() = runBlockingTest {
mocks.setupAccountRepository()
val humanVerificationDetails = HumanVerificationDetails(
verificationMethods = listOf(VerificationMethod.EMAIL),
captchaVerificationToken = null
)
coEvery { mocks.accountRepository.getAccount(any<SessionId>()) } returns flowOf(
account1,
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
account1.copy(sessionState = SessionState.HumanVerificationFailed)
)
val result = sessionManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
val sessionStateLists = accountManager.onSessionStateChanged().toList()
assertEquals(2, sessionStateLists.size)
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
assertEquals(SessionListener.HumanVerificationResult.Failure, result)
}
}

View File

@ -19,9 +19,11 @@
package me.proton.core.accountmanager.domain
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
@ -36,3 +38,6 @@ fun AccountManager.getPrimaryAccount(): Flow<Account?> =
getPrimaryUserId().flatMapLatest { userId ->
userId?.let { getAccount(it) } ?: flowOf(null)
}
fun AccountManager.getAccounts(state: AccountState): Flow<List<Account>> =
getAccounts().map { list -> list.filter { it.state == state } }.distinctUntilChanged()

View File

@ -16,20 +16,9 @@
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.auth.presentation.ui
package me.proton.core.accountmanager.domain
import androidx.databinding.ViewDataBinding
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.core.network.domain.session.SessionListener
import me.proton.core.network.domain.session.SessionProvider
/**
* Bridge between authentication activities and the interface.
*
* @author Dino Kadrikj.
*/
interface AuthActivityComponent<DB : ViewDataBinding> : AuthActivity {
/**
* Sets and initializes the authentication activity that want to implement [AuthActivity].
*/
fun initializeAuth(protonAuthActivity: ProtonActivity<DB>)
}
interface SessionManager : SessionProvider, SessionListener

View File

@ -91,7 +91,6 @@ class AccountManagerObserver(
onAccountRemovedListener = block
}
internal fun setOnSessionHumanVerificationNeeded(block: suspend (Account) -> Unit) {
onSessionHumanVerificationNeededListener = block
}

View File

@ -30,6 +30,7 @@ android()
dependencies {
implementation(
project(Module.kotlinUtil),
project(Module.data),
project(Module.domain),
project(Module.network),

View File

@ -22,8 +22,8 @@ import androidx.room.Dao
import androidx.room.Query
import kotlinx.coroutines.flow.Flow
import me.proton.core.account.data.entity.SessionEntity
import me.proton.core.data.db.BaseDao
import me.proton.core.data.crypto.EncryptedString
import me.proton.core.data.db.BaseDao
import me.proton.core.domain.entity.Product
@Dao
@ -36,17 +36,20 @@ abstract class SessionDao : BaseDao<SessionEntity>() {
abstract fun findBySessionId(sessionId: String): Flow<SessionEntity?>
@Query("SELECT * FROM SessionEntity WHERE sessionId = :sessionId")
abstract fun get(sessionId: String): SessionEntity?
abstract suspend fun get(sessionId: String): SessionEntity?
@Query("SELECT sessionId FROM SessionEntity WHERE userId = :userId")
abstract suspend fun getSessionId(userId: String): String?
@Query("DELETE FROM SessionEntity WHERE sessionId = :sessionId")
abstract suspend fun delete(sessionId: String)
@Query("UPDATE SessionEntity SET scopes = :scopes WHERE sessionId = :sessionId")
abstract fun updateScopes(sessionId: String, scopes: String)
abstract suspend fun updateScopes(sessionId: String, scopes: String)
@Query("UPDATE SessionEntity SET humanHeaderTokenType = :tokenType, humanHeaderTokenCode = :tokenCode WHERE sessionId = :sessionId")
abstract fun updateHeaders(sessionId: String, tokenType: EncryptedString?, tokenCode: EncryptedString?)
abstract suspend fun updateHeaders(sessionId: String, tokenType: EncryptedString?, tokenCode: EncryptedString?)
@Query("UPDATE SessionEntity SET accessToken = :accessToken, refreshToken = :refreshToken WHERE sessionId = :sessionId")
abstract fun updateToken(sessionId: String, accessToken: EncryptedString, refreshToken: EncryptedString)
abstract suspend fun updateToken(sessionId: String, accessToken: EncryptedString, refreshToken: EncryptedString)
}

View File

@ -30,13 +30,16 @@ import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.account.domain.repository.AccountRepository
import me.proton.core.data.db.CommonConverters
import me.proton.core.data.crypto.StringCrypto
import me.proton.core.data.crypto.encrypt
import me.proton.core.data.db.CommonConverters
import me.proton.core.domain.entity.Product
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.util.kotlin.exhaustive
import java.util.concurrent.ConcurrentHashMap
class AccountRepositoryImpl(
private val product: Product,
@ -48,6 +51,8 @@ class AccountRepositoryImpl(
private val sessionDao = db.sessionDao()
private val accountMetadataDao = db.accountMetadataDao()
private val humanVerificationDetails: ConcurrentHashMap<SessionId, HumanVerificationDetails?> = ConcurrentHashMap()
private suspend fun updateAccountMetadata(userId: UserId) =
accountMetadataDao.insertOrUpdate(
AccountMetadataEntity(
@ -91,9 +96,12 @@ class AccountRepositoryImpl(
.map { it?.toSession(stringCrypto) }
.distinctUntilChanged()
override fun getSessionOrNull(sessionId: SessionId): Session? =
override suspend fun getSessionOrNull(sessionId: SessionId): Session? =
sessionDao.get(sessionId.id)?.toSession(stringCrypto)
override suspend fun getSessionIdOrNull(userId: UserId): SessionId? =
sessionDao.getSessionId(userId.id)?.let { SessionId(it) }
override suspend fun createOrUpdateAccountSession(account: Account, session: Session) {
require(session.isValid()) {
"Session is not valid: $session\n.At least sessionId.id, accessToken and refreshToken must be valid."
@ -102,7 +110,7 @@ class AccountRepositoryImpl(
db.inTransaction {
accountDao.insertOrUpdate(
account.copy(
state = AccountState.Initializing,
state = AccountState.NotReady,
sessionId = null,
sessionState = null
).toAccountEntity()
@ -133,22 +141,19 @@ class AccountRepositoryImpl(
when (state) {
AccountState.Ready -> updateAccountMetadata(userId)
AccountState.Disabled,
AccountState.Removed -> deleteAccountMetadata(userId)
AccountState.Added,
AccountState.Initializing,
AccountState.Removed,
AccountState.NotReady,
AccountState.TwoPassModeNeeded,
AccountState.TwoPassModeSuccess,
AccountState.TwoPassModeFailed -> Unit
}
AccountState.TwoPassModeFailed -> deleteAccountMetadata(userId)
}.exhaustive
accountDao.updateAccountState(userId.id, state)
}
}
override suspend fun updateAccountState(sessionId: SessionId, state: AccountState) {
db.inTransaction {
getAccountOrNull(sessionId)?.let {
updateAccountState(it.userId, state)
}
getAccountOrNull(sessionId)?.let {
updateAccountState(it.userId, state)
}
}
@ -177,4 +182,11 @@ class AccountRepositoryImpl(
updateAccountMetadata(userId)
}
}
override suspend fun getHumanVerificationDetails(id: SessionId): HumanVerificationDetails? =
humanVerificationDetails[id]
override suspend fun setHumanVerificationDetails(id: SessionId, details: HumanVerificationDetails?) {
humanVerificationDetails[id] = details
}
}

View File

@ -29,3 +29,11 @@ data class Account(
val sessionId: SessionId?,
val sessionState: SessionState?
)
fun Account.isReady() = state == AccountState.Ready
fun Account.isDisabled() = state == AccountState.Disabled
fun Account.isTwoPassModeNeeded() = state == AccountState.TwoPassModeNeeded
fun Account.isAuthenticated() = sessionState == SessionState.Authenticated
fun Account.isSecondFactorNeeded() = sessionState == SessionState.SecondFactorNeeded
fun Account.isHumanVerificationNeeded() = sessionState == SessionState.HumanVerificationNeeded

View File

@ -19,17 +19,10 @@
package me.proton.core.account.domain.entity
enum class AccountState {
/**
* First state emitted after adding a new [Account], it is not yet [Ready] to use.
*
* Note: Usually followed by [Initializing] and/or [Ready].
*/
Added,
/**
* State emitted if this [Account] need more step(s) to be [Ready] to use.
*/
Initializing,
NotReady,
/**
* A two pass mode is needed.

View File

@ -23,6 +23,7 @@ import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
@ -65,7 +66,12 @@ interface AccountRepository {
/**
* Get [Session], by sessionId.
*/
fun getSessionOrNull(sessionId: SessionId): Session?
suspend fun getSessionOrNull(sessionId: SessionId): Session?
/**
* Get [SessionId], by userId.
*/
suspend fun getSessionIdOrNull(userId: UserId): SessionId?
/**
* Create or update an [Account], locally.
@ -121,4 +127,14 @@ interface AccountRepository {
* Set the primary [UserId].
*/
suspend fun setAsPrimary(userId: UserId)
/**
* Get [HumanVerificationDetails], if exist, by sessionId.
*/
suspend fun getHumanVerificationDetails(id: SessionId): HumanVerificationDetails?
/**
* Set [HumanVerificationDetails], by sessionId.
*/
suspend fun setHumanVerificationDetails(id: SessionId, details: HumanVerificationDetails?)
}

View File

@ -28,10 +28,12 @@ import me.proton.core.auth.data.entity.SecondFactorResponse
import me.proton.core.auth.data.entity.UserResponse
import me.proton.core.network.data.protonApi.BaseRetrofitApi
import me.proton.core.network.data.protonApi.GenericResponse
import me.proton.core.network.domain.TimeoutOverride
import retrofit2.http.Body
import retrofit2.http.DELETE
import retrofit2.http.GET
import retrofit2.http.POST
import retrofit2.http.Tag
interface AuthenticationApi : BaseRetrofitApi {
@ -45,7 +47,7 @@ interface AuthenticationApi : BaseRetrofitApi {
suspend fun performSecondFactor(@Body request: SecondFactorRequest): SecondFactorResponse
@DELETE("auth")
suspend fun revokeSession(): GenericResponse
suspend fun revokeSession(@Tag timeout: TimeoutOverride): GenericResponse
@GET("users")
suspend fun getUser(): UserResponse

View File

@ -30,12 +30,12 @@ import me.proton.core.auth.domain.entity.SecondFactorProof
import me.proton.core.auth.domain.entity.SessionInfo
import me.proton.core.auth.domain.entity.User
import me.proton.core.auth.domain.repository.AuthRepository
import me.proton.core.data.arch.ApiResultMapper
import me.proton.core.data.arch.toDataResponse
import me.proton.core.domain.arch.DataResult
import me.proton.core.network.data.ApiProvider
import me.proton.core.network.data.ResponseCodes
import me.proton.core.network.domain.TimeoutOverride
import me.proton.core.network.domain.session.SessionId
import me.proton.core.util.kotlin.invoke
/**
* Implementation of the [AuthRepository].
@ -44,8 +44,7 @@ import me.proton.core.util.kotlin.invoke
* @author Dino Kadrikj.
*/
class AuthRepositoryImpl(
private val provider: ApiProvider,
private val apiResultMapper: ApiResultMapper = ApiResultMapper()
private val provider: ApiProvider
) : AuthRepository {
/**
@ -60,13 +59,10 @@ class AuthRepositoryImpl(
override suspend fun getLoginInfo(
username: String,
clientSecret: String
): DataResult<LoginInfo> =
apiResultMapper {
provider.get<AuthenticationApi>().invoke {
val request = LoginInfoRequest(username, clientSecret)
getLoginInfo(request).toLoginInfo(username)
}.toDataResponse()
}
): DataResult<LoginInfo> = provider.get<AuthenticationApi>().invoke {
val request = LoginInfoRequest(username, clientSecret)
getLoginInfo(request).toLoginInfo(username)
}.toDataResponse()
/**
* Performs the login request to the API to try to get a valid Access Token and Session for the Account/username.
@ -85,13 +81,10 @@ class AuthRepositoryImpl(
clientEphemeral: String,
clientProof: String,
srpSession: String
): DataResult<SessionInfo> =
apiResultMapper {
provider.get<AuthenticationApi>().invoke {
val request = LoginRequest(username, clientSecret, clientEphemeral, clientProof, srpSession)
performLogin(request).toSessionInfo(username)
}.toDataResponse()
}
): DataResult<SessionInfo> = provider.get<AuthenticationApi>().invoke {
val request = LoginRequest(username, clientSecret, clientEphemeral, clientProof, srpSession)
performLogin(request).toSessionInfo(username)
}.toDataResponse()
/**
* Performs the second factor request for the Accounts that have second factor enabled.
@ -104,23 +97,21 @@ class AuthRepositoryImpl(
override suspend fun performSecondFactor(
sessionId: SessionId,
secondFactorProof: SecondFactorProof
): DataResult<ScopeInfo> = apiResultMapper {
provider.get<AuthenticationApi>(sessionId).invoke {
val request = when (secondFactorProof) {
is SecondFactorProof.SecondFactorCode -> SecondFactorRequest(
secondFactorCode = secondFactorProof.code
): DataResult<ScopeInfo> = provider.get<AuthenticationApi>(sessionId).invoke {
val request = when (secondFactorProof) {
is SecondFactorProof.SecondFactorCode -> SecondFactorRequest(
secondFactorCode = secondFactorProof.code
)
is SecondFactorProof.SecondFactorSignature -> SecondFactorRequest(
universalTwoFactorRequest = UniversalTwoFactorRequest(
keyHandle = secondFactorProof.keyHandle,
clientData = secondFactorProof.clientData,
signatureData = secondFactorProof.signatureData
)
is SecondFactorProof.SecondFactorSignature -> SecondFactorRequest(
universalTwoFactorRequest = UniversalTwoFactorRequest(
keyHandle = secondFactorProof.keyHandle,
clientData = secondFactorProof.clientData,
signatureData = secondFactorProof.signatureData
)
)
}
performSecondFactor(request).toScopeInfo()
}.toDataResponse()
}
)
}
performSecondFactor(request).toScopeInfo()
}.toDataResponse()
/**
* Revokes the session for the user. In particular this is practically logging out the user from the backend.
@ -130,11 +121,15 @@ class AuthRepositoryImpl(
* @return boolean result of the logout/session revoking operation.
*/
override suspend fun revokeSession(sessionId: SessionId): DataResult<Boolean> =
apiResultMapper {
provider.get<AuthenticationApi>(sessionId).invoke {
revokeSession().code.isSuccessResponse()
}.toDataResponse()
}
provider.get<AuthenticationApi>(sessionId).invoke(true) {
revokeSession(
TimeoutOverride(
connectionTimeoutSeconds = 1,
readTimeoutSeconds = 1,
writeTimeoutSeconds = 1
)
).code.isSuccessResponse()
}.toDataResponse()
/**
* Fetches the full user details from the API.
@ -144,12 +139,9 @@ class AuthRepositoryImpl(
* @return [User] object with full user details.
*/
override suspend fun getUser(sessionId: SessionId): DataResult<User> =
apiResultMapper {
provider.get<AuthenticationApi>(sessionId).invoke {
val userResponse = getUser()
userResponse.user.toUser()
}.toDataResponse()
}
provider.get<AuthenticationApi>(sessionId).invoke {
getUser().user.toUser()
}.toDataResponse()
/**
* Fetches the user-keys salts from the API.
@ -159,11 +151,9 @@ class AuthRepositoryImpl(
* @return [KeySalts] containing salts for all user keys.
*/
override suspend fun getSalts(sessionId: SessionId): DataResult<KeySalts> =
apiResultMapper {
provider.get<AuthenticationApi>(sessionId).invoke {
getSalts().toKeySalts()
}.toDataResponse()
}
provider.get<AuthenticationApi>(sessionId).invoke {
getSalts().toKeySalts()
}.toDataResponse()
}
internal fun Int.isSuccessResponse(): Boolean = this == ResponseCodes.OK

View File

@ -36,6 +36,7 @@ import me.proton.core.network.data.di.ApiFactory
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.ApiResult
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionProvider
import org.junit.Before
import org.junit.Test
import java.net.ConnectException
@ -49,6 +50,7 @@ import kotlin.test.assertTrue
class AuthRepositoryImplTest {
// region mocks
private val sessionProvider = mockk<SessionProvider>(relaxed = true)
private val apiFactory = mockk<ApiFactory>(relaxed = true)
private val apiManager = mockk<ApiManager<AuthenticationApi>>(relaxed = true)
private lateinit var apiProvider: ApiProvider
@ -82,7 +84,8 @@ class AuthRepositoryImplTest {
@Before
fun beforeEveryTest() {
// GIVEN
apiProvider = ApiProvider(apiFactory)
coEvery { sessionProvider.getSessionId(any()) } returns SessionId(testSessionId)
apiProvider = ApiProvider(apiFactory, sessionProvider)
every { apiFactory.create(interfaceClass = AuthenticationApi::class) } returns apiManager
every { apiFactory.create(SessionId(testSessionId), interfaceClass = AuthenticationApi::class) } returns apiManager
repository = AuthRepositoryImpl(apiProvider)

View File

@ -67,7 +67,7 @@ interface AuthRepository {
suspend fun getSalts(sessionId: SessionId): DataResult<KeySalts>
/**
* Perform Two Factor for the Login process for a given [SessionId].
* Revoke session for a given [SessionId].
*/
suspend fun revokeSession(sessionId: SessionId): DataResult<Boolean>

View File

@ -46,7 +46,7 @@ class PerformSecondFactor @Inject constructor(
data class Success(
val sessionId: SessionId,
val scopeInfo: ScopeInfo,
val isMailboxLoginNeeded: Boolean = false
val isTwoPassModeNeeded: Boolean
) : SecondFactorState()
sealed class Error : SecondFactorState() {
@ -61,7 +61,8 @@ class PerformSecondFactor @Inject constructor(
*/
operator fun invoke(
sessionId: SessionId,
secondFactorCode: String
secondFactorCode: String,
isTwoPassModeNeeded: Boolean = false
): Flow<SecondFactorState> = flow {
if (secondFactorCode.isEmpty()) {
@ -73,11 +74,11 @@ class PerformSecondFactor @Inject constructor(
authRepository.performSecondFactor(
sessionId,
SecondFactorProof.SecondFactorCode(secondFactorCode)
SecondFactorProof.SecondFactorCode(secondFactorCode),
).onFailure { errorMessage, _ ->
emit(SecondFactorState.Error.Message(errorMessage))
}.onSuccess { scopeInfo ->
emit(SecondFactorState.Success(sessionId, scopeInfo))
emit(SecondFactorState.Success(sessionId, scopeInfo, isTwoPassModeNeeded))
}
}
}

View File

@ -20,87 +20,126 @@ package me.proton.core.auth.presentation
import androidx.activity.ComponentActivity
import androidx.activity.result.ActivityResultLauncher
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.launch
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.presentation.entity.ScopeResult
import me.proton.core.auth.presentation.entity.SecondFactorInput
import me.proton.core.auth.presentation.entity.SessionResult
import me.proton.core.auth.presentation.entity.UserResult
import me.proton.core.auth.presentation.ui.StartLogin
import me.proton.core.auth.presentation.ui.StartMailboxLogin
import me.proton.core.auth.presentation.ui.StartSecondFactor
import me.proton.core.auth.presentation.ui.StartTwoPassMode
import me.proton.core.humanverification.presentation.entity.HumanVerificationInput
import me.proton.core.humanverification.presentation.entity.HumanVerificationResult
import me.proton.core.humanverification.presentation.ui.StartHumanVerification
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.session.SessionId
import javax.inject.Inject
class AuthOrchestrator {
class AuthOrchestrator @Inject constructor(
private val accountWorkflowHandler: AccountWorkflowHandler
) {
// region result launchers
private var loginWorkflowLauncher: ActivityResultLauncher<List<String>>? = null
private var secondFactorWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
private var mailboxWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
private var secondFactorWorkflowLauncher: ActivityResultLauncher<SecondFactorInput>? = null
private var twoPassModeWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
private var humanWorkflowLauncher: ActivityResultLauncher<HumanVerificationInput>? = null
// endregion
private var onUserResultListener: (result: UserResult?) -> Unit = {}
private var onSessionResultListener: (result: SessionResult?) -> Unit = {}
private var onScopeResultListener: (result: ScopeResult?) -> Unit = {}
private var onHumanVerificationResultListener: (result: HumanVerificationResult?) -> Unit = {}
fun setOnUserResult(block: (result: UserResult?) -> Unit) {
onUserResultListener = block
}
fun setOnSessionResult(block: (result: SessionResult?) -> Unit) {
onSessionResultListener = block
}
fun setOnScopeResult(block: (result: ScopeResult?) -> Unit) {
onScopeResultListener = block
}
fun setOnHumanVerificationResult(block: (result: HumanVerificationResult?) -> Unit) {
onHumanVerificationResultListener = block
}
// region private module functions
private fun registerLoginWorkflowLauncher(
context: ComponentActivity,
onSessionResult: (result: SessionResult?) -> Unit = {}
context: ComponentActivity
): ActivityResultLauncher<List<String>> =
context.registerForActivityResult(StartLogin()) { result ->
context.registerForActivityResult(
StartLogin()
) { result ->
result?.let {
if (it.isSecondFactorNeeded) {
startSecondFactorWorkflow(SessionId(it.sessionId))
} else if (it.isMailboxLoginNeeded) {
startMailboxLoginWorkflow(SessionId(it.sessionId))
startSecondFactorWorkflow(SecondFactorInput(it.sessionId, it.isTwoPassModeNeeded))
} else if (it.isTwoPassModeNeeded) {
startTwoPassModeWorkflow(SessionId(it.sessionId))
}
onSessionResult(it)
}
onSessionResultListener(result)
}
private fun registerMailboxLoginWorkflowLauncher(
context: ComponentActivity,
onUserResult: (result: UserResult?) -> Unit = {}
private fun registerTwoPassModeWorkflowLauncher(
context: ComponentActivity
): ActivityResultLauncher<SessionId> =
context.registerForActivityResult(StartMailboxLogin()) {
onUserResult(it)
context.registerForActivityResult(
StartTwoPassMode()
) {
onUserResultListener(it)
}
private fun registerSecondFactorWorkflow(
context: ComponentActivity,
onScopeResult: (result: ScopeResult?) -> Unit = {}
): ActivityResultLauncher<SessionId> =
context.registerForActivityResult(StartSecondFactor()) { result ->
context: ComponentActivity
): ActivityResultLauncher<SecondFactorInput> =
context.registerForActivityResult(
StartSecondFactor()
) { result ->
result?.let {
if (it.isMailboxLoginNeeded) {
startMailboxLoginWorkflow(SessionId(it.sessionId))
if (it.isTwoPassModeNeeded) {
startTwoPassModeWorkflow(SessionId(it.sessionId))
}
onScopeResult(it)
}
onScopeResultListener(result)
}
private fun registerHumanVerificationWorkflow(
context: ComponentActivity,
onHumanVerificationResult: (result: HumanVerificationResult?) -> Unit = {}
context: ComponentActivity
): ActivityResultLauncher<HumanVerificationInput> =
context.registerForActivityResult(StartHumanVerification()) {
onHumanVerificationResult(it)
context.registerForActivityResult(
StartHumanVerification()
) { result ->
if (result != null) {
context.lifecycleScope.launch {
if (!result.tokenType.isNullOrBlank() && !result.tokenCode.isNullOrBlank()) {
accountWorkflowHandler.handleHumanVerificationSuccess(
sessionId = SessionId(result.sessionId),
tokenType = result.tokenType!!,
tokenCode = result.tokenCode!!
)
} else {
accountWorkflowHandler.handleHumanVerificationFailed(
sessionId = SessionId(result.sessionId)
)
}
}
}
onHumanVerificationResultListener(result)
}
/**
* Start a Second Factor workflow.
*/
private fun startSecondFactorWorkflow(input: SessionId) {
private fun startSecondFactorWorkflow(input: SecondFactorInput) {
secondFactorWorkflowLauncher?.launch(input)
?: throw IllegalStateException("You must call register before any start workflow function!")
}
/**
* Start a MailboxLogin workflow.
*/
private fun startMailboxLoginWorkflow(input: SessionId) {
mailboxWorkflowLauncher?.launch(input)
?: throw IllegalStateException("You must call register before any start workflow function!")
}
// endregion
// region public API
@ -110,10 +149,10 @@ class AuthOrchestrator {
* Note: This function have to be called [ComponentActivity.onCreate]] before [ComponentActivity.onResume].
*/
fun register(context: ComponentActivity) {
loginWorkflowLauncher ?: run { loginWorkflowLauncher = registerLoginWorkflowLauncher(context) }
humanWorkflowLauncher ?: run { humanWorkflowLauncher = registerHumanVerificationWorkflow(context) }
secondFactorWorkflowLauncher ?: run { secondFactorWorkflowLauncher = registerSecondFactorWorkflow(context) }
mailboxWorkflowLauncher ?: run { mailboxWorkflowLauncher = registerMailboxLoginWorkflowLauncher(context) }
loginWorkflowLauncher = registerLoginWorkflowLauncher(context)
humanWorkflowLauncher = registerHumanVerificationWorkflow(context)
secondFactorWorkflowLauncher = registerSecondFactorWorkflow(context)
twoPassModeWorkflowLauncher = registerTwoPassModeWorkflowLauncher(context)
}
/**
@ -124,6 +163,14 @@ class AuthOrchestrator {
?: throw IllegalStateException("You must call register before any start workflow function!")
}
/**
* Start a TwoPassMode workflow.
*/
fun startTwoPassModeWorkflow(input: SessionId) {
twoPassModeWorkflowLauncher?.launch(input)
?: throw IllegalStateException("You must call register before any start workflow function!")
}
/**
* Start a Human Verification workflow.
*/
@ -138,3 +185,31 @@ class AuthOrchestrator {
}
// endregion
}
fun AuthOrchestrator.onUserResult(
block: (result: UserResult?) -> Unit
): AuthOrchestrator {
setOnUserResult { block(it) }
return this
}
fun AuthOrchestrator.onScopeResult(
block: (result: ScopeResult?) -> Unit
): AuthOrchestrator {
setOnScopeResult { block(it) }
return this
}
fun AuthOrchestrator.onSessionResult(
block: (result: SessionResult?) -> Unit
): AuthOrchestrator {
setOnSessionResult { block(it) }
return this
}
fun AuthOrchestrator.onHumanVerificationResult(
block: (result: HumanVerificationResult?) -> Unit
): AuthOrchestrator {
setOnHumanVerificationResult { block(it) }
return this
}

View File

@ -21,16 +21,15 @@ package me.proton.core.auth.presentation.entity
import android.os.Parcelable
import kotlinx.android.parcel.Parcelize
import me.proton.core.auth.domain.entity.ScopeInfo
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.session.SessionId
@Parcelize
data class ScopeResult(
val sessionId: String,
val scopes: List<String>,
val isMailboxLoginNeeded: Boolean = false
val isTwoPassModeNeeded: Boolean = false
) : Parcelable {
constructor(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean = false)
: this(sessionId.id, scopeInfo.scopes, isMailboxLoginNeeded)
constructor(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean = false) :
this(sessionId.id, scopeInfo.scopes, isMailboxLoginNeeded)
}

View File

@ -0,0 +1,28 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.auth.presentation.entity
import android.os.Parcelable
import kotlinx.android.parcel.Parcelize
@Parcelize
data class SecondFactorInput(
val sessionId: String,
val isTwoPassModeNeeded: Boolean
) : Parcelable

View File

@ -44,7 +44,7 @@ data class SessionResult(
) : Parcelable {
@IgnoredOnParcel
val isMailboxLoginNeeded = passwordMode == 2
val isTwoPassModeNeeded = passwordMode == 2
companion object {

View File

@ -23,9 +23,9 @@ import android.content.Context
import android.content.Intent
import androidx.activity.result.contract.ActivityResultContract
import me.proton.core.auth.presentation.entity.ScopeResult
import me.proton.core.auth.presentation.entity.SecondFactorInput
import me.proton.core.auth.presentation.entity.SessionResult
import me.proton.core.auth.presentation.entity.UserResult
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.session.SessionId
class StartLogin : ActivityResultContract<List<String>, SessionResult?>() {
@ -41,11 +41,11 @@ class StartLogin : ActivityResultContract<List<String>, SessionResult?>() {
}
}
class StartSecondFactor : ActivityResultContract<SessionId, ScopeResult?>() {
class StartSecondFactor : ActivityResultContract<SecondFactorInput, ScopeResult?>() {
override fun createIntent(context: Context, sessionId: SessionId) =
override fun createIntent(context: Context, inupt: SecondFactorInput) =
Intent(context, SecondFactorActivity::class.java).apply {
putExtra(SecondFactorActivity.ARG_SESSION_ID, sessionId.id)
putExtra(SecondFactorActivity.ARG_SECOND_FACTOR_INPUT, inupt)
}
override fun parseResult(resultCode: Int, result: Intent?): ScopeResult? {
@ -54,11 +54,11 @@ class StartSecondFactor : ActivityResultContract<SessionId, ScopeResult?>() {
}
}
class StartMailboxLogin : ActivityResultContract<SessionId, UserResult?>() {
class StartTwoPassMode : ActivityResultContract<SessionId, UserResult?>() {
override fun createIntent(context: Context, sessionId: SessionId) =
override fun createIntent(context: Context, inupt: SessionId) =
Intent(context, MailboxLoginActivity::class.java).apply {
putExtra(MailboxLoginActivity.ARG_SESSION_ID, sessionId.id)
putExtra(MailboxLoginActivity.ARG_SESSION_ID, inupt.id)
}
override fun parseResult(resultCode: Int, result: Intent?): UserResult? {

View File

@ -18,20 +18,30 @@
package me.proton.core.auth.presentation.ui
/**
* Interface common for all authentication activities.
*
* @author Dino Kadrikj.
*/
interface AuthActivity {
import android.os.Build
import android.os.Bundle
import android.view.View
import androidx.databinding.ViewDataBinding
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.errorSnack
import me.proton.core.auth.presentation.R
/**
* Instructs the activity to show loading animation (custom for eacch activity).
*/
fun showLoading(loading: Boolean)
abstract class AuthActivity<DB : ViewDataBinding> : ProtonActivity<DB>() {
/**
* Provide default implementation for error UI.
*/
fun showError(message: String?)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
window.decorView.systemUiVisibility = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR
}
}
open fun showLoading(loading: Boolean) {
// No op
}
open fun showError(message: String?) {
showLoading(false)
binding.root.errorSnack(message = message ?: getString(R.string.auth_login_general_error))
}
}

View File

@ -1,55 +0,0 @@
/*
* Copyright (c) 2020 Proton Technologies AG
* This file is part of Proton Technologies AG and ProtonCore.
*
* ProtonCore is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonCore is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
*/
package me.proton.core.auth.presentation.ui
import android.os.Build
import android.view.View
import androidx.databinding.ViewDataBinding
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.errorSnack
import me.proton.core.auth.presentation.R
/**
* Delegate class implementing the [AuthActivity] interface.
*
* @author Dino Kadrikj.
*/
class AuthActivityDelegate<DB : ViewDataBinding> : AuthActivityComponent<DB> {
/** A reference to the Activity that will handle the rotation */
private lateinit var activity : ProtonActivity<DB>
override fun initializeAuth(protonAuthActivity: ProtonActivity<DB>) {
activity = protonAuthActivity
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
activity.window.decorView.systemUiVisibility = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR
}
}
override fun showLoading(loading: Boolean) {
// noop
}
override fun showError(message: String?) {
showLoading(false)
activity.binding.root.errorSnack(message = message ?: activity.getString(R.string.auth_login_general_error))
}
}

View File

@ -19,7 +19,6 @@
package me.proton.core.auth.presentation.ui
import android.os.Bundle
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.onClick
import me.proton.android.core.presentation.utils.openBrowserLink
import me.proton.core.auth.presentation.R
@ -29,12 +28,11 @@ import me.proton.core.auth.presentation.databinding.ActivityAuthHelpBinding
* Authentication help Activity which offers common authentication problems help.
* @author Dino Kadrikj.
*/
class AuthHelpActivity : ProtonActivity<ActivityAuthHelpBinding>(), AuthActivityComponent<ActivityAuthHelpBinding> by AuthActivityDelegate() {
class AuthHelpActivity : AuthActivity<ActivityAuthHelpBinding>() {
override fun layoutId(): Int = R.layout.activity_auth_help
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
initializeAuth(this)
binding.apply {
closeButton.onClick {
finish()
@ -54,8 +52,4 @@ class AuthHelpActivity : ProtonActivity<ActivityAuthHelpBinding>(), AuthActivity
}
}
}
override fun showLoading(loading: Boolean) {
// no-operation
}
}

View File

@ -23,7 +23,6 @@ import android.content.Intent
import android.os.Bundle
import androidx.activity.viewModels
import dagger.hilt.android.AndroidEntryPoint
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.hideKeyboard
import me.proton.android.core.presentation.utils.onClick
import me.proton.android.core.presentation.utils.onFailure
@ -32,7 +31,6 @@ import me.proton.android.core.presentation.utils.validatePassword
import me.proton.android.core.presentation.utils.validateUsername
import me.proton.core.auth.domain.entity.SessionInfo
import me.proton.core.auth.domain.usecase.PerformLogin
import me.proton.core.auth.presentation.AuthOrchestrator
import me.proton.core.auth.presentation.R
import me.proton.core.auth.presentation.databinding.ActivityLoginBinding
import me.proton.core.auth.presentation.entity.SessionResult
@ -43,18 +41,14 @@ import me.proton.core.util.kotlin.exhaustive
* Login Activity which allows users to Login to any Proton client application.
*/
@AndroidEntryPoint
class LoginActivity : ProtonActivity<ActivityLoginBinding>(),
AuthActivityComponent<ActivityLoginBinding> by AuthActivityDelegate() {
class LoginActivity : AuthActivity<ActivityLoginBinding>() {
private val viewModel by viewModels<LoginViewModel>()
private val authOrchestrator = AuthOrchestrator()
override fun layoutId(): Int = R.layout.activity_login
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
initializeAuth(this)
authOrchestrator.register(this)
binding.apply {
closeButton.onClick {

View File

@ -23,7 +23,6 @@ import android.content.Intent
import android.os.Bundle
import androidx.activity.viewModels
import dagger.hilt.android.AndroidEntryPoint
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.hideKeyboard
import me.proton.android.core.presentation.utils.onClick
import me.proton.android.core.presentation.utils.onFailure
@ -45,11 +44,10 @@ import me.proton.core.util.kotlin.exhaustive
* mailbox).
*/
@AndroidEntryPoint
class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
AuthActivityComponent<ActivityMailboxLoginBinding> by AuthActivityDelegate() {
class MailboxLoginActivity : AuthActivity<ActivityMailboxLoginBinding>() {
private val sessionId: SessionId by lazy {
intent?.extras?.get(ARG_SESSION_ID) as SessionId
SessionId(requireNotNull(intent?.extras?.getString(ARG_SESSION_ID)))
}
private val viewModel by viewModels<MailboxLoginViewModel>()
@ -58,10 +56,10 @@ class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
initializeAuth(this)
binding.apply {
closeButton.onClick {
finish()
onBackPressed()
}
forgotPasswordButton.onClick {
@ -93,6 +91,12 @@ class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
}
}
override fun onBackPressed() {
viewModel.stopMailboxLoginFlow(sessionId).invokeOnCompletion {
finish()
}
}
/**
* Invoked on successful completed mailbox login operation.
*/

View File

@ -24,7 +24,6 @@ import android.os.Bundle
import android.text.InputType
import androidx.activity.viewModels
import dagger.hilt.android.AndroidEntryPoint
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.hideKeyboard
import me.proton.android.core.presentation.utils.onClick
import me.proton.android.core.presentation.utils.onFailure
@ -35,6 +34,7 @@ import me.proton.core.auth.domain.usecase.PerformSecondFactor
import me.proton.core.auth.presentation.R
import me.proton.core.auth.presentation.databinding.Activity2faBinding
import me.proton.core.auth.presentation.entity.ScopeResult
import me.proton.core.auth.presentation.entity.SecondFactorInput
import me.proton.core.auth.presentation.viewmodel.SecondFactorViewModel
import me.proton.core.network.domain.session.SessionId
import me.proton.core.util.kotlin.exhaustive
@ -45,15 +45,10 @@ import me.proton.core.util.kotlin.exhaustive
* Optional, only shown for accounts with 2FA login enabled.
*/
@AndroidEntryPoint
class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
AuthActivityComponent<Activity2faBinding> by AuthActivityDelegate() {
class SecondFactorActivity : AuthActivity<Activity2faBinding>() {
private val sessionId: SessionId by lazy {
intent?.extras?.get(ARG_SESSION_ID) as SessionId
}
private val twoPassMode: Boolean by lazy {
intent?.extras?.get(ARG_TWO_PASS_MODE) as Boolean
private val input: SecondFactorInput by lazy {
requireNotNull(intent?.extras?.getParcelable(ARG_SECOND_FACTOR_INPUT))
}
// initial mode is the second factor input mode.
@ -65,10 +60,10 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
initializeAuth(this)
binding.apply {
closeButton.onClick {
finish()
onBackPressed()
}
recoveryCodeButton.onClick {
@ -87,7 +82,7 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
is PerformSecondFactor.SecondFactorState.Success -> onSuccess(
it.sessionId,
it.scopeInfo,
it.isMailboxLoginNeeded
it.isTwoPassModeNeeded
)
is PerformSecondFactor.SecondFactorState.Error.Message -> onError(false, it.message)
is PerformSecondFactor.SecondFactorState.Error.EmptyCredentials -> {
@ -111,20 +106,29 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
secondFactorInput.validate()
.onFailure { secondFactorInput.setInputError() }
.onSuccess { secondFactorCode ->
viewModel.startSecondFactorFlow(sessionId, secondFactorCode, twoPassMode)
viewModel.startSecondFactorFlow(
sessionId = SessionId(input.sessionId),
secondFactorCode = secondFactorCode,
isTwoPassModeNeeded = input.isTwoPassModeNeeded
)
}
}
}
override fun onBackPressed() {
viewModel.stopSecondFactorFlow(SessionId(input.sessionId)).invokeOnCompletion {
finish()
}
}
/**
* Invoked on successful completed mailbox login operation.
*/
private fun onSuccess(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean) {
val intent =
Intent().putExtra(
ARG_SCOPE_RESULT,
ScopeResult(sessionId, scopeInfo, isMailboxLoginNeeded)
)
private fun onSuccess(sessionId: SessionId, scopeInfo: ScopeInfo, isTwoPassModeNeeded: Boolean) {
val intent = Intent().putExtra(
ARG_SCOPE_RESULT,
ScopeResult(sessionId, scopeInfo, isTwoPassModeNeeded)
)
setResult(Activity.RESULT_OK, intent)
finish()
}
@ -173,8 +177,7 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
}
companion object {
const val ARG_SESSION_ID = "arg.sessionId"
const val ARG_TWO_PASS_MODE = "arg.twoPassMode"
const val ARG_SECOND_FACTOR_INPUT = "arg.secondFactorInput"
const val ARG_SCOPE_RESULT = "arg.scopeResult"
}
}

View File

@ -54,16 +54,14 @@ class LoginViewModel @ViewModelInject constructor(
username: String,
password: ByteArray
) {
performLogin(username, password)
.onEach {
if (it is PerformLogin.LoginState.Success) {
// on success result, contact account manager
onSuccess(it)
}
// inform the view for each state change
loginState.post(it)
performLogin(username, password).onEach {
if (it is PerformLogin.LoginState.Success) {
// on success result, contact account manager
onSuccess(it)
}
.launchIn(viewModelScope)
// inform the view for each state change
loginState.post(it)
}.launchIn(viewModelScope)
}
private suspend fun onSuccess(success: PerformLogin.LoginState.Success) {

View File

@ -20,15 +20,14 @@ package me.proton.core.auth.presentation.viewmodel
import androidx.hilt.lifecycle.ViewModelInject
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import me.proton.android.core.presentation.viewmodel.ProtonViewModel
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.domain.usecase.PerformMailboxLogin
import me.proton.core.domain.entity.UserId
import me.proton.core.network.domain.session.SessionId
import me.proton.core.util.kotlin.DispatcherProvider
import studio.forface.viewstatestore.ViewStateStore
import studio.forface.viewstatestore.ViewStateStoreScope
@ -49,14 +48,15 @@ class MailboxLoginViewModel @ViewModelInject constructor(
sessionId: SessionId,
password: ByteArray
) {
performMailboxLogin(sessionId, password)
.onEach {
if (it is PerformMailboxLogin.MailboxLoginState.Success) {
accountWorkflowHandler.handleTwoPassModeSuccess(sessionId)
} else if (it is PerformMailboxLogin.MailboxLoginState.Error) {
accountWorkflowHandler.handleTwoPassModeFailed(sessionId)
}
mailboxLoginState.post(it)
}.launchIn(viewModelScope)
performMailboxLogin(sessionId, password).onEach {
if (it is PerformMailboxLogin.MailboxLoginState.Success) {
accountWorkflowHandler.handleTwoPassModeSuccess(sessionId)
}
mailboxLoginState.post(it)
}.launchIn(viewModelScope)
}
fun stopMailboxLoginFlow(
sessionId: SessionId
): Job = viewModelScope.launch { accountWorkflowHandler.handleTwoPassModeFailed(sessionId) }
}

View File

@ -20,14 +20,14 @@ package me.proton.core.auth.presentation.viewmodel
import androidx.hilt.lifecycle.ViewModelInject
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import me.proton.android.core.presentation.viewmodel.ProtonViewModel
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.domain.usecase.PerformSecondFactor
import me.proton.core.network.domain.session.SessionId
import me.proton.core.util.kotlin.DispatcherProvider
import studio.forface.viewstatestore.ViewStateStore
import studio.forface.viewstatestore.ViewStateStoreScope
@ -44,27 +44,20 @@ class SecondFactorViewModel @ViewModelInject constructor(
fun startSecondFactorFlow(
sessionId: SessionId,
secondFactorCode: String,
isMailboxLoginNeeded: Boolean
isTwoPassModeNeeded: Boolean = false
) {
performSecondFactor(sessionId, secondFactorCode)
.onEach {
when (it) {
is PerformSecondFactor.SecondFactorState.Success -> {
secondFactorState.post(it.copy(isMailboxLoginNeeded = isMailboxLoginNeeded))
accountWorkflowHandler.handleSecondFactorSuccess(
sessionId = sessionId,
updatedScopes = it.scopeInfo.scopes
)
}
is PerformSecondFactor.SecondFactorState.Error -> {
secondFactorState.post(it)
accountWorkflowHandler.handleSecondFactorFailed(sessionId)
}
else -> {
secondFactorState.post(it)
}
}
performSecondFactor(sessionId, secondFactorCode, isTwoPassModeNeeded).onEach {
if (it is PerformSecondFactor.SecondFactorState.Success) {
accountWorkflowHandler.handleSecondFactorSuccess(
sessionId = sessionId,
updatedScopes = it.scopeInfo.scopes
)
}
.launchIn(viewModelScope)
secondFactorState.post(it)
}.launchIn(viewModelScope)
}
fun stopSecondFactorFlow(
sessionId: SessionId
): Job = viewModelScope.launch { accountWorkflowHandler.handleSecondFactorFailed(sessionId) }
}

View File

@ -26,9 +26,10 @@ import io.mockk.slot
import io.mockk.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.flowOf
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.auth.domain.AccountWorkflowHandler
import me.proton.core.auth.domain.crypto.SrpProofProvider
import me.proton.core.auth.domain.entity.Account
import me.proton.core.auth.domain.entity.SessionInfo
import me.proton.core.auth.domain.repository.AuthRepository
import me.proton.core.auth.domain.usecase.PerformLogin
@ -232,7 +233,7 @@ class LoginViewModelTest : ArchTest, CoroutinesTest {
val account = accountArgument.captured
val session = sessionArgument.captured
assertNotNull(account)
assertTrue(account.isTwoPassModeNeeded)
assertEquals(AccountState.TwoPassModeNeeded, account.state)
assertEquals(testSessionId, session.sessionId.id)
}
}

View File

@ -123,18 +123,12 @@ class MailboxLoginViewModelTest : ArchTest, CoroutinesTest {
}
@Test
fun `failed mailbox login invokes failed on account manager`() = coroutinesTest {
// GIVEN
coEvery { useCase.invoke(SessionId(testSessionId), testPassword.toByteArray()) } returns flowOf(
PerformMailboxLogin.MailboxLoginState.Processing,
PerformMailboxLogin.MailboxLoginState.Error.Message("test error")
)
fun `stop mailbox login invokes failed on account manager`() = coroutinesTest {
// WHEN
viewModel.startMailboxLoginFlow(SessionId(testSessionId), testPassword.toByteArray())
viewModel.stopMailboxLoginFlow(SessionId(testSessionId))
// THEN
val arguments = slot<SessionId>()
coVerify(exactly = 1) { accountManager.handleTwoPassModeFailed(capture(arguments)) }
coVerify(exactly = 0) { accountManager.handleTwoPassModeSuccess(any()) }
assertEquals(testSessionId, arguments.captured.id)
}
}

View File

@ -66,7 +66,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
val isMailboxLoginNeeded = false
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
PerformSecondFactor.SecondFactorState.Processing,
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
)
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
viewModel.secondFactorState.observeDataForever(observer)
@ -103,7 +103,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
val isMailboxLoginNeeded = false
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
PerformSecondFactor.SecondFactorState.Processing,
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
)
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
viewModel.secondFactorState.observeDataForever(observer)
@ -119,7 +119,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
val successState = arguments[1]
assertTrue(processingState is PerformSecondFactor.SecondFactorState.Processing)
assertTrue(successState is PerformSecondFactor.SecondFactorState.Success)
assertFalse(successState.isMailboxLoginNeeded)
assertFalse(successState.isTwoPassModeNeeded)
assertEquals(SessionId(testSessionId), accountManagerArguments.captured)
}
@ -127,9 +127,9 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
fun `submit 2fa two pass mode flow states are handled correctly`() = coroutinesTest {
// GIVEN
val isMailboxLoginNeeded = true
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode, isMailboxLoginNeeded) } returns flowOf(
PerformSecondFactor.SecondFactorState.Processing,
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
)
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
viewModel.secondFactorState.observeDataForever(observer)
@ -145,7 +145,17 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
val successState = arguments[1]
assertTrue(processingState is PerformSecondFactor.SecondFactorState.Processing)
assertTrue(successState is PerformSecondFactor.SecondFactorState.Success)
assertTrue(successState.isMailboxLoginNeeded)
assertTrue(successState.isTwoPassModeNeeded)
assertEquals(SessionId(testSessionId), accountManagerArguments.captured)
}
@Test
fun `stop 2fa invokes failed on account manager`() = coroutinesTest {
// WHEN
viewModel.stopSecondFactorFlow(SessionId(testSessionId))
// THEN
val arguments = slot<SessionId>()
coVerify(exactly = 1) { accountManager.handleSecondFactorFailed(capture(arguments)) }
coVerify(exactly = 0) { accountManager.handleSecondFactorSuccess(any(), any()) }
}
}

View File

@ -80,8 +80,8 @@ object ApplicationModule {
@Provides
@Singleton
fun provideApiProvider(apiFactory: ApiFactory): ApiProvider =
ApiProvider(apiFactory)
fun provideApiProvider(apiFactory: ApiFactory, sessionProvider: SessionProvider): ApiProvider =
ApiProvider(apiFactory, sessionProvider)
@Provides
@Singleton
@ -105,11 +105,4 @@ object ApplicationModule {
@Provides
@ClientSecret
fun provideClientSecret(): String = ""
@Provides
fun provideDispatcherProvider() = object : DispatcherProvider {
override val Io = Dispatchers.IO
override val Comp = Dispatchers.Default
override val Main = Dispatchers.Main
}
}

View File

@ -18,37 +18,31 @@
package me.proton.android.core.coreexample
import android.annotation.SuppressLint
import android.content.Intent
import android.os.Bundle
import android.widget.Button
import androidx.lifecycle.lifecycleScope
import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import me.proton.android.core.coreexample.databinding.ActivityMainBinding
import me.proton.android.core.coreexample.ui.CustomViewsActivity
import me.proton.android.core.presentation.ui.ProtonActivity
import me.proton.android.core.presentation.utils.onClick
import me.proton.core.account.domain.entity.Account
import me.proton.core.account.domain.entity.AccountState
import me.proton.core.account.domain.entity.SessionState
import me.proton.core.accountmanager.domain.AccountManager
import me.proton.core.accountmanager.presentation.observe
import me.proton.core.accountmanager.presentation.onAccountAdded
import me.proton.core.accountmanager.presentation.onAccountDisabled
import me.proton.core.accountmanager.presentation.onAccountInitializing
import me.proton.core.accountmanager.presentation.onAccountReady
import me.proton.core.accountmanager.presentation.onAccountRemoved
import me.proton.core.accountmanager.presentation.onAccountTwoPassModeFailed
import me.proton.core.accountmanager.presentation.onSessionAuthenticated
import me.proton.core.accountmanager.presentation.onSessionForceLogout
import me.proton.core.accountmanager.presentation.onSessionHumanVerificationFailed
import me.proton.core.accountmanager.domain.getPrimaryAccount
import me.proton.core.auth.presentation.AuthOrchestrator
import me.proton.core.domain.entity.UserId
import me.proton.core.auth.presentation.onHumanVerificationResult
import me.proton.core.auth.presentation.onScopeResult
import me.proton.core.auth.presentation.onSessionResult
import me.proton.core.auth.presentation.onUserResult
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.humanverification.VerificationMethod
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import java.util.UUID
import javax.inject.Inject
@AndroidEntryPoint
@ -57,79 +51,80 @@ class MainActivity : ProtonActivity<ActivityMainBinding>() {
@Inject
lateinit var accountManager: AccountManager
@Inject
lateinit var authOrchestrator: AuthOrchestrator
override fun layoutId(): Int = R.layout.activity_main
private val authWorkflowLauncher = AuthOrchestrator()
@SuppressLint("SetTextI18n")
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
authWorkflowLauncher.register(this)
authOrchestrator.register(this)
authOrchestrator
.onUserResult { }
.onScopeResult { }
.onSessionResult { }
.onHumanVerificationResult { }
binding.humanVerification.onClick {
authWorkflowLauncher.startHumanVerificationWorkflow(
SessionId("sessionId"),
HumanVerificationDetails(
listOf(
VerificationMethod.CAPTCHA,
VerificationMethod.EMAIL,
VerificationMethod.PHONE
with(binding) {
humanVerification.onClick {
authOrchestrator.startHumanVerificationWorkflow(
SessionId("sessionId"),
HumanVerificationDetails(
listOf(
VerificationMethod.CAPTCHA,
VerificationMethod.EMAIL,
VerificationMethod.PHONE
)
)
)
)
}
customViews.onClick { startActivity(Intent(this@MainActivity, CustomViewsActivity::class.java)) }
login.onClick { authOrchestrator.startLoginWorkflow() }
}
binding.customViews.onClick {
startActivity(Intent(this, CustomViewsActivity::class.java))
}
accountManager.getPrimaryAccount().onEach { primary ->
binding.primaryAccountText.text = "Primary: ${primary?.username}"
}.launchIn(lifecycleScope)
binding.login.onClick {
authWorkflowLauncher.startLoginWorkflow()
}
accountManager.getAccounts().onEach { accounts ->
if (accounts.isEmpty()) authOrchestrator.startLoginWorkflow()
accountManager.getPrimaryUserId()
.onEach { userId ->
if (userId == null) {
val userId = UserId(UUID.randomUUID().toString())
val sessionId = SessionId(UUID.randomUUID().toString())
val session = Session(
sessionId = sessionId,
accessToken = "accessToken",
refreshToken = "refreshToken",
headers = null,
scopes = listOf()
)
accountManager.addAccount(
Account(
userId,
"username",
"example@example.com",
AccountState.Ready,
sessionId,
SessionState.Authenticated
),
session
)
}
}.launchIn(lifecycleScope)
accountManager.getSessions().onEach { }.launchIn(lifecycleScope)
accountManager.observe(lifecycleScope)
.onAccountAdded { }
.onAccountDisabled { }
.onAccountInitializing { }
.onAccountReady { }
.onAccountRemoved { }
.onAccountTwoPassModeFailed { }
.onSessionAuthenticated { }
.onSessionForceLogout { }
.onSessionHumanVerificationFailed { }
binding.accountsLayout.removeAllViews()
accounts.forEach { account ->
binding.accountsLayout.addView(
Button(this@MainActivity).apply {
text = "${account.username} -> ${account.state}/${account.sessionState}"
onClick {
lifecycleScope.launch {
when (account.state) {
AccountState.Ready ->
accountManager.disableAccount(account.userId)
AccountState.Disabled ->
accountManager.removeAccount(account.userId)
AccountState.NotReady,
AccountState.TwoPassModeNeeded,
AccountState.TwoPassModeFailed ->
when (account.sessionState) {
SessionState.SecondFactorNeeded,
SessionState.SecondFactorFailed ->
accountManager.disableAccount(account.userId)
SessionState.Authenticated ->
authOrchestrator.startTwoPassModeWorkflow(account.sessionId!!)
else -> Unit
}
else -> Unit
}
}
}
}
)
}
}.launchIn(lifecycleScope)
accountManager.onHumanVerificationNeeded().onEach { (account, details) ->
account.sessionId?.let {
authWorkflowLauncher.startHumanVerificationWorkflow(it, details)
}
authOrchestrator.startHumanVerificationWorkflow(account.sessionId!!, details)
}.launchIn(lifecycleScope)
}
}

View File

@ -55,6 +55,43 @@
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@id/customViews" />
<LinearLayout
android:id="@+id/titleLayout"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/default_gap"
android:orientation="horizontal"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@id/login">
<TextView
android:id="@+id/accountsText"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="Accounts:" />
<TextView
android:id="@+id/primaryAccountText"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_weight="1"
android:gravity="end"
android:text="-" />
</LinearLayout>
<LinearLayout
android:id="@+id/accountsLayout"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/default_top_margin"
android:orientation="vertical"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@id/titleLayout" />
</androidx.constraintlayout.widget.ConstraintLayout>
</layout>

View File

@ -21,20 +21,17 @@ package me.proton.core.data.arch
import me.proton.core.domain.arch.DataResult
import me.proton.core.domain.arch.ResponseSource
import me.proton.core.network.domain.ApiResult
import me.proton.core.util.kotlin.Invokable
import me.proton.core.util.kotlin.exhaustive
class ApiResultMapper : Invokable {
fun <T> ApiResult<T>.toDataResponse(): DataResult<T> = when (this) {
is ApiResult.Success -> DataResult.Success(value, ResponseSource.Remote)
is ApiResult.Error.Http -> {
DataResult.Error.Message(
message = proton?.error ?: message,
source = ResponseSource.Remote,
code = proton?.code ?: 0 // 0 means no code is present
)
}
is ApiResult.Error.Parse -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
is ApiResult.Error.Connection -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
}.exhaustive
}
fun <T> ApiResult<T>.toDataResponse(): DataResult<T> = when (this) {
is ApiResult.Success -> DataResult.Success(value, ResponseSource.Remote)
is ApiResult.Error.Http -> {
DataResult.Error.Message(
message = proton?.error ?: message,
source = ResponseSource.Remote,
code = proton?.code ?: 0 // 0 means no code is present
)
}
is ApiResult.Error.Parse -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
is ApiResult.Error.Connection -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
}.exhaustive

View File

@ -31,6 +31,7 @@ import me.proton.core.network.data.protonApi.GenericResponse
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.ApiResult
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionProvider
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
@ -49,6 +50,8 @@ class HumanVerificationRemoteRepositoryImplTest {
private val errorResponse = "test error response"
private val errorResponseCode = 422
@RelaxedMockK
private lateinit var sessionProvider: SessionProvider
@RelaxedMockK
private lateinit var apiFactory: ApiFactory
private lateinit var apiProvider: ApiProvider
@ -59,7 +62,7 @@ class HumanVerificationRemoteRepositoryImplTest {
@Before
fun before() {
MockKAnnotations.init(this)
apiProvider = ApiProvider(apiFactory)
apiProvider = ApiProvider(apiFactory, sessionProvider)
every { apiFactory.create(sessionId, HumanVerificationApi::class) } returns apiManager
}

View File

@ -20,7 +20,6 @@ package me.proton.core.humanverification.presentation.entity
import android.os.Parcelable
import kotlinx.android.parcel.Parcelize
import me.proton.core.network.domain.session.SessionId
/**
* Human Verification result entity.
@ -28,6 +27,6 @@ import me.proton.core.network.domain.session.SessionId
@Parcelize
data class HumanVerificationResult(
val sessionId: String,
val tokenType: String,
val tokenCode: String
val tokenType: String?,
val tokenCode: String?
) : Parcelable

View File

@ -49,8 +49,8 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
companion object {
private const val ARG_SESSION_ID = "arg.sessionId"
const val ARG_VERIFICATION_OPTIONS = "arg.verification-options"
private const val ARG_CAPTCHA_TOKEN = "arg.captcha-token"
const val ARG_VERIFICATION_OPTIONS = "arg.verification-options"
const val ARG_DESTINATION = "arg.destination"
const val ARG_TOKEN_CODE = "arg.token-code"
const val ARG_TOKEN_TYPE = "arg.token-type"
@ -65,7 +65,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
* @param captchaToken if the API returns it, otherwise null
*/
operator fun invoke(
sessionId: SessionId,
sessionId: String,
availableVerificationMethods: List<String>,
captchaToken: String?
) = HumanVerificationDialogFragment().apply {
@ -81,7 +81,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
private lateinit var resultListener: OnResultListener
private val sessionId: SessionId by lazy {
requireArguments().get(ARG_SESSION_ID) as SessionId
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
}
private val captchaToken: String? by lazy {
@ -187,13 +187,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
private fun onClose(tokenType: String? = null, tokenCode: String? = null) {
if (!tokenType.isNullOrEmpty() && !tokenCode.isNullOrEmpty()) {
resultListener.setResult(
HumanVerificationResult(
sessionId.id,
tokenType,
tokenCode
)
)
resultListener.setResult(HumanVerificationResult(sessionId.id, tokenType, tokenCode))
dismissAllowingStateLoss()
return
}
@ -205,7 +199,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
if (backStackEntryCount >= 1) {
popBackStack()
} else {
resultListener.setResult(null)
resultListener.setResult(HumanVerificationResult(sessionId.id, null, null))
dismissAllowingStateLoss()
}
}

View File

@ -43,8 +43,7 @@ import me.proton.core.network.domain.session.SessionId
* @author Dino Kadrikj.
*/
@AndroidEntryPoint
internal class HumanVerificationCaptchaFragment :
ProtonFragment<FragmentHumanVerificationCaptchaBinding>() {
internal class HumanVerificationCaptchaFragment : ProtonFragment<FragmentHumanVerificationCaptchaBinding>() {
companion object {
private const val ARG_SESSION_ID = "arg.sessionId"
@ -52,7 +51,7 @@ internal class HumanVerificationCaptchaFragment :
private const val MAX_PROGRESS = 100
operator fun invoke(
sessionId: SessionId,
sessionId: String,
urlToken: String,
host: String
) = HumanVerificationCaptchaFragment().apply {
@ -65,7 +64,7 @@ internal class HumanVerificationCaptchaFragment :
}
private val sessionId: SessionId by lazy {
requireArguments().get(ARG_SESSION_ID) as SessionId
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
}
private val host: String by lazy {

View File

@ -49,7 +49,7 @@ internal class HumanVerificationEmailFragment : ProtonFragment<FragmentHumanVeri
private const val ARG_RECOVERY_EMAIL = "arg.recoveryemail"
operator fun invoke(
sessionId: SessionId,
sessionId: String,
token: String,
recoveryEmailAddress: String? = null
) = HumanVerificationEmailFragment().apply {
@ -72,7 +72,7 @@ internal class HumanVerificationEmailFragment : ProtonFragment<FragmentHumanVeri
}
private val sessionId: SessionId by lazy {
requireArguments().get(ARG_SESSION_ID) as SessionId
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
}
private val recoveryEmailAddress: String? by lazy {

View File

@ -53,7 +53,7 @@ class HumanVerificationEnterCodeFragment :
private const val ARG_TOKEN_TYPE = "arg.enter-code-token-type"
operator fun invoke(
sessionId: SessionId,
sessionId: String,
tokenType: TokenType,
destination: String?
) = HumanVerificationEnterCodeFragment().apply {
@ -68,7 +68,7 @@ class HumanVerificationEnterCodeFragment :
private val viewModel by viewModels<HumanVerificationEnterCodeViewModel>()
private val sessionId: SessionId by lazy {
requireArguments().get(ARG_SESSION_ID) as SessionId
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
}
private val destination: String? by lazy {

View File

@ -53,7 +53,10 @@ internal class HumanVerificationSMSFragment :
internal const val KEY_COUNTRY_SELECTED = "key.country_selected"
internal const val BUNDLE_KEY_COUNTRY = "bundle.country"
operator fun invoke(sessionId: SessionId, token: String) = HumanVerificationSMSFragment().apply {
operator fun invoke(
sessionId: String,
token: String
) = HumanVerificationSMSFragment().apply {
arguments = bundleOf(
ARG_SESSION_ID to sessionId,
ARG_URL_TOKEN to token
@ -64,7 +67,7 @@ internal class HumanVerificationSMSFragment :
private val viewModel by viewModels<HumanVerificationSMSViewModel>()
private val sessionId: SessionId by lazy {
requireArguments().get(ARG_SESSION_ID) as SessionId
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
}
private val humanVerificationBase by lazy {

View File

@ -56,7 +56,7 @@ fun FragmentManager.showHumanVerification(
largeLayout: Boolean
) {
val newFragment = HumanVerificationDialogFragment(sessionId, availableVerificationMethods, captchaToken)
val newFragment = HumanVerificationDialogFragment(sessionId.id, availableVerificationMethods, captchaToken)
if (largeLayout) {
// For large screens (tablets), we show the fragment as a dialog
newFragment.show(this, TAG_HUMAN_VERIFICATION_DIALOG)
@ -79,7 +79,7 @@ internal fun FragmentManager.showHumanVerificationCaptchaContent(
token: String?,
host: String = HOST_DEFAULT
): Fragment {
val captchaFragment = HumanVerificationCaptchaFragment(sessionId, token ?: TOKEN_DEFAULT, host)
val captchaFragment = HumanVerificationCaptchaFragment(sessionId.id, token ?: TOKEN_DEFAULT, host)
inTransaction {
setCustomAnimations(0, 0)
replace(containerId, captchaFragment)
@ -92,7 +92,7 @@ internal fun FragmentManager.showHumanVerificationEmailContent(
sessionId: SessionId,
token: String = TOKEN_DEFAULT
) {
val emailFragment = HumanVerificationEmailFragment(sessionId, token)
val emailFragment = HumanVerificationEmailFragment(sessionId.id, token)
inTransaction {
setCustomAnimations(0, 0)
replace(containerId, emailFragment)
@ -104,7 +104,7 @@ internal fun FragmentManager.showHumanVerificationSMSContent(
containerId: Int = android.R.id.content,
token: String = TOKEN_DEFAULT
) {
val smsFragment = HumanVerificationSMSFragment(sessionId, token)
val smsFragment = HumanVerificationSMSFragment(sessionId.id, token)
inTransaction {
setCustomAnimations(0, 0)
replace(containerId, smsFragment)
@ -116,7 +116,7 @@ internal fun FragmentManager.showEnterCode(
tokenType: TokenType,
destination: String?
) {
val enterCodeFragment = HumanVerificationEnterCodeFragment(sessionId, tokenType, destination)
val enterCodeFragment = HumanVerificationEnterCodeFragment(sessionId.id, tokenType, destination)
inTransaction {
setCustomAnimations(0, 0)
add(enterCodeFragment, TAG_HUMAN_VERIFICATION_ENTER_CODE)

View File

@ -36,6 +36,7 @@ dependencies {
project(Module.kotlinUtil),
project(Module.sharedPreferencesUtil),
project(Module.networkDomain),
project(Module.domain),
// Kotlin
`kotlin-jdk7`,

View File

@ -18,22 +18,31 @@
package me.proton.core.network.data
import me.proton.core.domain.entity.UserId
import me.proton.core.network.data.di.ApiFactory
import me.proton.core.network.data.protonApi.BaseRetrofitApi
import me.proton.core.network.domain.ApiManager
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionProvider
import java.lang.ref.Reference
import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
/**
* Provide [ApiManager] instance bound to a specific [SessionId].
*/
class ApiProvider(
val apiFactory: ApiFactory
val apiFactory: ApiFactory,
val sessionProvider: SessionProvider
) {
val instances: ConcurrentHashMap<String, ConcurrentHashMap<String, WeakReference<ApiManager<*>>>> =
val instances: ConcurrentHashMap<String, ConcurrentHashMap<String, Reference<ApiManager<*>>>> =
ConcurrentHashMap()
suspend inline fun <reified Api : BaseRetrofitApi> get(
userId: UserId
): ApiManager<out Api> = get(sessionProvider.getSessionId(userId))
inline fun <reified Api : BaseRetrofitApi> get(
sessionId: SessionId? = null
): ApiManager<out Api> {
@ -44,7 +53,9 @@ class ApiProvider(
val className = Api::class.java.name
return instances
.getOrPut(sessionName) { ConcurrentHashMap() }
.getOrPut(className) { WeakReference(apiFactory.create(sessionId, Api::class)) }
.get() as ApiManager<out Api>
.getOrPutWeakRef(className) { apiFactory.create(sessionId, Api::class) } as ApiManager<out Api>
}
fun <K, V> ConcurrentMap<K, Reference<V>>.getOrPutWeakRef(key: K, defaultValue: () -> V): V =
this[key]?.get() ?: defaultValue().apply { put(key, WeakReference(this)) }
}

View File

@ -17,6 +17,7 @@
*/
package me.proton.core.network.data
import kotlinx.coroutines.runBlocking
import me.proton.core.network.data.di.Constants
import me.proton.core.network.data.protonApi.BaseRetrofitApi
import me.proton.core.network.data.protonApi.ProtonErrorData
@ -97,15 +98,13 @@ internal class ProtonApiBackend<Api : BaseRetrofitApi>(
}
private fun handleTimeoutTag(chain: Interceptor.Chain): Interceptor.Chain {
val tag = chain.request().tag()
return if (tag is TimeoutOverride) {
val tag = chain.request().tag(TimeoutOverride::class.java)
return tag?.let {
chain
.withConnectTimeout(tag.connectionTimeoutSeconds, TimeUnit.SECONDS)
.withReadTimeout(tag.readTimeoutSeconds, TimeUnit.SECONDS)
.withWriteTimeout(tag.writeTimeoutSeconds, TimeUnit.SECONDS)
} else {
chain
}
} ?: chain
}
private fun prepareHeaders(original: Request): Request.Builder {
@ -118,7 +117,7 @@ internal class ProtonApiBackend<Api : BaseRetrofitApi>(
request.header("Accept", "application/vnd.protonmail.v1+json")
}
sessionId?.let { sessionProvider.getSession(it) }?.let { session ->
sessionId?.let { runBlocking { sessionProvider.getSession(it) } }?.let { session ->
session.headers?.let {
request.header("x-pm-human-verification-token-type", it.tokenType)
request.header("x-pm-human-verification-token", it.tokenCode)

View File

@ -100,7 +100,8 @@ internal class ApiManagerTests {
apiClient = MockApiClient()
session = MockSession.getDefault()
every { sessionProvider.getSession(any()) } returns session
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
coEvery { sessionProvider.getSession(any()) } returns session
networkManager = MockNetworkManager()
networkManager.networkStatus = NetworkStatus.Unmetered

View File

@ -19,6 +19,7 @@
package me.proton.core.network.data
import io.mockk.MockKAnnotations
import io.mockk.coEvery
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.mockk
@ -127,7 +128,8 @@ internal class HumanVerificationTests {
prefs = MockNetworkPrefs()
session = MockSession.getDefault()
every { sessionProvider.getSession(any()) } returns session
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
coEvery { sessionProvider.getSession(any()) } returns session
apiFactory =
ApiFactory(
@ -218,7 +220,7 @@ internal class HumanVerificationTests {
)
)
every { sessionProvider.getSession(any()) } returns MockSession.getWithHeader(
coEvery { sessionProvider.getSession(any()) } returns MockSession.getWithHeader(
humanVerificationHeaders
)

View File

@ -58,10 +58,14 @@ internal class NetworkManagerTests {
networkManager.networkStatus = NetworkStatus.Unmetered
flow2.cancel()
assertEquals(listOf(NetworkStatus.Unmetered, NetworkStatus.Metered, NetworkStatus.Disconnected),
collectedStates1.toList())
assertEquals(listOf(NetworkStatus.Metered, NetworkStatus.Disconnected, NetworkStatus.Unmetered),
collectedStates2.toList())
assertEquals(
listOf(NetworkStatus.Unmetered, NetworkStatus.Metered, NetworkStatus.Disconnected),
collectedStates1.toList()
)
assertEquals(
listOf(NetworkStatus.Metered, NetworkStatus.Disconnected, NetworkStatus.Unmetered),
collectedStates2.toList()
)
assertFalse(networkManager.registered)
}

View File

@ -18,6 +18,7 @@
package me.proton.core.network.data
import io.mockk.MockKAnnotations
import io.mockk.coEvery
import io.mockk.every
import io.mockk.impl.annotations.MockK
import kotlinx.coroutines.CoroutineScope
@ -92,7 +93,8 @@ internal class ProtonApiBackendTests {
prefs = MockNetworkPrefs()
session = MockSession.getDefault()
every { sessionProvider.getSession(any()) } returns session
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
coEvery { sessionProvider.getSession(any()) } returns session
apiFactory = ApiFactory(
"https://example.com/",

View File

@ -31,6 +31,7 @@ dependencies {
implementation(
project(Module.kotlinUtil),
project(Module.domain),
// Kotlin
`kotlin-jdk7`,

View File

@ -18,6 +18,16 @@
package me.proton.core.network.domain.session
import me.proton.core.domain.entity.UserId
interface SessionProvider {
fun getSession(sessionId: SessionId): Session?
/**
* Get [Session], if exist, by sessionId.
*/
suspend fun getSession(sessionId: SessionId): Session?
/**
* Get [SessionId], if exist, by userId.
*/
suspend fun getSessionId(userId: UserId): SessionId?
}

View File

@ -27,10 +27,12 @@ import kotlinx.coroutines.test.runBlockingTest
import me.proton.core.network.domain.handlers.HumanVerificationHandler
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
import me.proton.core.network.domain.humanverification.VerificationMethod
import me.proton.core.network.domain.session.Session
import me.proton.core.network.domain.session.SessionId
import me.proton.core.network.domain.session.SessionListener
import me.proton.core.network.domain.session.SessionProvider
import org.junit.Test
import kotlin.test.BeforeTest
import kotlin.test.assertNotNull
/**
@ -39,12 +41,18 @@ import kotlin.test.assertNotNull
class HumanVerificationHandlerTest {
private val sessionId: SessionId = SessionId("id")
private val session = mockk<Session>(relaxed = true)
private val sessionListener = mockk<SessionListener>(relaxed = true)
private val sessionProvider = mockk<SessionProvider>(relaxed = true)
val scope = CoroutineScope(TestCoroutineDispatcher())
val apiBackend = mockk<ApiBackend<Any>>()
@BeforeTest
fun beforeTest() {
coEvery { sessionProvider.getSession(any()) } returns session
}
@Test
fun `test human verification called`() = runBlockingTest {
val humanVerificationDetails =
@ -59,7 +67,12 @@ class HumanVerificationHandlerTest {
)
)
coEvery { sessionListener.onHumanVerificationNeeded(any(), any()) } returns SessionListener.HumanVerificationResult.Success
coEvery {
sessionListener.onHumanVerificationNeeded(
any(),
any()
)
} returns SessionListener.HumanVerificationResult.Success
coEvery { apiBackend.invoke<Any>(any()) } returns ApiResult.Success("test")
val humanVerificationHandler =