Skip to content

Commit

Permalink
Refactor test
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko committed Nov 26, 2024
1 parent e171b44 commit 2ea154e
Showing 1 changed file with 20 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.comet.opik.api.resources.utils.resources.TraceResourceClient;
import com.comet.opik.domain.FeedbackScoreMapper;
import com.comet.opik.domain.SpanType;
import com.comet.opik.domain.cost.ModelPrice;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.podam.PodamFactoryUtils;
import com.comet.opik.utils.JsonUtils;
Expand Down Expand Up @@ -3226,14 +3227,12 @@ void getTraceWithUsage() {

@ParameterizedTest
@MethodSource
void getTraceWithCost(BigDecimal spanExpectedCost, String model) {
BigDecimal traceExpectedCost = spanExpectedCost == null ? null : spanExpectedCost.multiply(BigDecimal.valueOf(5));
void getTraceWithCost(String model) {
var projectName = RandomStringUtils.randomAlphanumeric(10);
var trace = factory.manufacturePojo(Trace.class)
.toBuilder()
.id(null)
.projectName(projectName)
.usage(Map.of("completion_tokens", 200 * 5L, "prompt_tokens", 300 * 5L, "total_tokens", 4 * 5L))
.feedbackScores(null)
.build();
var id = create(trace, API_KEY, TEST_WORKSPACE);
Expand All @@ -3242,24 +3241,36 @@ void getTraceWithCost(BigDecimal spanExpectedCost, String model) {
.map(spanInStream -> spanInStream.toBuilder()
.projectName(projectName)
.traceId(id)
.usage(Map.of("completion_tokens", 200, "prompt_tokens", 300, "total_tokens", 4))
.usage(Map.of("completion_tokens", Math.abs(factory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(factory.manufacturePojo(Integer.class))))
.model(model)
.build())
.collect(Collectors.toList());

var usage = spans.stream()
.flatMap(span -> span.usage().entrySet().stream())
.map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum));

BigDecimal traceExpectedCost = spans.stream()
.map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage()))
.reduce(BigDecimal.ZERO, BigDecimal::add);

batchCreateSpansAndAssert(spans, API_KEY, TEST_WORKSPACE);

var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY);
trace = trace.toBuilder().id(id).build();
trace = trace.toBuilder().id(id).usage(usage).build();
Trace createdTrace = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
assertThat(createdTrace.totalEstimatedCost()).isEqualTo(traceExpectedCost);
assertThat(traceExpectedCost.compareTo(BigDecimal.ZERO) == 0 ?
createdTrace.totalEstimatedCost() == null :
traceExpectedCost.compareTo(createdTrace.totalEstimatedCost()) == 0)
.isEqualTo(true);
}

static Stream<Arguments> getTraceWithCost() {
return Stream.of(
Arguments.of(new BigDecimal("0.00070000"), "gpt-3.5-turbo-1106"),
Arguments.of(null, "unknown-model"),
Arguments.of(null, null));
Arguments.of("gpt-3.5-turbo-1106"),
Arguments.of("unknown-model"));
}

@Test
Expand Down

0 comments on commit 2ea154e

Please sign in to comment.