Skip to content

Commit

Permalink
Fix #7347 - Catch errors loading tflite models in Identity SDK (#9812)
Browse files Browse the repository at this point in the history
* Validates that cached models are valid, exits flow with IdentityVerificationSheet.VerificationFlowResult.Failed if not
* Handle error If an invalid model is still somehow still loaded, preventing an app crash
  • Loading branch information
kentwilliams-stripe authored Dec 20, 2024
1 parent 75d9647 commit dd48576
Show file tree
Hide file tree
Showing 18 changed files with 117 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ interface ScanFlow<Parameters, DataType> {
* @param lifecycleOwner: The activity that owns this flow. The flow will pause if the activity
* is paused
* @param coroutineScope: The coroutine scope used to run async tasks for this flow
* @param errorHandler: A handler to report errors to
*/
fun startFlow(
context: Context,
imageStream: Flow<DataType>,
viewFinder: Rect,
lifecycleOwner: LifecycleOwner,
coroutineScope: CoroutineScope,
parameters: Parameters
parameters: Parameters,
errorHandler: (e: Exception) -> Unit
)

/**
Expand Down
2 changes: 1 addition & 1 deletion identity/detekt-baseline.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<ID>LongMethod:DebugScreen.kt$@Composable internal fun CompleteWithTestDataSection( onClickSubmit: (CompleteOption) -> Unit )</ID>
<ID>LongMethod:DebugScreen.kt$@Composable internal fun DebugScreen( navController: NavController, identityViewModel: IdentityViewModel, verificationFlowFinishable: VerificationFlowFinishable )</ID>
<ID>LongMethod:DocWarmupScreen.kt$@Composable internal fun DocWarmupView( documentSelectPage: VerificationPageStaticContentDocumentSelectPage, onContinueClick: () -> Unit )</ID>
<ID>LongMethod:DocumentScanScreen.kt$@Composable internal fun DocumentScanScreen( navController: NavController, identityViewModel: IdentityViewModel, documentScanViewModel: DocumentScanViewModel )</ID>
<ID>LongMethod:DocumentScanScreen.kt$@Composable internal fun DocumentScanScreen( navController: NavController, identityViewModel: IdentityViewModel, documentScanViewModel: DocumentScanViewModel, )</ID>
<ID>LongMethod:DocumentScanScreen.kt$@Composable private fun DocumentCaptureScreen( documentScannerState: IdentityScanViewModel.State, @StringRes feedback: Int, targetScanType: IdentityScanState.ScanType?, identityScanViewModel: IdentityScanViewModel, identityViewModel: IdentityViewModel, lifecycleOwner: LifecycleOwner, cameraManager: IdentityCameraManager, onContinueClick: () -> Unit )</ID>
<ID>LongMethod:ErrorScreen.kt$@Composable internal fun ErrorScreen( identityViewModel: IdentityViewModel, title: String, modifier: Modifier = Modifier, message1: String? = null, message2: String? = null, topButton: ErrorScreenButton? = null, bottomButton: ErrorScreenButton? = null, )</ID>
<ID>LongMethod:IDNumberSection.kt$@Composable internal fun IDNumberSection( enabled: Boolean, idNumberCountries: List&lt;Country>, countryNotListedText: String, navController: NavController, onIdNumberCollected: (Resource&lt;IdNumberParam>) -> Unit )</ID>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import com.stripe.android.identity.navigation.navigateToFinalErrorScreen
import com.stripe.android.identity.ui.IdentityTheme
import com.stripe.android.identity.ui.IdentityTopBarState
import com.stripe.android.identity.viewmodel.IdentityViewModel
import kotlinx.coroutines.launch
import javax.inject.Inject
import javax.inject.Provider
import kotlin.coroutines.CoroutineContext
Expand All @@ -64,7 +63,8 @@ internal class IdentityActivity :
{ application },
{ uiContext },
{ workContext },
{ subcomponent }
{ subcomponent },
{ finishWithResult(it) }
)

private val starterArgs: IdentityVerificationSheetContract.Args by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import com.stripe.android.identity.networking.models.VerificationPage
import com.stripe.android.identity.states.IdentityScanState
import com.stripe.android.identity.states.LaplacianBlurDetector
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File

/**
Expand Down Expand Up @@ -80,7 +82,8 @@ internal class IdentityScanFlow(
viewFinder: Rect,
lifecycleOwner: LifecycleOwner,
coroutineScope: CoroutineScope,
parameters: IdentityScanState.ScanType
parameters: IdentityScanState.ScanType,
errorHandler: (e: Exception) -> Unit,
) {
coroutineScope.launch {
if (canceled) {
Expand All @@ -94,26 +97,34 @@ internal class IdentityScanFlow(

requireNotNull(aggregator).bindToLifecycle(lifecycleOwner)

analyzerPool =
AnalyzerPool.of(
if (parameters == IdentityScanState.ScanType.SELFIE) {
FaceDetectorAnalyzer.Factory(
requireNotNull(faceDetectorModelFile) {
"Failed to initialize FaceDetectorAnalyzer, " +
"faceDetectorModelFile is null"
},
modelPerformanceTracker
)
} else {
IDDetectorAnalyzer.Factory(
idDetectorModelFile,
verificationPage.documentCapture.models.idDetectorMinScore,
modelPerformanceTracker,
laplacianBlurDetector,
identityAnalyticsRequestFactory,
)
}
)
try {
analyzerPool =
AnalyzerPool.of(
if (parameters == IdentityScanState.ScanType.SELFIE) {
FaceDetectorAnalyzer.Factory(
requireNotNull(faceDetectorModelFile) {
"Failed to initialize FaceDetectorAnalyzer, " +
"faceDetectorModelFile is null"
},
modelPerformanceTracker
)
} else {
IDDetectorAnalyzer.Factory(
idDetectorModelFile,
verificationPage.documentCapture.models.idDetectorMinScore,
modelPerformanceTracker,
laplacianBlurDetector,
identityAnalyticsRequestFactory,
)
}
)
} catch (e: IllegalStateException) {
withContext(Dispatchers.Main) {
errorHandler(e)
}

return@launch
}

loop = ProcessBoundAnalyzerLoop(
analyzerPool = requireNotNull(analyzerPool),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ internal object DocWarmupDestination : IdentityTopLevelDestination(

override val destinationRoute = ROUTE
}

private const val DOC_WARMUP = "DocWarmup"
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ internal fun IdentityNavGraph(
DocumentScanScreen(
navController = navController,
identityViewModel = identityViewModel,
documentScanViewModel = documentScanViewModel
documentScanViewModel = documentScanViewModel,
)
}
screen(SelfieWarmupDestination.ROUTE) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.stripe.android.identity.networking

import com.stripe.android.identity.utils.IdentityIO
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.impl.InterpreterWrapperImpl
import java.io.File
import javax.inject.Inject

Expand All @@ -11,11 +13,29 @@ internal class DefaultIdentityModelFetcher @Inject constructor(
override suspend fun fetchIdentityModel(modelUrl: String): File {
// Use the filename as a look up key
identityIO.createTFLiteFile(modelUrl).let { tfliteFile ->
return if (tfliteFile.exists()) {
return if (tfliteFile.exists() && validateModel(tfliteFile)) {
tfliteFile
} else {
identityRepository.downloadModel(modelUrl)
identityRepository.downloadModel(modelUrl).also {
if (!validateModel(tfliteFile)) {
throw IllegalStateException("Invalid TFLite model, likely a corrupted download")
}
}
}
}
}

private fun validateModel(modelFile: File): Boolean {
// Try to load the model file
@Suppress("SwallowedException")
return try {
InterpreterWrapperImpl(
modelFile,
InterpreterOptionsWrapper.Builder().build()
)
true
} catch (e: IllegalStateException) {
false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ internal const val VIEW_FINDER_ASPECT_RATIO = 1f
internal fun DocumentScanScreen(
navController: NavController,
identityViewModel: IdentityViewModel,
documentScanViewModel: DocumentScanViewModel
documentScanViewModel: DocumentScanViewModel,
) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.CreationExtras
import com.stripe.android.core.utils.requireApplication
import com.stripe.android.identity.R
import com.stripe.android.identity.VerificationFlowFinishable
import com.stripe.android.identity.analytics.FPSTracker
import com.stripe.android.identity.analytics.IdentityAnalyticsRequestFactory
import com.stripe.android.identity.analytics.ModelPerformanceTracker
Expand All @@ -26,13 +27,15 @@ internal class DocumentScanViewModel(
override val fpsTracker: FPSTracker,
override val identityAnalyticsRequestFactory: IdentityAnalyticsRequestFactory,
modelPerformanceTracker: ModelPerformanceTracker,
laplacianBlurDetector: LaplacianBlurDetector
laplacianBlurDetector: LaplacianBlurDetector,
verificationFlowFinishable: VerificationFlowFinishable
) : IdentityScanViewModel(
applicationContext,
fpsTracker,
identityAnalyticsRequestFactory,
modelPerformanceTracker,
laplacianBlurDetector
laplacianBlurDetector,
verificationFlowFinishable
) {

@OptIn(FlowPreview::class)
Expand All @@ -48,7 +51,6 @@ internal class DocumentScanViewModel(
R.string.stripe_position_id_back
}
}

is State.Scanned -> R.string.stripe_scanned
is State.Scanning -> {
when (scannerState.scanState) {
Expand Down Expand Up @@ -87,6 +89,7 @@ internal class DocumentScanViewModel(
}

internal class DocumentScanViewModelFactory @Inject constructor(
private val verificationFlowFinishable: VerificationFlowFinishable,
private val modelPerformanceTracker: ModelPerformanceTracker,
private val laplacianBlurDetector: LaplacianBlurDetector,
private val fpsTracker: FPSTracker,
Expand All @@ -99,7 +102,8 @@ internal class DocumentScanViewModel(
fpsTracker,
identityAnalyticsRequestFactory,
modelPerformanceTracker,
laplacianBlurDetector
laplacianBlurDetector,
verificationFlowFinishable
) as T
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import android.app.Application
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.viewModelScope
import com.stripe.android.camera.scanui.util.asRect
import com.stripe.android.identity.IdentityVerificationSheet
import com.stripe.android.identity.VerificationFlowFinishable
import com.stripe.android.identity.analytics.FPSTracker
import com.stripe.android.identity.analytics.IdentityAnalyticsRequestFactory
import com.stripe.android.identity.analytics.ModelPerformanceTracker
Expand All @@ -22,7 +24,8 @@ internal abstract class IdentityScanViewModel(
open val fpsTracker: FPSTracker,
open val identityAnalyticsRequestFactory: IdentityAnalyticsRequestFactory,
modelPerformanceTracker: ModelPerformanceTracker,
laplacianBlurDetector: LaplacianBlurDetector
laplacianBlurDetector: LaplacianBlurDetector,
private val verificationFlowFinishable: VerificationFlowFinishable
) :
CameraViewModel(
modelPerformanceTracker,
Expand Down Expand Up @@ -133,7 +136,12 @@ internal abstract class IdentityScanViewModel(
viewFinder = cameraManager.requireCameraView().viewFinderWindowView.asRect(),
lifecycleOwner = lifecycleOwner,
coroutineScope = viewModelScope,
parameters = scanType
parameters = scanType,
errorHandler = { e ->
verificationFlowFinishable.finishWithResult(
IdentityVerificationSheet.VerificationFlowResult.Failed(e)
)
}
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.stripe.android.camera.framework.image.longerEdge
import com.stripe.android.core.injection.IOContext
import com.stripe.android.core.injection.UIContext
import com.stripe.android.core.model.StripeFilePurpose
import com.stripe.android.identity.IdentityVerificationSheet
import com.stripe.android.identity.IdentityVerificationSheetContract
import com.stripe.android.identity.analytics.AnalyticsState
import com.stripe.android.identity.analytics.IdentityAnalyticsRequestFactory
Expand Down Expand Up @@ -100,7 +101,7 @@ import kotlin.coroutines.CoroutineContext
/**
* ViewModel hosted by IdentityActivity, shared across fragments.
*/
internal class IdentityViewModel constructor(
internal class IdentityViewModel(
application: Application,
internal val verificationArgs: IdentityVerificationSheetContract.Args,
internal val identityRepository: IdentityRepository,
Expand All @@ -112,7 +113,8 @@ internal class IdentityViewModel constructor(
private val tfLiteInitializer: InterpreterInitializer,
private val savedStateHandle: SavedStateHandle,
@UIContext internal val uiContext: CoroutineContext,
@IOContext internal val workContext: CoroutineContext
@IOContext internal val workContext: CoroutineContext,
private val finishWithResult: (IdentityVerificationSheet.VerificationFlowResult) -> Unit
) : AndroidViewModel(application) {

/**
Expand Down Expand Up @@ -878,14 +880,10 @@ internal class IdentityViewModel constructor(
(
"sessionID: ${verificationArgs.verificationSessionId} and ephemeralKey: " +
verificationArgs.ephemeralKeySecret
).let { msg ->
_verificationPage.postValue(
Resource.error(
msg,
IllegalStateException(msg, it)
)
)
}
.let { msg ->
_verificationPage.postValue(Resource.error(msg, IllegalStateException(msg, it)))
}
}
)
}
Expand Down Expand Up @@ -1022,6 +1020,9 @@ internal class IdentityViewModel constructor(
it
)
)

// Exit with failure
finishWithResult(IdentityVerificationSheet.VerificationFlowResult.Failed(it))
}
)
}
Expand Down Expand Up @@ -1763,7 +1764,8 @@ internal class IdentityViewModel constructor(
private val applicationSupplier: () -> Application,
private val uiContextSupplier: () -> CoroutineContext,
private val workContextSupplier: () -> CoroutineContext,
private val subcomponentSupplier: () -> IdentityActivitySubcomponent
private val subcomponentSupplier: () -> IdentityActivitySubcomponent,
private val finishWithResult: (IdentityVerificationSheet.VerificationFlowResult) -> Unit,
) : ViewModelProvider.Factory {

@Suppress("UNCHECKED_CAST")
Expand All @@ -1783,7 +1785,8 @@ internal class IdentityViewModel constructor(
subcomponent.tfLiteInitializer,
savedStateHandle,
uiContextSupplier(),
workContextSupplier()
workContextSupplier(),
finishWithResult
) as T
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.CreationExtras
import com.stripe.android.core.utils.requireApplication
import com.stripe.android.identity.R
import com.stripe.android.identity.VerificationFlowFinishable
import com.stripe.android.identity.analytics.FPSTracker
import com.stripe.android.identity.analytics.IdentityAnalyticsRequestFactory
import com.stripe.android.identity.analytics.ModelPerformanceTracker
Expand All @@ -24,13 +25,15 @@ internal class SelfieScanViewModel(
override val fpsTracker: FPSTracker,
override val identityAnalyticsRequestFactory: IdentityAnalyticsRequestFactory,
modelPerformanceTracker: ModelPerformanceTracker,
laplacianBlurDetector: LaplacianBlurDetector
laplacianBlurDetector: LaplacianBlurDetector,
private val verificationFlowFinishable: VerificationFlowFinishable
) : IdentityScanViewModel(
applicationContext,
fpsTracker,
identityAnalyticsRequestFactory,
modelPerformanceTracker,
laplacianBlurDetector
laplacianBlurDetector,
verificationFlowFinishable
) {

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -58,6 +61,7 @@ internal class SelfieScanViewModel(
)

internal class SelfieScanViewModelFactory @Inject constructor(
private val verificationFlowFinishable: VerificationFlowFinishable,
private val modelPerformanceTracker: ModelPerformanceTracker,
private val laplacianBlurDetector: LaplacianBlurDetector,
private val fpsTracker: FPSTracker,
Expand All @@ -70,7 +74,8 @@ internal class SelfieScanViewModel(
fpsTracker,
identityAnalyticsRequestFactory,
modelPerformanceTracker,
laplacianBlurDetector
laplacianBlurDetector,
verificationFlowFinishable
) as T
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ internal class IdentityScanViewModelTest {
mockFpsTracker,
mockIdentityAnalyticsRequestFactory,
mock(),
mock(),
mock()
) {
override val scanFeedback = MutableStateFlow(null)
Expand Down
Loading

0 comments on commit dd48576

Please sign in to comment.