Skip to content

Commit

Permalink
[Identity] Use ml-core-default by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ccen-stripe committed Mar 16, 2023
1 parent f758cbc commit 794f4f2
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 16 deletions.
12 changes: 12 additions & 0 deletions identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ If you intend to use this SDK with Stripe's Identity service, you must not modif

Get started with Stripe Identity's [Android integration guide](https://stripe.com/docs/identity/verify-identity-documents?platform=android) and [example project](../identity-example), or [📘 browse the SDK reference](https://stripe.dev/stripe-android/identity/index.html) for fine-grained documentation of all the classes and methods in the SDK.

### Use TFLite in Google play to reduce binary size

Identity Android SDK uses a portable TFLite runtime to execute machine learning models, if your application is released through Google play, you could instead use the Google play runtime, this would reduce the SDK size by ~1.2mb.

To do so, configure your app's dependency on stripe identity as follows.
```
implementation("com.stripe:identity:x.y.z") {
exclude group: 'com.stripe', module: 'ml-core-default' // exclude the default tflite runtime
}
implementation("com.stripe:ml-core-googleplay:x.y.z") // include the google play tflite runtime
```

### Example

[identity-example](../identity-example) – This example demonstrates how to capture your users' ID documents on Android and securely send them to Stripe Identity for identity verification.
6 changes: 1 addition & 5 deletions identity/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ dependencies {
implementation project(":camera-core")
implementation project(":stripe-core")
implementation project(":stripe-ui-core")

// vanilla tflite library
implementation "org.tensorflow:tensorflow-lite:2.11.0"
// support library to reshape image to the input model shape
implementation 'org.tensorflow:tensorflow-lite-support:0.4.3'
implementation project(":ml-core:default")

implementation "androidx.constraintlayout:constraintlayout:$androidxConstraintlayoutVersion"
implementation "androidx.activity:activity-ktx:$androidxActivityVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ internal class IdentityActivity :
super.onSaveInstanceState(outState)
outState.putBoolean(KEY_PRESENTED, true)
}

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
injectWithFallback(
Expand All @@ -120,6 +119,7 @@ internal class IdentityActivity :
.fallbackUrlLauncher(this)
.build()
identityViewModel.retrieveAndBufferVerificationPage()
identityViewModel.initializeTfLite()
identityViewModel.registerActivityResultCaller(this)
fallbackUrlLauncher = registerForActivityResult(
ActivityResultContracts.StartActivityForResult()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import com.stripe.android.identity.networking.IdentityRepository
import com.stripe.android.identity.utils.IdentityIO
import com.stripe.android.identity.utils.IdentityImageHandler
import com.stripe.android.identity.viewmodel.IdentityScanViewModel
import com.stripe.android.mlcore.base.InterpreterInitializer
import com.stripe.android.uicore.address.AddressRepository
import dagger.BindsInstance
import dagger.Subcomponent
Expand All @@ -31,6 +32,7 @@ internal interface IdentityActivitySubcomponent {
val verificationArgs: IdentityVerificationSheetContract.Args
val identityImageHandler: IdentityImageHandler
val addressRepository: AddressRepository
val tfLiteInitializer: InterpreterInitializer

@Subcomponent.Builder
interface Builder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import com.stripe.android.identity.networking.IdentityModelFetcher
import com.stripe.android.identity.networking.IdentityRepository
import com.stripe.android.identity.utils.DefaultIdentityIO
import com.stripe.android.identity.utils.IdentityIO
import com.stripe.android.mlcore.base.InterpreterInitializer
import com.stripe.android.mlcore.impl.InterpreterInitializerImpl
import dagger.Binds
import dagger.Module
import dagger.Provides
Expand Down Expand Up @@ -39,5 +41,9 @@ internal abstract class IdentityCommonModule {
@Provides
@Singleton
fun provideResources(context: Context): Resources = context.resources

@Provides
@Singleton
fun provideInterpreterInitializer(): InterpreterInitializer = InterpreterInitializerImpl
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import com.stripe.android.camera.framework.util.maxAspectRatioInSize
import com.stripe.android.identity.analytics.ModelPerformanceTracker
import com.stripe.android.identity.states.IdentityScanState
import com.stripe.android.identity.utils.roundToMaxDecimals
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import com.stripe.android.mlcore.impl.InterpreterWrapperImpl
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
Expand All @@ -24,7 +26,12 @@ internal class FaceDetectorAnalyzer(
private val modelPerformanceTracker: ModelPerformanceTracker
) : Analyzer<AnalyzerInput, IdentityScanState, AnalyzerOutput> {

private val tfliteInterpreter = Interpreter(modelFile)
// private val tfliteInterpreter = Interpreter(modelFile)
// private val interpreterApi: InterpreterApi

private val interpreterApi: InterpreterWrapper = InterpreterWrapperImpl(
modelFile, InterpreterOptionsWrapper.Builder().build()
)

override suspend fun analyze(
data: AnalyzerInput,
Expand Down Expand Up @@ -56,15 +63,18 @@ internal class FaceDetectorAnalyzer(

val inferenceStat = modelPerformanceTracker.trackInference()
// inference - input: (1, 128, 128, 3), output: (1, 4), (1, 1)

val boundingBoxes = Array(1) { FloatArray(OUTPUT_BOUNDING_BOX_TENSOR_SIZE) }
val score = FloatArray(OUTPUT_SCORE_TENSOR_SIZE)
tfliteInterpreter.runForMultipleInputsOutputs(

interpreterApi.runForMultipleInputsOutputs(
arrayOf(tensorImage.buffer),
mapOf(
OUTPUT_BOUNDING_BOX_TENSOR_INDEX to boundingBoxes,
OUTPUT_SCORE_TENSOR_INDEX to score
)
)

inferenceStat.trackResult()

// FaceDetector outputs (left, top, right, bottom) with absolute value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import com.stripe.android.camera.framework.util.maxAspectRatioInSize
import com.stripe.android.identity.analytics.ModelPerformanceTracker
import com.stripe.android.identity.states.IdentityScanState
import com.stripe.android.identity.utils.roundToMaxDecimals
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import com.stripe.android.mlcore.impl.InterpreterWrapperImpl
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
Expand All @@ -18,8 +20,6 @@ import java.io.File

/**
* Analyzer to run IDDetector.
*
* TODO(ccen): reimplement with ImageClassifier
*/
internal class IDDetectorAnalyzer(
modelFile: File,
Expand All @@ -28,7 +28,10 @@ internal class IDDetectorAnalyzer(
) :
Analyzer<AnalyzerInput, IdentityScanState, AnalyzerOutput> {

private val tfliteInterpreter = Interpreter(modelFile)
// private val interpreterApi: InterpreterApi
private val interpreterApi: InterpreterWrapper = InterpreterWrapperImpl(
modelFile, InterpreterOptionsWrapper.Builder().build()
)

override suspend fun analyze(
data: AnalyzerInput,
Expand Down Expand Up @@ -60,7 +63,7 @@ internal class IDDetectorAnalyzer(
// inference - input: (1, 224, 224, 3), output: (392, 4), (392, 4)
val boundingBoxes = Array(OUTPUT_SIZE) { FloatArray(OUTPUT_BOUNDING_BOX_TENSOR_SIZE) }
val categories = Array(OUTPUT_SIZE) { FloatArray(OUTPUT_CATEGORY_TENSOR_SIZE) }
tfliteInterpreter.runForMultipleInputsOutputs(
interpreterApi.runForMultipleInputsOutputs(
arrayOf(tensorImage.buffer),
mapOf(
OUTPUT_BOUNDING_BOX_TENSOR_INDEX to boundingBoxes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ import com.stripe.android.identity.states.IdentityScanState
import com.stripe.android.identity.ui.IndividualCollectedStates
import com.stripe.android.identity.utils.IdentityIO
import com.stripe.android.identity.utils.IdentityImageHandler
import com.stripe.android.mlcore.base.InterpreterInitializer
import com.stripe.android.uicore.address.AddressSchemaRepository
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
Expand All @@ -112,6 +113,7 @@ internal class IdentityViewModel constructor(
internal val screenTracker: ScreenTracker,
internal val imageHandler: IdentityImageHandler,
internal val addressSchemaRepository: AddressSchemaRepository,
private val tfLiteInitializer: InterpreterInitializer,
private val savedStateHandle: SavedStateHandle,
@UIContext internal val uiContext: CoroutineContext,
@IOContext internal val workContext: CoroutineContext
Expand Down Expand Up @@ -235,6 +237,21 @@ internal class IdentityViewModel constructor(
(SingleSideDocumentUploadState() to CollectedDataParam())
)

private val _isTfLiteInitialized: MutableLiveData<Boolean> = MutableLiveData(false)
val isTfLiteInitialized: LiveData<Boolean> = _isTfLiteInitialized

fun initializeTfLite() {
viewModelScope.launch(workContext) {
tfLiteInitializer.initialize(
getApplication(),
{
_isTfLiteInitialized.postValue(true)
},
{ throw IllegalStateException("Failed to initialize TFLite runtime: $it") }
)
}
}

/**
* Response for initial VerificationPage, used for building UI.
*/
Expand Down Expand Up @@ -274,6 +291,7 @@ internal class IdentityViewModel constructor(
private var idDetectorModel: File? = null
private var faceDetectorModel: File? = null
private var faceDetectorModelValueSet = false
private var isTfliteInitialized = false

init {
postValue(Resource.loading())
Expand Down Expand Up @@ -317,12 +335,18 @@ internal class IdentityViewModel constructor(
Status.IDLE -> {} // no-op
}
}
addSource(isTfLiteInitialized) { initialized ->
isTfliteInitialized = initialized
if (isTfliteInitialized) {
maybePostSuccess()
}
}
}

private fun maybePostSuccess() {
page?.let { page ->
idDetectorModel?.let { idDetectorModel ->
if (faceDetectorModelValueSet) {
if (isTfliteInitialized && faceDetectorModelValueSet) {
postValue(
Resource.success(
PageAndModelFiles(
Expand Down Expand Up @@ -1597,6 +1621,7 @@ internal class IdentityViewModel constructor(
subcomponent.screenTracker,
subcomponent.identityImageHandler,
subcomponent.addressRepository,
subcomponent.tfLiteInitializer,
savedStateHandle,
uiContextSupplier(),
workContextSupplier()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ import com.stripe.android.identity.states.IdentityScanState
import com.stripe.android.identity.utils.IdentityIO
import com.stripe.android.identity.viewmodel.IdentityViewModel.Companion.BACK
import com.stripe.android.identity.viewmodel.IdentityViewModel.Companion.FRONT
import com.stripe.android.mlcore.base.InterpreterInitializer
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TestRule
import org.junit.runner.RunWith
import org.mockito.kotlin.any
import org.mockito.kotlin.anyOrNull
import org.mockito.kotlin.argWhere
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
Expand All @@ -74,6 +78,7 @@ import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import org.robolectric.RobolectricTestRunner
import java.io.File
import kotlin.test.assertFailsWith

@RunWith(RobolectricTestRunner::class)
internal class IdentityViewModelTest {
Expand Down Expand Up @@ -123,7 +128,9 @@ internal class IdentityViewModelTest {
private val mockOnMissingFront = mock<() -> Unit>()
private val mockOnMissingBack = mock<() -> Unit>()
private val mockOnReadyToSubmit = mock<() -> Unit>()
private val mockTfLiteInitializer = mock<InterpreterInitializer>()

@OptIn(ExperimentalCoroutinesApi::class)
private val viewModel = IdentityViewModel(
ApplicationProvider.getApplicationContext(),
IdentityVerificationSheetContract.Args(
Expand All @@ -141,9 +148,10 @@ internal class IdentityViewModelTest {
mockScreenTracker,
mock(),
mock(),
mockTfLiteInitializer,
mockSavedStateHandle,
mock(),
mock()
UnconfinedTestDispatcher()
)

private fun mockUploadSuccess() = runBlocking {
Expand Down Expand Up @@ -523,6 +531,30 @@ internal class IdentityViewModelTest {
}
}

@Test
fun `verify tfLite initialization success`() {
viewModel.initializeTfLite()
val successCaptor = argumentCaptor<() -> Unit>()

verify(mockTfLiteInitializer).initialize(any(), successCaptor.capture(), any())

successCaptor.firstValue.invoke()

assertThat(viewModel.isTfLiteInitialized.value).isTrue()
}

@Test
fun `verify tfLite initialization failure`() {
viewModel.initializeTfLite()
val failureCaptor = argumentCaptor<(Exception) -> Unit>()

verify(mockTfLiteInitializer).initialize(any(), any(), failureCaptor.capture())

assertFailsWith<Exception> {
failureCaptor.firstValue.invoke(Exception())
}
}

private fun testPostVerificationPageDataAndMaybeNavigate(
verificationPageData: VerificationPageData,
targetTopLevelDestination: IdentityTopLevelDestination
Expand Down

0 comments on commit 794f4f2

Please sign in to comment.