Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(e2ei): respect E2EI during login and MLS client creation (WPB-5851) #2633

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.android.di

import com.wire.kalium.logic.CoreLogic
import com.wire.kalium.logic.data.user.UserId
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject

class ObserveIfE2EIRequiredDuringLoginUseCaseProvider @AssistedInject constructor(
@KaliumCoreLogic private val coreLogic: CoreLogic,
@Assisted
private val userId: UserId
) {
suspend fun observeIfE2EIIsRequiredDuringLogin() = coreLogic.getSessionScope(userId).observeIfE2EIRequiredDuringLogin()

@AssistedFactory
interface Factory {
fun create(userId: UserId): ObserveIfE2EIRequiredDuringLoginUseCaseProvider
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.asset.DeleteAssetUseCase
import com.wire.kalium.logic.feature.asset.GetAssetSizeLimitUseCase
import com.wire.kalium.logic.feature.asset.GetAvatarAssetUseCase
import com.wire.kalium.logic.feature.client.FinalizeMLSClientAfterE2EIEnrollment
import com.wire.kalium.logic.feature.conversation.GetAllContactsNotInConversationUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
Expand Down Expand Up @@ -117,6 +118,11 @@ class UserModule {
fun provideEnrollE2EIUseCase(userScope: UserScope): EnrollE2EIUseCase =
userScope.enrollE2EI

@ViewModelScoped
@Provides
fun provideFinalizeMLSClientAfterE2EIEnrollmentUseCase(userScope: UserScope): FinalizeMLSClientAfterE2EIEnrollment =
userScope.finalizeMLSClientAfterE2EIEnrollment

@ViewModelScoped
@Provides
fun provideObserveTypingIndicatorEnabled(userScope: UserScope): ObserveTypingIndicatorEnabledUseCase =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ class GetE2EICertificateUseCase @Inject constructor(
private lateinit var initialEnrollmentResult: E2EIEnrollmentResult.Initialized
lateinit var enrollmentResultHandler: (Either<E2EIFailure, E2EIEnrollmentResult>) -> Unit

operator fun invoke(context: Context, enrollmentResultHandler: (Either<CoreFailure, E2EIEnrollmentResult>) -> Unit) {
operator fun invoke(
context: Context,
isNewClient: Boolean,
enrollmentResultHandler: (Either<CoreFailure, E2EIEnrollmentResult>) -> Unit
) {
this.enrollmentResultHandler = enrollmentResultHandler
scope.launch {
enrollE2EI.initialEnrollment().fold({
enrollE2EI.initialEnrollment(isNewClientRegistration = isNewClient).fold({
enrollmentResultHandler(Either.Left(it))
}, {
if (it is E2EIEnrollmentResult.Initialized) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import androidx.activity.result.ActivityResultRegistry
import androidx.activity.result.contract.ActivityResultContracts
import com.wire.android.appLogger
import com.wire.android.util.deeplink.DeepLinkProcessor
import com.wire.android.util.removeQueryParams
import kotlinx.serialization.json.JsonObject
import net.openid.appauth.AppAuthConfiguration
import net.openid.appauth.AuthState
Expand All @@ -44,6 +45,7 @@ import net.openid.appauth.browser.VersionedBrowserMatcher
import net.openid.appauth.connectivity.ConnectionBuilder
import org.json.JSONObject
import java.net.HttpURLConnection
import java.net.URI
import java.net.URL
import java.security.MessageDigest
import java.security.SecureRandom
Expand Down Expand Up @@ -119,7 +121,7 @@ class OAuthUseCase(context: Context, private val authUrl: String, private val cl
handleActivityResult(result, resultHandler)
}
AuthorizationServiceConfiguration.fetchFromUrl(
Uri.parse(authUrl.plus(IDP_CONFIGURATION_PATH)),
Uri.parse(URI(authUrl).removeQueryParams().toString().plus(IDP_CONFIGURATION_PATH)),
{ configuration, ex ->
if (ex == null) {
authServiceConfig = configuration!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MigrateClientsDataUseCase @Inject constructor(
private val scalaUserDBProvider: ScalaUserDatabaseProvider,
private val userDataStoreProvider: UserDataStoreProvider
) {
@Suppress("ReturnCount")
@Suppress("ReturnCount", "ComplexMethod")
suspend operator fun invoke(userId: UserId, isFederated: Boolean): Either<CoreFailure, Unit> =
scalaUserDBProvider.clientDAO(userId.value).flatMap { clientDAO ->
val clientId = clientDAO.clientInfo()?.clientId?.let { ClientId(it) }
Expand Down Expand Up @@ -103,6 +103,19 @@ class MigrateClientsDataUseCase @Inject constructor(
userDataStoreProvider.getOrCreate(userId).setInitialSyncCompleted()
}
}

is RegisterClientResult.E2EICertificateRequired ->
withTimeoutOrNull(SYNC_START_TIMEOUT) {
syncManager.waitUntilStartedOrFailure()
}.let {
it ?: Either.Left(NetworkFailure.NoNetworkConnection(null))
}.flatMap {
syncManager.waitUntilLiveOrFailure()
.onSuccess {
userDataStoreProvider.getOrCreate(userId).setInitialSyncCompleted()
TODO() // TODO: ask question about this!
}
}
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions app/src/main/kotlin/com/wire/android/ui/WireActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import com.wire.android.ui.common.topappbar.CommonTopAppBar
import com.wire.android.ui.common.topappbar.CommonTopAppBarViewModel
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.ConversationScreenDestination
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.E2eiCertificateDetailsScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.ImportMediaScreenDestination
Expand Down Expand Up @@ -166,9 +167,9 @@ class WireActivity : AppCompatActivity() {
val startDestination = when (viewModel.initialAppState) {
InitialAppState.NOT_MIGRATED -> MigrationScreenDestination
InitialAppState.NOT_LOGGED_IN -> WelcomeScreenDestination
InitialAppState.LOGGED_IN -> HomeScreenDestination
}

InitialAppState.ENROLL_E2EI -> E2EIEnrollmentScreenDestination
InitialAppState.LOGGED_IN -> HomeScreenDestination
}
appLogger.i("$TAG composable content")
setComposableContent(startDestination) {
appLogger.i("$TAG splash hide")
Expand Down
27 changes: 26 additions & 1 deletion app/src/main/kotlin/com/wire/android/ui/WireActivityViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.wire.android.appLogger
import com.wire.android.datastore.GlobalDataStore
import com.wire.android.di.AuthServerConfigProvider
import com.wire.android.di.KaliumCoreLogic
import com.wire.android.di.ObserveIfE2EIRequiredDuringLoginUseCaseProvider
import com.wire.android.di.ObserveScreenshotCensoringConfigUseCaseProvider
import com.wire.android.di.ObserveSyncStateUseCaseProvider
import com.wire.android.feature.AccountSwitchUseCase
Expand Down Expand Up @@ -111,6 +112,7 @@ class WireActivityViewModel @Inject constructor(
private val currentScreenManager: CurrentScreenManager,
private val observeScreenshotCensoringConfigUseCaseProviderFactory: ObserveScreenshotCensoringConfigUseCaseProvider.Factory,
private val globalDataStore: GlobalDataStore,
private val observeIfE2EIRequiredDuringLoginUseCaseProviderFactory: ObserveIfE2EIRequiredDuringLoginUseCaseProvider.Factory
) : ViewModel() {

var globalAppState: GlobalAppState by mutableStateOf(GlobalAppState())
Expand Down Expand Up @@ -143,12 +145,16 @@ class WireActivityViewModel @Inject constructor(
private val _observeSyncFlowState: MutableStateFlow<SyncState?> = MutableStateFlow(null)
val observeSyncFlowState: StateFlow<SyncState?> = _observeSyncFlowState

private val _observeE2EIState: MutableStateFlow<Boolean?> = MutableStateFlow(null)
private val observeE2EIState: StateFlow<Boolean?> = _observeE2EIState

init {
observeSyncState()
observeUpdateAppState()
observeNewClientState()
observeScreenshotCensoringConfigState()
observeAppThemeState()
observerE2EIState()
}

private fun observeAppThemeState() {
Expand All @@ -161,6 +167,18 @@ class WireActivityViewModel @Inject constructor(
}
}

fun observerE2EIState() {
viewModelScope.launch(dispatchers.io()) {
observeUserId
.flatMapLatest {
it?.let { observeIfE2EIRequiredDuringLoginUseCaseProviderFactory.create(it).observeIfE2EIIsRequiredDuringLogin() }
?: flowOf(null)
}
.distinctUntilChanged()
.collect { _observeE2EIState.emit(it) }
}
}

private fun observeSyncState() {
viewModelScope.launch(dispatchers.io()) {
observeUserId
Expand Down Expand Up @@ -234,6 +252,7 @@ class WireActivityViewModel @Inject constructor(
get() = when {
shouldMigrate() -> InitialAppState.NOT_MIGRATED
shouldLogIn() -> InitialAppState.NOT_LOGGED_IN
blockedByE2EI() -> InitialAppState.ENROLL_E2EI
else -> InitialAppState.LOGGED_IN
}

Expand Down Expand Up @@ -264,8 +283,10 @@ class WireActivityViewModel @Inject constructor(
// to handle the deepLinks above user needs to be Logged in
// do nothing, already handled by initialAppState
}

result is DeepLinkResult.JoinConversation ->
onConversationInviteDeepLink(result.code, result.key, result.domain, onOpenConversation)

result != null -> onResult(result)
result is DeepLinkResult.Unknown -> appLogger.e("unknown deeplink result $result")
}
Expand Down Expand Up @@ -413,6 +434,10 @@ class WireActivityViewModel @Inject constructor(

fun shouldLogIn(): Boolean = !hasValidCurrentSession()

fun blockedByE2EI(): Boolean {
return observeE2EIState.value == true
}

private fun hasValidCurrentSession(): Boolean = runBlocking {
// TODO: the usage of currentSessionFlow is a temporary solution, it should be replaced with a proper solution
currentSessionFlow().first().let {
Expand Down Expand Up @@ -532,5 +557,5 @@ data class GlobalAppState(
)

enum class InitialAppState {
NOT_MIGRATED, NOT_LOGGED_IN, LOGGED_IN
NOT_MIGRATED, NOT_LOGGED_IN, LOGGED_IN, ENROLL_E2EI
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ class CreateAccountCodeViewModel @Inject constructor(
is RegisterClientResult.Success -> {
onSuccess()
}

is RegisterClientResult.E2EICertificateRequired -> {
// TODO
onSuccess()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ data class Device(
mlsPublicKeys = client.mlsPublicKeys,
e2eiCertificateStatus = e2eiCertificateStatus
)

fun updateFromClient(client: Client): Device = copy(
name = client.displayName(),
clientId = client.id,
registrationTime = client.registrationTime?.toIsoDateTimeString(),
lastActiveInWholeWeeks = client.lastActiveInWholeWeeks(),
isValid = client.isValid,
isVerifiedProteus = client.isVerified,
mlsPublicKeys = client.mlsPublicKeys,
)

fun updateE2EICertificateStatus(e2eiCertificateStatus: CertificateStatus): Device = copy(
e2eiCertificateStatus = e2eiCertificateStatus
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import com.wire.android.ui.common.textfield.clearAutofillTree
import com.wire.android.ui.common.topappbar.NavigationIconType
import com.wire.android.ui.common.topappbar.WireCenterAlignedTopAppBar
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.InitialSyncScreenDestination
import com.wire.android.ui.destinations.RemoveDeviceScreenDestination
Expand All @@ -81,11 +82,14 @@ fun RegisterDeviceScreen(navigator: Navigator) {
is RegisterDeviceFlowState.Success -> {
navigator.navigate(
NavigationCommand(
destination = if (flowState.initialSyncCompleted) HomeScreenDestination else InitialSyncScreenDestination,
destination = if (flowState.isE2EIRequired) E2EIEnrollmentScreenDestination
else if (flowState.initialSyncCompleted) HomeScreenDestination
else InitialSyncScreenDestination,
backStackMode = BackStackMode.CLEAR_WHOLE
)
)
}

is RegisterDeviceFlowState.TooManyDevices -> navigator.navigate(NavigationCommand(RemoveDeviceScreenDestination))
else ->
RegisterDeviceContent(
Expand Down Expand Up @@ -189,6 +193,7 @@ private fun PasswordTextField(state: RegisterDeviceState, onPasswordChange: (Tex
state = when (state.flowState) {
is RegisterDeviceFlowState.Error.InvalidCredentialsError ->
WireTextFieldState.Error(stringResource(id = R.string.remove_device_invalid_password))

else -> WireTextFieldState.Default
},
imeAction = ImeAction.Done,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,26 @@ package com.wire.android.ui.authentication.devices.register

import androidx.compose.ui.text.input.TextFieldValue
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.UserId

data class RegisterDeviceState(
val password: TextFieldValue = TextFieldValue(""),
val continueEnabled: Boolean = false,
val flowState: RegisterDeviceFlowState = RegisterDeviceFlowState.Default
)

sealed class RegisterDeviceFlowState {
object Default : RegisterDeviceFlowState()
object Loading : RegisterDeviceFlowState()
object TooManyDevices : RegisterDeviceFlowState()
data class Success(val initialSyncCompleted: Boolean) : RegisterDeviceFlowState()
data class Success(
val initialSyncCompleted: Boolean,
val isE2EIRequired: Boolean,
val clientId: ClientId,
val userId: UserId? = null
) : RegisterDeviceFlowState()

sealed class Error : RegisterDeviceFlowState() {
object InvalidCredentialsError : Error()
data class GenericError(val coreFailure: CoreFailure) : Error()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,25 @@ class RegisterDeviceViewModel @Inject constructor(
)) {
is RegisterClientResult.Failure.TooManyClients ->
updateFlowState(RegisterDeviceFlowState.TooManyDevices)

is RegisterClientResult.Success ->
updateFlowState(RegisterDeviceFlowState.Success(userDataStore.initialSyncCompleted.first()))
updateFlowState(
RegisterDeviceFlowState.Success(
userDataStore.initialSyncCompleted.first(),
false,
registerDeviceResult.client.id
)
)

is RegisterClientResult.E2EICertificateRequired ->
updateFlowState(
RegisterDeviceFlowState.Success(
userDataStore.initialSyncCompleted.first(),
true,
registerDeviceResult.client.id,
registerDeviceResult.userId
)
)

is RegisterClientResult.Failure.Generic -> state = state.copy(
continueEnabled = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import com.wire.android.ui.common.divider.WireDivider
import com.wire.android.ui.common.rememberTopBarElevationState
import com.wire.android.ui.common.textfield.clearAutofillTree
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.InitialSyncScreenDestination
import com.wire.android.util.dialogErrorStrings
Expand All @@ -73,9 +74,11 @@ fun RemoveDeviceScreen(navigator: Navigator) {
val state: RemoveDeviceState = viewModel.state
val clearSessionState: ClearSessionState = clearSessionViewModel.state

fun navigateAfterSuccess(initialSyncCompleted: Boolean) = navigator.navigate(
fun navigateAfterSuccess(initialSyncCompleted: Boolean, isE2EIRequired: Boolean) = navigator.navigate(
NavigationCommand(
destination = if (initialSyncCompleted) HomeScreenDestination else InitialSyncScreenDestination,
destination = if (isE2EIRequired) E2EIEnrollmentScreenDestination
else if (initialSyncCompleted) HomeScreenDestination
else InitialSyncScreenDestination,
backStackMode = BackStackMode.CLEAR_WHOLE
)
)
Expand All @@ -84,9 +87,9 @@ fun RemoveDeviceScreen(navigator: Navigator) {
RemoveDeviceContent(
state = state,
clearSessionState = clearSessionState,
onItemClicked = { viewModel.onItemClicked(it) { navigateAfterSuccess(it) } },
onItemClicked = { viewModel.onItemClicked(it, ::navigateAfterSuccess) },
onPasswordChange = viewModel::onPasswordChange,
onRemoveConfirm = { viewModel.onRemoveConfirmed { navigateAfterSuccess(it) } },
onRemoveConfirm = { viewModel.onRemoveConfirmed(::navigateAfterSuccess) },
onDialogDismiss = viewModel::onDialogDismissed,
onErrorDialogDismiss = viewModel::clearDeleteClientError,
onBackButtonClicked = clearSessionViewModel::onBackButtonClicked,
Expand Down
Loading
Loading