Skip to content

Commit

Permalink
Refactor getServingInfo tests
Browse files Browse the repository at this point in the history
  • Loading branch information
caiocamatta-stripe committed Feb 27, 2024
1 parent acc83c2 commit 354067a
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.mockito.stubbing.Answer
import org.mockito.{Answers, ArgumentCaptor}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import org.junit.Assert.assertSame

import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext, Future}
Expand Down Expand Up @@ -149,41 +148,44 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper {
actualRequest.get.keys shouldBe query.keyMapping.get
}

// updateServingInfo() is called when the batch response is from the KV store.
@Test
def test_getServingInfo_ShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val oldServingInfo = mock[GroupByServingInfoParsed]
val updatedServingInfo = mock[GroupByServingInfoParsed]
val oldServingInfo: GroupByServingInfoParsed = mock[GroupByServingInfoParsed]
val updatedServingInfo: GroupByServingInfoParsed = mock[GroupByServingInfoParsed]
when(fetcherBase.updateServingInfo(any(), any())).thenReturn(updatedServingInfo)

// Prepare a KV store response
val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)
doReturn(updatedServingInfo).when(spiedFetcherBase).updateServingInfo(any(), any())

// updateServingInfo is called
val result = spiedFetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)
assertSame(result, updatedServingInfo)
verify(spiedFetcherBase).updateServingInfo(any(), any())
val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)

// updateServingInfo() is called
verify(fetcherBase).updateServingInfo(any(), any())
result shouldEqual updatedServingInfo
}

// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
// the serving info from becoming stale if all the requests are cached.
@Test
def test_getServingInfo_ShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val oldServingInfo = mock[GroupByServingInfoParsed]
val metaData = mock[MetaData]
val groupByOpsMock = mock[GroupByOps]
val ttlCacheMock: TTLCache[String, Try[GroupByServingInfoParsed]] =
mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
when(fetcherBase.getGroupByServingInfo).thenReturn(ttlCacheMock)
val groupByServingInfoParsedMock: GroupByServingInfoParsed = mock[GroupByServingInfoParsed]
when(ttlCacheMock.refresh(any[String])).thenReturn(Success(groupByServingInfoParsedMock))
val metaDataMock: MetaData = mock[MetaData]
val groupByOpsMock: GroupByOps = mock[GroupByOps]
when(groupByOpsMock.metaData).thenReturn(metaDataMock)
when(groupByServingInfoParsedMock.groupByOps).thenReturn(groupByOpsMock)

val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
doReturn(ttlCache).when(spiedFetcherBase).getGroupByServingInfo
doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String])
metaData.name = "test"
groupByOpsMock.metaData = metaData
when(oldServingInfo.groupByOps).thenReturn(groupByOpsMock)
val result = fetcherBase.getServingInfo(groupByServingInfoParsedMock, cachedBatchResponses)

// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
val result = spiedFetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)
assertSame(result, oldServingInfo)
verify(ttlCache).refresh(any())
verify(spiedFetcherBase, never()).updateServingInfo(any(), any())
result shouldEqual groupByServingInfoParsedMock
verify(ttlCacheMock).refresh(any[String])
verify(fetcherBase, never()).updateServingInfo(any(), any())
}
}

0 comments on commit 354067a

Please sign in to comment.