Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stripecardscan] Use ml-core for CardScan #6409

Merged
merged 4 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ interface InterpreterWrapper {
)

fun run(input: Any, output: Any)

fun close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@ import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import org.tensorflow.lite.Interpreter
import java.io.File
import java.nio.ByteBuffer

@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
class InterpreterWrapperImpl(file: File, options: InterpreterOptionsWrapper) : InterpreterWrapper {
private val interpreter: Interpreter = Interpreter(file, options.toInterpreterOptions())
class InterpreterWrapperImpl : InterpreterWrapper {
private val interpreter: Interpreter

constructor(byteBuffer: ByteBuffer, options: InterpreterOptionsWrapper) {
interpreter = Interpreter(byteBuffer, options.toInterpreterOptions())
}

constructor(file: File, options: InterpreterOptionsWrapper) {
interpreter = Interpreter(file, options.toInterpreterOptions())
}

override fun runForMultipleInputsOutputs(inputs: Array<Any>, outputs: Map<Int, Any>) {
interpreter.runForMultipleInputsOutputs(inputs, outputs)
Expand All @@ -17,6 +26,10 @@ class InterpreterWrapperImpl(file: File, options: InterpreterOptionsWrapper) : I
override fun run(input: Any, output: Any) {
interpreter.run(input, output)
}

override fun close() {
interpreter.close()
}
}

private fun InterpreterOptionsWrapper.toInterpreterOptions(): Interpreter.Options {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class InterpreterWrapperImpl(file: File, options: InterpreterOptionsWrapper) : I
override fun run(input: Any, output: Any) {
interpreter.run(input, output)
}

override fun close() {
interpreter.close()
}
}

private fun InterpreterOptionsWrapper.toInterpreterOptions(): Interpreter.Options {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@ import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import org.tensorflow.lite.InterpreterApi
import java.io.File
import java.nio.ByteBuffer

@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
class InterpreterWrapperImpl(file: File, options: InterpreterOptionsWrapper) : InterpreterWrapper {
private val interpreter: InterpreterApi =
InterpreterApi.create(file, options.toInterpreterApiOptions())
class InterpreterWrapperImpl : InterpreterWrapper {
private val interpreter: InterpreterApi

constructor(byteBuffer: ByteBuffer, options: InterpreterOptionsWrapper) {
interpreter = InterpreterApi.create(byteBuffer, options.toInterpreterApiOptions())
}

constructor(file: File, options: InterpreterOptionsWrapper) {
interpreter = InterpreterApi.create(file, options.toInterpreterApiOptions())
}

override fun runForMultipleInputsOutputs(inputs: Array<Any>, outputs: Map<Int, Any>) {
interpreter.runForMultipleInputsOutputs(inputs, outputs)
Expand All @@ -18,6 +26,10 @@ class InterpreterWrapperImpl(file: File, options: InterpreterOptionsWrapper) : I
override fun run(input: Any, output: Any) {
interpreter.run(input, output)
}

override fun close() {
interpreter.close()
}
}

private fun InterpreterOptionsWrapper.toInterpreterApiOptions(): InterpreterApi.Options {
Expand Down
12 changes: 12 additions & 0 deletions stripecardscan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ See the `stripecardscan-example` directory for an example application that you c
}
```

## Use TFLite in Google play to reduce binary size

CardScan Android SDK uses a portable TFLite runtime to execute machine learning models, if your application is released through Google play, you could instead use the Google play runtime, this would reduce the SDK size by ~400kb.

To do so, configure your app's dependency on stripecardscan as follows.
```
implementation('com.stripe:stripecardscan:$stripeVersion') {
exclude group: 'com.stripe', module: 'ml-core-cardscan' // exclude the cardscan-specific portable tflite runtime
}
implementation('com.stripe:ml-core-googleplay:$stripeVersion') // include the google play tflite runtime
```

# Credit Card OCR

Add `CardScanSheet` in your activity or fragment where you want to invoke the verification flow
Expand Down
4 changes: 1 addition & 3 deletions stripecardscan/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ dependencies {
api project(":stripe-core")
implementation project(":camera-core")

// If a user wants to use their own TFLite implementation, they can exclude this dependency
// explicitly in their gradle dependency.
implementation project(":stripecardscan-tflite")
implementation project(":ml-core:cardscan")

implementation "androidx.appcompat:appcompat:$androidxAppcompatVersion"
implementation "androidx.constraintlayout:constraintlayout:$androidxConstraintlayoutVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@ import android.content.Context
import android.util.Log
import com.stripe.android.camera.framework.Analyzer
import com.stripe.android.camera.framework.AnalyzerFactory
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import com.stripe.android.mlcore.impl.InterpreterWrapperImpl
import com.stripe.android.stripecardscan.framework.FetchedData
import com.stripe.android.stripecardscan.framework.Loader
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.nnapi.NnApiDelegate
import java.io.Closeable
import java.nio.ByteBuffer

/**
* A TensorFlowLite analyzer uses an [Interpreter] to analyze data.
*/
internal abstract class TensorFlowLiteAnalyzer<Input, MLInput, Output, MLOutput>(
private val tfInterpreter: Interpreter,
private val delegate: NnApiDelegate? = null
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is removed as it's never set

private val tfInterpreter: InterpreterWrapper,
) : Analyzer<Input, Any, Output>, Closeable {

protected abstract suspend fun interpretMLOutput(data: Input, mlOutput: MLOutput): Output

protected abstract suspend fun transformData(data: Input): MLInput

protected abstract suspend fun executeInference(
tfInterpreter: Interpreter,
tfInterpreter: InterpreterWrapper,
data: MLInput
): MLOutput

Expand All @@ -40,7 +40,6 @@ internal abstract class TensorFlowLiteAnalyzer<Input, MLInput, Output, MLOutput>

override fun close() {
tfInterpreter.close()
delegate?.close()
}
}

Expand All @@ -55,19 +54,22 @@ internal abstract class TFLAnalyzerFactory<
private val context: Context,
private val fetchedModel: FetchedData
) : AnalyzerFactory<Input, Any, Output, AnalyzerType> {
protected abstract val tfOptions: Interpreter.Options
protected abstract val tfOptions: InterpreterOptionsWrapper

private val loader by lazy { Loader(context) }

private val loadModelMutex = Mutex()

private var loadedModel: ByteBuffer? = null

protected suspend fun createInterpreter(): Interpreter? =
// protected suspend fun createInterpreter(): Interpreter? =
// createInterpreter(fetchedModel)

protected suspend fun createInterpreter(): InterpreterWrapper? =
createInterpreter(fetchedModel)

private suspend fun createInterpreter(fetchedModel: FetchedData): Interpreter? = try {
loadModel(fetchedModel)?.let { Interpreter(it, tfOptions) }
private suspend fun createInterpreter(fetchedModel: FetchedData): InterpreterWrapper? = try {
loadModel(fetchedModel)?.let { InterpreterWrapperImpl(it, tfOptions) }
} catch (t: Throwable) {
Log.e(
LOG_TAG,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import android.util.Size
import com.stripe.android.camera.framework.image.cropCameraPreviewToSquare
import com.stripe.android.camera.framework.image.hasOpenGl31
import com.stripe.android.camera.framework.image.scale
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import com.stripe.android.stripecardscan.framework.FetchedData
import com.stripe.android.stripecardscan.framework.image.MLImage
import com.stripe.android.stripecardscan.framework.image.toMLImage
import com.stripe.android.stripecardscan.framework.ml.TFLAnalyzerFactory
import com.stripe.android.stripecardscan.framework.ml.TensorFlowLiteAnalyzer
import com.stripe.android.stripecardscan.framework.util.indexOfMax
import org.tensorflow.lite.Interpreter
import java.nio.ByteBuffer
import kotlin.math.max

Expand All @@ -22,7 +23,7 @@ private val TRAINED_IMAGE_SIZE = Size(224, 224)
/** model returns whether or not there is a card present */
private const val NUM_CLASS = 3

internal class CardDetect private constructor(interpreter: Interpreter) :
internal class CardDetect private constructor(interpreter: InterpreterWrapper) :
TensorFlowLiteAnalyzer<
CardDetect.Input,
ByteBuffer,
Expand Down Expand Up @@ -97,7 +98,7 @@ internal class CardDetect private constructor(interpreter: Interpreter) :
data.cardDetectImage.getData()

override suspend fun executeInference(
tfInterpreter: Interpreter,
tfInterpreter: InterpreterWrapper,
data: ByteBuffer
): Array<FloatArray> {
val mlOutput = arrayOf(FloatArray(NUM_CLASS))
Expand All @@ -118,10 +119,10 @@ internal class CardDetect private constructor(interpreter: Interpreter) :
private const val DEFAULT_THREADS = 4
}

override val tfOptions: Interpreter.Options = Interpreter
.Options()
.setUseNNAPI(USE_GPU && hasOpenGl31(context))
.setNumThreads(threads)
override val tfOptions = InterpreterOptionsWrapper.Builder()
.useNNAPI(USE_GPU && hasOpenGl31(context.applicationContext))
.numThreads(threads)
.build()

override suspend fun newInstance(): CardDetect? =
createInterpreter()?.let { CardDetect(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import androidx.annotation.VisibleForTesting
import com.stripe.android.camera.framework.image.cropCameraPreviewToViewFinder
import com.stripe.android.camera.framework.image.hasOpenGl31
import com.stripe.android.camera.framework.image.scale
import com.stripe.android.mlcore.base.InterpreterOptionsWrapper
import com.stripe.android.mlcore.base.InterpreterWrapper
import com.stripe.android.stripecardscan.framework.FetchedData
import com.stripe.android.stripecardscan.framework.image.MLImage
import com.stripe.android.stripecardscan.framework.image.toMLImage
Expand All @@ -23,7 +25,6 @@ import com.stripe.android.stripecardscan.payment.ml.ssd.combinePriors
import com.stripe.android.stripecardscan.payment.ml.ssd.determineLayoutAndFilter
import com.stripe.android.stripecardscan.payment.ml.ssd.extractPredictions
import com.stripe.android.stripecardscan.payment.ml.ssd.rearrangeOCRArray
import org.tensorflow.lite.Interpreter
import java.nio.ByteBuffer

/** Training images are normalized with mean 127.5 and std 128.5. */
Expand Down Expand Up @@ -89,7 +90,7 @@ private val PRIORS = combinePriors(SSDOcr.Factory.TRAINED_IMAGE_SIZE)
/**
* This model performs SSD OCR recognition on a card.
*/
internal class SSDOcr private constructor(interpreter: Interpreter) :
internal class SSDOcr private constructor(interpreter: InterpreterWrapper) :
TensorFlowLiteAnalyzer<
SSDOcr.Input,
Array<ByteBuffer>,
Expand Down Expand Up @@ -178,15 +179,16 @@ internal class SSDOcr private constructor(interpreter: Interpreter) :
}

override suspend fun executeInference(
tfInterpreter: Interpreter,
tfInterpreter: InterpreterWrapper,
data: Array<ByteBuffer>
): Map<Int, Array<FloatArray>> {
val mlOutput = mapOf(
0 to arrayOf(FloatArray(NUM_CLASS)),
1 to arrayOf(FloatArray(NUM_LOC))
)

tfInterpreter.runForMultipleInputsOutputs(data, mlOutput)
@Suppress("UNCHECKED_CAST")
tfInterpreter.runForMultipleInputsOutputs(data as Array<Any>, mlOutput)
return mlOutput
}

Expand All @@ -205,10 +207,10 @@ internal class SSDOcr private constructor(interpreter: Interpreter) :
val TRAINED_IMAGE_SIZE = Size(600, 375)
}

override val tfOptions: Interpreter.Options = Interpreter
.Options()
.setUseNNAPI(USE_GPU && hasOpenGl31(context.applicationContext))
.setNumThreads(threads)
override val tfOptions = InterpreterOptionsWrapper.Builder()
.useNNAPI(USE_GPU && hasOpenGl31(context.applicationContext))
.numThreads(threads)
.build()

override suspend fun newInstance(): SSDOcr? = createInterpreter()?.let { SSDOcr(it) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import android.view.ViewGroup
import android.view.WindowManager
import androidx.activity.addCallback
import androidx.appcompat.app.AlertDialog
import androidx.lifecycle.lifecycleScope
import com.stripe.android.camera.CameraAdapter
import com.stripe.android.camera.CameraPermissionCheckingActivity
import com.stripe.android.camera.CameraPreviewImage
import com.stripe.android.camera.DefaultCameraErrorListener
import com.stripe.android.camera.framework.Stats
import com.stripe.android.mlcore.impl.InterpreterInitializerImpl
import com.stripe.android.stripecardscan.R
import com.stripe.android.stripecardscan.camera.getCameraAdapter
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -249,7 +251,17 @@ internal abstract class ScanActivity : CameraPermissionCheckingActivity(), Corou
onSupportsMultipleCameras(it)
}

launch { onCameraStreamAvailable(cameraAdapter.getImageStream()) }
lifecycleScope.launch(Dispatchers.IO) {
InterpreterInitializerImpl.initialize(
context = this@ScanActivity,
onSuccess = {
lifecycleScope.launch { onCameraStreamAvailable(cameraAdapter.getImageStream()) }
},
onFailure = {
scanFailure(it)
}
)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.RestrictTo
import androidx.core.content.ContextCompat
import androidx.fragment.app.Fragment
import androidx.lifecycle.lifecycleScope
import com.stripe.android.camera.CameraAdapter
import com.stripe.android.camera.CameraPreviewImage
import com.stripe.android.camera.DefaultCameraErrorListener
import com.stripe.android.camera.framework.Stats
import com.stripe.android.core.storage.StorageFactory
import com.stripe.android.mlcore.impl.InterpreterInitializerImpl
import com.stripe.android.stripecardscan.R
import com.stripe.android.stripecardscan.camera.getCameraAdapter
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -236,7 +238,17 @@ abstract class ScanFragment : Fragment(), CoroutineScope {
onSupportsMultipleCameras(it)
}

launch { onCameraStreamAvailable(cameraAdapter.getImageStream()) }
lifecycleScope.launch(Dispatchers.IO) {
InterpreterInitializerImpl.initialize(
context = requireContext(),
onSuccess = {
lifecycleScope.launch { onCameraStreamAvailable(cameraAdapter.getImageStream()) }
},
onFailure = {
scanFailure(it)
}
)
}
}

/**
Expand Down