diff --git a/example/src/main/java/com/stripe/example/ExampleApplication.kt b/example/src/main/java/com/stripe/example/ExampleApplication.kt index 87a29af4e58..dae4de34eb1 100644 --- a/example/src/main/java/com/stripe/example/ExampleApplication.kt +++ b/example/src/main/java/com/stripe/example/ExampleApplication.kt @@ -4,9 +4,8 @@ import android.os.StrictMode import androidx.multidex.MultiDexApplication import com.facebook.stetho.Stetho import com.stripe.android.CustomerSession -import com.stripe.android.EphemeralKeyProvider -import com.stripe.android.EphemeralKeyUpdateListener import com.stripe.android.PaymentConfiguration +import com.stripe.example.service.ExampleEphemeralKeyProvider import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -38,12 +37,10 @@ class ExampleApplication : MultiDexApplication() { Stetho.initializeWithDefaults(this@ExampleApplication) } - // initialize empty CustomerSession - CustomerSession.initCustomerSession(this, - object : EphemeralKeyProvider { - override fun createEphemeralKey(apiVersion: String, keyUpdateListener: EphemeralKeyUpdateListener) { - } - } + CustomerSession.initCustomerSession( + this, + ExampleEphemeralKeyProvider(this), + false ) } } diff --git a/example/src/main/java/com/stripe/example/activity/CustomerSessionActivity.kt b/example/src/main/java/com/stripe/example/activity/CustomerSessionActivity.kt index 1d326d8b0de..1117eb08850 100644 --- a/example/src/main/java/com/stripe/example/activity/CustomerSessionActivity.kt +++ b/example/src/main/java/com/stripe/example/activity/CustomerSessionActivity.kt @@ -11,7 +11,6 @@ import com.stripe.android.model.Customer import com.stripe.android.model.PaymentMethod import com.stripe.android.view.PaymentMethodsActivityStarter import com.stripe.example.R -import com.stripe.example.service.ExampleEphemeralKeyProvider import kotlinx.android.synthetic.main.activity_customer_session.* /** @@ -28,11 +27,6 @@ class CustomerSessionActivity : AppCompatActivity() { super.onCreate(savedInstanceState) setContentView(R.layout.activity_customer_session) setTitle(R.string.customer_payment_data_example) - CustomerSession.initCustomerSession( - this, - ExampleEphemeralKeyProvider(this), - false - ) progress_bar.visibility = View.VISIBLE CustomerSession.getInstance().retrieveCurrentCustomer( diff --git a/example/src/main/java/com/stripe/example/activity/FragmentExamplesActivity.kt b/example/src/main/java/com/stripe/example/activity/FragmentExamplesActivity.kt index 9356c9e8eca..a11894e1361 100644 --- a/example/src/main/java/com/stripe/example/activity/FragmentExamplesActivity.kt +++ b/example/src/main/java/com/stripe/example/activity/FragmentExamplesActivity.kt @@ -28,7 +28,6 @@ import com.stripe.example.R import com.stripe.example.StripeFactory import com.stripe.example.module.BackendApiFactory import com.stripe.example.service.BackendApi -import com.stripe.example.service.ExampleEphemeralKeyProvider import io.reactivex.android.schedulers.AndroidSchedulers import io.reactivex.disposables.CompositeDisposable import io.reactivex.schedulers.Schedulers @@ -50,8 +49,8 @@ class FragmentExamplesActivity : AppCompatActivity() { val newFragment = LauncherFragment() - val ft = supportFragmentManager.beginTransaction() - ft.add(R.id.root, newFragment, LauncherFragment::class.java.simpleName) + supportFragmentManager.beginTransaction() + .replace(R.id.root, newFragment, LauncherFragment::class.java.simpleName) .commit() } @@ -275,10 +274,6 @@ class FragmentExamplesActivity : AppCompatActivity() { } private fun createCustomerSession(): CustomerSession { - CustomerSession.initCustomerSession( - requireContext(), - ExampleEphemeralKeyProvider(requireContext()) - ) val customerSession = CustomerSession.getInstance() customerSession.retrieveCurrentCustomer( object : CustomerSession.CustomerRetrievalListener { diff --git a/example/src/main/java/com/stripe/example/activity/PaymentSessionActivity.kt b/example/src/main/java/com/stripe/example/activity/PaymentSessionActivity.kt index 381e0567cc0..6047016218f 100644 --- a/example/src/main/java/com/stripe/example/activity/PaymentSessionActivity.kt +++ b/example/src/main/java/com/stripe/example/activity/PaymentSessionActivity.kt @@ -20,7 +20,6 @@ import com.stripe.android.view.BillingAddressFields import com.stripe.android.view.PaymentUtils import com.stripe.android.view.ShippingInfoWidget import com.stripe.example.R -import com.stripe.example.service.ExampleEphemeralKeyProvider import kotlinx.android.synthetic.main.activity_payment_session.* import java.util.Currency import java.util.Locale @@ -45,9 +44,7 @@ class PaymentSessionActivity : AppCompatActivity() { super.onCreate(savedInstanceState) setContentView(R.layout.activity_payment_session) - progress_bar.visibility = View.VISIBLE - - paymentSession = createPaymentSession(savedInstanceState) + paymentSession = createPaymentSession(savedInstanceState == null) btn_select_payment_method.setOnClickListener { paymentSession.presentPaymentMethodSelection() @@ -57,21 +54,17 @@ class PaymentSessionActivity : AppCompatActivity() { } } - private fun createCustomerSession(): CustomerSession { - CustomerSession.initCustomerSession( - this, - ExampleEphemeralKeyProvider(this), - false - ) - return CustomerSession.getInstance() - } - private fun createPaymentSession( - savedInstanceState: Bundle?, - shouldPrefetchCustomer: Boolean = true + shouldPrefetchCustomer: Boolean = false ): PaymentSession { + if (shouldPrefetchCustomer) { + disableUi() + } else { + enableUi() + } + // CustomerSession only needs to be initialized once per app. - val customerSession = createCustomerSession() + val customerSession = CustomerSession.getInstance() val paymentSession = PaymentSession( activity = this, @@ -96,8 +89,7 @@ class PaymentSessionActivity : AppCompatActivity() { .build() ) paymentSession.init( - listener = PaymentSessionListenerImpl(this, customerSession), - savedInstanceState = savedInstanceState + listener = PaymentSessionListenerImpl(this, customerSession) ) paymentSession.setCartTotal(2000L) @@ -159,22 +151,12 @@ class PaymentSessionActivity : AppCompatActivity() { paymentSession.handlePaymentData(requestCode, resultCode, data ?: Intent()) } - override fun onDestroy() { - paymentSession.onDestroy() - super.onDestroy() - } - - override fun onSaveInstanceState(outState: Bundle) { - super.onSaveInstanceState(outState) - paymentSession.savePaymentSessionInstanceState(outState) - } - private fun onPaymentSessionDataChanged( customerSession: CustomerSession, data: PaymentSessionData ) { paymentSessionData = data - progress_bar.visibility = View.VISIBLE + disableUi() customerSession.retrieveCurrentCustomer( PaymentSessionChangeCustomerRetrievalListener(this) ) @@ -186,6 +168,12 @@ class PaymentSessionActivity : AppCompatActivity() { btn_start_payment_flow.isEnabled = true } + private fun disableUi() { + progress_bar.visibility = View.VISIBLE + btn_select_payment_method.isEnabled = false + btn_start_payment_flow.isEnabled = false + } + private fun onCustomerRetrieved() { enableUi() @@ -239,10 +227,10 @@ class PaymentSessionActivity : AppCompatActivity() { BackgroundTaskTracker.onStart() } - listenerActivity?.progress_bar?.visibility = if (isCommunicating) { - View.VISIBLE + if (isCommunicating) { + listenerActivity?.disableUi() } else { - View.INVISIBLE + listenerActivity?.enableUi() } } diff --git a/stripe/src/main/java/com/stripe/android/PaymentSession.kt b/stripe/src/main/java/com/stripe/android/PaymentSession.kt index 379c45eed33..6d5c6715648 100644 --- a/stripe/src/main/java/com/stripe/android/PaymentSession.kt +++ b/stripe/src/main/java/com/stripe/android/PaymentSession.kt @@ -1,13 +1,20 @@ package com.stripe.android import android.app.Activity +import android.app.Application import android.content.Context import android.content.Intent -import android.os.Bundle +import androidx.activity.ComponentActivity import androidx.annotation.IntRange import androidx.annotation.VisibleForTesting import androidx.fragment.app.Fragment -import com.stripe.android.model.Customer +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.LifecycleObserver +import androidx.lifecycle.LifecycleOwner +import androidx.lifecycle.Observer +import androidx.lifecycle.OnLifecycleEvent +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.ViewModelStoreOwner import com.stripe.android.model.PaymentMethod import com.stripe.android.view.ActivityStarter import com.stripe.android.view.PaymentFlowActivity @@ -21,38 +28,65 @@ import java.lang.ref.WeakReference */ class PaymentSession @VisibleForTesting internal constructor( private val context: Context, + application: Application, + viewModelStoreOwner: ViewModelStoreOwner, + private val lifecycleOwner: LifecycleOwner, private val config: PaymentSessionConfig, - private val customerSession: CustomerSession, + customerSession: CustomerSession, private val paymentMethodsActivityStarter: ActivityStarter, private val paymentFlowActivityStarter: ActivityStarter, - private val paymentSessionPrefs: PaymentSessionPrefs, paymentSessionData: PaymentSessionData = PaymentSessionData(config) ) { + internal val viewModel: PaymentSessionViewModel = + ViewModelProvider( + viewModelStoreOwner, + PaymentSessionViewModel.Factory( + application, + paymentSessionData, + customerSession + ) + )[PaymentSessionViewModel::class.java] + /** * @return the data associated with the instance of this class. */ - var paymentSessionData: PaymentSessionData = paymentSessionData - private set - private var paymentSessionListener: PaymentSessionListener? = null + val paymentSessionData: PaymentSessionData + get() = viewModel.paymentSessionData + + @JvmSynthetic + internal var listener: PaymentSessionListener? = null + + private val lifecycleObserver = object : LifecycleObserver { + @OnLifecycleEvent(Lifecycle.Event.ON_DESTROY) + fun onDestroy() { + listener = null + } + } + + init { + lifecycleOwner.lifecycle.addObserver(lifecycleObserver) + } /** * Create a PaymentSession attached to the given host Activity. * - * @param activity an `Activity` from which to launch other Stripe Activities. This + * @param activity a `ComponentActivity` from which to launch other Stripe Activities. This * Activity will receive results in * `Activity#onActivityResult(int, int, Intent)` that should be * passed back to this session. * @param config a [PaymentSessionConfig] that configures this [PaymentSession] instance */ - constructor(activity: Activity, config: PaymentSessionConfig) : this( + constructor(activity: ComponentActivity, config: PaymentSessionConfig) : this( activity.applicationContext, + activity.application, + activity, + activity, config, CustomerSession.getInstance(), PaymentMethodsActivityStarter(activity), - PaymentFlowActivityStarter(activity, config), - PaymentSessionPrefs.create(activity) + PaymentFlowActivityStarter(activity, config) ) /** @@ -64,19 +98,21 @@ class PaymentSession @VisibleForTesting internal constructor( * @param config a [PaymentSessionConfig] that configures this [PaymentSession] instance */ constructor(fragment: Fragment, config: PaymentSessionConfig) : this( - fragment.requireContext().applicationContext, + fragment.requireActivity().applicationContext, + fragment.requireActivity().application, + fragment, + fragment, config, CustomerSession.getInstance(), PaymentMethodsActivityStarter(fragment), - PaymentFlowActivityStarter(fragment, config), - PaymentSessionPrefs.create(fragment.requireActivity()) + PaymentFlowActivityStarter(fragment, config) ) /** * Notify this payment session that it is complete */ fun onCompleted() { - customerSession.resetUsageTokens() + viewModel.onCompleted() } /** @@ -115,8 +151,8 @@ class PaymentSession @VisibleForTesting internal constructor( val paymentSessionData = data.getParcelableExtra(EXTRA_PAYMENT_SESSION_DATA) ?: this.paymentSessionData - this.paymentSessionData = paymentSessionData - paymentSessionListener?.onPaymentSessionDataChanged(paymentSessionData) + viewModel.paymentSessionData = paymentSessionData + listener?.onPaymentSessionDataChanged(paymentSessionData) return true } else -> { @@ -137,49 +173,37 @@ class PaymentSession @VisibleForTesting internal constructor( } private fun dispatchUpdates() { - paymentSessionListener?.onPaymentSessionDataChanged(paymentSessionData) - paymentSessionListener?.onCommunicatingStateChanged(false) + listener?.onPaymentSessionDataChanged(paymentSessionData) + listener?.onCommunicatingStateChanged(false) } private fun persistPaymentMethodResult( paymentMethod: PaymentMethod?, useGooglePay: Boolean ) { - customerSession.cachedCustomer?.id?.let { customerId -> - paymentSessionPrefs.saveSelectedPaymentMethodId(customerId, paymentMethod?.id) - } - paymentSessionData = paymentSessionData.copy( - paymentMethod = paymentMethod, - useGooglePay = useGooglePay - ) + viewModel.persistPaymentMethodResult(paymentMethod, useGooglePay) } /** - * Initialize the PaymentSession with a [PaymentSessionListener] to be notified of - * data changes. + * Initialize the [PaymentSession] with a [PaymentSessionListener] to be notified of + * data changes. The reference to the [listener] will be released when the host (i.e. + * `Activity` or `Fragment`) is destroyed. + * + * If the [PaymentSessionConfig.shouldPrefetchCustomer] is true, a new Customer instance + * will be fetched. Otherwise, the [listener] will be immediately called with the current + * [paymentSessionData]. * - * @param listener a [PaymentSessionListener] that will receive notifications of changes - * in payment session status, including networking status - * @param savedInstanceState a `Bundle` containing the saved state of a - * PaymentSession that was stored in [savePaymentSessionInstanceState] + * @param listener a [PaymentSessionListener] */ - @JvmOverloads fun init( - listener: PaymentSessionListener, - savedInstanceState: Bundle? = null + listener: PaymentSessionListener ) { - if (savedInstanceState == null) { - customerSession.resetUsageTokens() - } - customerSession.addProductUsageTokenIfValid(TOKEN_PAYMENT_SESSION) - - paymentSessionListener = listener - - paymentSessionData = savedInstanceState?.getParcelable(STATE_PAYMENT_SESSION_DATA) - ?: PaymentSessionData(config) + this.listener = listener if (config.shouldPrefetchCustomer) { fetchCustomer() + } else { + dispatchUpdates() } } @@ -217,24 +241,7 @@ class PaymentSession @VisibleForTesting internal constructor( @VisibleForTesting internal fun getSelectedPaymentMethodId(userSelectedPaymentMethodId: String?): String? { - return userSelectedPaymentMethodId - ?: if (paymentSessionData.paymentMethod != null) { - paymentSessionData.paymentMethod?.id - } else { - customerSession.cachedCustomer?.id?.let { customerId -> - paymentSessionPrefs.getSelectedPaymentMethodId(customerId) - } - } - } - - /** - * Save the data associated with this PaymentSession. This should be called in the host's - * `onSaveInstanceState(Bundle)` method. - * - * @param outState the host activity's outgoing `Bundle` - */ - fun savePaymentSessionInstanceState(outState: Bundle) { - outState.putParcelable(STATE_PAYMENT_SESSION_DATA, paymentSessionData) + return viewModel.getSelectedPaymentMethodId(userSelectedPaymentMethodId) } /** @@ -244,7 +251,7 @@ class PaymentSession @VisibleForTesting internal constructor( * a customer's cart */ fun setCartTotal(@IntRange(from = 0) cartTotal: Long) { - paymentSessionData = paymentSessionData.copy(cartTotal = cartTotal) + viewModel.updateCartTotal(cartTotal) } /** @@ -261,31 +268,20 @@ class PaymentSession @VisibleForTesting internal constructor( ) } - /** - * Should be called during the host `Activity`'s onDestroy to detach listeners. - */ - fun onDestroy() { - paymentSessionListener = null - } - private fun fetchCustomer() { - paymentSessionListener?.onCommunicatingStateChanged(true) + listener?.onCommunicatingStateChanged(true) - customerSession.retrieveCurrentCustomer( - object : CustomerSession.CustomerRetrievalListener { - override fun onCustomerRetrieved(customer: Customer) { + viewModel.fetchCustomer().observe(lifecycleOwner, Observer { + when (it) { + PaymentSessionViewModel.FetchCustomerResult.Success -> { dispatchUpdates() } - - override fun onError( - errorCode: Int, - errorMessage: String, - stripeError: StripeError? - ) { - paymentSessionListener?.onError(errorCode, errorMessage) - paymentSessionListener?.onCommunicatingStateChanged(false) + is PaymentSessionViewModel.FetchCustomerResult.Error -> { + listener?.onError(it.errorCode, it.errorMessage) + listener?.onCommunicatingStateChanged(false) } - }) + } + }) } /** @@ -334,8 +330,6 @@ class PaymentSession @VisibleForTesting internal constructor( internal const val EXTRA_PAYMENT_SESSION_DATA: String = "extra_payment_session_data" - private const val STATE_PAYMENT_SESSION_DATA: String = "state_payment_session_data" - private val VALID_REQUEST_CODES = setOf( PaymentMethodsActivityStarter.REQUEST_CODE, PaymentFlowActivityStarter.REQUEST_CODE diff --git a/stripe/src/main/java/com/stripe/android/PaymentSessionData.kt b/stripe/src/main/java/com/stripe/android/PaymentSessionData.kt index 21f6ce81b9a..1489bfd1afa 100644 --- a/stripe/src/main/java/com/stripe/android/PaymentSessionData.kt +++ b/stripe/src/main/java/com/stripe/android/PaymentSessionData.kt @@ -52,7 +52,7 @@ data class PaymentSessionData internal constructor( */ val isPaymentReadyToCharge: Boolean get() = - (paymentMethod != null || useGooglePay) && config != null && + (paymentMethod != null || useGooglePay) && (!config.isShippingInfoRequired || shippingInformation != null) && (!config.isShippingMethodRequired || shippingMethod != null) } diff --git a/stripe/src/main/java/com/stripe/android/PaymentSessionViewModel.kt b/stripe/src/main/java/com/stripe/android/PaymentSessionViewModel.kt new file mode 100644 index 00000000000..bd822029bf0 --- /dev/null +++ b/stripe/src/main/java/com/stripe/android/PaymentSessionViewModel.kt @@ -0,0 +1,106 @@ +package com.stripe.android + +import android.app.Application +import androidx.annotation.IntRange +import androidx.lifecycle.AndroidViewModel +import androidx.lifecycle.LiveData +import androidx.lifecycle.MutableLiveData +import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider +import com.stripe.android.model.Customer +import com.stripe.android.model.PaymentMethod + +internal class PaymentSessionViewModel( + application: Application, + var paymentSessionData: PaymentSessionData, + private val customerSession: CustomerSession, + private val paymentSessionPrefs: PaymentSessionPrefs = + PaymentSessionPrefs.create(application.applicationContext) +) : AndroidViewModel(application) { + init { + customerSession.resetUsageTokens() + customerSession.addProductUsageTokenIfValid(PaymentSession.TOKEN_PAYMENT_SESSION) + } + + @JvmSynthetic + fun updateCartTotal(@IntRange(from = 0) cartTotal: Long) { + paymentSessionData = paymentSessionData.copy(cartTotal = cartTotal) + } + + @JvmSynthetic + fun persistPaymentMethodResult( + paymentMethod: PaymentMethod?, + useGooglePay: Boolean + ) { + customerSession.cachedCustomer?.id?.let { customerId -> + paymentSessionPrefs.saveSelectedPaymentMethodId(customerId, paymentMethod?.id) + } + paymentSessionData = paymentSessionData.copy( + paymentMethod = paymentMethod, + useGooglePay = useGooglePay + ) + } + + @JvmSynthetic + fun onCompleted() { + customerSession.resetUsageTokens() + } + + @JvmSynthetic + fun fetchCustomer(): LiveData { + val resultData: MutableLiveData = MutableLiveData() + customerSession.retrieveCurrentCustomer( + object : CustomerSession.CustomerRetrievalListener { + override fun onCustomerRetrieved(customer: Customer) { + resultData.value = FetchCustomerResult.Success + } + + override fun onError( + errorCode: Int, + errorMessage: String, + stripeError: StripeError? + ) { + resultData.value = FetchCustomerResult.Error( + errorCode, errorMessage, stripeError + ) + } + } + ) + return resultData + } + + @JvmSynthetic + fun getSelectedPaymentMethodId(userSelectedPaymentMethodId: String? = null): String? { + return userSelectedPaymentMethodId + ?: if (paymentSessionData.paymentMethod != null) { + paymentSessionData.paymentMethod?.id + } else { + customerSession.cachedCustomer?.id?.let { customerId -> + paymentSessionPrefs.getSelectedPaymentMethodId(customerId) + } + } + } + + sealed class FetchCustomerResult { + object Success : FetchCustomerResult() + class Error( + val errorCode: Int, + val errorMessage: String, + val stripeError: StripeError? + ) : FetchCustomerResult() + } + + internal class Factory( + private val application: Application, + private val paymentSessionData: PaymentSessionData, + private val customerSession: CustomerSession + ) : ViewModelProvider.Factory { + override fun create(modelClass: Class): T { + return PaymentSessionViewModel( + application, + paymentSessionData, + customerSession + ) as T + } + } +} diff --git a/stripe/src/test/java/com/stripe/android/PaymentSessionTest.kt b/stripe/src/test/java/com/stripe/android/PaymentSessionTest.kt index 6f48e63ca68..eb4c4456c77 100644 --- a/stripe/src/test/java/com/stripe/android/PaymentSessionTest.kt +++ b/stripe/src/test/java/com/stripe/android/PaymentSessionTest.kt @@ -1,11 +1,15 @@ package com.stripe.android -import android.app.Activity import android.app.Activity.RESULT_CANCELED import android.app.Activity.RESULT_OK import android.content.Context import android.content.Intent -import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.LifecycleOwner +import androidx.lifecycle.ViewModelStore +import androidx.lifecycle.ViewModelStoreOwner +import androidx.test.core.app.ActivityScenario import androidx.test.core.app.ApplicationProvider import com.google.common.truth.Truth.assertThat import com.nhaarman.mockitokotlin2.KArgumentCaptor @@ -32,10 +36,10 @@ import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFalse +import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue import org.junit.runner.RunWith -import org.mockito.Mockito.`when` import org.mockito.Mockito.doAnswer import org.mockito.Mockito.mock import org.mockito.Mockito.never @@ -58,7 +62,6 @@ class PaymentSessionTest { ActivityStarter = mock() private val paymentFlowActivityStarter: ActivityStarter = mock() - private val paymentSessionPrefs: PaymentSessionPrefs = mock() private val paymentSessionDataArgumentCaptor: KArgumentCaptor = argumentCaptor() @@ -73,21 +76,23 @@ class PaymentSessionTest { invocation.getArgument(0).run() null }.`when`(threadPoolExecutor).execute(any()) + + CustomerSession.instance = createCustomerSession() } @Test fun init_addsPaymentSessionToken_andFetchesCustomer() { - val customerSession = createCustomerSession() - CustomerSession.instance = customerSession - createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) paymentSession.init(paymentSessionListener) - assertTrue(customerSession.productUsageTokens - .contains(PaymentSession.TOKEN_PAYMENT_SESSION)) + assertEquals( + setOf(PaymentSession.TOKEN_PAYMENT_SESSION), + CustomerSession.getInstance().productUsageTokens + ) - verify(paymentSessionListener).onCommunicatingStateChanged(eq(true)) + verify(paymentSessionListener) + .onCommunicatingStateChanged(eq(true)) } } @@ -95,7 +100,6 @@ class PaymentSessionTest { fun init_whenEphemeralKeyProviderContinues_fetchesCustomerAndNotifiesListener() { ephemeralKeyProvider .setNextRawEphemeralKey(EphemeralKeyFixtures.FIRST_JSON) - CustomerSession.instance = createCustomerSession() createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) @@ -111,8 +115,6 @@ class PaymentSessionTest { @Test fun handlePaymentData_whenPaymentMethodSelected_notifiesListenerAndFetchesCustomer() { - CustomerSession.instance = createCustomerSession() - createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) paymentSession.init(paymentSessionListener) @@ -138,8 +140,6 @@ class PaymentSessionTest { @Test fun handlePaymentData_whenGooglePaySelected_notifiesListenerAndFetchesCustomer() { - CustomerSession.instance = createCustomerSession() - createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) paymentSession.init(paymentSessionListener) @@ -165,8 +165,6 @@ class PaymentSessionTest { @Test fun selectPaymentMethod_launchesPaymentMethodsActivityWithLog() { - CustomerSession.instance = createCustomerSession() - createActivity { activity -> val paymentSession = PaymentSession(activity, DEFAULT_CONFIG) paymentSession.init(paymentSessionListener) @@ -189,8 +187,6 @@ class PaymentSessionTest { @Test fun presentPaymentMethodSelection_withShouldRequirePostalCode_shouldPassInIntent() { - CustomerSession.instance = createCustomerSession() - createActivity { activity -> val paymentSession = PaymentSession( activity, @@ -218,13 +214,6 @@ class PaymentSessionTest { } } - @Test - fun getSelectedPaymentMethodId_whenPrefsNotSet_returnsNull() { - `when`(customerSession.cachedCustomer).thenReturn(FIRST_CUSTOMER) - CustomerSession.instance = customerSession - assertNull(createPaymentSession().getSelectedPaymentMethodId(null)) - } - @Test fun getSelectedPaymentMethodId_whenHasPaymentSessionData_returnsExpectedId() { val paymentSession = createPaymentSession( @@ -238,19 +227,6 @@ class PaymentSessionTest { ) } - @Test - fun getSelectedPaymentMethodId_whenHasPrefsSet_returnsExpectedId() { - val customerId = requireNotNull(FIRST_CUSTOMER.id) - `when`(paymentSessionPrefs.getSelectedPaymentMethodId(customerId)) - .thenReturn("pm_12345") - - `when`(customerSession.cachedCustomer).thenReturn(FIRST_CUSTOMER) - CustomerSession.instance = customerSession - - assertEquals("pm_12345", - createPaymentSession().getSelectedPaymentMethodId(null)) - } - @Test fun getSelectedPaymentMethodId_whenHasUserSpecifiedPaymentMethod_returnsExpectedId() { val paymentSession = createPaymentSession( @@ -264,59 +240,45 @@ class PaymentSessionTest { @Test fun init_withoutSavedState_clearsLoggingTokensAndStartsWithPaymentSession() { - val customerSession = createCustomerSession() - CustomerSession.instance = customerSession + val customerSession = CustomerSession.getInstance() customerSession .addProductUsageTokenIfValid(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY) - assertEquals(1, customerSession.productUsageTokens.size) + assertEquals( + setOf(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY), + customerSession.productUsageTokens + ) createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) paymentSession.init(paymentSessionListener) // The init removes PaymentMethodsActivity, but then adds PaymentSession - val loggingTokens = customerSession.productUsageTokens - assertEquals(1, loggingTokens.size) - assertFalse(loggingTokens.contains(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY)) - assertTrue(loggingTokens.contains(PaymentSession.TOKEN_PAYMENT_SESSION)) - } - } - - @Test - fun init_withSavedStateBundle_doesNotClearLoggingTokens() { - val customerSession = createCustomerSession() - CustomerSession.instance = customerSession - customerSession - .addProductUsageTokenIfValid(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY) - assertEquals(1, customerSession.productUsageTokens.size) - - createActivity { - val paymentSession = PaymentSession(it, DEFAULT_CONFIG) - // If it is given any saved state at all, the tokens are not cleared out. - paymentSession.init(paymentSessionListener, Bundle()) - - val loggingTokens = customerSession.productUsageTokens - assertEquals(2, loggingTokens.size) - assertTrue(loggingTokens.contains(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY)) - assertTrue(loggingTokens.contains(PaymentSession.TOKEN_PAYMENT_SESSION)) + assertEquals( + setOf(PaymentSession.TOKEN_PAYMENT_SESSION), + customerSession.productUsageTokens + ) } } @Test fun completePayment_withLoggedActions_clearsLoggingTokensAndSetsResult() { - val customerSession = createCustomerSession() - CustomerSession.instance = customerSession + val customerSession = CustomerSession.getInstance() customerSession .addProductUsageTokenIfValid(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY) - assertEquals(1, customerSession.productUsageTokens.size) + assertEquals( + setOf(PaymentMethodsActivity.TOKEN_PAYMENT_METHODS_ACTIVITY), + customerSession.productUsageTokens + ) createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) // If it is given any saved state at all, the tokens are not cleared out. - paymentSession.init(paymentSessionListener, Bundle()) + paymentSession.init(paymentSessionListener) - val loggingTokens = customerSession.productUsageTokens - assertEquals(2, loggingTokens.size) + assertEquals( + setOf(PaymentSession.TOKEN_PAYMENT_SESSION), + customerSession.productUsageTokens + ) reset(paymentSessionListener) @@ -328,7 +290,6 @@ class PaymentSessionTest { @Test fun init_withSavedState_setsPaymentSessionData() { ephemeralKeyProvider.setNextRawEphemeralKey(EphemeralKeyFixtures.FIRST_JSON) - CustomerSession.instance = createCustomerSession() createActivity { val paymentSession = PaymentSession(it, DEFAULT_CONFIG) @@ -337,14 +298,12 @@ class PaymentSessionTest { verify(paymentSessionListener) .onPaymentSessionDataChanged(paymentSessionDataArgumentCaptor.capture()) - val bundle = Bundle() - paymentSession.savePaymentSessionInstanceState(bundle) val firstPaymentSessionData = paymentSessionDataArgumentCaptor.firstValue val secondListener = mock(PaymentSession.PaymentSessionListener::class.java) - paymentSession.init(secondListener, bundle) + paymentSession.init(secondListener) verify(secondListener) .onPaymentSessionDataChanged(paymentSessionDataArgumentCaptor.capture()) @@ -358,9 +317,11 @@ class PaymentSessionTest { @Test fun handlePaymentData_withInvalidRequestCode_aborts() { - val paymentSession = createPaymentSession() - assertFalse(paymentSession.handlePaymentData(-1, RESULT_CANCELED, Intent())) - verify(customerSession, never()).retrieveCurrentCustomer(any()) + createActivity { + val paymentSession = PaymentSession(it, DEFAULT_CONFIG) + assertFalse(paymentSession.handlePaymentData(-1, RESULT_CANCELED, Intent())) + verify(customerSession, never()).retrieveCurrentCustomer(any()) + } } @Test @@ -379,17 +340,32 @@ class PaymentSessionTest { verify(customerSession).retrieveCurrentCustomer(any()) } + @Test + fun onDestroy_shouldReleaseListener() { + createActivityScenario { activityScenario -> + activityScenario.onActivity { activity -> + val paymentSession = PaymentSession(activity, DEFAULT_CONFIG) + paymentSession.init(paymentSessionListener) + assertNotNull(paymentSession.listener) + activityScenario.moveToState(Lifecycle.State.DESTROYED) + assertNull(paymentSession.listener) + } + } + } + private fun createPaymentSession( config: PaymentSessionConfig = DEFAULT_CONFIG, paymentSessionData: PaymentSessionData = PaymentSessionData(config) ): PaymentSession { return PaymentSession( context, + ApplicationProvider.getApplicationContext(), + ViewModelStoreOwner { ViewModelStore() }, + LifecycleOwner { mock() }, config, customerSession, paymentMethodsActivityStarter, paymentFlowActivityStarter, - paymentSessionPrefs, paymentSessionData ) } @@ -408,13 +384,17 @@ class PaymentSessionTest { ) } - private fun createActivity(callback: (Activity) -> Unit) { - // start an arbitrary Activity + private fun createActivity(callback: (ComponentActivity) -> Unit) { + // start an arbitrary ComponentActivity + createActivityScenario { it.onActivity(callback) } + } + + private fun createActivityScenario( + callback: (ActivityScenario) -> Unit + ) { activityScenarioFactory.create( PaymentMethodsActivityStarter.Args.DEFAULT - ).use { activityScenario -> - activityScenario.onActivity(callback) - } + ).use(callback) } private class FakeStripeRepository : AbsFakeStripeRepository() { diff --git a/stripe/src/test/java/com/stripe/android/PaymentSessionViewModelTest.kt b/stripe/src/test/java/com/stripe/android/PaymentSessionViewModelTest.kt new file mode 100644 index 00000000000..58f2658b45d --- /dev/null +++ b/stripe/src/test/java/com/stripe/android/PaymentSessionViewModelTest.kt @@ -0,0 +1,70 @@ +package com.stripe.android + +import androidx.test.core.app.ApplicationProvider +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.verify +import com.stripe.android.model.Customer +import com.stripe.android.model.CustomerFixtures +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import org.junit.runner.RunWith +import org.mockito.Mockito +import org.robolectric.RobolectricTestRunner + +@RunWith(RobolectricTestRunner::class) +class PaymentSessionViewModelTest { + private val customerSession: CustomerSession = mock() + private val paymentSessionPrefs: PaymentSessionPrefs = mock() + + private val viewModel: PaymentSessionViewModel by lazy { + PaymentSessionViewModel( + ApplicationProvider.getApplicationContext(), + PaymentSessionFixtures.PAYMENT_SESSION_DATA, + customerSession, + paymentSessionPrefs + ) + } + + @Test + fun init_shouldUpdateProductUsage() { + viewModel.paymentSessionData + + verify(customerSession).resetUsageTokens() + verify(customerSession).addProductUsageTokenIfValid( + PaymentSession.TOKEN_PAYMENT_SESSION + ) + } + + @Test + fun updateCartTotal_shouldUpdatePaymentSessionData() { + viewModel.updateCartTotal(5000) + assertEquals( + 5000, + viewModel.paymentSessionData.cartTotal + ) + } + + @Test + fun getSelectedPaymentMethodId_whenPrefsNotSet_returnsNull() { + Mockito.`when`(customerSession.cachedCustomer) + .thenReturn(FIRST_CUSTOMER) + assertNull(viewModel.getSelectedPaymentMethodId(null)) + } + + @Test + fun getSelectedPaymentMethodId_whenHasPrefsSet_returnsExpectedId() { + val customerId = requireNotNull(FIRST_CUSTOMER.id) + Mockito.`when`(paymentSessionPrefs.getSelectedPaymentMethodId(customerId)) + .thenReturn("pm_12345") + + Mockito.`when`(customerSession.cachedCustomer).thenReturn(FIRST_CUSTOMER) + CustomerSession.instance = customerSession + + assertEquals("pm_12345", viewModel.getSelectedPaymentMethodId()) + } + + private companion object { + private val FIRST_CUSTOMER = CustomerFixtures.CUSTOMER + } +}