Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-470] Project Metrics - add cost #733

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ public enum MetricType {
TRACE_COUNT,
TOKEN_USAGE,
DURATION,
COST,
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
@ImplementedBy(ProjectMetricsDAOImpl.class)
public interface ProjectMetricsDAO {
String NAME_TRACES = "traces";
String NAME_COST = "cost";

@Builder
record Entry(String name, Instant time, Number value) {
Expand All @@ -42,6 +43,7 @@ record Entry(String name, Instant time, Number value) {
Mono<List<Entry>> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getFeedbackScores(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getTokenUsage(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getCost(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
}

@Slf4j
Expand Down Expand Up @@ -125,6 +127,30 @@ TO parseDateTimeBestEffort(:end_time)
STEP <step>;
""";

private static final String GET_COST = """
WITH spans_dedup AS (
SELECT t.start_time AS start_time,
s.total_estimated_cost AS value
FROM spans s
JOIN traces t ON spans.trace_id = t.id
WHERE project_id = :project_id
AND workspace_id = :workspace_id
ORDER BY s.id DESC, s.last_updated_at DESC
LIMIT 1 BY s.id
)
SELECT <bucket> AS bucket,
nullIf(sum(value), 0) AS value
FROM spans_dedup
WHERE start_time >= parseDateTime64BestEffort(:start_time, 9)
AND start_time \\<= parseDateTime64BestEffort(:end_time, 9)
GROUP BY bucket
ORDER BY bucket
WITH FILL
FROM <fill_from>
TO parseDateTimeBestEffort(:end_time)
STEP <step>;
""";

@Override
public Mono<List<Entry>> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) {
return template.nonTransaction(connection -> getMetric(projectId, request, connection,
Expand Down Expand Up @@ -156,6 +182,17 @@ public Mono<List<Entry>> getTokenUsage(@NonNull UUID projectId, @NonNull Project
.collectList());
}

@Override
public Mono<List<Entry>> getCost(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) {
return template.nonTransaction(connection -> getMetric(projectId, request, connection,
GET_COST, "cost")
.flatMapMany(result -> rowToDataPoint(
result,
row -> NAME_COST,
row -> row.get("value", BigDecimal.class)))
.collectList());
}

private Mono<? extends Result> getMetric(
UUID projectId, ProjectMetricRequest request, Connection connection, String query, String segmentName) {
var template = new ST(query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ private Optional<BiFunction<UUID, ProjectMetricRequest, Mono<List<ProjectMetrics
.of(
MetricType.TRACE_COUNT, projectMetricsDAO::getTraceCount,
MetricType.FEEDBACK_SCORES, projectMetricsDAO::getFeedbackScores,
MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage);
MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage,
MetricType.COST, projectMetricsDAO::getCost);

return Optional.ofNullable(HANDLER_BY_TYPE.get(metricType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.comet.opik.api.resources.utils.resources.TraceResourceClient;
import com.comet.opik.domain.ProjectMetricsDAO;
import com.comet.opik.domain.ProjectMetricsService;
import com.comet.opik.domain.cost.ModelPrice;
import com.comet.opik.infrastructure.DatabaseAnalyticsFactory;
import com.comet.opik.podam.PodamFactoryUtils;
import com.github.tomakehurst.wiremock.client.WireMock;
Expand Down Expand Up @@ -574,6 +575,87 @@ private Map<String, Long> createSpans(
}
}

@Nested
@DisplayName("Cost")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class CostTest {
@ParameterizedTest
@EnumSource(TimeInterval.class)
void happyPath(TimeInterval interval) {
// setup
mockTargetWorkspace();

Instant marker = getIntervalStart(interval);
String projectName = RandomStringUtils.randomAlphabetic(10);
var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME);

var costMinus3 = Map.of(ProjectMetricsDAO.NAME_COST,
createSpans(projectName, subtract(marker, TIME_BUCKET_3, interval)));
var costMinus1 = Map.of(ProjectMetricsDAO.NAME_COST,
createSpans(projectName, subtract(marker, TIME_BUCKET_1, interval)));
var costCurrent = Map.of(ProjectMetricsDAO.NAME_COST, createSpans(projectName, marker));

getMetricsAndAssert(projectId, ProjectMetricRequest.builder()
.metricType(MetricType.COST)
.interval(interval)
.intervalStart(subtract(marker, TIME_BUCKET_4, interval))
.intervalEnd(Instant.now())
.build(), marker, List.of(ProjectMetricsDAO.NAME_COST), BigDecimal.class, costMinus3, costMinus1,
costCurrent);
}

@ParameterizedTest
@EnumSource(TimeInterval.class)
void emptyData(TimeInterval interval) {
// setup
mockTargetWorkspace();

Instant marker = getIntervalStart(interval);
String projectName = RandomStringUtils.randomAlphabetic(10);
var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME);
Map<String, BigDecimal> empty = new HashMap<>() {
{
put(ProjectMetricsDAO.NAME_COST, null);
}
};

getMetricsAndAssert(projectId, ProjectMetricRequest.builder()
.metricType(MetricType.COST)
.interval(interval)
.intervalStart(subtract(marker, TIME_BUCKET_4, interval))
.intervalEnd(Instant.now())
.build(), marker, List.of(ProjectMetricsDAO.NAME_COST), BigDecimal.class, empty, empty, empty);
}

private BigDecimal createSpans(
String projectName, Instant marker) {
var MODEL_NAME = "gpt-3.5-turbo";

List<Trace> traces = IntStream.range(0, 5)
.mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder()
.projectName(projectName)
.startTime(marker.plusSeconds(i))
.build())
.toList();
traceResourceClient.batchCreateTraces(traces, API_KEY, WORKSPACE_NAME);

List<Span> spans = traces.stream()
.map(trace -> factory.manufacturePojo(Span.class).toBuilder()
.projectName(projectName)
.model(MODEL_NAME)
.usage(Map.of(
"prompt_tokens", RANDOM.nextInt(),
"completion_tokens", RANDOM.nextInt()))
.traceId(trace.id())
.build())
.toList();

spanResourceClient.batchCreateSpans(spans, API_KEY, WORKSPACE_NAME);
return spans.stream().map(span -> ModelPrice.fromString(MODEL_NAME).calculateCost(span.usage()))
.reduce(BigDecimal.ZERO, BigDecimal::add);
}
}

private <T extends Number> ProjectMetricResponse<T> getProjectMetrics(
UUID projectId, ProjectMetricRequest request, Class<T> aClass) {
try (var response = client.target(URL_TEMPLATE.formatted(baseURI, projectId))
Expand Down