Integrated AccountManager with modules.
This commit is contained in:
parent
755dc1b8a0
commit
bf33818a5a
|
@ -28,9 +28,11 @@ import dagger.hilt.android.qualifiers.ApplicationContext
|
|||
import me.proton.core.account.data.repository.AccountRepositoryImpl
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.accountmanager.data.AccountManagerImpl
|
||||
import me.proton.core.accountmanager.data.SessionManagerImpl
|
||||
import me.proton.core.accountmanager.data.db.AccountManagerDatabase
|
||||
import me.proton.core.accountmanager.domain.AccountManager
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.domain.repository.AuthRepository
|
||||
import me.proton.core.data.crypto.KeyStoreStringCrypto
|
||||
import me.proton.core.data.crypto.StringCrypto
|
||||
import me.proton.core.domain.entity.Product
|
||||
|
@ -63,8 +65,19 @@ object AccountManagerModule {
|
|||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideAccountManagerImpl(product: Product, accountRepository: AccountRepository): AccountManagerImpl =
|
||||
AccountManagerImpl(product, accountRepository)
|
||||
fun provideAccountManagerImpl(
|
||||
product: Product,
|
||||
accountRepository: AccountRepository,
|
||||
authRepository: AuthRepository
|
||||
): AccountManagerImpl =
|
||||
AccountManagerImpl(product, accountRepository, authRepository)
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideSessionManagerImpl(
|
||||
accountRepository: AccountRepository
|
||||
): SessionManagerImpl =
|
||||
SessionManagerImpl(accountRepository)
|
||||
}
|
||||
|
||||
@Module
|
||||
|
@ -78,8 +91,8 @@ interface AccountManagerBindModule {
|
|||
fun bindAccountWorkflowHandler(accountManagerImpl: AccountManagerImpl): AccountWorkflowHandler
|
||||
|
||||
@Binds
|
||||
fun bindSessionProvider(accountManagerImpl: AccountManagerImpl): SessionProvider
|
||||
fun bindSessionProvider(sessionManagerImpl: SessionManagerImpl): SessionProvider
|
||||
|
||||
@Binds
|
||||
fun bindSessionListener(accountManagerImpl: AccountManagerImpl): SessionListener
|
||||
fun bindSessionListener(sessionManagerImpl: SessionManagerImpl): SessionListener
|
||||
}
|
||||
|
|
|
@ -21,34 +21,32 @@ package me.proton.core.accountmanager.data
|
|||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.distinctUntilChanged
|
||||
import kotlinx.coroutines.flow.filter
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.map
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.account.domain.entity.isReady
|
||||
import me.proton.core.account.domain.entity.isSecondFactorNeeded
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.accountmanager.domain.AccountManager
|
||||
import me.proton.core.accountmanager.domain.onSessionStateChanged
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.domain.repository.AuthRepository
|
||||
import me.proton.core.domain.arch.extension.onEntityChanged
|
||||
import me.proton.core.domain.entity.Product
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
class AccountManagerImpl constructor(
|
||||
product: Product,
|
||||
private val accountRepository: AccountRepository
|
||||
) : AccountManager(product), AccountWorkflowHandler, SessionProvider, SessionListener {
|
||||
|
||||
private val humanVerificationDetails: ConcurrentHashMap<SessionId, HumanVerificationDetails?> = ConcurrentHashMap()
|
||||
private val accountRepository: AccountRepository,
|
||||
private val authRepository: AuthRepository
|
||||
) : AccountManager(product), AccountWorkflowHandler {
|
||||
|
||||
private suspend fun removeSession(sessionId: SessionId) {
|
||||
// TODO: Revoke session (fire and forget, need Auth module).
|
||||
authRepository.revokeSession(sessionId)
|
||||
accountRepository.deleteSession(sessionId)
|
||||
}
|
||||
|
||||
|
@ -100,7 +98,7 @@ class AccountManagerImpl constructor(
|
|||
|
||||
override fun onHumanVerificationNeeded(): Flow<Pair<Account, HumanVerificationDetails?>> =
|
||||
onSessionStateChanged(SessionState.HumanVerificationNeeded)
|
||||
.map { it to it.sessionId?.let { id -> humanVerificationDetails[id] } }
|
||||
.map { it to it.sessionId?.let { id -> accountRepository.getHumanVerificationDetails(id) } }
|
||||
|
||||
override fun getPrimaryUserId(): Flow<UserId?> =
|
||||
accountRepository.getPrimaryUserId()
|
||||
|
@ -111,30 +109,31 @@ class AccountManagerImpl constructor(
|
|||
// region AccountWorkflowHandler
|
||||
|
||||
override suspend fun handleSession(account: Account, session: Session) {
|
||||
val initializing = when {
|
||||
account.state == AccountState.TwoPassModeNeeded -> true
|
||||
account.sessionState == SessionState.SecondFactorNeeded -> true
|
||||
else -> false
|
||||
}
|
||||
val updatedAccount = if (initializing) account.copy(state = AccountState.Initializing) else account
|
||||
accountRepository.createOrUpdateAccountSession(updatedAccount, session)
|
||||
// Account state must be != Ready if SecondFactorNeeded.
|
||||
val state = if (account.isReady() && account.isSecondFactorNeeded()) AccountState.NotReady else account.state
|
||||
accountRepository.createOrUpdateAccountSession(account.copy(state = state), session)
|
||||
}
|
||||
|
||||
override suspend fun handleTwoPassModeSuccess(sessionId: SessionId) {
|
||||
accountRepository.updateAccountState(sessionId, AccountState.TwoPassModeSuccess)
|
||||
accountRepository.updateAccountState(sessionId, AccountState.Ready)
|
||||
accountRepository.getAccountOrNull(sessionId)?.let { account ->
|
||||
if (account.sessionState == SessionState.Authenticated)
|
||||
accountRepository.updateAccountState(sessionId, AccountState.Ready)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun handleTwoPassModeFailed(sessionId: SessionId) {
|
||||
accountRepository.updateAccountState(sessionId, AccountState.TwoPassModeFailed)
|
||||
disableAccount(sessionId)
|
||||
}
|
||||
|
||||
override suspend fun handleSecondFactorSuccess(sessionId: SessionId, updatedScopes: List<String>) {
|
||||
accountRepository.updateSessionScopes(sessionId, updatedScopes)
|
||||
accountRepository.updateSessionState(sessionId, SessionState.SecondFactorSuccess)
|
||||
accountRepository.updateSessionState(sessionId, SessionState.Authenticated)
|
||||
accountRepository.updateAccountState(sessionId, AccountState.Ready)
|
||||
accountRepository.getAccountOrNull(sessionId)?.let { account ->
|
||||
if (account.state != AccountState.TwoPassModeNeeded)
|
||||
accountRepository.updateAccountState(sessionId, AccountState.Ready)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun handleSecondFactorFailed(sessionId: SessionId) {
|
||||
|
@ -152,51 +151,7 @@ class AccountManagerImpl constructor(
|
|||
override suspend fun handleHumanVerificationFailed(sessionId: SessionId) {
|
||||
accountRepository.updateSessionHeaders(sessionId, null, null)
|
||||
accountRepository.updateSessionState(sessionId, SessionState.HumanVerificationFailed)
|
||||
disableAccount(sessionId)
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region SessionListener
|
||||
|
||||
override suspend fun onSessionTokenRefreshed(session: Session) {
|
||||
accountRepository.updateSessionToken(session.sessionId, session.accessToken, session.refreshToken)
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.Authenticated)
|
||||
}
|
||||
|
||||
override suspend fun onSessionForceLogout(session: Session) {
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.ForceLogout)
|
||||
accountRepository.getAccountOrNull(session.sessionId)?.let { account ->
|
||||
accountRepository.updateAccountState(account.userId, AccountState.Disabled)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun onHumanVerificationNeeded(
|
||||
session: Session,
|
||||
details: HumanVerificationDetails?
|
||||
): SessionListener.HumanVerificationResult {
|
||||
humanVerificationDetails[session.sessionId] = details
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.HumanVerificationNeeded)
|
||||
accountRepository.updateAccountState(session.sessionId, AccountState.Initializing)
|
||||
|
||||
// Wait for HumanVerification Success or Failure.
|
||||
val state = accountRepository.getAccount(session.sessionId)
|
||||
.map { it?.sessionState }
|
||||
.filter { it == SessionState.HumanVerificationSuccess || it == SessionState.HumanVerificationFailed }
|
||||
.first()
|
||||
|
||||
return when (state) {
|
||||
null -> SessionListener.HumanVerificationResult.Failure
|
||||
SessionState.HumanVerificationSuccess -> SessionListener.HumanVerificationResult.Success
|
||||
else -> SessionListener.HumanVerificationResult.Failure
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region SessionProvider
|
||||
|
||||
override fun getSession(sessionId: SessionId): Session? = accountRepository.getSessionOrNull(sessionId)
|
||||
|
||||
// endregion
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Proton Technologies AG
|
||||
* This file is part of Proton Technologies AG and ProtonCore.
|
||||
*
|
||||
* ProtonCore is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonCore is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package me.proton.core.accountmanager.data
|
||||
|
||||
import kotlinx.coroutines.flow.filter
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.map
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.accountmanager.domain.SessionManager
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
|
||||
class SessionManagerImpl(
|
||||
private val accountRepository: AccountRepository
|
||||
) : SessionManager {
|
||||
|
||||
// region SessionListener
|
||||
|
||||
override suspend fun onSessionTokenRefreshed(session: Session) {
|
||||
accountRepository.updateSessionToken(session.sessionId, session.accessToken, session.refreshToken)
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.Authenticated)
|
||||
}
|
||||
|
||||
override suspend fun onSessionForceLogout(session: Session) {
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.ForceLogout)
|
||||
accountRepository.getAccountOrNull(session.sessionId)?.let { account ->
|
||||
accountRepository.updateAccountState(account.userId, AccountState.Disabled)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun onHumanVerificationNeeded(
|
||||
session: Session,
|
||||
details: HumanVerificationDetails?
|
||||
): SessionListener.HumanVerificationResult {
|
||||
accountRepository.setHumanVerificationDetails(session.sessionId, details)
|
||||
accountRepository.updateAccountState(session.sessionId, AccountState.NotReady)
|
||||
accountRepository.updateSessionState(session.sessionId, SessionState.HumanVerificationNeeded)
|
||||
|
||||
// Wait for HumanVerification Success or Failure.
|
||||
val state = accountRepository.getAccount(session.sessionId)
|
||||
.map { it?.sessionState }
|
||||
.filter { it == SessionState.HumanVerificationSuccess || it == SessionState.HumanVerificationFailed }
|
||||
.first()
|
||||
|
||||
return when (state) {
|
||||
null -> SessionListener.HumanVerificationResult.Failure
|
||||
SessionState.HumanVerificationSuccess -> SessionListener.HumanVerificationResult.Success
|
||||
else -> SessionListener.HumanVerificationResult.Failure
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region SessionProvider
|
||||
|
||||
override suspend fun getSession(sessionId: SessionId): Session? =
|
||||
accountRepository.getSessionOrNull(sessionId)
|
||||
|
||||
override suspend fun getSessionId(userId: UserId): SessionId? =
|
||||
accountRepository.getSessionIdOrNull(userId)
|
||||
|
||||
// endregion
|
||||
}
|
|
@ -17,27 +17,19 @@
|
|||
*/
|
||||
package me.proton.core.accountmanager.data
|
||||
|
||||
import io.mockk.MockKAnnotations
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.coVerify
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.RelaxedMockK
|
||||
import io.mockk.slot
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.test.runBlockingTest
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.domain.entity.Product
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
|
||||
import me.proton.core.network.domain.humanverification.VerificationMethod
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
@ -46,9 +38,6 @@ class AccountManagerImplTest {
|
|||
|
||||
private lateinit var accountManager: AccountManagerImpl
|
||||
|
||||
@RelaxedMockK
|
||||
private lateinit var accountRepository: AccountRepository
|
||||
|
||||
private val session1 = Session(
|
||||
sessionId = SessionId("session1"),
|
||||
accessToken = "accessToken",
|
||||
|
@ -66,145 +55,25 @@ class AccountManagerImplTest {
|
|||
sessionState = SessionState.Authenticated
|
||||
)
|
||||
|
||||
private val flowOfAccountLists = mutableListOf<List<Account>>()
|
||||
private val flowOfSessionLists = mutableListOf<List<Session>>()
|
||||
|
||||
@Suppress("LongMethod")
|
||||
private fun setupUpdateListsFlow() {
|
||||
val userIdSlot = slot<UserId>()
|
||||
val sessionIdSlot = slot<SessionId>()
|
||||
val accountStateSlot = slot<AccountState>()
|
||||
val sessionStateSlot = slot<SessionState>()
|
||||
val updatedScopesSlot = slot<List<String>>()
|
||||
val tokenTypeSlot = slot<String>()
|
||||
val tokenCodeSlot = slot<String>()
|
||||
val accessTokenSlot = slot<String>()
|
||||
val refreshTokenSlot = slot<String>()
|
||||
|
||||
// Initial state.
|
||||
flowOfAccountLists.clear()
|
||||
flowOfSessionLists.clear()
|
||||
flowOfAccountLists.add(listOf(account1))
|
||||
flowOfSessionLists.add(listOf(session1))
|
||||
|
||||
// For each updateAccountState -> emit a new updated List<Account> from getAccounts().
|
||||
coEvery { accountRepository.updateAccountState(capture(userIdSlot), capture(accountStateSlot)) } answers {
|
||||
flowOfAccountLists.add(
|
||||
listOf(
|
||||
flowOfAccountLists.last().first { it.userId == userIdSlot.captured }.copy(
|
||||
userId = userIdSlot.captured,
|
||||
state = accountStateSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionState -> emit a new updated List<Account> from getAccounts().
|
||||
coEvery { accountRepository.updateSessionState(capture(sessionIdSlot), capture(sessionStateSlot)) } answers {
|
||||
flowOfAccountLists.add(
|
||||
listOf(
|
||||
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
sessionState = sessionStateSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionScopes -> emit a new updated List<Session> from getSessions().
|
||||
coEvery { accountRepository.updateSessionScopes(capture(sessionIdSlot), capture(updatedScopesSlot)) } answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
scopes = updatedScopesSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionHeaders -> emit a new updated List<Session> from getSessions().
|
||||
coEvery {
|
||||
accountRepository.updateSessionHeaders(
|
||||
capture(sessionIdSlot),
|
||||
capture(tokenTypeSlot),
|
||||
capture(tokenCodeSlot)
|
||||
)
|
||||
} answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
headers = HumanVerificationHeaders(
|
||||
tokenTypeSlot.captured,
|
||||
tokenCodeSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionToken -> emit a new updated List<Session> from getSessions().
|
||||
coEvery {
|
||||
accountRepository.updateSessionToken(
|
||||
capture(sessionIdSlot),
|
||||
capture(accessTokenSlot),
|
||||
capture(refreshTokenSlot)
|
||||
)
|
||||
} answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
session1.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
accessToken = accessTokenSlot.captured,
|
||||
refreshToken = refreshTokenSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// Emit last state with same id if exist.
|
||||
coEvery { accountRepository.getAccountOrNull(capture(userIdSlot)) } answers {
|
||||
flowOfAccountLists.last().firstOrNull { it.userId == userIdSlot.captured }
|
||||
}
|
||||
coEvery { accountRepository.getAccountOrNull(capture(sessionIdSlot)) } answers {
|
||||
flowOfAccountLists.last().firstOrNull { it.sessionId == sessionIdSlot.captured }
|
||||
}
|
||||
|
||||
// Emit all state with same id if exist.
|
||||
coEvery { accountRepository.getAccount(capture(sessionIdSlot)) } answers {
|
||||
val filteredLists = flowOfAccountLists.map { list ->
|
||||
list.firstOrNull { it.sessionId == sessionIdSlot.captured }
|
||||
}
|
||||
flowOf(*filteredLists.toTypedArray())
|
||||
}
|
||||
|
||||
// Finally, emit all flow of Lists.
|
||||
every { accountRepository.getAccounts() } answers {
|
||||
flowOf(*flowOfAccountLists.toTypedArray())
|
||||
}
|
||||
every { accountRepository.getSessions() } answers {
|
||||
flowOf(*flowOfSessionLists.toTypedArray())
|
||||
}
|
||||
}
|
||||
private val mocks = RepositoryMocks(session1, account1)
|
||||
|
||||
@Before
|
||||
fun beforeEveryTest() {
|
||||
MockKAnnotations.init(this)
|
||||
mocks.init()
|
||||
|
||||
accountManager = AccountManagerImpl(Product.Calendar, accountRepository)
|
||||
accountManager = AccountManagerImpl(Product.Calendar, mocks.accountRepository, mocks.authRepository)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `add user with session`() = runBlockingTest {
|
||||
accountManager.addAccount(account1, session1)
|
||||
|
||||
coVerify(exactly = 1) { accountRepository.createOrUpdateAccountSession(any(), any()) }
|
||||
coVerify(exactly = 1) { mocks.accountRepository.createOrUpdateAccountSession(any(), any()) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on account state changed`() = runBlockingTest {
|
||||
every { accountRepository.getAccounts() } returns flowOf(
|
||||
every { mocks.accountRepository.getAccounts() } returns flowOf(
|
||||
listOf(account1),
|
||||
listOf(account1),
|
||||
listOf(account1.copy(state = AccountState.Disabled))
|
||||
|
@ -221,7 +90,7 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on session state changed`() = runBlockingTest {
|
||||
every { accountRepository.getAccounts() } returns flowOf(
|
||||
every { mocks.accountRepository.getAccounts() } returns flowOf(
|
||||
listOf(account1),
|
||||
listOf(account1),
|
||||
listOf(account1.copy(sessionState = SessionState.ForceLogout))
|
||||
|
@ -238,9 +107,9 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on handleTwoPassModeSuccess`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
accountManager.handleTwoPassModeSuccess(account1.userId)
|
||||
accountManager.handleTwoPassModeSuccess(account1.sessionId!!)
|
||||
|
||||
val stateLists = accountManager.onAccountStateChanged().toList()
|
||||
assertEquals(3, stateLists.size)
|
||||
|
@ -251,20 +120,19 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on handleTwoPassModeFailed`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
accountManager.handleTwoPassModeFailed(account1.userId)
|
||||
accountManager.handleTwoPassModeFailed(account1.sessionId!!)
|
||||
|
||||
val stateLists = accountManager.onAccountStateChanged().toList()
|
||||
assertEquals(3, stateLists.size)
|
||||
assertEquals(2, stateLists.size)
|
||||
assertEquals(account1.state, stateLists[0].state)
|
||||
assertEquals(AccountState.TwoPassModeFailed, stateLists[1].state)
|
||||
assertEquals(AccountState.Disabled, stateLists[2].state)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on handleSecondFactorSuccess`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
val newScopes = listOf("scope1", "scope2")
|
||||
|
||||
|
@ -288,7 +156,8 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on handleSecondFactorFailed`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
mocks.setupAuthRepository()
|
||||
|
||||
accountManager.handleSecondFactorFailed(session1.sessionId)
|
||||
|
||||
|
@ -305,7 +174,7 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on handleHumanVerificationSuccess`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
val tokenType = "newTokenType"
|
||||
val tokenCode = "newTokenCode"
|
||||
|
@ -331,107 +200,17 @@ class AccountManagerImplTest {
|
|||
|
||||
@Test
|
||||
fun `on handleHumanVerificationFailed`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
accountManager.handleHumanVerificationFailed(session1.sessionId)
|
||||
|
||||
val stateLists = accountManager.onAccountStateChanged().toList()
|
||||
assertEquals(2, stateLists.size)
|
||||
assertEquals(1, stateLists.size)
|
||||
assertEquals(account1.state, stateLists[0].state)
|
||||
assertEquals(AccountState.Disabled, stateLists[1].state)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.HumanVerificationFailed, sessionStateLists[1].sessionState)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onSessionTokenRefreshed`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
|
||||
val newAccessToken = "newAccessToken"
|
||||
val newRefreshToken = "newRefreshToken"
|
||||
|
||||
accountManager.onSessionTokenRefreshed(
|
||||
session1.refreshWith(
|
||||
accessToken = newAccessToken,
|
||||
refreshToken = newRefreshToken
|
||||
)
|
||||
)
|
||||
|
||||
val sessionLists = accountManager.getSessions().toList()
|
||||
assertEquals(2, sessionLists.size)
|
||||
assertEquals(session1.accessToken, sessionLists[0][0].accessToken)
|
||||
assertEquals(session1.refreshToken, sessionLists[0][0].refreshToken)
|
||||
assertEquals(newAccessToken, sessionLists[1][0].accessToken)
|
||||
assertEquals(newRefreshToken, sessionLists[1][0].refreshToken)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onSessionForceLogout`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
|
||||
accountManager.onSessionForceLogout(session1)
|
||||
|
||||
val stateLists = accountManager.onAccountStateChanged().toList()
|
||||
assertEquals(2, stateLists.size)
|
||||
assertEquals(account1.state, stateLists[0].state)
|
||||
assertEquals(AccountState.Disabled, stateLists[1].state)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.ForceLogout, sessionStateLists[1].sessionState)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onHumanVerificationNeeded success`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
|
||||
val humanVerificationDetails = HumanVerificationDetails(
|
||||
verificationMethods = listOf(VerificationMethod.EMAIL),
|
||||
captchaVerificationToken = null
|
||||
)
|
||||
|
||||
coEvery { accountRepository.getAccount(any<SessionId>()) } returns flowOf(
|
||||
account1,
|
||||
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
|
||||
account1.copy(sessionState = SessionState.HumanVerificationSuccess)
|
||||
)
|
||||
|
||||
val result = accountManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
|
||||
|
||||
assertEquals(SessionListener.HumanVerificationResult.Success, result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onHumanVerificationNeeded failed`() = runBlockingTest {
|
||||
setupUpdateListsFlow()
|
||||
|
||||
val humanVerificationDetails = HumanVerificationDetails(
|
||||
verificationMethods = listOf(VerificationMethod.EMAIL),
|
||||
captchaVerificationToken = null
|
||||
)
|
||||
|
||||
coEvery { accountRepository.getAccount(any<SessionId>()) } returns flowOf(
|
||||
account1,
|
||||
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
|
||||
account1.copy(sessionState = SessionState.HumanVerificationFailed)
|
||||
)
|
||||
|
||||
val result = accountManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
|
||||
|
||||
assertEquals(SessionListener.HumanVerificationResult.Failure, result)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Proton Technologies AG
|
||||
* This file is part of Proton Technologies AG and ProtonCore.
|
||||
*
|
||||
* ProtonCore is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonCore is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package me.proton.core.accountmanager.data
|
||||
|
||||
import io.mockk.MockKAnnotations
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.RelaxedMockK
|
||||
import io.mockk.slot
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.auth.domain.repository.AuthRepository
|
||||
import me.proton.core.domain.arch.DataResult
|
||||
import me.proton.core.domain.arch.ResponseSource
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
|
||||
class RepositoryMocks(
|
||||
private val session: Session,
|
||||
private val account: Account
|
||||
) {
|
||||
|
||||
@RelaxedMockK
|
||||
lateinit var accountRepository: AccountRepository
|
||||
|
||||
@RelaxedMockK
|
||||
lateinit var authRepository: AuthRepository
|
||||
|
||||
private val flowOfAccountLists = mutableListOf<List<Account>>()
|
||||
private val flowOfSessionLists = mutableListOf<List<Session>>()
|
||||
|
||||
fun init() {
|
||||
MockKAnnotations.init(this)
|
||||
}
|
||||
|
||||
@Suppress("LongMethod")
|
||||
fun setupAccountRepository() {
|
||||
val userIdSlot = slot<UserId>()
|
||||
val sessionIdSlot = slot<SessionId>()
|
||||
val accountStateSlot = slot<AccountState>()
|
||||
val sessionStateSlot = slot<SessionState>()
|
||||
val updatedScopesSlot = slot<List<String>>()
|
||||
val tokenTypeSlot = slot<String>()
|
||||
val tokenCodeSlot = slot<String>()
|
||||
val accessTokenSlot = slot<String>()
|
||||
val refreshTokenSlot = slot<String>()
|
||||
|
||||
// Initial state.
|
||||
flowOfAccountLists.clear()
|
||||
flowOfSessionLists.clear()
|
||||
flowOfAccountLists.add(listOf(account))
|
||||
flowOfSessionLists.add(listOf(session))
|
||||
|
||||
// For each updateAccountState -> emit a new updated List<Account> from getAccounts().
|
||||
coEvery { accountRepository.updateAccountState(capture(userIdSlot), capture(accountStateSlot)) } answers {
|
||||
flowOfAccountLists.add(
|
||||
listOf(
|
||||
flowOfAccountLists.last().first { it.userId == userIdSlot.captured }.copy(
|
||||
userId = userIdSlot.captured,
|
||||
state = accountStateSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
coEvery { accountRepository.updateAccountState(capture(sessionIdSlot), capture(accountStateSlot)) } answers {
|
||||
flowOfAccountLists.add(
|
||||
listOf(
|
||||
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
state = accountStateSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionState -> emit a new updated List<Account> from getAccounts().
|
||||
coEvery { accountRepository.updateSessionState(capture(sessionIdSlot), capture(sessionStateSlot)) } answers {
|
||||
flowOfAccountLists.add(
|
||||
listOf(
|
||||
flowOfAccountLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
sessionState = sessionStateSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionScopes -> emit a new updated List<Session> from getSessions().
|
||||
coEvery { accountRepository.updateSessionScopes(capture(sessionIdSlot), capture(updatedScopesSlot)) } answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
scopes = updatedScopesSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionHeaders -> emit a new updated List<Session> from getSessions().
|
||||
coEvery {
|
||||
accountRepository.updateSessionHeaders(
|
||||
capture(sessionIdSlot),
|
||||
capture(tokenTypeSlot),
|
||||
capture(tokenCodeSlot)
|
||||
)
|
||||
} answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
flowOfSessionLists.last().first { it.sessionId == sessionIdSlot.captured }.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
headers = HumanVerificationHeaders(
|
||||
tokenTypeSlot.captured,
|
||||
tokenCodeSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// For each updateSessionToken -> emit a new updated List<Session> from getSessions().
|
||||
coEvery {
|
||||
accountRepository.updateSessionToken(
|
||||
capture(sessionIdSlot),
|
||||
capture(accessTokenSlot),
|
||||
capture(refreshTokenSlot)
|
||||
)
|
||||
} answers {
|
||||
flowOfSessionLists.add(
|
||||
listOf(
|
||||
session.copy(
|
||||
sessionId = sessionIdSlot.captured,
|
||||
accessToken = accessTokenSlot.captured,
|
||||
refreshToken = refreshTokenSlot.captured
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// Emit last state with same id if exist.
|
||||
coEvery { accountRepository.getAccountOrNull(capture(userIdSlot)) } answers {
|
||||
flowOfAccountLists.last().firstOrNull { it.userId == userIdSlot.captured }
|
||||
}
|
||||
coEvery { accountRepository.getAccountOrNull(capture(sessionIdSlot)) } answers {
|
||||
flowOfAccountLists.last().firstOrNull { it.sessionId == sessionIdSlot.captured }
|
||||
}
|
||||
|
||||
// Emit all state with same id if exist.
|
||||
coEvery { accountRepository.getAccount(capture(sessionIdSlot)) } answers {
|
||||
val filteredLists = flowOfAccountLists.map { list ->
|
||||
list.firstOrNull { it.sessionId == sessionIdSlot.captured }
|
||||
}
|
||||
flowOf(*filteredLists.toTypedArray())
|
||||
}
|
||||
|
||||
// Finally, emit all flow of Lists.
|
||||
every { accountRepository.getAccounts() } answers {
|
||||
flowOf(*flowOfAccountLists.toTypedArray())
|
||||
}
|
||||
every { accountRepository.getSessions() } answers {
|
||||
flowOf(*flowOfSessionLists.toTypedArray())
|
||||
}
|
||||
}
|
||||
|
||||
fun setupAuthRepository() {
|
||||
val sessionIdSlot = slot<SessionId>()
|
||||
|
||||
// Assume revokeSession is done successfully.
|
||||
coEvery { authRepository.revokeSession(capture(sessionIdSlot)) } answers {
|
||||
DataResult.Success(ResponseSource.Remote, true)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Proton Technologies AG
|
||||
* This file is part of Proton Technologies AG and ProtonCore.
|
||||
*
|
||||
* ProtonCore is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonCore is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package me.proton.core.accountmanager.data
|
||||
|
||||
import io.mockk.coEvery
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.test.runBlockingTest
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.domain.entity.Product
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationHeaders
|
||||
import me.proton.core.network.domain.humanverification.VerificationMethod
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class SessionManagerImplTest {
|
||||
|
||||
private lateinit var accountManager: AccountManagerImpl
|
||||
private lateinit var sessionManager: SessionManagerImpl
|
||||
|
||||
private val session1 = Session(
|
||||
sessionId = SessionId("session1"),
|
||||
accessToken = "accessToken",
|
||||
refreshToken = "refreshToken",
|
||||
scopes = listOf("full", "calendar", "mail"),
|
||||
headers = HumanVerificationHeaders("tokenType", "tokenCode")
|
||||
)
|
||||
|
||||
private val account1 = Account(
|
||||
userId = UserId("user1"),
|
||||
username = "username",
|
||||
email = "test@example.com",
|
||||
state = AccountState.Ready,
|
||||
sessionId = session1.sessionId,
|
||||
sessionState = SessionState.Authenticated
|
||||
)
|
||||
|
||||
private val mocks = RepositoryMocks(session1, account1)
|
||||
|
||||
@Before
|
||||
fun beforeEveryTest() {
|
||||
mocks.init()
|
||||
|
||||
accountManager = AccountManagerImpl(Product.Calendar, mocks.accountRepository, mocks.authRepository)
|
||||
sessionManager = SessionManagerImpl(mocks.accountRepository)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onSessionTokenRefreshed`() = runBlockingTest {
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
val newAccessToken = "newAccessToken"
|
||||
val newRefreshToken = "newRefreshToken"
|
||||
|
||||
sessionManager.onSessionTokenRefreshed(
|
||||
session1.refreshWith(
|
||||
accessToken = newAccessToken,
|
||||
refreshToken = newRefreshToken
|
||||
)
|
||||
)
|
||||
|
||||
val sessionLists = accountManager.getSessions().toList()
|
||||
assertEquals(2, sessionLists.size)
|
||||
assertEquals(session1.accessToken, sessionLists[0][0].accessToken)
|
||||
assertEquals(session1.refreshToken, sessionLists[0][0].refreshToken)
|
||||
assertEquals(newAccessToken, sessionLists[1][0].accessToken)
|
||||
assertEquals(newRefreshToken, sessionLists[1][0].refreshToken)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onSessionForceLogout`() = runBlockingTest {
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
sessionManager.onSessionForceLogout(session1)
|
||||
|
||||
val stateLists = accountManager.onAccountStateChanged().toList()
|
||||
assertEquals(2, stateLists.size)
|
||||
assertEquals(account1.state, stateLists[0].state)
|
||||
assertEquals(AccountState.Disabled, stateLists[1].state)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.ForceLogout, sessionStateLists[1].sessionState)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onHumanVerificationNeeded success`() = runBlockingTest {
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
val humanVerificationDetails = HumanVerificationDetails(
|
||||
verificationMethods = listOf(VerificationMethod.EMAIL),
|
||||
captchaVerificationToken = null
|
||||
)
|
||||
|
||||
coEvery { mocks.accountRepository.getAccount(any<SessionId>()) } returns flowOf(
|
||||
account1,
|
||||
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
|
||||
account1.copy(sessionState = SessionState.HumanVerificationSuccess)
|
||||
)
|
||||
|
||||
val result = sessionManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
|
||||
|
||||
assertEquals(SessionListener.HumanVerificationResult.Success, result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `on onHumanVerificationNeeded failed`() = runBlockingTest {
|
||||
mocks.setupAccountRepository()
|
||||
|
||||
val humanVerificationDetails = HumanVerificationDetails(
|
||||
verificationMethods = listOf(VerificationMethod.EMAIL),
|
||||
captchaVerificationToken = null
|
||||
)
|
||||
|
||||
coEvery { mocks.accountRepository.getAccount(any<SessionId>()) } returns flowOf(
|
||||
account1,
|
||||
account1.copy(sessionState = SessionState.HumanVerificationNeeded),
|
||||
account1.copy(sessionState = SessionState.HumanVerificationFailed)
|
||||
)
|
||||
|
||||
val result = sessionManager.onHumanVerificationNeeded(session1, humanVerificationDetails)
|
||||
|
||||
val sessionStateLists = accountManager.onSessionStateChanged().toList()
|
||||
assertEquals(2, sessionStateLists.size)
|
||||
assertEquals(account1.sessionState, sessionStateLists[0].sessionState)
|
||||
assertEquals(SessionState.HumanVerificationNeeded, sessionStateLists[1].sessionState)
|
||||
|
||||
assertEquals(SessionListener.HumanVerificationResult.Failure, result)
|
||||
}
|
||||
}
|
|
@ -19,9 +19,11 @@
|
|||
package me.proton.core.accountmanager.domain
|
||||
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.distinctUntilChanged
|
||||
import kotlinx.coroutines.flow.filter
|
||||
import kotlinx.coroutines.flow.flatMapLatest
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.flow.map
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
|
@ -36,3 +38,6 @@ fun AccountManager.getPrimaryAccount(): Flow<Account?> =
|
|||
getPrimaryUserId().flatMapLatest { userId ->
|
||||
userId?.let { getAccount(it) } ?: flowOf(null)
|
||||
}
|
||||
|
||||
fun AccountManager.getAccounts(state: AccountState): Flow<List<Account>> =
|
||||
getAccounts().map { list -> list.filter { it.state == state } }.distinctUntilChanged()
|
||||
|
|
|
@ -16,20 +16,9 @@
|
|||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package me.proton.core.auth.presentation.ui
|
||||
package me.proton.core.accountmanager.domain
|
||||
|
||||
import androidx.databinding.ViewDataBinding
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
|
||||
/**
|
||||
* Bridge between authentication activities and the interface.
|
||||
*
|
||||
* @author Dino Kadrikj.
|
||||
*/
|
||||
interface AuthActivityComponent<DB : ViewDataBinding> : AuthActivity {
|
||||
|
||||
/**
|
||||
* Sets and initializes the authentication activity that want to implement [AuthActivity].
|
||||
*/
|
||||
fun initializeAuth(protonAuthActivity: ProtonActivity<DB>)
|
||||
}
|
||||
interface SessionManager : SessionProvider, SessionListener
|
|
@ -91,7 +91,6 @@ class AccountManagerObserver(
|
|||
onAccountRemovedListener = block
|
||||
}
|
||||
|
||||
|
||||
internal fun setOnSessionHumanVerificationNeeded(block: suspend (Account) -> Unit) {
|
||||
onSessionHumanVerificationNeededListener = block
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ android()
|
|||
dependencies {
|
||||
|
||||
implementation(
|
||||
project(Module.kotlinUtil),
|
||||
project(Module.data),
|
||||
project(Module.domain),
|
||||
project(Module.network),
|
||||
|
|
|
@ -22,8 +22,8 @@ import androidx.room.Dao
|
|||
import androidx.room.Query
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import me.proton.core.account.data.entity.SessionEntity
|
||||
import me.proton.core.data.db.BaseDao
|
||||
import me.proton.core.data.crypto.EncryptedString
|
||||
import me.proton.core.data.db.BaseDao
|
||||
import me.proton.core.domain.entity.Product
|
||||
|
||||
@Dao
|
||||
|
@ -36,17 +36,20 @@ abstract class SessionDao : BaseDao<SessionEntity>() {
|
|||
abstract fun findBySessionId(sessionId: String): Flow<SessionEntity?>
|
||||
|
||||
@Query("SELECT * FROM SessionEntity WHERE sessionId = :sessionId")
|
||||
abstract fun get(sessionId: String): SessionEntity?
|
||||
abstract suspend fun get(sessionId: String): SessionEntity?
|
||||
|
||||
@Query("SELECT sessionId FROM SessionEntity WHERE userId = :userId")
|
||||
abstract suspend fun getSessionId(userId: String): String?
|
||||
|
||||
@Query("DELETE FROM SessionEntity WHERE sessionId = :sessionId")
|
||||
abstract suspend fun delete(sessionId: String)
|
||||
|
||||
@Query("UPDATE SessionEntity SET scopes = :scopes WHERE sessionId = :sessionId")
|
||||
abstract fun updateScopes(sessionId: String, scopes: String)
|
||||
abstract suspend fun updateScopes(sessionId: String, scopes: String)
|
||||
|
||||
@Query("UPDATE SessionEntity SET humanHeaderTokenType = :tokenType, humanHeaderTokenCode = :tokenCode WHERE sessionId = :sessionId")
|
||||
abstract fun updateHeaders(sessionId: String, tokenType: EncryptedString?, tokenCode: EncryptedString?)
|
||||
abstract suspend fun updateHeaders(sessionId: String, tokenType: EncryptedString?, tokenCode: EncryptedString?)
|
||||
|
||||
@Query("UPDATE SessionEntity SET accessToken = :accessToken, refreshToken = :refreshToken WHERE sessionId = :sessionId")
|
||||
abstract fun updateToken(sessionId: String, accessToken: EncryptedString, refreshToken: EncryptedString)
|
||||
abstract suspend fun updateToken(sessionId: String, accessToken: EncryptedString, refreshToken: EncryptedString)
|
||||
}
|
||||
|
|
|
@ -30,13 +30,16 @@ import me.proton.core.account.domain.entity.Account
|
|||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.account.domain.repository.AccountRepository
|
||||
import me.proton.core.data.db.CommonConverters
|
||||
import me.proton.core.data.crypto.StringCrypto
|
||||
import me.proton.core.data.crypto.encrypt
|
||||
import me.proton.core.data.db.CommonConverters
|
||||
import me.proton.core.domain.entity.Product
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.util.kotlin.exhaustive
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
class AccountRepositoryImpl(
|
||||
private val product: Product,
|
||||
|
@ -48,6 +51,8 @@ class AccountRepositoryImpl(
|
|||
private val sessionDao = db.sessionDao()
|
||||
private val accountMetadataDao = db.accountMetadataDao()
|
||||
|
||||
private val humanVerificationDetails: ConcurrentHashMap<SessionId, HumanVerificationDetails?> = ConcurrentHashMap()
|
||||
|
||||
private suspend fun updateAccountMetadata(userId: UserId) =
|
||||
accountMetadataDao.insertOrUpdate(
|
||||
AccountMetadataEntity(
|
||||
|
@ -91,9 +96,12 @@ class AccountRepositoryImpl(
|
|||
.map { it?.toSession(stringCrypto) }
|
||||
.distinctUntilChanged()
|
||||
|
||||
override fun getSessionOrNull(sessionId: SessionId): Session? =
|
||||
override suspend fun getSessionOrNull(sessionId: SessionId): Session? =
|
||||
sessionDao.get(sessionId.id)?.toSession(stringCrypto)
|
||||
|
||||
override suspend fun getSessionIdOrNull(userId: UserId): SessionId? =
|
||||
sessionDao.getSessionId(userId.id)?.let { SessionId(it) }
|
||||
|
||||
override suspend fun createOrUpdateAccountSession(account: Account, session: Session) {
|
||||
require(session.isValid()) {
|
||||
"Session is not valid: $session\n.At least sessionId.id, accessToken and refreshToken must be valid."
|
||||
|
@ -102,7 +110,7 @@ class AccountRepositoryImpl(
|
|||
db.inTransaction {
|
||||
accountDao.insertOrUpdate(
|
||||
account.copy(
|
||||
state = AccountState.Initializing,
|
||||
state = AccountState.NotReady,
|
||||
sessionId = null,
|
||||
sessionState = null
|
||||
).toAccountEntity()
|
||||
|
@ -133,22 +141,19 @@ class AccountRepositoryImpl(
|
|||
when (state) {
|
||||
AccountState.Ready -> updateAccountMetadata(userId)
|
||||
AccountState.Disabled,
|
||||
AccountState.Removed -> deleteAccountMetadata(userId)
|
||||
AccountState.Added,
|
||||
AccountState.Initializing,
|
||||
AccountState.Removed,
|
||||
AccountState.NotReady,
|
||||
AccountState.TwoPassModeNeeded,
|
||||
AccountState.TwoPassModeSuccess,
|
||||
AccountState.TwoPassModeFailed -> Unit
|
||||
}
|
||||
AccountState.TwoPassModeFailed -> deleteAccountMetadata(userId)
|
||||
}.exhaustive
|
||||
accountDao.updateAccountState(userId.id, state)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun updateAccountState(sessionId: SessionId, state: AccountState) {
|
||||
db.inTransaction {
|
||||
getAccountOrNull(sessionId)?.let {
|
||||
updateAccountState(it.userId, state)
|
||||
}
|
||||
getAccountOrNull(sessionId)?.let {
|
||||
updateAccountState(it.userId, state)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -177,4 +182,11 @@ class AccountRepositoryImpl(
|
|||
updateAccountMetadata(userId)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun getHumanVerificationDetails(id: SessionId): HumanVerificationDetails? =
|
||||
humanVerificationDetails[id]
|
||||
|
||||
override suspend fun setHumanVerificationDetails(id: SessionId, details: HumanVerificationDetails?) {
|
||||
humanVerificationDetails[id] = details
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,3 +29,11 @@ data class Account(
|
|||
val sessionId: SessionId?,
|
||||
val sessionState: SessionState?
|
||||
)
|
||||
|
||||
fun Account.isReady() = state == AccountState.Ready
|
||||
fun Account.isDisabled() = state == AccountState.Disabled
|
||||
fun Account.isTwoPassModeNeeded() = state == AccountState.TwoPassModeNeeded
|
||||
|
||||
fun Account.isAuthenticated() = sessionState == SessionState.Authenticated
|
||||
fun Account.isSecondFactorNeeded() = sessionState == SessionState.SecondFactorNeeded
|
||||
fun Account.isHumanVerificationNeeded() = sessionState == SessionState.HumanVerificationNeeded
|
||||
|
|
|
@ -19,17 +19,10 @@
|
|||
package me.proton.core.account.domain.entity
|
||||
|
||||
enum class AccountState {
|
||||
/**
|
||||
* First state emitted after adding a new [Account], it is not yet [Ready] to use.
|
||||
*
|
||||
* Note: Usually followed by [Initializing] and/or [Ready].
|
||||
*/
|
||||
Added,
|
||||
|
||||
/**
|
||||
* State emitted if this [Account] need more step(s) to be [Ready] to use.
|
||||
*/
|
||||
Initializing,
|
||||
NotReady,
|
||||
|
||||
/**
|
||||
* A two pass mode is needed.
|
||||
|
|
|
@ -23,6 +23,7 @@ import me.proton.core.account.domain.entity.Account
|
|||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
|
||||
|
@ -65,7 +66,12 @@ interface AccountRepository {
|
|||
/**
|
||||
* Get [Session], by sessionId.
|
||||
*/
|
||||
fun getSessionOrNull(sessionId: SessionId): Session?
|
||||
suspend fun getSessionOrNull(sessionId: SessionId): Session?
|
||||
|
||||
/**
|
||||
* Get [SessionId], by userId.
|
||||
*/
|
||||
suspend fun getSessionIdOrNull(userId: UserId): SessionId?
|
||||
|
||||
/**
|
||||
* Create or update an [Account], locally.
|
||||
|
@ -121,4 +127,14 @@ interface AccountRepository {
|
|||
* Set the primary [UserId].
|
||||
*/
|
||||
suspend fun setAsPrimary(userId: UserId)
|
||||
|
||||
/**
|
||||
* Get [HumanVerificationDetails], if exist, by sessionId.
|
||||
*/
|
||||
suspend fun getHumanVerificationDetails(id: SessionId): HumanVerificationDetails?
|
||||
|
||||
/**
|
||||
* Set [HumanVerificationDetails], by sessionId.
|
||||
*/
|
||||
suspend fun setHumanVerificationDetails(id: SessionId, details: HumanVerificationDetails?)
|
||||
}
|
||||
|
|
|
@ -28,10 +28,12 @@ import me.proton.core.auth.data.entity.SecondFactorResponse
|
|||
import me.proton.core.auth.data.entity.UserResponse
|
||||
import me.proton.core.network.data.protonApi.BaseRetrofitApi
|
||||
import me.proton.core.network.data.protonApi.GenericResponse
|
||||
import me.proton.core.network.domain.TimeoutOverride
|
||||
import retrofit2.http.Body
|
||||
import retrofit2.http.DELETE
|
||||
import retrofit2.http.GET
|
||||
import retrofit2.http.POST
|
||||
import retrofit2.http.Tag
|
||||
|
||||
interface AuthenticationApi : BaseRetrofitApi {
|
||||
|
||||
|
@ -45,7 +47,7 @@ interface AuthenticationApi : BaseRetrofitApi {
|
|||
suspend fun performSecondFactor(@Body request: SecondFactorRequest): SecondFactorResponse
|
||||
|
||||
@DELETE("auth")
|
||||
suspend fun revokeSession(): GenericResponse
|
||||
suspend fun revokeSession(@Tag timeout: TimeoutOverride): GenericResponse
|
||||
|
||||
@GET("users")
|
||||
suspend fun getUser(): UserResponse
|
||||
|
|
|
@ -30,12 +30,12 @@ import me.proton.core.auth.domain.entity.SecondFactorProof
|
|||
import me.proton.core.auth.domain.entity.SessionInfo
|
||||
import me.proton.core.auth.domain.entity.User
|
||||
import me.proton.core.auth.domain.repository.AuthRepository
|
||||
import me.proton.core.data.arch.ApiResultMapper
|
||||
import me.proton.core.data.arch.toDataResponse
|
||||
import me.proton.core.domain.arch.DataResult
|
||||
import me.proton.core.network.data.ApiProvider
|
||||
import me.proton.core.network.data.ResponseCodes
|
||||
import me.proton.core.network.domain.TimeoutOverride
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.util.kotlin.invoke
|
||||
|
||||
/**
|
||||
* Implementation of the [AuthRepository].
|
||||
|
@ -44,8 +44,7 @@ import me.proton.core.util.kotlin.invoke
|
|||
* @author Dino Kadrikj.
|
||||
*/
|
||||
class AuthRepositoryImpl(
|
||||
private val provider: ApiProvider,
|
||||
private val apiResultMapper: ApiResultMapper = ApiResultMapper()
|
||||
private val provider: ApiProvider
|
||||
) : AuthRepository {
|
||||
|
||||
/**
|
||||
|
@ -60,13 +59,10 @@ class AuthRepositoryImpl(
|
|||
override suspend fun getLoginInfo(
|
||||
username: String,
|
||||
clientSecret: String
|
||||
): DataResult<LoginInfo> =
|
||||
apiResultMapper {
|
||||
provider.get<AuthenticationApi>().invoke {
|
||||
val request = LoginInfoRequest(username, clientSecret)
|
||||
getLoginInfo(request).toLoginInfo(username)
|
||||
}.toDataResponse()
|
||||
}
|
||||
): DataResult<LoginInfo> = provider.get<AuthenticationApi>().invoke {
|
||||
val request = LoginInfoRequest(username, clientSecret)
|
||||
getLoginInfo(request).toLoginInfo(username)
|
||||
}.toDataResponse()
|
||||
|
||||
/**
|
||||
* Performs the login request to the API to try to get a valid Access Token and Session for the Account/username.
|
||||
|
@ -85,13 +81,10 @@ class AuthRepositoryImpl(
|
|||
clientEphemeral: String,
|
||||
clientProof: String,
|
||||
srpSession: String
|
||||
): DataResult<SessionInfo> =
|
||||
apiResultMapper {
|
||||
provider.get<AuthenticationApi>().invoke {
|
||||
val request = LoginRequest(username, clientSecret, clientEphemeral, clientProof, srpSession)
|
||||
performLogin(request).toSessionInfo(username)
|
||||
}.toDataResponse()
|
||||
}
|
||||
): DataResult<SessionInfo> = provider.get<AuthenticationApi>().invoke {
|
||||
val request = LoginRequest(username, clientSecret, clientEphemeral, clientProof, srpSession)
|
||||
performLogin(request).toSessionInfo(username)
|
||||
}.toDataResponse()
|
||||
|
||||
/**
|
||||
* Performs the second factor request for the Accounts that have second factor enabled.
|
||||
|
@ -104,23 +97,21 @@ class AuthRepositoryImpl(
|
|||
override suspend fun performSecondFactor(
|
||||
sessionId: SessionId,
|
||||
secondFactorProof: SecondFactorProof
|
||||
): DataResult<ScopeInfo> = apiResultMapper {
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
val request = when (secondFactorProof) {
|
||||
is SecondFactorProof.SecondFactorCode -> SecondFactorRequest(
|
||||
secondFactorCode = secondFactorProof.code
|
||||
): DataResult<ScopeInfo> = provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
val request = when (secondFactorProof) {
|
||||
is SecondFactorProof.SecondFactorCode -> SecondFactorRequest(
|
||||
secondFactorCode = secondFactorProof.code
|
||||
)
|
||||
is SecondFactorProof.SecondFactorSignature -> SecondFactorRequest(
|
||||
universalTwoFactorRequest = UniversalTwoFactorRequest(
|
||||
keyHandle = secondFactorProof.keyHandle,
|
||||
clientData = secondFactorProof.clientData,
|
||||
signatureData = secondFactorProof.signatureData
|
||||
)
|
||||
is SecondFactorProof.SecondFactorSignature -> SecondFactorRequest(
|
||||
universalTwoFactorRequest = UniversalTwoFactorRequest(
|
||||
keyHandle = secondFactorProof.keyHandle,
|
||||
clientData = secondFactorProof.clientData,
|
||||
signatureData = secondFactorProof.signatureData
|
||||
)
|
||||
)
|
||||
}
|
||||
performSecondFactor(request).toScopeInfo()
|
||||
}.toDataResponse()
|
||||
}
|
||||
)
|
||||
}
|
||||
performSecondFactor(request).toScopeInfo()
|
||||
}.toDataResponse()
|
||||
|
||||
/**
|
||||
* Revokes the session for the user. In particular this is practically logging out the user from the backend.
|
||||
|
@ -130,11 +121,15 @@ class AuthRepositoryImpl(
|
|||
* @return boolean result of the logout/session revoking operation.
|
||||
*/
|
||||
override suspend fun revokeSession(sessionId: SessionId): DataResult<Boolean> =
|
||||
apiResultMapper {
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
revokeSession().code.isSuccessResponse()
|
||||
}.toDataResponse()
|
||||
}
|
||||
provider.get<AuthenticationApi>(sessionId).invoke(true) {
|
||||
revokeSession(
|
||||
TimeoutOverride(
|
||||
connectionTimeoutSeconds = 1,
|
||||
readTimeoutSeconds = 1,
|
||||
writeTimeoutSeconds = 1
|
||||
)
|
||||
).code.isSuccessResponse()
|
||||
}.toDataResponse()
|
||||
|
||||
/**
|
||||
* Fetches the full user details from the API.
|
||||
|
@ -144,12 +139,9 @@ class AuthRepositoryImpl(
|
|||
* @return [User] object with full user details.
|
||||
*/
|
||||
override suspend fun getUser(sessionId: SessionId): DataResult<User> =
|
||||
apiResultMapper {
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
val userResponse = getUser()
|
||||
userResponse.user.toUser()
|
||||
}.toDataResponse()
|
||||
}
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
getUser().user.toUser()
|
||||
}.toDataResponse()
|
||||
|
||||
/**
|
||||
* Fetches the user-keys salts from the API.
|
||||
|
@ -159,11 +151,9 @@ class AuthRepositoryImpl(
|
|||
* @return [KeySalts] containing salts for all user keys.
|
||||
*/
|
||||
override suspend fun getSalts(sessionId: SessionId): DataResult<KeySalts> =
|
||||
apiResultMapper {
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
getSalts().toKeySalts()
|
||||
}.toDataResponse()
|
||||
}
|
||||
provider.get<AuthenticationApi>(sessionId).invoke {
|
||||
getSalts().toKeySalts()
|
||||
}.toDataResponse()
|
||||
}
|
||||
|
||||
internal fun Int.isSuccessResponse(): Boolean = this == ResponseCodes.OK
|
||||
|
|
|
@ -36,6 +36,7 @@ import me.proton.core.network.data.di.ApiFactory
|
|||
import me.proton.core.network.domain.ApiManager
|
||||
import me.proton.core.network.domain.ApiResult
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.net.ConnectException
|
||||
|
@ -49,6 +50,7 @@ import kotlin.test.assertTrue
|
|||
class AuthRepositoryImplTest {
|
||||
|
||||
// region mocks
|
||||
private val sessionProvider = mockk<SessionProvider>(relaxed = true)
|
||||
private val apiFactory = mockk<ApiFactory>(relaxed = true)
|
||||
private val apiManager = mockk<ApiManager<AuthenticationApi>>(relaxed = true)
|
||||
private lateinit var apiProvider: ApiProvider
|
||||
|
@ -82,7 +84,8 @@ class AuthRepositoryImplTest {
|
|||
@Before
|
||||
fun beforeEveryTest() {
|
||||
// GIVEN
|
||||
apiProvider = ApiProvider(apiFactory)
|
||||
coEvery { sessionProvider.getSessionId(any()) } returns SessionId(testSessionId)
|
||||
apiProvider = ApiProvider(apiFactory, sessionProvider)
|
||||
every { apiFactory.create(interfaceClass = AuthenticationApi::class) } returns apiManager
|
||||
every { apiFactory.create(SessionId(testSessionId), interfaceClass = AuthenticationApi::class) } returns apiManager
|
||||
repository = AuthRepositoryImpl(apiProvider)
|
||||
|
|
|
@ -67,7 +67,7 @@ interface AuthRepository {
|
|||
suspend fun getSalts(sessionId: SessionId): DataResult<KeySalts>
|
||||
|
||||
/**
|
||||
* Perform Two Factor for the Login process for a given [SessionId].
|
||||
* Revoke session for a given [SessionId].
|
||||
*/
|
||||
suspend fun revokeSession(sessionId: SessionId): DataResult<Boolean>
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class PerformSecondFactor @Inject constructor(
|
|||
data class Success(
|
||||
val sessionId: SessionId,
|
||||
val scopeInfo: ScopeInfo,
|
||||
val isMailboxLoginNeeded: Boolean = false
|
||||
val isTwoPassModeNeeded: Boolean
|
||||
) : SecondFactorState()
|
||||
|
||||
sealed class Error : SecondFactorState() {
|
||||
|
@ -61,7 +61,8 @@ class PerformSecondFactor @Inject constructor(
|
|||
*/
|
||||
operator fun invoke(
|
||||
sessionId: SessionId,
|
||||
secondFactorCode: String
|
||||
secondFactorCode: String,
|
||||
isTwoPassModeNeeded: Boolean = false
|
||||
): Flow<SecondFactorState> = flow {
|
||||
|
||||
if (secondFactorCode.isEmpty()) {
|
||||
|
@ -73,11 +74,11 @@ class PerformSecondFactor @Inject constructor(
|
|||
|
||||
authRepository.performSecondFactor(
|
||||
sessionId,
|
||||
SecondFactorProof.SecondFactorCode(secondFactorCode)
|
||||
SecondFactorProof.SecondFactorCode(secondFactorCode),
|
||||
).onFailure { errorMessage, _ ->
|
||||
emit(SecondFactorState.Error.Message(errorMessage))
|
||||
}.onSuccess { scopeInfo ->
|
||||
emit(SecondFactorState.Success(sessionId, scopeInfo))
|
||||
emit(SecondFactorState.Success(sessionId, scopeInfo, isTwoPassModeNeeded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,87 +20,126 @@ package me.proton.core.auth.presentation
|
|||
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.result.ActivityResultLauncher
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import kotlinx.coroutines.launch
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.presentation.entity.ScopeResult
|
||||
import me.proton.core.auth.presentation.entity.SecondFactorInput
|
||||
import me.proton.core.auth.presentation.entity.SessionResult
|
||||
import me.proton.core.auth.presentation.entity.UserResult
|
||||
import me.proton.core.auth.presentation.ui.StartLogin
|
||||
import me.proton.core.auth.presentation.ui.StartMailboxLogin
|
||||
import me.proton.core.auth.presentation.ui.StartSecondFactor
|
||||
import me.proton.core.auth.presentation.ui.StartTwoPassMode
|
||||
import me.proton.core.humanverification.presentation.entity.HumanVerificationInput
|
||||
import me.proton.core.humanverification.presentation.entity.HumanVerificationResult
|
||||
import me.proton.core.humanverification.presentation.ui.StartHumanVerification
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import javax.inject.Inject
|
||||
|
||||
class AuthOrchestrator {
|
||||
class AuthOrchestrator @Inject constructor(
|
||||
private val accountWorkflowHandler: AccountWorkflowHandler
|
||||
) {
|
||||
|
||||
// region result launchers
|
||||
private var loginWorkflowLauncher: ActivityResultLauncher<List<String>>? = null
|
||||
private var secondFactorWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
|
||||
private var mailboxWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
|
||||
private var secondFactorWorkflowLauncher: ActivityResultLauncher<SecondFactorInput>? = null
|
||||
private var twoPassModeWorkflowLauncher: ActivityResultLauncher<SessionId>? = null
|
||||
private var humanWorkflowLauncher: ActivityResultLauncher<HumanVerificationInput>? = null
|
||||
// endregion
|
||||
|
||||
private var onUserResultListener: (result: UserResult?) -> Unit = {}
|
||||
private var onSessionResultListener: (result: SessionResult?) -> Unit = {}
|
||||
private var onScopeResultListener: (result: ScopeResult?) -> Unit = {}
|
||||
private var onHumanVerificationResultListener: (result: HumanVerificationResult?) -> Unit = {}
|
||||
|
||||
fun setOnUserResult(block: (result: UserResult?) -> Unit) {
|
||||
onUserResultListener = block
|
||||
}
|
||||
|
||||
fun setOnSessionResult(block: (result: SessionResult?) -> Unit) {
|
||||
onSessionResultListener = block
|
||||
}
|
||||
|
||||
fun setOnScopeResult(block: (result: ScopeResult?) -> Unit) {
|
||||
onScopeResultListener = block
|
||||
}
|
||||
|
||||
fun setOnHumanVerificationResult(block: (result: HumanVerificationResult?) -> Unit) {
|
||||
onHumanVerificationResultListener = block
|
||||
}
|
||||
|
||||
// region private module functions
|
||||
private fun registerLoginWorkflowLauncher(
|
||||
context: ComponentActivity,
|
||||
onSessionResult: (result: SessionResult?) -> Unit = {}
|
||||
context: ComponentActivity
|
||||
): ActivityResultLauncher<List<String>> =
|
||||
context.registerForActivityResult(StartLogin()) { result ->
|
||||
context.registerForActivityResult(
|
||||
StartLogin()
|
||||
) { result ->
|
||||
result?.let {
|
||||
if (it.isSecondFactorNeeded) {
|
||||
startSecondFactorWorkflow(SessionId(it.sessionId))
|
||||
} else if (it.isMailboxLoginNeeded) {
|
||||
startMailboxLoginWorkflow(SessionId(it.sessionId))
|
||||
startSecondFactorWorkflow(SecondFactorInput(it.sessionId, it.isTwoPassModeNeeded))
|
||||
} else if (it.isTwoPassModeNeeded) {
|
||||
startTwoPassModeWorkflow(SessionId(it.sessionId))
|
||||
}
|
||||
onSessionResult(it)
|
||||
}
|
||||
onSessionResultListener(result)
|
||||
}
|
||||
|
||||
private fun registerMailboxLoginWorkflowLauncher(
|
||||
context: ComponentActivity,
|
||||
onUserResult: (result: UserResult?) -> Unit = {}
|
||||
private fun registerTwoPassModeWorkflowLauncher(
|
||||
context: ComponentActivity
|
||||
): ActivityResultLauncher<SessionId> =
|
||||
context.registerForActivityResult(StartMailboxLogin()) {
|
||||
onUserResult(it)
|
||||
context.registerForActivityResult(
|
||||
StartTwoPassMode()
|
||||
) {
|
||||
onUserResultListener(it)
|
||||
}
|
||||
|
||||
private fun registerSecondFactorWorkflow(
|
||||
context: ComponentActivity,
|
||||
onScopeResult: (result: ScopeResult?) -> Unit = {}
|
||||
): ActivityResultLauncher<SessionId> =
|
||||
context.registerForActivityResult(StartSecondFactor()) { result ->
|
||||
context: ComponentActivity
|
||||
): ActivityResultLauncher<SecondFactorInput> =
|
||||
context.registerForActivityResult(
|
||||
StartSecondFactor()
|
||||
) { result ->
|
||||
result?.let {
|
||||
if (it.isMailboxLoginNeeded) {
|
||||
startMailboxLoginWorkflow(SessionId(it.sessionId))
|
||||
if (it.isTwoPassModeNeeded) {
|
||||
startTwoPassModeWorkflow(SessionId(it.sessionId))
|
||||
}
|
||||
onScopeResult(it)
|
||||
}
|
||||
onScopeResultListener(result)
|
||||
}
|
||||
|
||||
private fun registerHumanVerificationWorkflow(
|
||||
context: ComponentActivity,
|
||||
onHumanVerificationResult: (result: HumanVerificationResult?) -> Unit = {}
|
||||
context: ComponentActivity
|
||||
): ActivityResultLauncher<HumanVerificationInput> =
|
||||
context.registerForActivityResult(StartHumanVerification()) {
|
||||
onHumanVerificationResult(it)
|
||||
context.registerForActivityResult(
|
||||
StartHumanVerification()
|
||||
) { result ->
|
||||
if (result != null) {
|
||||
context.lifecycleScope.launch {
|
||||
if (!result.tokenType.isNullOrBlank() && !result.tokenCode.isNullOrBlank()) {
|
||||
accountWorkflowHandler.handleHumanVerificationSuccess(
|
||||
sessionId = SessionId(result.sessionId),
|
||||
tokenType = result.tokenType!!,
|
||||
tokenCode = result.tokenCode!!
|
||||
)
|
||||
} else {
|
||||
accountWorkflowHandler.handleHumanVerificationFailed(
|
||||
sessionId = SessionId(result.sessionId)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
onHumanVerificationResultListener(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a Second Factor workflow.
|
||||
*/
|
||||
private fun startSecondFactorWorkflow(input: SessionId) {
|
||||
private fun startSecondFactorWorkflow(input: SecondFactorInput) {
|
||||
secondFactorWorkflowLauncher?.launch(input)
|
||||
?: throw IllegalStateException("You must call register before any start workflow function!")
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a MailboxLogin workflow.
|
||||
*/
|
||||
private fun startMailboxLoginWorkflow(input: SessionId) {
|
||||
mailboxWorkflowLauncher?.launch(input)
|
||||
?: throw IllegalStateException("You must call register before any start workflow function!")
|
||||
}
|
||||
// endregion
|
||||
|
||||
// region public API
|
||||
|
@ -110,10 +149,10 @@ class AuthOrchestrator {
|
|||
* Note: This function have to be called [ComponentActivity.onCreate]] before [ComponentActivity.onResume].
|
||||
*/
|
||||
fun register(context: ComponentActivity) {
|
||||
loginWorkflowLauncher ?: run { loginWorkflowLauncher = registerLoginWorkflowLauncher(context) }
|
||||
humanWorkflowLauncher ?: run { humanWorkflowLauncher = registerHumanVerificationWorkflow(context) }
|
||||
secondFactorWorkflowLauncher ?: run { secondFactorWorkflowLauncher = registerSecondFactorWorkflow(context) }
|
||||
mailboxWorkflowLauncher ?: run { mailboxWorkflowLauncher = registerMailboxLoginWorkflowLauncher(context) }
|
||||
loginWorkflowLauncher = registerLoginWorkflowLauncher(context)
|
||||
humanWorkflowLauncher = registerHumanVerificationWorkflow(context)
|
||||
secondFactorWorkflowLauncher = registerSecondFactorWorkflow(context)
|
||||
twoPassModeWorkflowLauncher = registerTwoPassModeWorkflowLauncher(context)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -124,6 +163,14 @@ class AuthOrchestrator {
|
|||
?: throw IllegalStateException("You must call register before any start workflow function!")
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a TwoPassMode workflow.
|
||||
*/
|
||||
fun startTwoPassModeWorkflow(input: SessionId) {
|
||||
twoPassModeWorkflowLauncher?.launch(input)
|
||||
?: throw IllegalStateException("You must call register before any start workflow function!")
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a Human Verification workflow.
|
||||
*/
|
||||
|
@ -138,3 +185,31 @@ class AuthOrchestrator {
|
|||
}
|
||||
// endregion
|
||||
}
|
||||
|
||||
fun AuthOrchestrator.onUserResult(
|
||||
block: (result: UserResult?) -> Unit
|
||||
): AuthOrchestrator {
|
||||
setOnUserResult { block(it) }
|
||||
return this
|
||||
}
|
||||
|
||||
fun AuthOrchestrator.onScopeResult(
|
||||
block: (result: ScopeResult?) -> Unit
|
||||
): AuthOrchestrator {
|
||||
setOnScopeResult { block(it) }
|
||||
return this
|
||||
}
|
||||
|
||||
fun AuthOrchestrator.onSessionResult(
|
||||
block: (result: SessionResult?) -> Unit
|
||||
): AuthOrchestrator {
|
||||
setOnSessionResult { block(it) }
|
||||
return this
|
||||
}
|
||||
|
||||
fun AuthOrchestrator.onHumanVerificationResult(
|
||||
block: (result: HumanVerificationResult?) -> Unit
|
||||
): AuthOrchestrator {
|
||||
setOnHumanVerificationResult { block(it) }
|
||||
return this
|
||||
}
|
||||
|
|
|
@ -21,16 +21,15 @@ package me.proton.core.auth.presentation.entity
|
|||
import android.os.Parcelable
|
||||
import kotlinx.android.parcel.Parcelize
|
||||
import me.proton.core.auth.domain.entity.ScopeInfo
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
|
||||
@Parcelize
|
||||
data class ScopeResult(
|
||||
val sessionId: String,
|
||||
val scopes: List<String>,
|
||||
val isMailboxLoginNeeded: Boolean = false
|
||||
val isTwoPassModeNeeded: Boolean = false
|
||||
) : Parcelable {
|
||||
|
||||
constructor(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean = false)
|
||||
: this(sessionId.id, scopeInfo.scopes, isMailboxLoginNeeded)
|
||||
constructor(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean = false) :
|
||||
this(sessionId.id, scopeInfo.scopes, isMailboxLoginNeeded)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Proton Technologies AG
|
||||
* This file is part of Proton Technologies AG and ProtonCore.
|
||||
*
|
||||
* ProtonCore is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonCore is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package me.proton.core.auth.presentation.entity
|
||||
|
||||
import android.os.Parcelable
|
||||
import kotlinx.android.parcel.Parcelize
|
||||
|
||||
@Parcelize
|
||||
data class SecondFactorInput(
|
||||
val sessionId: String,
|
||||
val isTwoPassModeNeeded: Boolean
|
||||
) : Parcelable
|
|
@ -44,7 +44,7 @@ data class SessionResult(
|
|||
) : Parcelable {
|
||||
|
||||
@IgnoredOnParcel
|
||||
val isMailboxLoginNeeded = passwordMode == 2
|
||||
val isTwoPassModeNeeded = passwordMode == 2
|
||||
|
||||
companion object {
|
||||
|
||||
|
|
|
@ -23,9 +23,9 @@ import android.content.Context
|
|||
import android.content.Intent
|
||||
import androidx.activity.result.contract.ActivityResultContract
|
||||
import me.proton.core.auth.presentation.entity.ScopeResult
|
||||
import me.proton.core.auth.presentation.entity.SecondFactorInput
|
||||
import me.proton.core.auth.presentation.entity.SessionResult
|
||||
import me.proton.core.auth.presentation.entity.UserResult
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
|
||||
class StartLogin : ActivityResultContract<List<String>, SessionResult?>() {
|
||||
|
@ -41,11 +41,11 @@ class StartLogin : ActivityResultContract<List<String>, SessionResult?>() {
|
|||
}
|
||||
}
|
||||
|
||||
class StartSecondFactor : ActivityResultContract<SessionId, ScopeResult?>() {
|
||||
class StartSecondFactor : ActivityResultContract<SecondFactorInput, ScopeResult?>() {
|
||||
|
||||
override fun createIntent(context: Context, sessionId: SessionId) =
|
||||
override fun createIntent(context: Context, inupt: SecondFactorInput) =
|
||||
Intent(context, SecondFactorActivity::class.java).apply {
|
||||
putExtra(SecondFactorActivity.ARG_SESSION_ID, sessionId.id)
|
||||
putExtra(SecondFactorActivity.ARG_SECOND_FACTOR_INPUT, inupt)
|
||||
}
|
||||
|
||||
override fun parseResult(resultCode: Int, result: Intent?): ScopeResult? {
|
||||
|
@ -54,11 +54,11 @@ class StartSecondFactor : ActivityResultContract<SessionId, ScopeResult?>() {
|
|||
}
|
||||
}
|
||||
|
||||
class StartMailboxLogin : ActivityResultContract<SessionId, UserResult?>() {
|
||||
class StartTwoPassMode : ActivityResultContract<SessionId, UserResult?>() {
|
||||
|
||||
override fun createIntent(context: Context, sessionId: SessionId) =
|
||||
override fun createIntent(context: Context, inupt: SessionId) =
|
||||
Intent(context, MailboxLoginActivity::class.java).apply {
|
||||
putExtra(MailboxLoginActivity.ARG_SESSION_ID, sessionId.id)
|
||||
putExtra(MailboxLoginActivity.ARG_SESSION_ID, inupt.id)
|
||||
}
|
||||
|
||||
override fun parseResult(resultCode: Int, result: Intent?): UserResult? {
|
||||
|
|
|
@ -18,20 +18,30 @@
|
|||
|
||||
package me.proton.core.auth.presentation.ui
|
||||
|
||||
/**
|
||||
* Interface common for all authentication activities.
|
||||
*
|
||||
* @author Dino Kadrikj.
|
||||
*/
|
||||
interface AuthActivity {
|
||||
import android.os.Build
|
||||
import android.os.Bundle
|
||||
import android.view.View
|
||||
import androidx.databinding.ViewDataBinding
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.errorSnack
|
||||
import me.proton.core.auth.presentation.R
|
||||
|
||||
/**
|
||||
* Instructs the activity to show loading animation (custom for eacch activity).
|
||||
*/
|
||||
fun showLoading(loading: Boolean)
|
||||
abstract class AuthActivity<DB : ViewDataBinding> : ProtonActivity<DB>() {
|
||||
|
||||
/**
|
||||
* Provide default implementation for error UI.
|
||||
*/
|
||||
fun showError(message: String?)
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||
window.decorView.systemUiVisibility = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR
|
||||
}
|
||||
}
|
||||
|
||||
open fun showLoading(loading: Boolean) {
|
||||
// No op
|
||||
}
|
||||
|
||||
open fun showError(message: String?) {
|
||||
showLoading(false)
|
||||
binding.root.errorSnack(message = message ?: getString(R.string.auth_login_general_error))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Proton Technologies AG
|
||||
* This file is part of Proton Technologies AG and ProtonCore.
|
||||
*
|
||||
* ProtonCore is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonCore is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonCore. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package me.proton.core.auth.presentation.ui
|
||||
|
||||
import android.os.Build
|
||||
import android.view.View
|
||||
import androidx.databinding.ViewDataBinding
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.errorSnack
|
||||
import me.proton.core.auth.presentation.R
|
||||
|
||||
/**
|
||||
* Delegate class implementing the [AuthActivity] interface.
|
||||
*
|
||||
* @author Dino Kadrikj.
|
||||
*/
|
||||
class AuthActivityDelegate<DB : ViewDataBinding> : AuthActivityComponent<DB> {
|
||||
|
||||
/** A reference to the Activity that will handle the rotation */
|
||||
private lateinit var activity : ProtonActivity<DB>
|
||||
|
||||
override fun initializeAuth(protonAuthActivity: ProtonActivity<DB>) {
|
||||
activity = protonAuthActivity
|
||||
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||
activity.window.decorView.systemUiVisibility = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR
|
||||
}
|
||||
}
|
||||
|
||||
override fun showLoading(loading: Boolean) {
|
||||
// noop
|
||||
}
|
||||
|
||||
override fun showError(message: String?) {
|
||||
showLoading(false)
|
||||
activity.binding.root.errorSnack(message = message ?: activity.getString(R.string.auth_login_general_error))
|
||||
}
|
||||
|
||||
}
|
|
@ -19,7 +19,6 @@
|
|||
package me.proton.core.auth.presentation.ui
|
||||
|
||||
import android.os.Bundle
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.onClick
|
||||
import me.proton.android.core.presentation.utils.openBrowserLink
|
||||
import me.proton.core.auth.presentation.R
|
||||
|
@ -29,12 +28,11 @@ import me.proton.core.auth.presentation.databinding.ActivityAuthHelpBinding
|
|||
* Authentication help Activity which offers common authentication problems help.
|
||||
* @author Dino Kadrikj.
|
||||
*/
|
||||
class AuthHelpActivity : ProtonActivity<ActivityAuthHelpBinding>(), AuthActivityComponent<ActivityAuthHelpBinding> by AuthActivityDelegate() {
|
||||
class AuthHelpActivity : AuthActivity<ActivityAuthHelpBinding>() {
|
||||
override fun layoutId(): Int = R.layout.activity_auth_help
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
initializeAuth(this)
|
||||
binding.apply {
|
||||
closeButton.onClick {
|
||||
finish()
|
||||
|
@ -54,8 +52,4 @@ class AuthHelpActivity : ProtonActivity<ActivityAuthHelpBinding>(), AuthActivity
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun showLoading(loading: Boolean) {
|
||||
// no-operation
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,6 @@ import android.content.Intent
|
|||
import android.os.Bundle
|
||||
import androidx.activity.viewModels
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.hideKeyboard
|
||||
import me.proton.android.core.presentation.utils.onClick
|
||||
import me.proton.android.core.presentation.utils.onFailure
|
||||
|
@ -32,7 +31,6 @@ import me.proton.android.core.presentation.utils.validatePassword
|
|||
import me.proton.android.core.presentation.utils.validateUsername
|
||||
import me.proton.core.auth.domain.entity.SessionInfo
|
||||
import me.proton.core.auth.domain.usecase.PerformLogin
|
||||
import me.proton.core.auth.presentation.AuthOrchestrator
|
||||
import me.proton.core.auth.presentation.R
|
||||
import me.proton.core.auth.presentation.databinding.ActivityLoginBinding
|
||||
import me.proton.core.auth.presentation.entity.SessionResult
|
||||
|
@ -43,18 +41,14 @@ import me.proton.core.util.kotlin.exhaustive
|
|||
* Login Activity which allows users to Login to any Proton client application.
|
||||
*/
|
||||
@AndroidEntryPoint
|
||||
class LoginActivity : ProtonActivity<ActivityLoginBinding>(),
|
||||
AuthActivityComponent<ActivityLoginBinding> by AuthActivityDelegate() {
|
||||
class LoginActivity : AuthActivity<ActivityLoginBinding>() {
|
||||
|
||||
private val viewModel by viewModels<LoginViewModel>()
|
||||
private val authOrchestrator = AuthOrchestrator()
|
||||
|
||||
override fun layoutId(): Int = R.layout.activity_login
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
initializeAuth(this)
|
||||
authOrchestrator.register(this)
|
||||
|
||||
binding.apply {
|
||||
closeButton.onClick {
|
||||
|
|
|
@ -23,7 +23,6 @@ import android.content.Intent
|
|||
import android.os.Bundle
|
||||
import androidx.activity.viewModels
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.hideKeyboard
|
||||
import me.proton.android.core.presentation.utils.onClick
|
||||
import me.proton.android.core.presentation.utils.onFailure
|
||||
|
@ -45,11 +44,10 @@ import me.proton.core.util.kotlin.exhaustive
|
|||
* mailbox).
|
||||
*/
|
||||
@AndroidEntryPoint
|
||||
class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
|
||||
AuthActivityComponent<ActivityMailboxLoginBinding> by AuthActivityDelegate() {
|
||||
class MailboxLoginActivity : AuthActivity<ActivityMailboxLoginBinding>() {
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
intent?.extras?.get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireNotNull(intent?.extras?.getString(ARG_SESSION_ID)))
|
||||
}
|
||||
|
||||
private val viewModel by viewModels<MailboxLoginViewModel>()
|
||||
|
@ -58,10 +56,10 @@ class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
|
|||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
initializeAuth(this)
|
||||
|
||||
binding.apply {
|
||||
closeButton.onClick {
|
||||
finish()
|
||||
onBackPressed()
|
||||
}
|
||||
|
||||
forgotPasswordButton.onClick {
|
||||
|
@ -93,6 +91,12 @@ class MailboxLoginActivity : ProtonActivity<ActivityMailboxLoginBinding>(),
|
|||
}
|
||||
}
|
||||
|
||||
override fun onBackPressed() {
|
||||
viewModel.stopMailboxLoginFlow(sessionId).invokeOnCompletion {
|
||||
finish()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoked on successful completed mailbox login operation.
|
||||
*/
|
||||
|
|
|
@ -24,7 +24,6 @@ import android.os.Bundle
|
|||
import android.text.InputType
|
||||
import androidx.activity.viewModels
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.hideKeyboard
|
||||
import me.proton.android.core.presentation.utils.onClick
|
||||
import me.proton.android.core.presentation.utils.onFailure
|
||||
|
@ -35,6 +34,7 @@ import me.proton.core.auth.domain.usecase.PerformSecondFactor
|
|||
import me.proton.core.auth.presentation.R
|
||||
import me.proton.core.auth.presentation.databinding.Activity2faBinding
|
||||
import me.proton.core.auth.presentation.entity.ScopeResult
|
||||
import me.proton.core.auth.presentation.entity.SecondFactorInput
|
||||
import me.proton.core.auth.presentation.viewmodel.SecondFactorViewModel
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.util.kotlin.exhaustive
|
||||
|
@ -45,15 +45,10 @@ import me.proton.core.util.kotlin.exhaustive
|
|||
* Optional, only shown for accounts with 2FA login enabled.
|
||||
*/
|
||||
@AndroidEntryPoint
|
||||
class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
|
||||
AuthActivityComponent<Activity2faBinding> by AuthActivityDelegate() {
|
||||
class SecondFactorActivity : AuthActivity<Activity2faBinding>() {
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
intent?.extras?.get(ARG_SESSION_ID) as SessionId
|
||||
}
|
||||
|
||||
private val twoPassMode: Boolean by lazy {
|
||||
intent?.extras?.get(ARG_TWO_PASS_MODE) as Boolean
|
||||
private val input: SecondFactorInput by lazy {
|
||||
requireNotNull(intent?.extras?.getParcelable(ARG_SECOND_FACTOR_INPUT))
|
||||
}
|
||||
|
||||
// initial mode is the second factor input mode.
|
||||
|
@ -65,10 +60,10 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
|
|||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
initializeAuth(this)
|
||||
|
||||
binding.apply {
|
||||
closeButton.onClick {
|
||||
finish()
|
||||
onBackPressed()
|
||||
}
|
||||
|
||||
recoveryCodeButton.onClick {
|
||||
|
@ -87,7 +82,7 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
|
|||
is PerformSecondFactor.SecondFactorState.Success -> onSuccess(
|
||||
it.sessionId,
|
||||
it.scopeInfo,
|
||||
it.isMailboxLoginNeeded
|
||||
it.isTwoPassModeNeeded
|
||||
)
|
||||
is PerformSecondFactor.SecondFactorState.Error.Message -> onError(false, it.message)
|
||||
is PerformSecondFactor.SecondFactorState.Error.EmptyCredentials -> {
|
||||
|
@ -111,20 +106,29 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
|
|||
secondFactorInput.validate()
|
||||
.onFailure { secondFactorInput.setInputError() }
|
||||
.onSuccess { secondFactorCode ->
|
||||
viewModel.startSecondFactorFlow(sessionId, secondFactorCode, twoPassMode)
|
||||
viewModel.startSecondFactorFlow(
|
||||
sessionId = SessionId(input.sessionId),
|
||||
secondFactorCode = secondFactorCode,
|
||||
isTwoPassModeNeeded = input.isTwoPassModeNeeded
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun onBackPressed() {
|
||||
viewModel.stopSecondFactorFlow(SessionId(input.sessionId)).invokeOnCompletion {
|
||||
finish()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoked on successful completed mailbox login operation.
|
||||
*/
|
||||
private fun onSuccess(sessionId: SessionId, scopeInfo: ScopeInfo, isMailboxLoginNeeded: Boolean) {
|
||||
val intent =
|
||||
Intent().putExtra(
|
||||
ARG_SCOPE_RESULT,
|
||||
ScopeResult(sessionId, scopeInfo, isMailboxLoginNeeded)
|
||||
)
|
||||
private fun onSuccess(sessionId: SessionId, scopeInfo: ScopeInfo, isTwoPassModeNeeded: Boolean) {
|
||||
val intent = Intent().putExtra(
|
||||
ARG_SCOPE_RESULT,
|
||||
ScopeResult(sessionId, scopeInfo, isTwoPassModeNeeded)
|
||||
)
|
||||
setResult(Activity.RESULT_OK, intent)
|
||||
finish()
|
||||
}
|
||||
|
@ -173,8 +177,7 @@ class SecondFactorActivity : ProtonActivity<Activity2faBinding>(),
|
|||
}
|
||||
|
||||
companion object {
|
||||
const val ARG_SESSION_ID = "arg.sessionId"
|
||||
const val ARG_TWO_PASS_MODE = "arg.twoPassMode"
|
||||
const val ARG_SECOND_FACTOR_INPUT = "arg.secondFactorInput"
|
||||
const val ARG_SCOPE_RESULT = "arg.scopeResult"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,16 +54,14 @@ class LoginViewModel @ViewModelInject constructor(
|
|||
username: String,
|
||||
password: ByteArray
|
||||
) {
|
||||
performLogin(username, password)
|
||||
.onEach {
|
||||
if (it is PerformLogin.LoginState.Success) {
|
||||
// on success result, contact account manager
|
||||
onSuccess(it)
|
||||
}
|
||||
// inform the view for each state change
|
||||
loginState.post(it)
|
||||
performLogin(username, password).onEach {
|
||||
if (it is PerformLogin.LoginState.Success) {
|
||||
// on success result, contact account manager
|
||||
onSuccess(it)
|
||||
}
|
||||
.launchIn(viewModelScope)
|
||||
// inform the view for each state change
|
||||
loginState.post(it)
|
||||
}.launchIn(viewModelScope)
|
||||
}
|
||||
|
||||
private suspend fun onSuccess(success: PerformLogin.LoginState.Success) {
|
||||
|
|
|
@ -20,15 +20,14 @@ package me.proton.core.auth.presentation.viewmodel
|
|||
|
||||
import androidx.hilt.lifecycle.ViewModelInject
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import kotlinx.coroutines.flow.flowOn
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.launch
|
||||
import me.proton.android.core.presentation.viewmodel.ProtonViewModel
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.domain.usecase.PerformMailboxLogin
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.util.kotlin.DispatcherProvider
|
||||
import studio.forface.viewstatestore.ViewStateStore
|
||||
import studio.forface.viewstatestore.ViewStateStoreScope
|
||||
|
||||
|
@ -49,14 +48,15 @@ class MailboxLoginViewModel @ViewModelInject constructor(
|
|||
sessionId: SessionId,
|
||||
password: ByteArray
|
||||
) {
|
||||
performMailboxLogin(sessionId, password)
|
||||
.onEach {
|
||||
if (it is PerformMailboxLogin.MailboxLoginState.Success) {
|
||||
accountWorkflowHandler.handleTwoPassModeSuccess(sessionId)
|
||||
} else if (it is PerformMailboxLogin.MailboxLoginState.Error) {
|
||||
accountWorkflowHandler.handleTwoPassModeFailed(sessionId)
|
||||
}
|
||||
mailboxLoginState.post(it)
|
||||
}.launchIn(viewModelScope)
|
||||
performMailboxLogin(sessionId, password).onEach {
|
||||
if (it is PerformMailboxLogin.MailboxLoginState.Success) {
|
||||
accountWorkflowHandler.handleTwoPassModeSuccess(sessionId)
|
||||
}
|
||||
mailboxLoginState.post(it)
|
||||
}.launchIn(viewModelScope)
|
||||
}
|
||||
|
||||
fun stopMailboxLoginFlow(
|
||||
sessionId: SessionId
|
||||
): Job = viewModelScope.launch { accountWorkflowHandler.handleTwoPassModeFailed(sessionId) }
|
||||
}
|
||||
|
|
|
@ -20,14 +20,14 @@ package me.proton.core.auth.presentation.viewmodel
|
|||
|
||||
import androidx.hilt.lifecycle.ViewModelInject
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import kotlinx.coroutines.flow.flowOn
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.launch
|
||||
import me.proton.android.core.presentation.viewmodel.ProtonViewModel
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.domain.usecase.PerformSecondFactor
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.util.kotlin.DispatcherProvider
|
||||
import studio.forface.viewstatestore.ViewStateStore
|
||||
import studio.forface.viewstatestore.ViewStateStoreScope
|
||||
|
||||
|
@ -44,27 +44,20 @@ class SecondFactorViewModel @ViewModelInject constructor(
|
|||
fun startSecondFactorFlow(
|
||||
sessionId: SessionId,
|
||||
secondFactorCode: String,
|
||||
isMailboxLoginNeeded: Boolean
|
||||
isTwoPassModeNeeded: Boolean = false
|
||||
) {
|
||||
performSecondFactor(sessionId, secondFactorCode)
|
||||
.onEach {
|
||||
when (it) {
|
||||
is PerformSecondFactor.SecondFactorState.Success -> {
|
||||
secondFactorState.post(it.copy(isMailboxLoginNeeded = isMailboxLoginNeeded))
|
||||
accountWorkflowHandler.handleSecondFactorSuccess(
|
||||
sessionId = sessionId,
|
||||
updatedScopes = it.scopeInfo.scopes
|
||||
)
|
||||
}
|
||||
is PerformSecondFactor.SecondFactorState.Error -> {
|
||||
secondFactorState.post(it)
|
||||
accountWorkflowHandler.handleSecondFactorFailed(sessionId)
|
||||
}
|
||||
else -> {
|
||||
secondFactorState.post(it)
|
||||
}
|
||||
}
|
||||
performSecondFactor(sessionId, secondFactorCode, isTwoPassModeNeeded).onEach {
|
||||
if (it is PerformSecondFactor.SecondFactorState.Success) {
|
||||
accountWorkflowHandler.handleSecondFactorSuccess(
|
||||
sessionId = sessionId,
|
||||
updatedScopes = it.scopeInfo.scopes
|
||||
)
|
||||
}
|
||||
.launchIn(viewModelScope)
|
||||
secondFactorState.post(it)
|
||||
}.launchIn(viewModelScope)
|
||||
}
|
||||
|
||||
fun stopSecondFactorFlow(
|
||||
sessionId: SessionId
|
||||
): Job = viewModelScope.launch { accountWorkflowHandler.handleSecondFactorFailed(sessionId) }
|
||||
}
|
||||
|
|
|
@ -26,9 +26,10 @@ import io.mockk.slot
|
|||
import io.mockk.verify
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.auth.domain.AccountWorkflowHandler
|
||||
import me.proton.core.auth.domain.crypto.SrpProofProvider
|
||||
import me.proton.core.auth.domain.entity.Account
|
||||
import me.proton.core.auth.domain.entity.SessionInfo
|
||||
import me.proton.core.auth.domain.repository.AuthRepository
|
||||
import me.proton.core.auth.domain.usecase.PerformLogin
|
||||
|
@ -232,7 +233,7 @@ class LoginViewModelTest : ArchTest, CoroutinesTest {
|
|||
val account = accountArgument.captured
|
||||
val session = sessionArgument.captured
|
||||
assertNotNull(account)
|
||||
assertTrue(account.isTwoPassModeNeeded)
|
||||
assertEquals(AccountState.TwoPassModeNeeded, account.state)
|
||||
assertEquals(testSessionId, session.sessionId.id)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -123,18 +123,12 @@ class MailboxLoginViewModelTest : ArchTest, CoroutinesTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun `failed mailbox login invokes failed on account manager`() = coroutinesTest {
|
||||
// GIVEN
|
||||
coEvery { useCase.invoke(SessionId(testSessionId), testPassword.toByteArray()) } returns flowOf(
|
||||
PerformMailboxLogin.MailboxLoginState.Processing,
|
||||
PerformMailboxLogin.MailboxLoginState.Error.Message("test error")
|
||||
)
|
||||
fun `stop mailbox login invokes failed on account manager`() = coroutinesTest {
|
||||
// WHEN
|
||||
viewModel.startMailboxLoginFlow(SessionId(testSessionId), testPassword.toByteArray())
|
||||
viewModel.stopMailboxLoginFlow(SessionId(testSessionId))
|
||||
// THEN
|
||||
val arguments = slot<SessionId>()
|
||||
coVerify(exactly = 1) { accountManager.handleTwoPassModeFailed(capture(arguments)) }
|
||||
coVerify(exactly = 0) { accountManager.handleTwoPassModeSuccess(any()) }
|
||||
assertEquals(testSessionId, arguments.captured.id)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,7 +66,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
|
|||
val isMailboxLoginNeeded = false
|
||||
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
|
||||
PerformSecondFactor.SecondFactorState.Processing,
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
|
||||
)
|
||||
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
|
||||
viewModel.secondFactorState.observeDataForever(observer)
|
||||
|
@ -103,7 +103,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
|
|||
val isMailboxLoginNeeded = false
|
||||
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
|
||||
PerformSecondFactor.SecondFactorState.Processing,
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
|
||||
)
|
||||
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
|
||||
viewModel.secondFactorState.observeDataForever(observer)
|
||||
|
@ -119,7 +119,7 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
|
|||
val successState = arguments[1]
|
||||
assertTrue(processingState is PerformSecondFactor.SecondFactorState.Processing)
|
||||
assertTrue(successState is PerformSecondFactor.SecondFactorState.Success)
|
||||
assertFalse(successState.isMailboxLoginNeeded)
|
||||
assertFalse(successState.isTwoPassModeNeeded)
|
||||
assertEquals(SessionId(testSessionId), accountManagerArguments.captured)
|
||||
}
|
||||
|
||||
|
@ -127,9 +127,9 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
|
|||
fun `submit 2fa two pass mode flow states are handled correctly`() = coroutinesTest {
|
||||
// GIVEN
|
||||
val isMailboxLoginNeeded = true
|
||||
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode) } returns flowOf(
|
||||
coEvery { useCase.invoke(SessionId(testSessionId), testSecondFactorCode, isMailboxLoginNeeded) } returns flowOf(
|
||||
PerformSecondFactor.SecondFactorState.Processing,
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo)
|
||||
PerformSecondFactor.SecondFactorState.Success(SessionId(testSessionId), testScopeInfo, isMailboxLoginNeeded)
|
||||
)
|
||||
val observer = mockk<(PerformSecondFactor.SecondFactorState) -> Unit>(relaxed = true)
|
||||
viewModel.secondFactorState.observeDataForever(observer)
|
||||
|
@ -145,7 +145,17 @@ class SecondFactorViewModelTest : ArchTest, CoroutinesTest {
|
|||
val successState = arguments[1]
|
||||
assertTrue(processingState is PerformSecondFactor.SecondFactorState.Processing)
|
||||
assertTrue(successState is PerformSecondFactor.SecondFactorState.Success)
|
||||
assertTrue(successState.isMailboxLoginNeeded)
|
||||
assertTrue(successState.isTwoPassModeNeeded)
|
||||
assertEquals(SessionId(testSessionId), accountManagerArguments.captured)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `stop 2fa invokes failed on account manager`() = coroutinesTest {
|
||||
// WHEN
|
||||
viewModel.stopSecondFactorFlow(SessionId(testSessionId))
|
||||
// THEN
|
||||
val arguments = slot<SessionId>()
|
||||
coVerify(exactly = 1) { accountManager.handleSecondFactorFailed(capture(arguments)) }
|
||||
coVerify(exactly = 0) { accountManager.handleSecondFactorSuccess(any(), any()) }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,8 +80,8 @@ object ApplicationModule {
|
|||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideApiProvider(apiFactory: ApiFactory): ApiProvider =
|
||||
ApiProvider(apiFactory)
|
||||
fun provideApiProvider(apiFactory: ApiFactory, sessionProvider: SessionProvider): ApiProvider =
|
||||
ApiProvider(apiFactory, sessionProvider)
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
|
@ -105,11 +105,4 @@ object ApplicationModule {
|
|||
@Provides
|
||||
@ClientSecret
|
||||
fun provideClientSecret(): String = ""
|
||||
|
||||
@Provides
|
||||
fun provideDispatcherProvider() = object : DispatcherProvider {
|
||||
override val Io = Dispatchers.IO
|
||||
override val Comp = Dispatchers.Default
|
||||
override val Main = Dispatchers.Main
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,37 +18,31 @@
|
|||
|
||||
package me.proton.android.core.coreexample
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.content.Intent
|
||||
import android.os.Bundle
|
||||
import android.widget.Button
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.launch
|
||||
import me.proton.android.core.coreexample.databinding.ActivityMainBinding
|
||||
import me.proton.android.core.coreexample.ui.CustomViewsActivity
|
||||
import me.proton.android.core.presentation.ui.ProtonActivity
|
||||
import me.proton.android.core.presentation.utils.onClick
|
||||
import me.proton.core.account.domain.entity.Account
|
||||
import me.proton.core.account.domain.entity.AccountState
|
||||
import me.proton.core.account.domain.entity.SessionState
|
||||
import me.proton.core.accountmanager.domain.AccountManager
|
||||
import me.proton.core.accountmanager.presentation.observe
|
||||
import me.proton.core.accountmanager.presentation.onAccountAdded
|
||||
import me.proton.core.accountmanager.presentation.onAccountDisabled
|
||||
import me.proton.core.accountmanager.presentation.onAccountInitializing
|
||||
import me.proton.core.accountmanager.presentation.onAccountReady
|
||||
import me.proton.core.accountmanager.presentation.onAccountRemoved
|
||||
import me.proton.core.accountmanager.presentation.onAccountTwoPassModeFailed
|
||||
import me.proton.core.accountmanager.presentation.onSessionAuthenticated
|
||||
import me.proton.core.accountmanager.presentation.onSessionForceLogout
|
||||
import me.proton.core.accountmanager.presentation.onSessionHumanVerificationFailed
|
||||
import me.proton.core.accountmanager.domain.getPrimaryAccount
|
||||
import me.proton.core.auth.presentation.AuthOrchestrator
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.auth.presentation.onHumanVerificationResult
|
||||
import me.proton.core.auth.presentation.onScopeResult
|
||||
import me.proton.core.auth.presentation.onSessionResult
|
||||
import me.proton.core.auth.presentation.onUserResult
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.humanverification.VerificationMethod
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import java.util.UUID
|
||||
import javax.inject.Inject
|
||||
|
||||
@AndroidEntryPoint
|
||||
|
@ -57,79 +51,80 @@ class MainActivity : ProtonActivity<ActivityMainBinding>() {
|
|||
@Inject
|
||||
lateinit var accountManager: AccountManager
|
||||
|
||||
@Inject
|
||||
lateinit var authOrchestrator: AuthOrchestrator
|
||||
|
||||
override fun layoutId(): Int = R.layout.activity_main
|
||||
|
||||
private val authWorkflowLauncher = AuthOrchestrator()
|
||||
|
||||
@SuppressLint("SetTextI18n")
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
|
||||
authWorkflowLauncher.register(this)
|
||||
authOrchestrator.register(this)
|
||||
authOrchestrator
|
||||
.onUserResult { }
|
||||
.onScopeResult { }
|
||||
.onSessionResult { }
|
||||
.onHumanVerificationResult { }
|
||||
|
||||
binding.humanVerification.onClick {
|
||||
authWorkflowLauncher.startHumanVerificationWorkflow(
|
||||
SessionId("sessionId"),
|
||||
HumanVerificationDetails(
|
||||
listOf(
|
||||
VerificationMethod.CAPTCHA,
|
||||
VerificationMethod.EMAIL,
|
||||
VerificationMethod.PHONE
|
||||
with(binding) {
|
||||
humanVerification.onClick {
|
||||
authOrchestrator.startHumanVerificationWorkflow(
|
||||
SessionId("sessionId"),
|
||||
HumanVerificationDetails(
|
||||
listOf(
|
||||
VerificationMethod.CAPTCHA,
|
||||
VerificationMethod.EMAIL,
|
||||
VerificationMethod.PHONE
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
customViews.onClick { startActivity(Intent(this@MainActivity, CustomViewsActivity::class.java)) }
|
||||
login.onClick { authOrchestrator.startLoginWorkflow() }
|
||||
}
|
||||
|
||||
binding.customViews.onClick {
|
||||
startActivity(Intent(this, CustomViewsActivity::class.java))
|
||||
}
|
||||
accountManager.getPrimaryAccount().onEach { primary ->
|
||||
binding.primaryAccountText.text = "Primary: ${primary?.username}"
|
||||
}.launchIn(lifecycleScope)
|
||||
|
||||
binding.login.onClick {
|
||||
authWorkflowLauncher.startLoginWorkflow()
|
||||
}
|
||||
accountManager.getAccounts().onEach { accounts ->
|
||||
if (accounts.isEmpty()) authOrchestrator.startLoginWorkflow()
|
||||
|
||||
accountManager.getPrimaryUserId()
|
||||
.onEach { userId ->
|
||||
if (userId == null) {
|
||||
val userId = UserId(UUID.randomUUID().toString())
|
||||
val sessionId = SessionId(UUID.randomUUID().toString())
|
||||
val session = Session(
|
||||
sessionId = sessionId,
|
||||
accessToken = "accessToken",
|
||||
refreshToken = "refreshToken",
|
||||
headers = null,
|
||||
scopes = listOf()
|
||||
)
|
||||
accountManager.addAccount(
|
||||
Account(
|
||||
userId,
|
||||
"username",
|
||||
"example@example.com",
|
||||
AccountState.Ready,
|
||||
sessionId,
|
||||
SessionState.Authenticated
|
||||
),
|
||||
session
|
||||
)
|
||||
}
|
||||
}.launchIn(lifecycleScope)
|
||||
|
||||
accountManager.getSessions().onEach { }.launchIn(lifecycleScope)
|
||||
|
||||
accountManager.observe(lifecycleScope)
|
||||
.onAccountAdded { }
|
||||
.onAccountDisabled { }
|
||||
.onAccountInitializing { }
|
||||
.onAccountReady { }
|
||||
.onAccountRemoved { }
|
||||
.onAccountTwoPassModeFailed { }
|
||||
.onSessionAuthenticated { }
|
||||
.onSessionForceLogout { }
|
||||
.onSessionHumanVerificationFailed { }
|
||||
binding.accountsLayout.removeAllViews()
|
||||
accounts.forEach { account ->
|
||||
binding.accountsLayout.addView(
|
||||
Button(this@MainActivity).apply {
|
||||
text = "${account.username} -> ${account.state}/${account.sessionState}"
|
||||
onClick {
|
||||
lifecycleScope.launch {
|
||||
when (account.state) {
|
||||
AccountState.Ready ->
|
||||
accountManager.disableAccount(account.userId)
|
||||
AccountState.Disabled ->
|
||||
accountManager.removeAccount(account.userId)
|
||||
AccountState.NotReady,
|
||||
AccountState.TwoPassModeNeeded,
|
||||
AccountState.TwoPassModeFailed ->
|
||||
when (account.sessionState) {
|
||||
SessionState.SecondFactorNeeded,
|
||||
SessionState.SecondFactorFailed ->
|
||||
accountManager.disableAccount(account.userId)
|
||||
SessionState.Authenticated ->
|
||||
authOrchestrator.startTwoPassModeWorkflow(account.sessionId!!)
|
||||
else -> Unit
|
||||
}
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}.launchIn(lifecycleScope)
|
||||
|
||||
accountManager.onHumanVerificationNeeded().onEach { (account, details) ->
|
||||
account.sessionId?.let {
|
||||
authWorkflowLauncher.startHumanVerificationWorkflow(it, details)
|
||||
}
|
||||
authOrchestrator.startHumanVerificationWorkflow(account.sessionId!!, details)
|
||||
}.launchIn(lifecycleScope)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,6 +55,43 @@
|
|||
app:layout_constraintStart_toStartOf="parent"
|
||||
app:layout_constraintTop_toBottomOf="@id/customViews" />
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/titleLayout"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginTop="@dimen/default_gap"
|
||||
android:orientation="horizontal"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent"
|
||||
app:layout_constraintTop_toBottomOf="@id/login">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/accountsText"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_weight="1"
|
||||
android:text="Accounts:" />
|
||||
|
||||
<TextView
|
||||
android:id="@+id/primaryAccountText"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_weight="1"
|
||||
android:gravity="end"
|
||||
android:text="-" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:id="@+id/accountsLayout"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginTop="@dimen/default_top_margin"
|
||||
android:orientation="vertical"
|
||||
app:layout_constraintEnd_toEndOf="parent"
|
||||
app:layout_constraintStart_toStartOf="parent"
|
||||
app:layout_constraintTop_toBottomOf="@id/titleLayout" />
|
||||
|
||||
</androidx.constraintlayout.widget.ConstraintLayout>
|
||||
|
||||
</layout>
|
||||
|
|
|
@ -21,20 +21,17 @@ package me.proton.core.data.arch
|
|||
import me.proton.core.domain.arch.DataResult
|
||||
import me.proton.core.domain.arch.ResponseSource
|
||||
import me.proton.core.network.domain.ApiResult
|
||||
import me.proton.core.util.kotlin.Invokable
|
||||
import me.proton.core.util.kotlin.exhaustive
|
||||
|
||||
class ApiResultMapper : Invokable {
|
||||
fun <T> ApiResult<T>.toDataResponse(): DataResult<T> = when (this) {
|
||||
is ApiResult.Success -> DataResult.Success(value, ResponseSource.Remote)
|
||||
is ApiResult.Error.Http -> {
|
||||
DataResult.Error.Message(
|
||||
message = proton?.error ?: message,
|
||||
source = ResponseSource.Remote,
|
||||
code = proton?.code ?: 0 // 0 means no code is present
|
||||
)
|
||||
}
|
||||
is ApiResult.Error.Parse -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
|
||||
is ApiResult.Error.Connection -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
|
||||
}.exhaustive
|
||||
}
|
||||
fun <T> ApiResult<T>.toDataResponse(): DataResult<T> = when (this) {
|
||||
is ApiResult.Success -> DataResult.Success(value, ResponseSource.Remote)
|
||||
is ApiResult.Error.Http -> {
|
||||
DataResult.Error.Message(
|
||||
message = proton?.error ?: message,
|
||||
source = ResponseSource.Remote,
|
||||
code = proton?.code ?: 0 // 0 means no code is present
|
||||
)
|
||||
}
|
||||
is ApiResult.Error.Parse -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
|
||||
is ApiResult.Error.Connection -> DataResult.Error.Message(cause?.message, ResponseSource.Remote)
|
||||
}.exhaustive
|
||||
|
|
|
@ -31,6 +31,7 @@ import me.proton.core.network.data.protonApi.GenericResponse
|
|||
import me.proton.core.network.domain.ApiManager
|
||||
import me.proton.core.network.domain.ApiResult
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
@ -49,6 +50,8 @@ class HumanVerificationRemoteRepositoryImplTest {
|
|||
private val errorResponse = "test error response"
|
||||
private val errorResponseCode = 422
|
||||
|
||||
@RelaxedMockK
|
||||
private lateinit var sessionProvider: SessionProvider
|
||||
@RelaxedMockK
|
||||
private lateinit var apiFactory: ApiFactory
|
||||
private lateinit var apiProvider: ApiProvider
|
||||
|
@ -59,7 +62,7 @@ class HumanVerificationRemoteRepositoryImplTest {
|
|||
@Before
|
||||
fun before() {
|
||||
MockKAnnotations.init(this)
|
||||
apiProvider = ApiProvider(apiFactory)
|
||||
apiProvider = ApiProvider(apiFactory, sessionProvider)
|
||||
every { apiFactory.create(sessionId, HumanVerificationApi::class) } returns apiManager
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ package me.proton.core.humanverification.presentation.entity
|
|||
|
||||
import android.os.Parcelable
|
||||
import kotlinx.android.parcel.Parcelize
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
|
||||
/**
|
||||
* Human Verification result entity.
|
||||
|
@ -28,6 +27,6 @@ import me.proton.core.network.domain.session.SessionId
|
|||
@Parcelize
|
||||
data class HumanVerificationResult(
|
||||
val sessionId: String,
|
||||
val tokenType: String,
|
||||
val tokenCode: String
|
||||
val tokenType: String?,
|
||||
val tokenCode: String?
|
||||
) : Parcelable
|
||||
|
|
|
@ -49,8 +49,8 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
|
|||
|
||||
companion object {
|
||||
private const val ARG_SESSION_ID = "arg.sessionId"
|
||||
const val ARG_VERIFICATION_OPTIONS = "arg.verification-options"
|
||||
private const val ARG_CAPTCHA_TOKEN = "arg.captcha-token"
|
||||
const val ARG_VERIFICATION_OPTIONS = "arg.verification-options"
|
||||
const val ARG_DESTINATION = "arg.destination"
|
||||
const val ARG_TOKEN_CODE = "arg.token-code"
|
||||
const val ARG_TOKEN_TYPE = "arg.token-type"
|
||||
|
@ -65,7 +65,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
|
|||
* @param captchaToken if the API returns it, otherwise null
|
||||
*/
|
||||
operator fun invoke(
|
||||
sessionId: SessionId,
|
||||
sessionId: String,
|
||||
availableVerificationMethods: List<String>,
|
||||
captchaToken: String?
|
||||
) = HumanVerificationDialogFragment().apply {
|
||||
|
@ -81,7 +81,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
|
|||
private lateinit var resultListener: OnResultListener
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
requireArguments().get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
|
||||
}
|
||||
|
||||
private val captchaToken: String? by lazy {
|
||||
|
@ -187,13 +187,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
|
|||
|
||||
private fun onClose(tokenType: String? = null, tokenCode: String? = null) {
|
||||
if (!tokenType.isNullOrEmpty() && !tokenCode.isNullOrEmpty()) {
|
||||
resultListener.setResult(
|
||||
HumanVerificationResult(
|
||||
sessionId.id,
|
||||
tokenType,
|
||||
tokenCode
|
||||
)
|
||||
)
|
||||
resultListener.setResult(HumanVerificationResult(sessionId.id, tokenType, tokenCode))
|
||||
dismissAllowingStateLoss()
|
||||
return
|
||||
}
|
||||
|
@ -205,7 +199,7 @@ class HumanVerificationDialogFragment : ProtonDialogFragment<DialogHumanVerifica
|
|||
if (backStackEntryCount >= 1) {
|
||||
popBackStack()
|
||||
} else {
|
||||
resultListener.setResult(null)
|
||||
resultListener.setResult(HumanVerificationResult(sessionId.id, null, null))
|
||||
dismissAllowingStateLoss()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,8 +43,7 @@ import me.proton.core.network.domain.session.SessionId
|
|||
* @author Dino Kadrikj.
|
||||
*/
|
||||
@AndroidEntryPoint
|
||||
internal class HumanVerificationCaptchaFragment :
|
||||
ProtonFragment<FragmentHumanVerificationCaptchaBinding>() {
|
||||
internal class HumanVerificationCaptchaFragment : ProtonFragment<FragmentHumanVerificationCaptchaBinding>() {
|
||||
|
||||
companion object {
|
||||
private const val ARG_SESSION_ID = "arg.sessionId"
|
||||
|
@ -52,7 +51,7 @@ internal class HumanVerificationCaptchaFragment :
|
|||
private const val MAX_PROGRESS = 100
|
||||
|
||||
operator fun invoke(
|
||||
sessionId: SessionId,
|
||||
sessionId: String,
|
||||
urlToken: String,
|
||||
host: String
|
||||
) = HumanVerificationCaptchaFragment().apply {
|
||||
|
@ -65,7 +64,7 @@ internal class HumanVerificationCaptchaFragment :
|
|||
}
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
requireArguments().get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
|
||||
}
|
||||
|
||||
private val host: String by lazy {
|
||||
|
|
|
@ -49,7 +49,7 @@ internal class HumanVerificationEmailFragment : ProtonFragment<FragmentHumanVeri
|
|||
private const val ARG_RECOVERY_EMAIL = "arg.recoveryemail"
|
||||
|
||||
operator fun invoke(
|
||||
sessionId: SessionId,
|
||||
sessionId: String,
|
||||
token: String,
|
||||
recoveryEmailAddress: String? = null
|
||||
) = HumanVerificationEmailFragment().apply {
|
||||
|
@ -72,7 +72,7 @@ internal class HumanVerificationEmailFragment : ProtonFragment<FragmentHumanVeri
|
|||
}
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
requireArguments().get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
|
||||
}
|
||||
|
||||
private val recoveryEmailAddress: String? by lazy {
|
||||
|
|
|
@ -53,7 +53,7 @@ class HumanVerificationEnterCodeFragment :
|
|||
private const val ARG_TOKEN_TYPE = "arg.enter-code-token-type"
|
||||
|
||||
operator fun invoke(
|
||||
sessionId: SessionId,
|
||||
sessionId: String,
|
||||
tokenType: TokenType,
|
||||
destination: String?
|
||||
) = HumanVerificationEnterCodeFragment().apply {
|
||||
|
@ -68,7 +68,7 @@ class HumanVerificationEnterCodeFragment :
|
|||
private val viewModel by viewModels<HumanVerificationEnterCodeViewModel>()
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
requireArguments().get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
|
||||
}
|
||||
|
||||
private val destination: String? by lazy {
|
||||
|
|
|
@ -53,7 +53,10 @@ internal class HumanVerificationSMSFragment :
|
|||
internal const val KEY_COUNTRY_SELECTED = "key.country_selected"
|
||||
internal const val BUNDLE_KEY_COUNTRY = "bundle.country"
|
||||
|
||||
operator fun invoke(sessionId: SessionId, token: String) = HumanVerificationSMSFragment().apply {
|
||||
operator fun invoke(
|
||||
sessionId: String,
|
||||
token: String
|
||||
) = HumanVerificationSMSFragment().apply {
|
||||
arguments = bundleOf(
|
||||
ARG_SESSION_ID to sessionId,
|
||||
ARG_URL_TOKEN to token
|
||||
|
@ -64,7 +67,7 @@ internal class HumanVerificationSMSFragment :
|
|||
private val viewModel by viewModels<HumanVerificationSMSViewModel>()
|
||||
|
||||
private val sessionId: SessionId by lazy {
|
||||
requireArguments().get(ARG_SESSION_ID) as SessionId
|
||||
SessionId(requireArguments().getString(ARG_SESSION_ID)!!)
|
||||
}
|
||||
|
||||
private val humanVerificationBase by lazy {
|
||||
|
|
|
@ -56,7 +56,7 @@ fun FragmentManager.showHumanVerification(
|
|||
largeLayout: Boolean
|
||||
) {
|
||||
|
||||
val newFragment = HumanVerificationDialogFragment(sessionId, availableVerificationMethods, captchaToken)
|
||||
val newFragment = HumanVerificationDialogFragment(sessionId.id, availableVerificationMethods, captchaToken)
|
||||
if (largeLayout) {
|
||||
// For large screens (tablets), we show the fragment as a dialog
|
||||
newFragment.show(this, TAG_HUMAN_VERIFICATION_DIALOG)
|
||||
|
@ -79,7 +79,7 @@ internal fun FragmentManager.showHumanVerificationCaptchaContent(
|
|||
token: String?,
|
||||
host: String = HOST_DEFAULT
|
||||
): Fragment {
|
||||
val captchaFragment = HumanVerificationCaptchaFragment(sessionId, token ?: TOKEN_DEFAULT, host)
|
||||
val captchaFragment = HumanVerificationCaptchaFragment(sessionId.id, token ?: TOKEN_DEFAULT, host)
|
||||
inTransaction {
|
||||
setCustomAnimations(0, 0)
|
||||
replace(containerId, captchaFragment)
|
||||
|
@ -92,7 +92,7 @@ internal fun FragmentManager.showHumanVerificationEmailContent(
|
|||
sessionId: SessionId,
|
||||
token: String = TOKEN_DEFAULT
|
||||
) {
|
||||
val emailFragment = HumanVerificationEmailFragment(sessionId, token)
|
||||
val emailFragment = HumanVerificationEmailFragment(sessionId.id, token)
|
||||
inTransaction {
|
||||
setCustomAnimations(0, 0)
|
||||
replace(containerId, emailFragment)
|
||||
|
@ -104,7 +104,7 @@ internal fun FragmentManager.showHumanVerificationSMSContent(
|
|||
containerId: Int = android.R.id.content,
|
||||
token: String = TOKEN_DEFAULT
|
||||
) {
|
||||
val smsFragment = HumanVerificationSMSFragment(sessionId, token)
|
||||
val smsFragment = HumanVerificationSMSFragment(sessionId.id, token)
|
||||
inTransaction {
|
||||
setCustomAnimations(0, 0)
|
||||
replace(containerId, smsFragment)
|
||||
|
@ -116,7 +116,7 @@ internal fun FragmentManager.showEnterCode(
|
|||
tokenType: TokenType,
|
||||
destination: String?
|
||||
) {
|
||||
val enterCodeFragment = HumanVerificationEnterCodeFragment(sessionId, tokenType, destination)
|
||||
val enterCodeFragment = HumanVerificationEnterCodeFragment(sessionId.id, tokenType, destination)
|
||||
inTransaction {
|
||||
setCustomAnimations(0, 0)
|
||||
add(enterCodeFragment, TAG_HUMAN_VERIFICATION_ENTER_CODE)
|
||||
|
|
|
@ -36,6 +36,7 @@ dependencies {
|
|||
project(Module.kotlinUtil),
|
||||
project(Module.sharedPreferencesUtil),
|
||||
project(Module.networkDomain),
|
||||
project(Module.domain),
|
||||
|
||||
// Kotlin
|
||||
`kotlin-jdk7`,
|
||||
|
|
|
@ -18,22 +18,31 @@
|
|||
|
||||
package me.proton.core.network.data
|
||||
|
||||
import me.proton.core.domain.entity.UserId
|
||||
import me.proton.core.network.data.di.ApiFactory
|
||||
import me.proton.core.network.data.protonApi.BaseRetrofitApi
|
||||
import me.proton.core.network.domain.ApiManager
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
import java.lang.ref.Reference
|
||||
import java.lang.ref.WeakReference
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ConcurrentMap
|
||||
|
||||
/**
|
||||
* Provide [ApiManager] instance bound to a specific [SessionId].
|
||||
*/
|
||||
class ApiProvider(
|
||||
val apiFactory: ApiFactory
|
||||
val apiFactory: ApiFactory,
|
||||
val sessionProvider: SessionProvider
|
||||
) {
|
||||
val instances: ConcurrentHashMap<String, ConcurrentHashMap<String, WeakReference<ApiManager<*>>>> =
|
||||
val instances: ConcurrentHashMap<String, ConcurrentHashMap<String, Reference<ApiManager<*>>>> =
|
||||
ConcurrentHashMap()
|
||||
|
||||
suspend inline fun <reified Api : BaseRetrofitApi> get(
|
||||
userId: UserId
|
||||
): ApiManager<out Api> = get(sessionProvider.getSessionId(userId))
|
||||
|
||||
inline fun <reified Api : BaseRetrofitApi> get(
|
||||
sessionId: SessionId? = null
|
||||
): ApiManager<out Api> {
|
||||
|
@ -44,7 +53,9 @@ class ApiProvider(
|
|||
val className = Api::class.java.name
|
||||
return instances
|
||||
.getOrPut(sessionName) { ConcurrentHashMap() }
|
||||
.getOrPut(className) { WeakReference(apiFactory.create(sessionId, Api::class)) }
|
||||
.get() as ApiManager<out Api>
|
||||
.getOrPutWeakRef(className) { apiFactory.create(sessionId, Api::class) } as ApiManager<out Api>
|
||||
}
|
||||
|
||||
fun <K, V> ConcurrentMap<K, Reference<V>>.getOrPutWeakRef(key: K, defaultValue: () -> V): V =
|
||||
this[key]?.get() ?: defaultValue().apply { put(key, WeakReference(this)) }
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
*/
|
||||
package me.proton.core.network.data
|
||||
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import me.proton.core.network.data.di.Constants
|
||||
import me.proton.core.network.data.protonApi.BaseRetrofitApi
|
||||
import me.proton.core.network.data.protonApi.ProtonErrorData
|
||||
|
@ -97,15 +98,13 @@ internal class ProtonApiBackend<Api : BaseRetrofitApi>(
|
|||
}
|
||||
|
||||
private fun handleTimeoutTag(chain: Interceptor.Chain): Interceptor.Chain {
|
||||
val tag = chain.request().tag()
|
||||
return if (tag is TimeoutOverride) {
|
||||
val tag = chain.request().tag(TimeoutOverride::class.java)
|
||||
return tag?.let {
|
||||
chain
|
||||
.withConnectTimeout(tag.connectionTimeoutSeconds, TimeUnit.SECONDS)
|
||||
.withReadTimeout(tag.readTimeoutSeconds, TimeUnit.SECONDS)
|
||||
.withWriteTimeout(tag.writeTimeoutSeconds, TimeUnit.SECONDS)
|
||||
} else {
|
||||
chain
|
||||
}
|
||||
} ?: chain
|
||||
}
|
||||
|
||||
private fun prepareHeaders(original: Request): Request.Builder {
|
||||
|
@ -118,7 +117,7 @@ internal class ProtonApiBackend<Api : BaseRetrofitApi>(
|
|||
request.header("Accept", "application/vnd.protonmail.v1+json")
|
||||
}
|
||||
|
||||
sessionId?.let { sessionProvider.getSession(it) }?.let { session ->
|
||||
sessionId?.let { runBlocking { sessionProvider.getSession(it) } }?.let { session ->
|
||||
session.headers?.let {
|
||||
request.header("x-pm-human-verification-token-type", it.tokenType)
|
||||
request.header("x-pm-human-verification-token", it.tokenCode)
|
||||
|
|
|
@ -100,7 +100,8 @@ internal class ApiManagerTests {
|
|||
apiClient = MockApiClient()
|
||||
|
||||
session = MockSession.getDefault()
|
||||
every { sessionProvider.getSession(any()) } returns session
|
||||
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
|
||||
coEvery { sessionProvider.getSession(any()) } returns session
|
||||
|
||||
networkManager = MockNetworkManager()
|
||||
networkManager.networkStatus = NetworkStatus.Unmetered
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
package me.proton.core.network.data
|
||||
|
||||
import io.mockk.MockKAnnotations
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.MockK
|
||||
import io.mockk.mockk
|
||||
|
@ -127,7 +128,8 @@ internal class HumanVerificationTests {
|
|||
prefs = MockNetworkPrefs()
|
||||
|
||||
session = MockSession.getDefault()
|
||||
every { sessionProvider.getSession(any()) } returns session
|
||||
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
|
||||
coEvery { sessionProvider.getSession(any()) } returns session
|
||||
|
||||
apiFactory =
|
||||
ApiFactory(
|
||||
|
@ -218,7 +220,7 @@ internal class HumanVerificationTests {
|
|||
)
|
||||
)
|
||||
|
||||
every { sessionProvider.getSession(any()) } returns MockSession.getWithHeader(
|
||||
coEvery { sessionProvider.getSession(any()) } returns MockSession.getWithHeader(
|
||||
humanVerificationHeaders
|
||||
)
|
||||
|
||||
|
|
|
@ -58,10 +58,14 @@ internal class NetworkManagerTests {
|
|||
networkManager.networkStatus = NetworkStatus.Unmetered
|
||||
flow2.cancel()
|
||||
|
||||
assertEquals(listOf(NetworkStatus.Unmetered, NetworkStatus.Metered, NetworkStatus.Disconnected),
|
||||
collectedStates1.toList())
|
||||
assertEquals(listOf(NetworkStatus.Metered, NetworkStatus.Disconnected, NetworkStatus.Unmetered),
|
||||
collectedStates2.toList())
|
||||
assertEquals(
|
||||
listOf(NetworkStatus.Unmetered, NetworkStatus.Metered, NetworkStatus.Disconnected),
|
||||
collectedStates1.toList()
|
||||
)
|
||||
assertEquals(
|
||||
listOf(NetworkStatus.Metered, NetworkStatus.Disconnected, NetworkStatus.Unmetered),
|
||||
collectedStates2.toList()
|
||||
)
|
||||
|
||||
assertFalse(networkManager.registered)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package me.proton.core.network.data
|
||||
|
||||
import io.mockk.MockKAnnotations
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.MockK
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
|
@ -92,7 +93,8 @@ internal class ProtonApiBackendTests {
|
|||
prefs = MockNetworkPrefs()
|
||||
|
||||
session = MockSession.getDefault()
|
||||
every { sessionProvider.getSession(any()) } returns session
|
||||
coEvery { sessionProvider.getSessionId(any()) } returns session.sessionId
|
||||
coEvery { sessionProvider.getSession(any()) } returns session
|
||||
|
||||
apiFactory = ApiFactory(
|
||||
"https://example.com/",
|
||||
|
|
|
@ -31,6 +31,7 @@ dependencies {
|
|||
implementation(
|
||||
|
||||
project(Module.kotlinUtil),
|
||||
project(Module.domain),
|
||||
|
||||
// Kotlin
|
||||
`kotlin-jdk7`,
|
||||
|
|
|
@ -18,6 +18,16 @@
|
|||
|
||||
package me.proton.core.network.domain.session
|
||||
|
||||
import me.proton.core.domain.entity.UserId
|
||||
|
||||
interface SessionProvider {
|
||||
fun getSession(sessionId: SessionId): Session?
|
||||
/**
|
||||
* Get [Session], if exist, by sessionId.
|
||||
*/
|
||||
suspend fun getSession(sessionId: SessionId): Session?
|
||||
|
||||
/**
|
||||
* Get [SessionId], if exist, by userId.
|
||||
*/
|
||||
suspend fun getSessionId(userId: UserId): SessionId?
|
||||
}
|
||||
|
|
|
@ -27,10 +27,12 @@ import kotlinx.coroutines.test.runBlockingTest
|
|||
import me.proton.core.network.domain.handlers.HumanVerificationHandler
|
||||
import me.proton.core.network.domain.humanverification.HumanVerificationDetails
|
||||
import me.proton.core.network.domain.humanverification.VerificationMethod
|
||||
import me.proton.core.network.domain.session.Session
|
||||
import me.proton.core.network.domain.session.SessionId
|
||||
import me.proton.core.network.domain.session.SessionListener
|
||||
import me.proton.core.network.domain.session.SessionProvider
|
||||
import org.junit.Test
|
||||
import kotlin.test.BeforeTest
|
||||
import kotlin.test.assertNotNull
|
||||
|
||||
/**
|
||||
|
@ -39,12 +41,18 @@ import kotlin.test.assertNotNull
|
|||
class HumanVerificationHandlerTest {
|
||||
|
||||
private val sessionId: SessionId = SessionId("id")
|
||||
private val session = mockk<Session>(relaxed = true)
|
||||
private val sessionListener = mockk<SessionListener>(relaxed = true)
|
||||
private val sessionProvider = mockk<SessionProvider>(relaxed = true)
|
||||
|
||||
val scope = CoroutineScope(TestCoroutineDispatcher())
|
||||
val apiBackend = mockk<ApiBackend<Any>>()
|
||||
|
||||
@BeforeTest
|
||||
fun beforeTest() {
|
||||
coEvery { sessionProvider.getSession(any()) } returns session
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test human verification called`() = runBlockingTest {
|
||||
val humanVerificationDetails =
|
||||
|
@ -59,7 +67,12 @@ class HumanVerificationHandlerTest {
|
|||
)
|
||||
)
|
||||
|
||||
coEvery { sessionListener.onHumanVerificationNeeded(any(), any()) } returns SessionListener.HumanVerificationResult.Success
|
||||
coEvery {
|
||||
sessionListener.onHumanVerificationNeeded(
|
||||
any(),
|
||||
any()
|
||||
)
|
||||
} returns SessionListener.HumanVerificationResult.Success
|
||||
coEvery { apiBackend.invoke<Any>(any()) } returns ApiResult.Success("test")
|
||||
|
||||
val humanVerificationHandler =
|
||||
|
|
Loading…
Reference in New Issue