From 18ae3cdd87cfb96827ca8f9c6906735dc11d6c3e Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Fri, 22 Nov 2024 18:36:07 +0100 Subject: [PATCH] [OPIK-441] Split feedback score name endpoints --- .../java/com/comet/opik/api/DataPoint.java | 3 +- .../api/metrics/ProjectMetricRequest.java | 3 +- .../api/metrics/ProjectMetricResponse.java | 4 +- .../queryparams/UUIDListParamConverter.java | 29 ++ .../UUIDListParamConverterProvider.java | 22 ++ .../v1/priv/ExperimentsResource.java | 29 ++ .../v1/priv/FeedbackScoreResource.java | 71 ----- .../api/resources/v1/priv/SpansResource.java | 33 ++ .../api/resources/v1/priv/TracesResource.java | 32 ++ .../comet/opik/domain/FeedbackScoreDAO.java | 106 ++++++- .../opik/domain/FeedbackScoreService.java | 20 +- .../opik/domain/ProjectMetricsService.java | 3 +- .../java/com/comet/opik/domain/SpanDAO.java | 3 +- .../comet/opik/domain/cost/ModelPrice.java | 81 +++-- .../opik/domain/cost/SpanCostCalculator.java | 3 +- .../resources/ProjectResourceClient.java | 19 ++ .../utils/resources/TraceResourceClient.java | 71 ++++- .../v1/priv/FeedbackScoreResourceTest.java | 3 +- .../v1/priv/ProjectMetricsResourceTest.java | 5 +- .../resources/v1/priv/TracesResourceTest.java | 297 +++++------------- 20 files changed, 503 insertions(+), 334 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverter.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverterProvider.java delete mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResource.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/DataPoint.java b/apps/opik-backend/src/main/java/com/comet/opik/api/DataPoint.java index ab1bfee407..a87988e694 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/DataPoint.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/DataPoint.java @@ -5,4 +5,5 @@ import java.time.Instant; @Builder(toBuilder = true) -public record DataPoint(Instant time, Number value) {} +public record DataPoint(Instant time, Number value) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricRequest.java b/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricRequest.java index 4c0ed5da17..b3a3ff4870 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricRequest.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricRequest.java @@ -16,4 +16,5 @@ public record ProjectMetricRequest( @NonNull MetricType metricType, @NonNull TimeInterval interval, Instant intervalStart, - Instant intervalEnd) {} + Instant intervalEnd) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricResponse.java b/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricResponse.java index 32ac97cb11..aeb4744ca3 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricResponse.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/metrics/ProjectMetricResponse.java @@ -7,7 +7,6 @@ import com.fasterxml.jackson.databind.annotation.JsonNaming; import lombok.Builder; -import java.time.Instant; import java.util.List; import java.util.UUID; @@ -27,5 +26,6 @@ public record ProjectMetricResponse( @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) - public record Results(String name, List data) {} + public record Results(String name, List data) { + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverter.java b/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverter.java new file mode 100644 index 0000000000..53bb9bbfa7 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverter.java @@ -0,0 +1,29 @@ +package com.comet.opik.api.queryparams; + +import jakarta.ws.rs.ext.ParamConverter; + +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +public class UUIDListParamConverter implements ParamConverter> { + + @Override + public List fromString(String value) { + if (value == null || value.trim().isEmpty()) { + return List.of(); // Return an empty list if no value is provided + } + return Arrays.stream(value.split(",")) + .map(String::trim) + .map(UUID::fromString) + .collect(Collectors.toList()); + } + + @Override + public String toString(List value) { + return value.stream() + .map(UUID::toString) + .collect(Collectors.joining(",")); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverterProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverterProvider.java new file mode 100644 index 0000000000..9e9c004097 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/queryparams/UUIDListParamConverterProvider.java @@ -0,0 +1,22 @@ +package com.comet.opik.api.queryparams; + +import jakarta.ws.rs.ext.ParamConverter; +import jakarta.ws.rs.ext.ParamConverterProvider; +import jakarta.ws.rs.ext.Provider; + +import java.lang.reflect.Type; +import java.util.List; + +@Provider +public class UUIDListParamConverterProvider implements ParamConverterProvider { + + @Override + @SuppressWarnings("unchecked") + public ParamConverter getConverter(Class rawType, Type genericType, + java.lang.annotation.Annotation[] annotations) { + if (rawType.equals(List.class) && genericType.getTypeName().contains("UUID")) { + return (ParamConverter) new UUIDListParamConverter(); + } + return null; + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 225ea9c6d7..c2d02126a6 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -8,8 +8,11 @@ import com.comet.opik.api.ExperimentItemsDelete; import com.comet.opik.api.ExperimentSearchCriteria; import com.comet.opik.api.ExperimentsDelete; +import com.comet.opik.api.FeedbackDefinition; +import com.comet.opik.api.FeedbackScoreNames; import com.comet.opik.domain.ExperimentItemService; import com.comet.opik.domain.ExperimentService; +import com.comet.opik.domain.FeedbackScoreService; import com.comet.opik.domain.IdGenerator; import com.comet.opik.domain.Streamer; import com.comet.opik.infrastructure.auth.RequestContext; @@ -48,6 +51,7 @@ import lombok.extern.slf4j.Slf4j; import org.glassfish.jersey.server.ChunkedOutput; +import java.util.List; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; @@ -66,6 +70,7 @@ public class ExperimentsResource { private final @NonNull ExperimentService experimentService; private final @NonNull ExperimentItemService experimentItemService; + private final @NonNull FeedbackScoreService feedbackScoreService; private final @NonNull Provider requestContext; private final @NonNull IdGenerator idGenerator; private final @NonNull Streamer streamer; @@ -241,4 +246,28 @@ public Response deleteExperimentItems( log.info("Deleted experiment items, count '{}'", request.ids().size()); return Response.noContent().build(); } + + @GET + @Path("/feedback-scores/names") + @Operation(operationId = "findFeedbackScoreNames", summary = "Find Feedback Score names", description = "Find Feedback Score names", responses = { + @ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(array = @ArraySchema(schema = @Schema(implementation = String.class)))) + }) + @JsonView({FeedbackDefinition.View.Public.class}) + public Response findFeedbackScoreNames(@QueryParam("experiment_ids") List experimentIds) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Find feedback score names by experiment_ids '{}', on workspaceId '{}'", + experimentIds, workspaceId); + FeedbackScoreNames feedbackScoreNames = feedbackScoreService + .getExperimentsFeedbackScoreNames(experimentIds) + .map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList()) + .map(FeedbackScoreNames::new) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + log.info("Found feedback score names '{}' by experiment_ids '{}', on workspaceId '{}'", + feedbackScoreNames.scores().size(), experimentIds, workspaceId); + + return Response.ok(feedbackScoreNames).build(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResource.java deleted file mode 100644 index d06a3cc9e6..0000000000 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResource.java +++ /dev/null @@ -1,71 +0,0 @@ -package com.comet.opik.api.resources.v1.priv; - -import com.codahale.metrics.annotation.Timed; -import com.comet.opik.api.FeedbackDefinition; -import com.comet.opik.api.FeedbackScoreNames; -import com.comet.opik.domain.FeedbackScoreService; -import com.comet.opik.infrastructure.auth.RequestContext; -import com.fasterxml.jackson.annotation.JsonView; -import io.swagger.v3.oas.annotations.Operation; -import io.swagger.v3.oas.annotations.media.ArraySchema; -import io.swagger.v3.oas.annotations.media.Content; -import io.swagger.v3.oas.annotations.media.Schema; -import io.swagger.v3.oas.annotations.responses.ApiResponse; -import io.swagger.v3.oas.annotations.tags.Tag; -import jakarta.inject.Inject; -import jakarta.inject.Provider; -import jakarta.ws.rs.Consumes; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.Produces; -import jakarta.ws.rs.QueryParam; -import jakarta.ws.rs.core.MediaType; -import jakarta.ws.rs.core.Response; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; - -import java.util.UUID; - -import static com.comet.opik.api.FeedbackScoreNames.ScoreName; -import static com.comet.opik.utils.AsyncUtils.setRequestContext; - -@Path("/v1/private/feedback-scores") -@Produces(MediaType.APPLICATION_JSON) -@Consumes(MediaType.APPLICATION_JSON) -@Timed -@Slf4j -@RequiredArgsConstructor(onConstructor_ = @Inject) -@Tag(name = "Feedback-scores", description = "Feedback scores related resources") -public class FeedbackScoreResource { - - private final @NonNull FeedbackScoreService feedbackScoreService; - private final @NonNull Provider requestContext; - - @GET - @Path("/names") - @Operation(operationId = "findFeedbackScoreNames", summary = "Find Feedback Score names", description = "Find Feedback Score names", responses = { - @ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(array = @ArraySchema(schema = @Schema(implementation = String.class)))) - }) - @JsonView({FeedbackDefinition.View.Public.class}) - public Response findFeedbackScoreNames( - @QueryParam("project_id") UUID projectId, - @QueryParam("with_experiments_only") boolean withExperimentsOnly) { - - String workspaceId = requestContext.get().getWorkspaceId(); - - log.info("Find feedback score names by project_id '{}' and with_experiments_only '{}', on workspaceId '{}'", - projectId, withExperimentsOnly, workspaceId); - FeedbackScoreNames feedbackScoreNames = feedbackScoreService - .getFeedbackScoreNames(projectId, withExperimentsOnly) - .map(names -> names.stream().map(ScoreName::new).toList()) - .map(FeedbackScoreNames::new) - .contextWrite(ctx -> setRequestContext(ctx, requestContext)) - .block(); - log.info("Found feedback score names by project_id '{}' and with_experiments_only '{}', on workspaceId '{}'", - projectId, withExperimentsOnly, workspaceId); - - return Response.ok(feedbackScoreNames).build(); - } - -} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java index 2ed94d5d53..0526d9035c 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java @@ -2,8 +2,10 @@ import com.codahale.metrics.annotation.Timed; import com.comet.opik.api.DeleteFeedbackScore; +import com.comet.opik.api.FeedbackDefinition; import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.FeedbackScoreBatch; +import com.comet.opik.api.FeedbackScoreNames; import com.comet.opik.api.Span; import com.comet.opik.api.SpanBatch; import com.comet.opik.api.SpanSearchCriteria; @@ -19,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; @@ -29,6 +32,7 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.BadRequestException; import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.DELETE; @@ -258,4 +262,33 @@ public Response scoreBatchOfSpans( return Response.noContent().build(); } + @GET + @Path("/feedback-scores/names") + @Operation(operationId = "findFeedbackScoreNames", summary = "Find Feedback Score names", description = "Find Feedback Score names", responses = { + @ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(array = @ArraySchema(schema = @Schema(implementation = String.class)))) + }) + @JsonView({FeedbackDefinition.View.Public.class}) + public Response findFeedbackScoreNames(@QueryParam("project_id") UUID projectId, + @QueryParam("type") SpanType type) { + + if (projectId == null) { + throw new BadRequestException("project_id must be provided"); + } + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Find feedback score names by project_id '{}', on workspaceId '{}'", + projectId, workspaceId); + FeedbackScoreNames feedbackScoreNames = feedbackScoreService + .getSpanFeedbackScoreNames(projectId, type) + .map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList()) + .map(FeedbackScoreNames::new) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + log.info("Found feedback score names '{}' by project_id '{}', on workspaceId '{}'", + feedbackScoreNames.scores().size(), projectId, workspaceId); + + return Response.ok(feedbackScoreNames).build(); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index f39527a670..86370b1b90 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -2,8 +2,10 @@ import com.codahale.metrics.annotation.Timed; import com.comet.opik.api.DeleteFeedbackScore; +import com.comet.opik.api.FeedbackDefinition; import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.FeedbackScoreBatch; +import com.comet.opik.api.FeedbackScoreNames; import com.comet.opik.api.Trace; import com.comet.opik.api.Trace.TracePage; import com.comet.opik.api.TraceBatch; @@ -20,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; @@ -29,6 +32,7 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.BadRequestException; import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.DELETE; @@ -290,4 +294,32 @@ public Response scoreBatchOfTraces( return Response.noContent().build(); } + @GET + @Path("/feedback-scores/names") + @Operation(operationId = "findFeedbackScoreNames", summary = "Find Feedback Score names", description = "Find Feedback Score names", responses = { + @ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(array = @ArraySchema(schema = @Schema(implementation = String.class)))) + }) + @JsonView({FeedbackDefinition.View.Public.class}) + public Response findFeedbackScoreNames(@QueryParam("project_id") UUID projectId) { + + if (projectId == null) { + throw new BadRequestException("project_id must be provided"); + } + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Find feedback score names by project_id '{}', on workspaceId '{}'", + projectId, workspaceId); + FeedbackScoreNames feedbackScoreNames = feedbackScoreService + .getTraceFeedbackScoreNames(projectId) + .map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList()) + .map(FeedbackScoreNames::new) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + log.info("Found feedback score names '{}' by project_id '{}', on workspaceId '{}'", + feedbackScoreNames.scores().size(), projectId, workspaceId); + + return Response.ok(feedbackScoreNames).build(); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreDAO.java index 35b8bc5a8b..4e6d20be63 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreDAO.java @@ -67,7 +67,11 @@ Mono scoreEntity(EntityType entityType, UUID entityId, FeedbackScore score Mono scoreBatchOf(EntityType entityType, List scores); - Mono> getFeedbackScoreNames(UUID projectId, boolean withExperimentsOnly); + Mono> getTraceFeedbackScoreNames(UUID projectId); + + Mono> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type); + + Mono> getExperimentsFeedbackScoreNames(List experimentIds); } @Singleton @@ -155,7 +159,7 @@ AND entity_id IN ( ; """; - private static final String SELECT_FEEDBACK_SCORE_NAMES = """ + private static final String SELECT_TRACE_FEEDBACK_SCORE_NAMES = """ SELECT distinct name FROM ( @@ -184,6 +188,9 @@ INNER JOIN ( trace_id FROM experiment_items WHERE workspace_id = :workspace_id + + AND experiment_id IN :experiment_ids + ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) ei ON e.id = ei.experiment_id @@ -196,6 +203,34 @@ INNER JOIN ( ; """; + private final static String SELECT_SPAN_FEEDBACK_SCORE_NAMES = """ + SELECT + distinct name + FROM ( + SELECT + name + FROM feedback_scores + WHERE workspace_id = :workspace_id + AND project_id = :project_id + + AND entity_id IN ( + SELECT + id + FROM spans + WHERE workspace_id = :workspace_id + AND project_id = :project_id + AND type = :type + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) + + AND entity_type = 'span' + ORDER BY entity_id DESC, last_updated_at DESC + LIMIT 1 BY entity_id, name + ) AS names + ; + """; + private final @NonNull TransactionTemplateAsync asyncTemplate; @Override @@ -363,16 +398,33 @@ public Mono deleteByEntityIds( @Override @WithSpan - public Mono> getFeedbackScoreNames(UUID projectId, boolean withExperimentsOnly) { + public Mono> getTraceFeedbackScoreNames(@NonNull UUID projectId) { return asyncTemplate.nonTransaction(connection -> { - ST template = new ST(SELECT_FEEDBACK_SCORE_NAMES); + ST template = new ST(SELECT_TRACE_FEEDBACK_SCORE_NAMES); - bindTemplateParam(projectId, withExperimentsOnly, template); + bindTemplateParam(projectId, false, null, template); var statement = connection.createStatement(template.render()); - bindStatementParam(projectId, statement); + bindStatementParam(projectId, null, statement); + + return getNames(statement); + }); + } + + @Override + @WithSpan + public Mono> getExperimentsFeedbackScoreNames(List experimentIds) { + return asyncTemplate.nonTransaction(connection -> { + + ST template = new ST(SELECT_TRACE_FEEDBACK_SCORE_NAMES); + + bindTemplateParam(null, true, experimentIds, template); + + var statement = connection.createStatement(template.render()); + + bindStatementParam(null, experimentIds, statement); return makeMonoContextAware(bindWorkspaceIdToMono(statement)) .flatMapMany(result -> result.map((row, rowMetadata) -> row.get("name", String.class))) @@ -381,18 +433,56 @@ public Mono> getFeedbackScoreNames(UUID projectId, boolean withExpe }); } - private void bindStatementParam(UUID projectId, Statement statement) { + private static Mono> getNames(Statement statement) { + return makeMonoContextAware(bindWorkspaceIdToMono(statement)) + .flatMapMany(result -> result.map((row, rowMetadata) -> row.get("name", String.class))) + .distinct() + .collect(Collectors.toList()); + } + + @Override + @WithSpan + public Mono> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type) { + return asyncTemplate.nonTransaction(connection -> { + + ST template = new ST(SELECT_SPAN_FEEDBACK_SCORE_NAMES); + + if (type != null) { + template.add("type", type.name()); + } + + var statement = connection.createStatement(template.render()); + + statement.bind("project_id", projectId); + + if (type != null) { + statement.bind("type", type.name()); + } + + return getNames(statement); + }); + } + + private void bindStatementParam(UUID projectId, List experimentIds, Statement statement) { if (projectId != null) { statement.bind("project_id", projectId); } + + if (CollectionUtils.isNotEmpty(experimentIds)) { + statement.bind("experiment_ids", experimentIds.toArray(UUID[]::new)); + } } - private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, ST template) { + private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, List experimentIds, ST template) { if (projectId != null) { template.add("project_id", projectId); } template.add("with_experiments_only", withExperimentsOnly); + + if (CollectionUtils.isNotEmpty(experimentIds)) { + template.add("experiment_ids", experimentIds); + } } private Mono cascadeSpanDelete(Set traceIds, Connection connection) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java index e67cedfae3..038a805bbf 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java @@ -45,7 +45,11 @@ public interface FeedbackScoreService { Mono deleteSpanScore(UUID id, String tag); Mono deleteTraceScore(UUID id, String tag); - Mono> getFeedbackScoreNames(UUID projectId, boolean withExperimentsOnly); + Mono> getTraceFeedbackScoreNames(UUID projectId); + + Mono> getSpanFeedbackScoreNames(UUID projectId, SpanType type); + + Mono> getExperimentsFeedbackScoreNames(List experimentIds); } @Slf4j @@ -231,8 +235,18 @@ public Mono deleteTraceScore(UUID id, String name) { } @Override - public Mono> getFeedbackScoreNames(UUID projectId, boolean withExperimentsOnly) { - return dao.getFeedbackScoreNames(projectId, withExperimentsOnly); + public Mono> getTraceFeedbackScoreNames(@NonNull UUID projectId) { + return dao.getTraceFeedbackScoreNames(projectId); + } + + @Override + public Mono> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type) { + return dao.getSpanFeedbackScoreNames(projectId, type); + } + + @Override + public Mono> getExperimentsFeedbackScoreNames(List experimentIds) { + return dao.getExperimentsFeedbackScoreNames(experimentIds); } private Mono failWithNotFound(String errorMessage) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java index a613d407ff..164c2d3842 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java @@ -1,6 +1,5 @@ package com.comet.opik.domain; -import com.comet.opik.api.DataPoint; import com.comet.opik.api.metrics.ProjectMetricRequest; import com.comet.opik.api.metrics.ProjectMetricResponse; import com.comet.opik.infrastructure.db.TransactionTemplateAsync; @@ -37,7 +36,7 @@ public Mono getProjectMetrics(UUID projectId, ProjectMetr validate(request); return template.nonTransaction(connection -> projectMetricsDAO.getTraceCount(projectId, request, - connection) + connection) .map(dataPoints -> ProjectMetricResponse.builder() .projectId(projectId) .metricType(request.metricType()) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index c94012e224..d42b905291 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -633,7 +633,8 @@ private Publisher insert(List spans, Connection connecti .bind("model" + i, span.model() != null ? span.model() : "") .bind("provider" + i, span.provider() != null ? span.provider() : "") .bind("total_estimated_cost" + i, estimatedCost.toString()) - .bind("total_estimated_cost_version" + i, estimatedCost.compareTo(ZERO_COST) > 0 ? ESTIMATED_COST_VERSION : "") + .bind("total_estimated_cost_version" + i, + estimatedCost.compareTo(ZERO_COST) > 0 ? ESTIMATED_COST_VERSION : "") .bind("tags" + i, span.tags() != null ? span.tags().toArray(String[]::new) : new String[]{}) .bind("created_by" + i, userName) .bind("last_updated_by" + i, userName); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/ModelPrice.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/ModelPrice.java index 012586474c..78aebb93f3 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/ModelPrice.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/ModelPrice.java @@ -12,36 +12,63 @@ @Getter public enum ModelPrice { gpt_4o("gpt-4o", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), SpanCostCalculator::textGenerationCost), - gpt_4o_2024_08_06("gpt-4o-2024-08-06", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), SpanCostCalculator::textGenerationCost), - gpt_4o_audio_preview("gpt-4o-audio-preview", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), SpanCostCalculator::textGenerationCost), - gpt_4o_audio_preview_2024_10_01("gpt-4o-audio-preview-2024-10-01", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), SpanCostCalculator::textGenerationCost), - gpt_4o_2024_05_13("gpt-4o-2024-05-13", new BigDecimal("0.000005"), new BigDecimal("0.000015"), SpanCostCalculator::textGenerationCost), - gpt_4o_mini("gpt-4o-mini", new BigDecimal("0.00000015"), new BigDecimal("0.00000060"), SpanCostCalculator::textGenerationCost), - gpt_4o_mini_2024_07_18("gpt-4o-mini-2024-07-18", new BigDecimal("0.00000015"), new BigDecimal("0.00000060"), SpanCostCalculator::textGenerationCost), - o1_preview("o1-preview", new BigDecimal("0.000015"), new BigDecimal("0.000060"), SpanCostCalculator::textGenerationCost), - o1_preview_2024_09_12("o1-preview-2024-09-12", new BigDecimal("0.000015"), new BigDecimal("0.000060"), SpanCostCalculator::textGenerationCost), + gpt_4o_2024_08_06("gpt-4o-2024-08-06", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), + SpanCostCalculator::textGenerationCost), + gpt_4o_audio_preview("gpt-4o-audio-preview", new BigDecimal("0.0000025"), new BigDecimal("0.000010"), + SpanCostCalculator::textGenerationCost), + gpt_4o_audio_preview_2024_10_01("gpt-4o-audio-preview-2024-10-01", new BigDecimal("0.0000025"), + new BigDecimal("0.000010"), SpanCostCalculator::textGenerationCost), + gpt_4o_2024_05_13("gpt-4o-2024-05-13", new BigDecimal("0.000005"), new BigDecimal("0.000015"), + SpanCostCalculator::textGenerationCost), + gpt_4o_mini("gpt-4o-mini", new BigDecimal("0.00000015"), new BigDecimal("0.00000060"), + SpanCostCalculator::textGenerationCost), + gpt_4o_mini_2024_07_18("gpt-4o-mini-2024-07-18", new BigDecimal("0.00000015"), new BigDecimal("0.00000060"), + SpanCostCalculator::textGenerationCost), + o1_preview("o1-preview", new BigDecimal("0.000015"), new BigDecimal("0.000060"), + SpanCostCalculator::textGenerationCost), + o1_preview_2024_09_12("o1-preview-2024-09-12", new BigDecimal("0.000015"), new BigDecimal("0.000060"), + SpanCostCalculator::textGenerationCost), o1_mini("o1-mini", new BigDecimal("0.000003"), new BigDecimal("0.000012"), SpanCostCalculator::textGenerationCost), - o1_mini_2024_09_12("o1-mini-2024-09-12", new BigDecimal("0.000003"), new BigDecimal("0.000012"), SpanCostCalculator::textGenerationCost), - gpt_4o_realtime_preview("gpt-4o-realtime-preview", new BigDecimal("0.000005"), new BigDecimal("0.000020"), SpanCostCalculator::textGenerationCost), - gpt_4o_realtime_preview_2024_10_01("gpt-4o-realtime-preview-2024-10-01", new BigDecimal("0.000005"), new BigDecimal("0.000020"), + o1_mini_2024_09_12("o1-mini-2024-09-12", new BigDecimal("0.000003"), new BigDecimal("0.000012"), + SpanCostCalculator::textGenerationCost), + gpt_4o_realtime_preview("gpt-4o-realtime-preview", new BigDecimal("0.000005"), new BigDecimal("0.000020"), + SpanCostCalculator::textGenerationCost), + gpt_4o_realtime_preview_2024_10_01("gpt-4o-realtime-preview-2024-10-01", new BigDecimal("0.000005"), + new BigDecimal("0.000020"), + SpanCostCalculator::textGenerationCost), + chatgpt_4o_latest("chatgpt-4o-latest", new BigDecimal("0.000005"), new BigDecimal("0.000015"), + SpanCostCalculator::textGenerationCost), + gpt_4_turbo("gpt-4-turbo", new BigDecimal("0.000010"), new BigDecimal("0.000030"), + SpanCostCalculator::textGenerationCost), + gpt_4_turbo_2024_04_09("gpt-4-turbo-2024-04-09", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), - chatgpt_4o_latest("chatgpt-4o-latest", new BigDecimal("0.000005"), new BigDecimal("0.000015"), SpanCostCalculator::textGenerationCost), - gpt_4_turbo("gpt-4-turbo", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), - gpt_4_turbo_2024_04_09("gpt-4-turbo-2024-04-09", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), gpt_4("gpt-4", new BigDecimal("0.000030"), new BigDecimal("0.000060"), SpanCostCalculator::textGenerationCost), - gpt_4_32k("gpt-4-32k", new BigDecimal("0.000060"), new BigDecimal("0.000120"), SpanCostCalculator::textGenerationCost), - gpt_4_0125_preview("gpt-4-0125-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), - gpt_4_1106_preview("gpt-4-1106-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), - gpt_4_vision_preview("gpt-4-vision-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo("gpt-3.5-turbo", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_0125("gpt-3.5-turbo-0125", new BigDecimal("0.00000050"), new BigDecimal("0.0000015"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_instruct("gpt-3.5-turbo-instruct", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_1106("gpt-3.5-turbo-1106", new BigDecimal("0.000001"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_0613("gpt-3.5-turbo-0613", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_16k_0613("gpt-3.5-turbo-16k-0613", new BigDecimal("0.000003"), new BigDecimal("0.000004"), SpanCostCalculator::textGenerationCost), - gpt_3_5_turbo_0301("gpt-3.5-turbo-0301", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - davinci_002("davinci-002", new BigDecimal("0.000005"), new BigDecimal("0.000002"), SpanCostCalculator::textGenerationCost), - babbage_002("babbage-002", new BigDecimal("0.0000004"), new BigDecimal("0.0000004"), SpanCostCalculator::textGenerationCost), + gpt_4_32k("gpt-4-32k", new BigDecimal("0.000060"), new BigDecimal("0.000120"), + SpanCostCalculator::textGenerationCost), + gpt_4_0125_preview("gpt-4-0125-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), + SpanCostCalculator::textGenerationCost), + gpt_4_1106_preview("gpt-4-1106-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), + SpanCostCalculator::textGenerationCost), + gpt_4_vision_preview("gpt-4-vision-preview", new BigDecimal("0.000010"), new BigDecimal("0.000030"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo("gpt-3.5-turbo", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_0125("gpt-3.5-turbo-0125", new BigDecimal("0.00000050"), new BigDecimal("0.0000015"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_instruct("gpt-3.5-turbo-instruct", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_1106("gpt-3.5-turbo-1106", new BigDecimal("0.000001"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_0613("gpt-3.5-turbo-0613", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_16k_0613("gpt-3.5-turbo-16k-0613", new BigDecimal("0.000003"), new BigDecimal("0.000004"), + SpanCostCalculator::textGenerationCost), + gpt_3_5_turbo_0301("gpt-3.5-turbo-0301", new BigDecimal("0.0000015"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + davinci_002("davinci-002", new BigDecimal("0.000005"), new BigDecimal("0.000002"), + SpanCostCalculator::textGenerationCost), + babbage_002("babbage-002", new BigDecimal("0.0000004"), new BigDecimal("0.0000004"), + SpanCostCalculator::textGenerationCost), DEFAULT("", new BigDecimal("0"), new BigDecimal("0"), SpanCostCalculator::defaultCost); private final String name; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/SpanCostCalculator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/SpanCostCalculator.java index d1fd69c3c8..3ac04c5644 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/SpanCostCalculator.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/cost/SpanCostCalculator.java @@ -9,7 +9,8 @@ class SpanCostCalculator { public static BigDecimal textGenerationCost(ModelPrice modelPrice, Map usage) { return modelPrice.getInputPrice().multiply(BigDecimal.valueOf(usage.getOrDefault("prompt_tokens", 0))) - .add(modelPrice.getOutputPrice().multiply(BigDecimal.valueOf(usage.getOrDefault("completion_tokens", 0)))); + .add(modelPrice.getOutputPrice() + .multiply(BigDecimal.valueOf(usage.getOrDefault("completion_tokens", 0)))); } public static BigDecimal defaultCost(ModelPrice modelPrice, Map usage) { diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ProjectResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ProjectResourceClient.java index 866126ffdc..a87f6c39f8 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ProjectResourceClient.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ProjectResourceClient.java @@ -55,4 +55,23 @@ public Project getProject(UUID projectId, String apiKey, String workspaceName) { } } + public Project getByName(String projectName, String apiKey, String workspaceName) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .queryParam("name", projectName) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_OK); + + return response.readEntity(Project.ProjectPage.class) + .content() + .stream() + .findFirst() + .orElseThrow(); + } + } + } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java index 456f2a09a6..9b609e87d8 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java @@ -1,16 +1,21 @@ package com.comet.opik.api.resources.utils.resources; +import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.FeedbackScoreBatch; import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.Trace; import com.comet.opik.api.TraceBatch; +import com.comet.opik.api.TracesDelete; +import com.comet.opik.api.resources.utils.TestUtils; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; import lombok.RequiredArgsConstructor; import org.apache.http.HttpStatus; import ru.vyarus.dropwizard.guice.test.ClientSupport; import java.util.List; +import java.util.UUID; import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; import static org.assertj.core.api.Assertions.assertThat; @@ -23,18 +28,27 @@ public class TraceResourceClient { private final ClientSupport client; private final String baseURI; - public void createTrace(Trace trace, String apiKey, String workspaceName) { + public UUID createTrace(Trace trace, String apiKey, String workspaceName) { try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) .request() + .accept(MediaType.APPLICATION_JSON_TYPE) .header(HttpHeaders.AUTHORIZATION, apiKey) .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(trace))) { assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_CREATED); + + var actualId = TestUtils.getIdFromLocation(response.getLocation()); + + if (trace.id() != null) { + assertThat(actualId).isEqualTo(trace.id()); + } + + return actualId; } } - public void feedbackScore(List score, String apiKey, String workspaceName) { + public void feedbackScores(List score, String apiKey, String workspaceName) { try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) .path("feedback-scores") @@ -47,6 +61,20 @@ public void feedbackScore(List score, String apiKey, Str } } + public void feedbackScore(UUID entityId, FeedbackScore score, String workspaceName, String apiKey) { + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .path(entityId.toString()) + .path("feedback-scores") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .put(Entity.json(score))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + public void batchCreateTraces(List traces, String apiKey, String workspaceName) { try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) .path("batch") @@ -55,8 +83,47 @@ public void batchCreateTraces(List traces, String apiKey, String workspac .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(TraceBatch.builder().traces(traces).build()))) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + + public Trace getById(UUID id, String workspaceName, String apiKey) { + var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .path(id.toString()) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); + return response.readEntity(Trace.class); + } + + public void deleteTrace(UUID id, String workspaceName, String apiKey) { + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .path(id.toString()) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .delete()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + + public void deleteTraces(TracesDelete request, String workspaceName, String apiKey) { + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .path("delete") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); assertThat(actualResponse.hasEntity()).isFalse(); } } + } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResourceTest.java index 25002b558d..5869b0ac95 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackScoreResourceTest.java @@ -129,6 +129,7 @@ class GetFeedbackScoreNames { void getFeedbackScoreNames__whenGetFeedbackScoreNames__thenReturnFeedbackScoreNames( boolean userProjectId, boolean withExperimentsOnly) { + // given var apiKey = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); @@ -225,7 +226,7 @@ private List> createMultiValueScores(List m .build()) .toList(); - traceResourceClient.feedbackScore(scores, apiKey, workspaceName); + traceResourceClient.feedbackScores(scores, apiKey, workspaceName); return scores; }).toList(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java index a8d30bee30..152d3a5867 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java @@ -2,10 +2,10 @@ import com.comet.opik.api.DataPoint; import com.comet.opik.api.TimeInterval; +import com.comet.opik.api.Trace; import com.comet.opik.api.metrics.MetricType; import com.comet.opik.api.metrics.ProjectMetricRequest; import com.comet.opik.api.metrics.ProjectMetricResponse; -import com.comet.opik.api.Trace; import com.comet.opik.api.resources.utils.AuthTestUtils; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; @@ -350,7 +350,8 @@ private void createTraces(String projectName, Instant marker, int count) { .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) .startTime(marker.plus(i, ChronoUnit.SECONDS)) - .build()).toList(); + .build()) + .toList(); traceResourceClient.batchCreateTraces(traces, API_KEY, WORKSPACE_NAME); } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 142aadad58..cdb2715a8b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -25,8 +25,10 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; -import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.api.resources.utils.resources.ExperimentResourceClient; +import com.comet.opik.api.resources.utils.resources.ProjectResourceClient; +import com.comet.opik.api.resources.utils.resources.TraceResourceClient; import com.comet.opik.domain.FeedbackScoreMapper; import com.comet.opik.domain.SpanType; import com.comet.opik.infrastructure.auth.RequestContext; @@ -43,6 +45,7 @@ import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.http.HttpStatus; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -142,6 +145,9 @@ class TracesResourceTest { private String baseURI; private ClientSupport client; + private ProjectResourceClient projectResourceClient; + private ExperimentResourceClient experimentResourceClient; + private TraceResourceClient traceResourceClient; @BeforeAll void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { @@ -159,6 +165,10 @@ void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { ClientSupportUtils.config(client); mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); + + this.projectResourceClient = new ProjectResourceClient(this.client, baseURI, factory); + this.experimentResourceClient = new ExperimentResourceClient(this.client, baseURI, factory); + this.traceResourceClient = new TraceResourceClient(this.client, baseURI); } private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { @@ -171,31 +181,11 @@ void tearDownAll() { } private UUID getProjectId(String projectName, String workspaceName, String apiKey) { - return client.target("%s/v1/private/projects".formatted(baseURI)) - .queryParam("name", projectName) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .get() - .readEntity(Project.ProjectPage.class) - .content() - .stream() - .findFirst() - .orElseThrow() - .id(); + return projectResourceClient.getByName(projectName, apiKey, workspaceName).id(); } private UUID createProject(String projectName, String workspaceName, String apiKey) { - try (Response response = client.target("%s/v1/private/projects".formatted(baseURI)) - .queryParam("name", projectName) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(Project.builder().name(projectName).build()))) { - - assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(201); - return TestUtils.getIdFromLocation(response.getLocation()); - } + return projectResourceClient.createProject(projectName, apiKey, workspaceName); } @Nested @@ -749,10 +739,10 @@ void feedbackBatch__whenSessionTokenIsPresent__thenReturnProperResponse(String s private void assertExpectedResponseWithoutABody(boolean expected, Response actualResponse) { if (expected) { - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT); assertThat(actualResponse.hasEntity()).isFalse(); } else { - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_UNAUTHORIZED); assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) .isEqualTo(UNAUTHORIZED_RESPONSE); } @@ -790,7 +780,8 @@ void findWithUsage() { .feedbackScores(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, API_KEY, TEST_WORKSPACE); + + traceResourceClient.batchCreateTraces(traces, API_KEY, TEST_WORKSPACE); var traceIdToSpansMap = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -827,7 +818,8 @@ void findWithoutUsage() { .feedbackScores(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -855,7 +847,8 @@ void findWithImageTruncation(JsonNode original, JsonNode expected, boolean trunc .metadata(original) .build()) .toList(); - batchCreateTracesAndAssert(traces, API_KEY, TEST_WORKSPACE); + + traceResourceClient.batchCreateTraces(traces, API_KEY, TEST_WORKSPACE); var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .queryParam("page", 1) @@ -3270,42 +3263,15 @@ void getTrace__whenTraceDoesNotExist__thenReturnNotFound() { } private UUID create(Trace trace, String apiKey, String workspaceName) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(trace))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - assertThat(actualResponse.hasEntity()).isFalse(); - - var actualId = TestUtils.getIdFromLocation(actualResponse.getLocation()); - - if (trace.id() != null) { - assertThat(actualId).isEqualTo(trace.id()); - } - return actualId; - } + return traceResourceClient.createTrace(trace, apiKey, workspaceName); } private void create(UUID entityId, FeedbackScore score, String workspaceName, String apiKey) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .path(entityId.toString()) - .path("feedback-scores") - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .put(Entity.json(score))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); - assertThat(actualResponse.hasEntity()).isFalse(); - } + traceResourceClient.feedbackScore(entityId, score, workspaceName, apiKey); } private Trace getAndAssert(Trace expectedTrace, UUID projectId, String apiKey, String workspaceName) { - var actualResponse = getById(expectedTrace.id(), workspaceName, apiKey); - var actualTrace = actualResponse.readEntity(Trace.class); + var actualTrace = traceResourceClient.getById(expectedTrace.id(), workspaceName, apiKey); assertThat(actualTrace) .usingRecursiveComparison() @@ -3326,7 +3292,7 @@ private void getAndAssertTraceNotFound(UUID id, String apiKey, String testWorksp .header(WORKSPACE_HEADER, testWorkspace) .get(); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(404); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NOT_FOUND); assertThat(actualResponse.hasEntity()).isTrue(); assertThat(actualResponse.readEntity(ErrorMessage.class).errors()) .allMatch(error -> Pattern.matches("Trace not found", error)); @@ -3347,17 +3313,7 @@ void createTrace() { .usage(null) .feedbackScores(null) .build(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)).request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, API_KEY) - .header(WORKSPACE_HEADER, TEST_WORKSPACE) - .post(Entity.json(trace))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - assertThat(actualResponse.hasEntity()).isFalse(); - var actualId = TestUtils.getIdFromLocation(actualResponse.getLocation()); - assertThat(actualId).isEqualTo(id); - } + traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE); var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); @@ -3408,30 +3364,16 @@ void createWithMissingId() { @DisplayName("when project doesn't exist, then accept and create project") void create__whenProjectDoesNotExist__thenAcceptAndCreateProject() { - var workspaceName = generator.generate().toString(); var projectName = RandomStringUtils.randomAlphanumeric(10); var trace = factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) .build(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)).request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, API_KEY) - .header(WORKSPACE_HEADER, TEST_WORKSPACE) - .post(Entity.json(trace))) { - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - } + traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE); - var actualResponse = client.target("%s/v1/private/projects".formatted(baseURI)) - .queryParam("workspace_name", workspaceName) - .queryParam("name", projectName) - .request() - .header(HttpHeaders.AUTHORIZATION, API_KEY) - .header(WORKSPACE_HEADER, TEST_WORKSPACE) - .get(); + Project project = projectResourceClient.getByName(projectName, API_KEY, TEST_WORKSPACE); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); - assertThat(actualResponse.readEntity(Project.ProjectPage.class).size()).isEqualTo(1); + assertThat(project).isNotNull(); } @Test @@ -3445,21 +3387,12 @@ void create__whenProjectNameIsNull__thenAcceptAndUseDefaultProject() { .projectName(null) .build(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)).request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, API_KEY) - .header(WORKSPACE_HEADER, TEST_WORKSPACE) - .post(Entity.json(trace))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - } + traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); UUID projectId = getProjectId(DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); - var actualEntity = actualResponse.readEntity(Trace.class); assertThat(actualEntity.projectId()).isEqualTo(projectId); } @@ -3511,7 +3444,7 @@ void batch__whenCreateTraces__thenReturnNoContent() { .build()) .toList(); - batchCreateTracesAndAssert(expectedTraces, API_KEY, TEST_WORKSPACE); + traceResourceClient.batchCreateTraces(expectedTraces, API_KEY, TEST_WORKSPACE); getAndAssertPage(TEST_WORKSPACE, projectName, List.of(), List.of(), expectedTraces.reversed(), List.of(), API_KEY); @@ -3535,7 +3468,7 @@ void batch__whenTraceProjectNameIsNull__thenUserDefaultProjectAndReturnNoContent .build()) .toList(); - batchCreateTracesAndAssert(expectedTraces, apiKey, workspaceName); + traceResourceClient.batchCreateTraces(expectedTraces, apiKey, workspaceName); getAndAssertPage(workspaceName, DEFAULT_PROJECT, List.of(), List.of(), expectedTraces.reversed(), List.of(), apiKey); @@ -3617,21 +3550,7 @@ void batch__whenSendingMultipleTracesWithNoId__thenReturnNoContent() { List expectedTraces = List.of(newTrace, expectedTrace); - batchCreateTracesAndAssert(expectedTraces, API_KEY, TEST_WORKSPACE); - } - } - - private void batchCreateTracesAndAssert(List traces, String apiKey, String workspaceName) { - - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .path("batch") - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(TraceBatch.builder().traces(traces).build()))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); - assertThat(actualResponse.hasEntity()).isFalse(); + traceResourceClient.batchCreateTraces(expectedTraces, API_KEY, TEST_WORKSPACE); } } @@ -3667,7 +3586,8 @@ void delete() { .projectName(projectName) .usage(null) .build()); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3684,7 +3604,8 @@ void delete() { .map(item -> FeedbackScoreMapper.INSTANCE.toFeedbackScoreBatchItem( trace.id(), projectName, item))) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(traceScores).build(), workspaceName, apiKey); + + traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName); var spanScores = spans.stream() .flatMap(span -> span.feedbackScores().stream() @@ -3696,7 +3617,7 @@ void delete() { getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, spans.reversed(), List.of(), apiKey); - deleteAndAssert(traces.getFirst().id(), workspaceName, apiKey); + traceResourceClient.deleteTrace(traces.getFirst().id(), workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3715,7 +3636,8 @@ void deleteWithoutSpansScores() { .projectName(projectName) .usage(null) .build()); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3733,12 +3655,13 @@ void deleteWithoutSpansScores() { .map(item -> FeedbackScoreMapper.INSTANCE.toFeedbackScoreBatchItem( trace.id(), projectName, item))) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(traceScores).build(), workspaceName, apiKey); + + traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName); getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, spans.reversed(), List.of(), apiKey); - deleteAndAssert(traces.getFirst().id(), workspaceName, apiKey); + traceResourceClient.deleteTrace(traces.getFirst().id(), workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3758,7 +3681,8 @@ void deleteWithoutScores() { .usage(null) .feedbackScores(null) .build()); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3774,7 +3698,7 @@ void deleteWithoutScores() { getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, spans.reversed(), List.of(), apiKey); - deleteAndAssert(traces.getFirst().id(), workspaceName, apiKey); + traceResourceClient.deleteTrace(traces.getFirst().id(), workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3794,11 +3718,12 @@ void deleteWithoutSpans() { .usage(null) .feedbackScores(null) .build()); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); - deleteAndAssert(traces.getFirst().id(), workspaceName, apiKey); + traceResourceClient.deleteTrace(traces.getFirst().id(), workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); } @@ -3813,7 +3738,7 @@ void delete__whenTraceDoesNotExist__thenNoContent() { var id = generator.generate(); - deleteAndAssert(id, workspaceName, apiKey); + traceResourceClient.deleteTrace(id, workspaceName, apiKey); getAndAssertTraceNotFound(id, apiKey, workspaceName); } @@ -3838,7 +3763,8 @@ void deleteTraces() { .usage(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3855,7 +3781,8 @@ void deleteTraces() { .map(item -> FeedbackScoreMapper.INSTANCE.toFeedbackScoreBatchItem( trace.id(), projectName, item))) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(traceScores).build(), workspaceName, apiKey); + + traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName); var spanScores = spans.stream() .flatMap(span -> span.feedbackScores().stream() @@ -3870,7 +3797,8 @@ void deleteTraces() { var request = TracesDelete.builder() .ids(traces.stream().map(Trace::id).collect(Collectors.toUnmodifiableSet())) .build(); - deleteAndAssert(request, workspaceName, apiKey); + + traceResourceClient.deleteTraces(request, workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3891,7 +3819,7 @@ void deleteTracesWithoutSpansScores() { .usage(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3909,7 +3837,8 @@ void deleteTracesWithoutSpansScores() { .map(item -> FeedbackScoreMapper.INSTANCE.toFeedbackScoreBatchItem( trace.id(), projectName, item))) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(traceScores).build(), workspaceName, apiKey); + + traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName); getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, spans.reversed(), List.of(), apiKey); @@ -3917,7 +3846,8 @@ void deleteTracesWithoutSpansScores() { var request = TracesDelete.builder() .ids(traces.stream().map(Trace::id).collect(Collectors.toUnmodifiableSet())) .build(); - deleteAndAssert(request, workspaceName, apiKey); + + traceResourceClient.deleteTraces(request, workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3939,7 +3869,8 @@ void deleteTracesWithoutScores() { .feedbackScores(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); var spans = traces.stream() .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() @@ -3958,7 +3889,8 @@ void deleteTracesWithoutScores() { var request = TracesDelete.builder() .ids(traces.stream().map(Trace::id).collect(Collectors.toUnmodifiableSet())) .build(); - deleteAndAssert(request, workspaceName, apiKey); + + traceResourceClient.deleteTraces(request, workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); getAndAssertPageSpans(workspaceName, projectName, List.of(), spans, List.of(), List.of(), apiKey); @@ -3980,14 +3912,16 @@ void deleteTracesWithoutSpans() { .feedbackScores(null) .build()) .toList(); - batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); var request = TracesDelete.builder() .ids(traces.stream().map(Trace::id).collect(Collectors.toUnmodifiableSet())) .build(); - deleteAndAssert(request, workspaceName, apiKey); + + traceResourceClient.deleteTraces(request, workspaceName, apiKey); getAndAssertPage(workspaceName, projectName, List.of(), traces, List.of(), List.of(), apiKey); } @@ -4000,7 +3934,7 @@ void deleteTracesWithoutTraces() { mockTargetWorkspace(apiKey, workspaceName, workspaceId); var request = factory.manufacturePojo(TracesDelete.class); - deleteAndAssert(request, workspaceName, apiKey); + traceResourceClient.deleteTraces(request, workspaceName, apiKey); } } @@ -4063,9 +3997,8 @@ void when__traceDoesNotExist__thenReturnCreateIt() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); - var actualEntity = actualResponse.readEntity(Trace.class); assertThat(actualEntity.id()).isEqualTo(id); assertThat(actualEntity.input()).isEqualTo(traceUpdate.input()); @@ -4099,9 +4032,8 @@ void when__traceUpdateAndInsertAreProcessedOutOfOther__thenReturnTrace() { create(newTrace, API_KEY, TEST_WORKSPACE); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); - var actualEntity = actualResponse.readEntity(Trace.class); assertThat(actualEntity.id()).isEqualTo(id); assertThat(actualEntity.input()).isEqualTo(traceUpdate.input()); @@ -4156,9 +4088,8 @@ void when__multipleTraceUpdateAndInsertAreProcessedOutOfOtherAndConcurrent__then var created = Instant.now(); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); - var actualEntity = actualResponse.readEntity(Trace.class); assertThat(actualEntity.id()).isEqualTo(id); assertThat(actualEntity.endTime()).isEqualTo(traceUpdate3.endTime()); @@ -4199,8 +4130,7 @@ void update() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); - var actualEntity = actualResponse.readEntity(Trace.class); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); assertThat(actualEntity.id()).isEqualTo(id); assertThat(actualEntity.input()).isEqualTo(traceUpdate.input()); @@ -4386,44 +4316,6 @@ void update__whenUpdatingUsingProjectId__thenAcceptUpdate() { } - private Response getById(UUID id, String workspaceName, String apiKey) { - var response = client.target(URL_TEMPLATE.formatted(baseURI)) - .path(id.toString()) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .get(); - - assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(200); - return response; - } - - private void deleteAndAssert(UUID id, String workspaceName, String apiKey) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .path(id.toString()) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .delete()) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); - assertThat(actualResponse.hasEntity()).isFalse(); - } - } - - private void deleteAndAssert(TracesDelete request, String workspaceName, String apiKey) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .path("delete") - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(request))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); - assertThat(actualResponse.hasEntity()).isFalse(); - } - } - @Nested @DisplayName("Feedback:") @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -4618,11 +4510,7 @@ void deleteFeedback() { assertThat(actualResponse.hasEntity()).isFalse(); } - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); - - var actualEntity = actualResponse.readEntity(Trace.class); + var actualEntity = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); assertThat(actualEntity.feedbackScores()).isNull(); } @@ -4723,8 +4611,7 @@ void feedback() { .projectName(trace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score1, score2, score3)).build(); - createAndAssertForTrace(feedbackScoreBatch, TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(score1, score2, score3), API_KEY, TEST_WORKSPACE); var projectId1 = getProjectId(trace1.projectName(), TEST_WORKSPACE, API_KEY); var projectId2 = getProjectId(trace2.projectName(), TEST_WORKSPACE, API_KEY); @@ -4780,8 +4667,7 @@ void feedback__whenWorkspaceIsSpecified__thenReturnNoContent() { .projectName(expectedTrace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score1, score2, score3)).build(); - createAndAssertForTrace(feedbackScoreBatch, workspaceName, apiKey); + traceResourceClient.feedbackScores(List.of(score1, score2, score3), apiKey, workspaceName); var projectId1 = getProjectId(DEFAULT_PROJECT, workspaceName, apiKey); var projectId2 = getProjectId(projectName, workspaceName, apiKey); @@ -4834,8 +4720,7 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { .value(factory.manufacturePojo(BigDecimal.class)) .reason(null) .build(); - createAndAssertForTrace( - FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(score), API_KEY, TEST_WORKSPACE); var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); trace = trace.toBuilder() @@ -4865,8 +4750,7 @@ void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { .projectName(expectedTrace.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - createAndAssertForTrace( - FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(score), API_KEY, TEST_WORKSPACE); expectedTrace = expectedTrace.toBuilder() .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score))) @@ -4897,12 +4781,10 @@ void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { .id(id) .projectName(trace.projectName()) .build(); - createAndAssertForTrace( - FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(score), API_KEY, TEST_WORKSPACE); var newScore = score.toBuilder().value(factory.manufacturePojo(BigDecimal.class)).build(); - createAndAssertForTrace( - FeedbackScoreBatch.builder().scores(List.of(newScore)).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(newScore), API_KEY, TEST_WORKSPACE); var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); trace = trace.toBuilder() @@ -4920,8 +4802,7 @@ void feedback__whenTraceDoesNotExist__thenReturnNoContentAndCreateScore() { .projectName(DEFAULT_PROJECT) .build(); - createAndAssertForTrace( - FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(List.of(score), API_KEY, TEST_WORKSPACE); } @Test @@ -4938,7 +4819,7 @@ void feedback__whenFeedbackSpanBatchHasMaxSize__thenReturnNoContentAndCreateScor .id(id) .build()) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(scores).build(), TEST_WORKSPACE, API_KEY); + traceResourceClient.feedbackScores(scores, API_KEY, TEST_WORKSPACE); } @Test @@ -4964,16 +4845,8 @@ void feedback__whenFeedbackTraceIdIsNotValid__thenReturn400() { } } - private void createAndAssertForTrace(FeedbackScoreBatch request, String workspaceName, String apiKey) { - createAndAssert(URL_TEMPLATE.formatted(baseURI), request, workspaceName, apiKey); - } - private void createAndAssertForSpan(FeedbackScoreBatch request, String workspaceName, String apiKey) { - createAndAssert(URL_TEMPLATE_SPANS.formatted(baseURI), request, workspaceName, apiKey); - } - - private void createAndAssert(String path, FeedbackScoreBatch request, String workspaceName, String apiKey) { - try (var actualResponse = client.target(path) + try (var actualResponse = client.target(URL_TEMPLATE_SPANS.formatted(baseURI)) .path("feedback-scores") .request() .header(HttpHeaders.AUTHORIZATION, apiKey)