Skip to content

Commit

Permalink
Merge pull request #234 from G12-Wanderwave/bugfix/unify-coroutine-sc…
Browse files Browse the repository at this point in the history
…opes

Unify usage of coroutine scope across project
  • Loading branch information
yzueger authored May 3, 2024
2 parents 008fde9 + 616cc8f commit 2146ce3
Show file tree
Hide file tree
Showing 17 changed files with 119 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ import com.google.android.gms.tasks.Task
import com.google.firebase.auth.AuthResult
import com.google.firebase.auth.FirebaseAuth
import io.mockk.called
import io.mockk.coEvery
import io.mockk.every
import io.mockk.impl.annotations.RelaxedMockK
import io.mockk.junit4.MockKRule
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import okhttp3.Call
import okhttp3.OkHttpClient
import org.junit.Before
Expand All @@ -40,10 +44,13 @@ class AuthenticationControllerTest {
AuthenticationUserData(
"testid", "testemail", "testDisplayName", "https://example.com/testphoto.jpg")

@OptIn(ExperimentalCoroutinesApi::class)
@Before
fun setup() {
val testDispatcher = UnconfinedTestDispatcher(TestCoroutineScheduler())
authenticationController =
AuthenticationController(mockFirebaseAuth, mockHttpClient, mockTokenRepository)
AuthenticationController(
mockFirebaseAuth, mockHttpClient, mockTokenRepository, testDispatcher)
}

fun setupDummyUserSignedIn() {
Expand Down Expand Up @@ -71,14 +78,14 @@ class AuthenticationControllerTest {
.trimIndent()
}
}
every { mockTokenRepository.setAuthToken(any(), any(), any()) } returns Unit
every {
coEvery { mockTokenRepository.setAuthToken(any(), any(), any()) } returns Unit
coEvery {
mockTokenRepository.getAuthToken(AuthTokenRepository.AuthTokenType.FIREBASE_TOKEN)
} returns "testtoken-firebase"
every {
coEvery {
mockTokenRepository.getAuthToken(AuthTokenRepository.AuthTokenType.SPOTIFY_ACCESS_TOKEN)
} returns "testtoken-spotify-access"
every {
coEvery {
mockTokenRepository.getAuthToken(AuthTokenRepository.AuthTokenType.SPOTIFY_REFRESH_TOKEN)
} returns "testtoken-spotify-refresh"
}
Expand All @@ -95,9 +102,12 @@ class AuthenticationControllerTest {

@Test
fun canSignOut() = runBlocking {
assert(authenticationController.isSignedIn())
every { mockFirebaseAuth.signOut() } returns Unit
authenticationController.deauthenticate()
verify { mockFirebaseAuth.signOut() }
every { mockFirebaseAuth.currentUser } returns null
assert(!authenticationController.isSignedIn())
}

@Test
Expand Down Expand Up @@ -163,5 +173,28 @@ class AuthenticationControllerTest {
val result = authenticationController.refreshTokenIfNecessary()
verify { mockFirebaseAuth.signInWithCustomToken("testtoken-firebase") }
assert(result)

every { mockFirebaseAuth.currentUser } returns mockFirebaseUser
val result2 = authenticationController.refreshTokenIfNecessary()
assert(result2)
verify { mockFirebaseAuth.signInWithCustomToken(any()) wasNot called }
}

@Test
fun failureCases() = runBlocking {
setupDummyUserSignedIn()
val call = mockk<Call>()
every { mockHttpClient.newCall(any()) } returns call
every { call.execute() } returns mockk { every { body } returns null }
every { mockFirebaseAuth.currentUser } returns null

assert(!authenticationController.refreshTokenIfNecessary())
verify { call.execute() }

coEvery { mockTokenRepository.getAuthToken(any()) } returns null

assert(!authenticationController.refreshTokenIfNecessary())

assert(!authenticationController.authenticate("%invalidcode%").first())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import junit.framework.TestCase.fail
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.withTimeout
import org.junit.Before
import org.junit.Rule
Expand Down Expand Up @@ -87,7 +89,9 @@ public class BeaconConnectionTest {
every { firestore.collection(any()) } returns collectionReference

// Pass the mock Firestore instance to your BeaconConnection
beaconConnection = BeaconConnection(firestore, trackConnection, profileConnection)
val testDispatcher = UnconfinedTestDispatcher(TestCoroutineScheduler())
beaconConnection =
BeaconConnection(firestore, trackConnection, profileConnection, testDispatcher)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import ch.epfl.cs311.wanderwave.model.data.Location
import ch.epfl.cs311.wanderwave.model.data.Profile
import ch.epfl.cs311.wanderwave.model.data.ProfileTrackAssociation
import ch.epfl.cs311.wanderwave.model.data.Track
import ch.epfl.cs311.wanderwave.model.remote.BeaconConnection
import ch.epfl.cs311.wanderwave.model.remote.ProfileConnection
import ch.epfl.cs311.wanderwave.model.remote.TrackConnection
import com.google.android.gms.maps.model.LatLng
Expand All @@ -25,7 +24,6 @@ import org.junit.Test
class DataClassesTest {
// Testing of all the data classes, I think it's better to test them all together
@get:Rule val mockkRule = MockKRule(this)
private lateinit var beaconConnection: BeaconConnection
private lateinit var trackConnection: TrackConnection
private lateinit var profileConnection: ProfileConnection

Expand All @@ -38,8 +36,6 @@ class DataClassesTest {
trackConnection = mockk<TrackConnection>(relaxed = true)
profileConnection = mockk<ProfileConnection>(relaxed = true)

beaconConnection =
BeaconConnection(trackConnection = trackConnection, profileConnection = profileConnection)
// Set up the document mock to return some tracks
every { document.id } returns "someId"
every { document["title"] } returns "someTitle"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit4.MockKRule
import io.mockk.verify
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import org.junit.Before
import org.junit.Rule
import org.junit.Test
Expand All @@ -31,7 +34,9 @@ class LocalAuthTokenRepositoryTest {
every { mockDatabase.authTokenDao() } returns mockAuthTokenDao
every { mockAuthTokenDao.setAuthToken(any()) } returns Unit
every { mockAuthTokenDao.deleteAuthToken(any()) } returns Unit
localAuthTokenRepository = LocalAuthTokenRepository(mockDatabase)

val testDispatcher = UnconfinedTestDispatcher(TestCoroutineScheduler())
localAuthTokenRepository = LocalAuthTokenRepository(mockDatabase, testDispatcher)

val now = System.currentTimeMillis() / 1000 + 3600

Expand All @@ -57,7 +62,7 @@ class LocalAuthTokenRepositoryTest {
}

@Test
fun canSetTokens() {
fun canSetTokens() = runBlocking {
localAuthTokenRepository.setAuthToken(
AuthTokenRepository.AuthTokenType.FIREBASE_TOKEN, "firebaseToken", 123L)

Expand Down Expand Up @@ -91,7 +96,7 @@ class LocalAuthTokenRepositoryTest {
}

@Test
fun canGetTokens() {
fun canGetTokens() = runBlocking {
val firebaseToken =
localAuthTokenRepository.getAuthToken(AuthTokenRepository.AuthTokenType.FIREBASE_TOKEN)
assert(firebaseToken == "firebaseToken")
Expand All @@ -108,7 +113,7 @@ class LocalAuthTokenRepositoryTest {
}

@Test
fun canDeleteTokens() {
fun canDeleteTokens() = runBlocking {
localAuthTokenRepository.deleteAuthToken(AuthTokenRepository.AuthTokenType.FIREBASE_TOKEN)
verify { mockAuthTokenDao.deleteAuthToken(AuthTokenRepository.AuthTokenType.FIREBASE_TOKEN.id) }

Expand All @@ -125,7 +130,7 @@ class LocalAuthTokenRepositoryTest {
}

@Test
fun doNotGetExpiredToken() {
fun doNotGetExpiredToken() = runBlocking {
every {
mockAuthTokenDao.getAuthToken(AuthTokenRepository.AuthTokenType.SPOTIFY_REFRESH_TOKEN.id)
} returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ProfileViewModelTest {
}

// Trigger the operations that will cause the song lists to be populated
viewModel.retrieveTracks(this)
viewModel.retrieveTracks()
// Wait for the job to complete which includes Flow collection
job.join()

Expand Down Expand Up @@ -140,8 +140,8 @@ class ProfileViewModelTest {
every {
spotifyController.getAllChildren(ListItem("id", "title", null, "subtitle", "", false, true))
} returns flowOf(listOf(expectedListItem))
viewModel.retrieveAndAddSubsection(this)
viewModel.retrieveChild(expectedListItem, this)
viewModel.retrieveAndAddSubsection()
viewModel.retrieveChild(expectedListItem)
advanceUntilIdle() // Ensure all coroutines are completed

// val result = viewModel.spotifySubsectionList.first() // Safely access the first item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton
import kotlinx.coroutines.Dispatchers
import okhttp3.OkHttpClient

@Module
Expand All @@ -21,6 +22,7 @@ object AuthenticationModule {
httpClient: OkHttpClient,
authenticationRepository: AuthTokenRepository
): AuthenticationController {
return AuthenticationController(Firebase.auth, httpClient, authenticationRepository)
return AuthenticationController(
Firebase.auth, httpClient, authenticationRepository, Dispatchers.IO)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton
import kotlinx.coroutines.Dispatchers

@Module
@InstallIn(SingletonComponent::class)
Expand All @@ -22,7 +23,9 @@ object ConnectionModule {
@Singleton
fun provideBeaconRepository(@ApplicationContext context: Context): BeaconRepository {
return BeaconConnection(
trackConnection = TrackConnection(), profileConnection = ProfileConnection())
trackConnection = TrackConnection(),
profileConnection = ProfileConnection(),
ioDispatcher = Dispatchers.IO)
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ object DatabaseModule {
@Singleton
fun provideAppDatabase(@ApplicationContext context: Context): AppDatabase {
return Room.databaseBuilder(context.applicationContext, AppDatabase::class.java, "app_database")
.allowMainThreadQueries()
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton
import kotlinx.coroutines.Dispatchers

@Module
@InstallIn(SingletonComponent::class)
Expand All @@ -16,6 +17,6 @@ object RepositoryModule {
@Provides
@Singleton
fun provideAuthTokenRepository(appDatabase: AppDatabase): AuthTokenRepository {
return LocalAuthTokenRepository(appDatabase)
return LocalAuthTokenRepository(appDatabase, Dispatchers.IO)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package ch.epfl.cs311.wanderwave.model.auth
import ch.epfl.cs311.wanderwave.model.repository.AuthTokenRepository
import com.google.firebase.auth.FirebaseAuth
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.tasks.await
Expand All @@ -18,7 +18,8 @@ class AuthenticationController
constructor(
private val auth: FirebaseAuth,
private val httpClient: OkHttpClient,
private val tokenRepository: AuthTokenRepository
private val tokenRepository: AuthTokenRepository,
private val ioDispatcher: CoroutineDispatcher
) {

private val AUTH_SERVICE_URL = "https://us-central1-wanderwave-95743.cloudfunctions.net"
Expand Down Expand Up @@ -47,7 +48,7 @@ constructor(

suspend fun refreshTokenIfNecessary(): Boolean {
if (auth.currentUser == null) {
return withContext(Dispatchers.IO) { refreshSpotifyToken() }
return refreshSpotifyToken()
}
return true
}
Expand All @@ -66,7 +67,9 @@ constructor(
.post("code=$authenticationCode".toRequestBody())
.build()

val responseJson = httpClient.newCall(request).execute().body?.string() ?: return false
val responseJson =
withContext(ioDispatcher) { httpClient.newCall(request).execute().body?.string() }
?: return false
return storeAndUseNewTokens(responseJson)
}

Expand Down Expand Up @@ -104,7 +107,9 @@ constructor(
.post("refresh_token=$refreshToken".toRequestBody())
.build()

val responseJson = httpClient.newCall(request).execute().body?.string() ?: return false
val responseJson =
withContext(ioDispatcher) { httpClient.newCall(request).execute().body?.string() }
?: return false
return storeAndUseNewTokens(responseJson)
}

Expand All @@ -119,12 +124,4 @@ constructor(
fun deauthenticate() {
auth.signOut()
}

private data class TokenResponse(
val accessToken: String,
val refreshToken: String?,
val firebaseToken: String
)

private data class State(val isSignedIn: Boolean = false)
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
package ch.epfl.cs311.wanderwave.model.localDb

import ch.epfl.cs311.wanderwave.model.repository.AuthTokenRepository
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.withContext

class LocalAuthTokenRepository(database: AppDatabase) : AuthTokenRepository {
class LocalAuthTokenRepository(
private val database: AppDatabase,
private val ioDispatcher: CoroutineDispatcher
) : AuthTokenRepository {

private val authTokenDao = database.authTokenDao()

override fun getAuthToken(tokenType: AuthTokenRepository.AuthTokenType): String? {
return authTokenDao.getAuthToken(tokenType.id)?.let { authTokenEntity ->
if (authTokenEntity.expirationDate > System.currentTimeMillis() / 1000) {
authTokenEntity.token
} else {
authTokenDao.deleteAuthToken(tokenType.id)
null
override suspend fun getAuthToken(tokenType: AuthTokenRepository.AuthTokenType): String? {
return withContext(ioDispatcher) {
authTokenDao.getAuthToken(tokenType.id)?.let { authTokenEntity ->
if (authTokenEntity.expirationDate > System.currentTimeMillis() / 1000) {
authTokenEntity.token
} else {
authTokenDao.deleteAuthToken(tokenType.id)
null
}
}
}
}

override fun setAuthToken(
override suspend fun setAuthToken(
tokenType: AuthTokenRepository.AuthTokenType,
token: String,
expirationTime: Long
) {
authTokenDao.setAuthToken(AuthTokenEntity(token, expirationTime, tokenType.id))
withContext(ioDispatcher) {
authTokenDao.setAuthToken(AuthTokenEntity(token, expirationTime, tokenType.id))
}
}

override fun deleteAuthToken(tokenType: AuthTokenRepository.AuthTokenType) {
authTokenDao.deleteAuthToken(tokenType.id)
override suspend fun deleteAuthToken(tokenType: AuthTokenRepository.AuthTokenType) {
withContext(ioDispatcher) { authTokenDao.deleteAuthToken(tokenType.id) }
}
}
Loading

0 comments on commit 2146ce3

Please sign in to comment.