From ccbf8c399102bd8934a64f07391018cc39c5ea73 Mon Sep 17 00:00:00 2001 From: Andres Cruz Date: Tue, 3 Sep 2024 11:48:13 +0200 Subject: [PATCH] OPIK-42 Experiment compare for single supplied experiment --- .../resources/v1/priv/DatasetsResource.java | 7 +- .../com/comet/opik/domain/DatasetItemDAO.java | 275 +++++++++++++++--- .../comet/opik/domain/DatasetItemService.java | 13 +- .../v1/priv/DatasetsResourceTest.java | 206 ++++++++++++- 4 files changed, 446 insertions(+), 55 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index 5f1284a173..e90b80455f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -347,12 +347,13 @@ public Response findDatasetItemsWithExperimentItems( .entityType(FeedbackScoreDAO.EntityType.TRACE) .build(); - log.info("Finding dataset items with experiment items by '{}'", datasetItemSearchCriteria); + log.info("Finding dataset items with experiment items by '{}', page '{}', size '{}'", + datasetItemSearchCriteria, page, size); var datasetItemPage = itemService.getItems(page, size, datasetItemSearchCriteria) .contextWrite(ctx -> setRequestContext(ctx, requestContext)) .block(); - log.info("Found dataset items with experiment items by '{}', count '{}'", - datasetItemSearchCriteria, datasetItemPage.content().size()); + log.info("Found dataset items with experiment items by '{}', count '{}', page '{}', size '{}'", + datasetItemSearchCriteria, datasetItemPage.content().size(), page, size); return Response.ok(datasetItemPage).build(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java index 9913f5bba5..824a0deb21 100644 --- 
a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java @@ -10,6 +10,7 @@ import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.base.Preconditions; import com.google.inject.ImplementedBy; import io.r2dbc.spi.Result; import io.r2dbc.spi.Statement; @@ -51,6 +52,9 @@ public interface DatasetItemDAO { Mono getItems(DatasetItemSearchCriteria datasetItemSearchCriteria, int page, int size); + Mono getItemsFromSingleExperiment( + DatasetItemSearchCriteria datasetItemSearchCriteria, int page, int size); + Mono get(UUID id); Flux getItems(UUID datasetId, int limit, UUID lastRetrievedId); @@ -68,7 +72,7 @@ class DatasetItemDAOImpl implements DatasetItemDAO { * This query is used to insert/update a dataset item into the database. * 1. The query uses a multiIf function to determine the value of the dataset_id field and validate if it matches with the previous value. * 2. The query uses a multiIf function to determine the value of the created_at field and validate if it matches with the previous value to avoid duplication of rows. 
- * */ + */ private static final String INSERT_DATASET_ITEM = """ INSERT INTO dataset_items ( id, @@ -90,7 +94,7 @@ INSERT INTO dataset_items ( LENGTH(CAST(old.dataset_id AS Nullable(String))) > 0 AND notEquals(old.dataset_id, new.dataset_id), leftPad('', 40, '*'), LENGTH(CAST(old.dataset_id AS Nullable(String))) > 0, old.dataset_id, new.dataset_id - ) as dataset_id, + ) AS dataset_id, new.source, new.trace_id, new.span_id, @@ -100,16 +104,16 @@ INSERT INTO dataset_items ( multiIf( notEquals(old.created_at, toDateTime64('1970-01-01 00:00:00.000', 9)), old.created_at, new.created_at - ) as created_at, + ) AS created_at, multiIf( LENGTH(old.workspace_id) > 0 AND notEquals(old.workspace_id, new.workspace_id), CAST(leftPad('', 40, '*') AS FixedString(19)), LENGTH(old.workspace_id) > 0, old.workspace_id, new.workspace_id - ) as workspace_id, + ) AS workspace_id, if( LENGTH(old.created_by) > 0, old.created_by, new.created_by - ) as created_by, + ) AS created_by, new.last_updated_by FROM ( SELECT @@ -141,7 +145,7 @@ LEFT JOIN ( private static final String SELECT_DATASET_ITEM = """ SELECT *, - null as experiment_items_array + null AS experiment_items_array FROM dataset_items WHERE id = :id AND workspace_id = :workspace_id @@ -153,7 +157,7 @@ LEFT JOIN ( private static final String SELECT_DATASET_ITEMS_STREAM = """ SELECT *, - null as experiment_items_array + null AS experiment_items_array FROM dataset_items WHERE dataset_id = :datasetId AND workspace_id = :workspace_id @@ -174,7 +178,7 @@ LEFT JOIN ( private static final String SELECT_DATASET_ITEMS = """ SELECT *, - null as experiment_items_array + null AS experiment_items_array FROM dataset_items WHERE dataset_id = :datasetId AND workspace_id = :workspace_id @@ -186,7 +190,25 @@ LEFT JOIN ( private static final String SELECT_DATASET_ITEMS_COUNT = """ SELECT - count(id) as count + count(id) AS count + FROM ( + SELECT + id + FROM dataset_items + WHERE dataset_id = :datasetId + AND workspace_id = :workspace_id + ORDER BY id 
DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS lastRows + ; + """; + + /** + * Counts dataset items only if there's a matching experiment item. + */ + private static final String SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_FROM_SINGLE_EXPERIMENT_COUNT = """ + SELECT + COUNT(DISTINCT di.id) AS count FROM ( SELECT id @@ -195,7 +217,18 @@ LEFT JOIN ( AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id - ) as lastRows + ) AS di + INNER JOIN ( + SELECT + dataset_item_id + FROM experiment_items + WHERE experiment_id = :experimentId + AND workspace_id = :workspace_id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS ei ON di.id = ei.dataset_item_id + GROUP BY + di.id ; """; @@ -240,30 +273,30 @@ LEFT JOIN ( */ private static final String SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS = """ SELECT - di.id as id, - di.input as input, - di.expected_output as expected_output, - di.metadata as metadata, - di.trace_id as trace_id, - di.span_id as span_id, - di.source as source, - di.created_at as created_at, - di.last_updated_at as last_updated_at, - di.created_by as created_by, - di.last_updated_by as last_updated_by, + di.id AS id, + di.input AS input, + di.expected_output AS expected_output, + di.metadata AS metadata, + di.trace_id AS trace_id, + di.span_id AS span_id, + di.source AS source, + di.created_at AS created_at, + di.last_updated_at AS last_updated_at, + di.created_by AS created_by, + di.last_updated_by AS last_updated_by, groupArray(tuple( ei.id, ei.experiment_id, ei.dataset_item_id, ei.trace_id, - t.input, - t.output, - t.feedback_scores_array, + tfs.input, + tfs.output, + tfs.feedback_scores_array, ei.created_at, ei.last_updated_at, ei.created_by, ei.last_updated_by - )) as experiment_items_array + )) AS experiment_items_array FROM ( SELECT * @@ -272,21 +305,21 @@ LEFT JOIN ( AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id - ) as di + ) AS di LEFT JOIN ( SELECT * FROM 
experiment_items - WHERE experiment_id in :experiment_ids + WHERE experiment_id in :experimentIds AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id - ) as ei ON di.id = ei.dataset_item_id + ) AS ei ON di.id = ei.dataset_item_id LEFT JOIN ( SELECT - id, - input, - output, + t.id, + t.input, + t.output, groupArray(tuple( fs.entity_id, fs.name, @@ -294,25 +327,144 @@ LEFT JOIN ( fs.value, fs.reason, fs.source - )) as feedback_scores_array - FROM traces + )) AS feedback_scores_array + FROM ( + SELECT + id, + input, + output + FROM traces + WHERE workspace_id = :workspace_id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS t LEFT JOIN ( SELECT - * + entity_id, + name, + category_name, + value, + reason, + source FROM feedback_scores - WHERE entity_type = :entity_type + WHERE entity_type = :entityType AND workspace_id = :workspace_id ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name - ) as fs ON id = fs.entity_id + ) AS fs ON t.id = fs.entity_id GROUP BY - id, - input, - output, - last_updated_at + t.id, + t.input, + t.output + ) AS tfs ON ei.trace_id = tfs.id + GROUP BY + di.id, + di.input, + di.expected_output, + di.metadata, + di.trace_id, + di.span_id, + di.source, + di.created_at, + di.last_updated_at, + di.created_by, + di.last_updated_by + ORDER BY di.id DESC, di.last_updated_at DESC + LIMIT :limit OFFSET :offset + ; + """; + + /** + * Same relationships as the query above, but with two logical changes: + * - Only accepts a single experiment id. + * - Only returns dataset items if there are matching experiment items for the given experiment id. 
+ */ + private static final String SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_FROM_SINGLE_EXPERIMENT = """ + SELECT + di.id AS id, + di.input AS input, + di.expected_output AS expected_output, + di.metadata AS metadata, + di.trace_id AS trace_id, + di.span_id AS span_id, + di.source AS source, + di.created_at AS created_at, + di.last_updated_at AS last_updated_at, + di.created_by AS created_by, + di.last_updated_by AS last_updated_by, + groupArray(tuple( + ei.id, + ei.experiment_id, + ei.dataset_item_id, + ei.trace_id, + tfs.input, + tfs.output, + tfs.feedback_scores_array, + ei.created_at, + ei.last_updated_at, + ei.created_by, + ei.last_updated_by + )) AS experiment_items_array + FROM ( + SELECT + * + FROM dataset_items + WHERE dataset_id = :datasetId + AND workspace_id = :workspace_id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS di + INNER JOIN ( + SELECT + * + FROM experiment_items + WHERE experiment_id = :experimentId + AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id - ) as t ON ei.trace_id = t.id + ) AS ei ON di.id = ei.dataset_item_id + LEFT JOIN ( + SELECT + t.id, + t.input, + t.output, + groupArray(tuple( + fs.entity_id, + fs.name, + fs.category_name, + fs.value, + fs.reason, + fs.source + )) AS feedback_scores_array + FROM ( + SELECT + id, + input, + output + FROM traces + WHERE workspace_id = :workspace_id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS t + LEFT JOIN ( + SELECT + entity_id, + name, + category_name, + value, + reason, + source + FROM feedback_scores + WHERE entity_type = :entityType + AND workspace_id = :workspace_id + ORDER BY entity_id DESC, last_updated_at DESC + LIMIT 1 BY entity_id, name + ) AS fs ON t.id = fs.entity_id + GROUP BY + t.id, + t.input, + t.output + ) AS tfs ON ei.trace_id = tfs.id GROUP BY di.id, di.input, @@ -573,7 +725,8 @@ public Mono getItems(@NonNull UUID datasetId, int page, int siz @Override public Mono getItems(@NonNull 
DatasetItemSearchCriteria datasetItemSearchCriteria, int page, int size) { - + log.info("Finding dataset items with experiment items by '{}', page '{}', size '{}'", + datasetItemSearchCriteria, page, size); return makeMonoContextAware( (userName, workspaceName, workspaceId) -> asyncTemplate @@ -591,9 +744,9 @@ public Mono getItems(@NonNull DatasetItemSearchCriteria dataset SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS) .bind("datasetId", datasetItemSearchCriteria.datasetId()) - .bind("experiment_ids", + .bind("experimentIds", datasetItemSearchCriteria.experimentIds()) - .bind("entity_type", + .bind("entityType", datasetItemSearchCriteria.entityType() .getType()) .bind("workspace_id", workspaceId) @@ -605,4 +758,38 @@ public Mono getItems(@NonNull DatasetItemSearchCriteria dataset .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count)))))); } + + @Override + public Mono getItemsFromSingleExperiment( + @NonNull DatasetItemSearchCriteria datasetItemSearchCriteria, int page, int size) { + Preconditions.checkArgument(datasetItemSearchCriteria.experimentIds().size() == 1); + log.info("Finding dataset items with experiment items from single experiment by '{}', page '{}', size '{}'", + datasetItemSearchCriteria, page, size); + var experimentId = datasetItemSearchCriteria.experimentIds().stream().toList().getFirst(); + return makeMonoContextAware((userName, workspaceName, workspaceId) -> asyncTemplate.nonTransaction( + connection -> Flux + .from(connection + .createStatement( + SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_FROM_SINGLE_EXPERIMENT_COUNT) + .bind("datasetId", datasetItemSearchCriteria.datasetId()) + .bind("experimentId", experimentId) + .bind("workspace_id", workspaceId) + .execute()) + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class))) + .reduce(0L, Long::sum) + .flatMap(count -> Flux + .from(connection + .createStatement( + SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_FROM_SINGLE_EXPERIMENT) + .bind("datasetId", 
datasetItemSearchCriteria.datasetId()) + .bind("experimentId", experimentId) + .bind("entityType", datasetItemSearchCriteria.entityType().getType()) + .bind("workspace_id", workspaceId) + .bind("limit", size) + .bind("offset", (page - 1) * size) + .execute()) + .flatMap(this::mapItem) + .collectList() + .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count)))))); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java index 46047d7545..b9c7072ac7 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java @@ -16,6 +16,7 @@ import jakarta.ws.rs.core.Response; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -46,6 +47,7 @@ public interface DatasetItemService { @Singleton @RequiredArgsConstructor(onConstructor_ = @Inject) +@Slf4j class DatasetItemServiceImpl implements DatasetItemService { private final @NonNull DatasetItemDAO dao; @@ -55,7 +57,6 @@ class DatasetItemServiceImpl implements DatasetItemService { @Override public Mono save(@NonNull DatasetItemBatch batch) { - if (batch.datasetId() == null && batch.datasetName() == null) { return Mono.error(failWithError("dataset_id or dataset_name must be provided")); } @@ -215,6 +216,14 @@ public Mono getItems(@NonNull UUID datasetId, int page, int siz @Override public Mono getItems( int page, int size, @NonNull DatasetItemSearchCriteria datasetItemSearchCriteria) { - return dao.getItems(datasetItemSearchCriteria, page, size); + if (datasetItemSearchCriteria.experimentIds().size() == 1) { + log.info("Finding dataset items with experiment items from single experiment by '{}', page '{}', size '{}'", + 
datasetItemSearchCriteria, page, size); + return dao.getItemsFromSingleExperiment(datasetItemSearchCriteria, page, size); + } else { + log.info("Finding dataset items with experiment items by '{}', page '{}', size '{}'", + datasetItemSearchCriteria, page, size); + return dao.getItems(datasetItemSearchCriteria, page, size); + } } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index 0140293c6b..1eee857098 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -3076,9 +3076,9 @@ class FindDatasetItemsWithExperimentItems { @Test void find() { - String workspaceName = UUID.randomUUID().toString(); - String apiKey = UUID.randomUUID().toString(); - String workspaceId = UUID.randomUUID().toString(); + var workspaceName = UUID.randomUUID().toString(); + var apiKey = UUID.randomUUID().toString(); + var workspaceId = UUID.randomUUID().toString(); mockTargetWorkspace(apiKey, workspaceName, workspaceId); @@ -3129,7 +3129,7 @@ void find() { createAndAssert(traceMissingFields, workspaceName, apiKey); // Creating the dataset - Dataset dataset = factory.manufacturePojo(Dataset.class); + var dataset = factory.manufacturePojo(Dataset.class); var datasetId = createAndAssert(dataset, apiKey, workspaceName); // Creating 5 dataset items for the dataset above @@ -3160,7 +3160,7 @@ void find() { .build())) .collect(Collectors.groupingBy(ExperimentItem::datasetItemId)); - // Dataset items 2 covers the case of experiments items related to a trace without input, output and scores. + // Dataset item 2 covers the case of experiments items related to a trace without input, output and scores. // It also has 2 experiment items per each of the 5 experiments. 
datasetItemIdToExperimentItemMap.put(expectedDatasetItems.get(2).id(), experimentIds.stream() .flatMap(experimentId -> IntStream.range(0, 2) @@ -3174,7 +3174,7 @@ void find() { .build())) .toList()); - // Dataset items 3 covers the case of experiments items related to an un-existing trace id. + // Dataset item 3 covers the case of experiments items related to an un-existing trace id. // It also has 2 experiment items per each of the 5 experiments. datasetItemIdToExperimentItemMap.put(expectedDatasetItems.get(3).id(), experimentIds.stream() .flatMap(experimentId -> IntStream.range(0, 2) @@ -3187,6 +3187,8 @@ void find() { .build())) .toList()); + // Dataset item 4 covers the case of not matching experiment items. + // When storing the experiment items in batch, adding some more unrelated random ones var experimentItemsBatch = factory.manufacturePojo(ExperimentItemsBatch.class); experimentItemsBatch = experimentItemsBatch.toBuilder() @@ -3274,6 +3276,198 @@ void find() { } } + @Test + void findFromSingleExperiment() { + var workspaceName = UUID.randomUUID().toString(); + var apiKey = UUID.randomUUID().toString(); + var workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + // Creating two traces with input, output and scores + var trace1 = factory.manufacturePojo(Trace.class); + createAndAssert(trace1, workspaceName, apiKey); + + var trace2 = factory.manufacturePojo(Trace.class); + createAndAssert(trace2, workspaceName, apiKey); + var traces = List.of(trace1, trace2); + + // Creating 5 scores per each of the two traces above + var scores1 = PodamFactoryUtils.manufacturePojoList(factory, FeedbackScoreBatchItem.class) + .stream() + .map(feedbackScoreBatchItem -> feedbackScoreBatchItem.toBuilder() + .id(trace1.id()) + .projectName(trace1.projectName()) + .value(factory.manufacturePojo(BigDecimal.class)) + .build()) + .toList(); + + var scores2 = PodamFactoryUtils.manufacturePojoList(factory, 
FeedbackScoreBatchItem.class) + .stream() + .map(feedbackScoreBatchItem -> feedbackScoreBatchItem.toBuilder() + .id(trace2.id()) + .projectName(trace2.projectName()) + .value(factory.manufacturePojo(BigDecimal.class)) + .build()) + .toList(); + + var traceIdToScoresMap = Stream.concat(scores1.stream(), scores2.stream()) + .collect(Collectors.groupingBy(FeedbackScoreBatchItem::id)); + + // When storing the scores in batch, adding some more unrelated random ones + var feedbackScoreBatch = factory.manufacturePojo(FeedbackScoreBatch.class); + feedbackScoreBatch = feedbackScoreBatch.toBuilder() + .scores(Stream.concat(feedbackScoreBatch.scores().stream(), + traceIdToScoresMap.values().stream().flatMap(List::stream)).toList()) + .build(); + + createScoreAndAssert(feedbackScoreBatch, apiKey, workspaceName); + + // Creating a trace without input, output and scores + var traceMissingFields = factory.manufacturePojo(Trace.class).toBuilder() + .input(null) + .output(null) + .build(); + createAndAssert(traceMissingFields, workspaceName, apiKey); + + // Creating the dataset + var dataset = factory.manufacturePojo(Dataset.class); + var datasetId = createAndAssert(dataset, apiKey, workspaceName); + + // Creating 5 dataset items for the dataset above + var datasetItemBatch = factory.manufacturePojo(DatasetItemBatch.class).toBuilder() + .datasetId(datasetId) + .build(); + + putAndAssert(datasetItemBatch, workspaceName, apiKey); + + // Creating 5 different experiment ids + var expectedDatasetItems = datasetItemBatch.items().subList(0, 4).reversed(); + var experimentIds = IntStream.range(0, 5).mapToObj(__ -> GENERATOR.generate()).toList(); + + // Dataset items 0 and 1 cover the general case. + // Per each dataset item there are 10 experiment items, so 2 experiment items per each of the 5 experiments. + // The first 5 experiment items are related to trace 1, the other 5 to trace 2. 
+ var datasetItemIdToExperimentItemMap = expectedDatasetItems.subList(0, 2).stream() + .flatMap(datasetItem -> IntStream.range(0, 10) + .mapToObj(i -> factory.manufacturePojo(ExperimentItem.class).toBuilder() + .experimentId(experimentIds.get(i / 2)) + .datasetItemId(datasetItem.id()) + .traceId(traces.get(i / 5).id()) + .input(traces.get(i / 5).input()) + .output(traces.get(i / 5).output()) + .feedbackScores(traceIdToScoresMap.get(traces.get(i / 5).id()).stream() + .map(FeedbackScoreMapper.INSTANCE::toFeedbackScore) + .toList()) + .build())) + .collect(Collectors.groupingBy(ExperimentItem::datasetItemId)); + + // Dataset item 2 covers the case of experiments items related to a trace without input, output and scores. + // It also has 2 experiment items per each of the 5 experiments. + datasetItemIdToExperimentItemMap.put(expectedDatasetItems.get(2).id(), experimentIds.stream() + .flatMap(experimentId -> IntStream.range(0, 2) + .mapToObj(i -> factory.manufacturePojo(ExperimentItem.class).toBuilder() + .experimentId(experimentId) + .datasetItemId(expectedDatasetItems.get(2).id()) + .traceId(traceMissingFields.id()) + .input(traceMissingFields.input()) + .output(traceMissingFields.output()) + .feedbackScores(null) + .build())) + .toList()); + + // Dataset item 3 covers the case of experiments items related to an un-existing trace id. + // It also has 2 experiment items per each of the 5 experiments. + datasetItemIdToExperimentItemMap.put(expectedDatasetItems.get(3).id(), experimentIds.stream() + .flatMap(experimentId -> IntStream.range(0, 2) + .mapToObj(i -> factory.manufacturePojo(ExperimentItem.class).toBuilder() + .experimentId(experimentId) + .datasetItemId(expectedDatasetItems.get(3).id()) + .input(null) + .output(null) + .feedbackScores(null) + .build())) + .toList()); + + // Dataset item 4 covers the case of not matching experiment items. 
+ + // When storing the experiment items in batch, adding some more unrelated random ones + var experimentItemsBatch = factory.manufacturePojo(ExperimentItemsBatch.class); + experimentItemsBatch = experimentItemsBatch.toBuilder() + .experimentItems(Stream.concat(experimentItemsBatch.experimentItems().stream(), + datasetItemIdToExperimentItemMap.values().stream().flatMap(Collection::stream)) + .collect(Collectors.toUnmodifiableSet())) + .build(); + createAndAssert(experimentItemsBatch, apiKey, workspaceName); + + var page = 1; + var pageSize = 5; + // Filtering by experiment 1. + var experimentIdsQueryParm = JsonUtils.writeValueAsString(List.of(experimentIds.get(1))); + + try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path(datasetId.toString()) + .path(DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_PATH) + .queryParam("page", page) + .queryParam("size", pageSize) + .queryParam("experiment_ids", experimentIdsQueryParm) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + var actualPage = actualResponse.readEntity(DatasetItemPage.class); + + assertThat(actualPage.page()).isEqualTo(page); + assertThat(actualPage.size()).isEqualTo(expectedDatasetItems.size()); + assertThat(actualPage.total()).isEqualTo(expectedDatasetItems.size()); + + var actualDatasetItems = actualPage.content(); + + assertThat(actualDatasetItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_DATA_ITEM) + .containsExactlyElementsOf(expectedDatasetItems); + + for (var i = 0; i < actualDatasetItems.size(); i++) { + var actualDatasetItem = actualDatasetItems.get(i); + var expectedDatasetItem = expectedDatasetItems.get(i); + + // Filtering by those related to experiment 1 + var experimentItems = datasetItemIdToExperimentItemMap.get(expectedDatasetItem.id()); + var expectedExperimentItems = 
List.of(experimentItems.get(2), experimentItems.get(3)).reversed(); + + assertThat(actualDatasetItem.experimentItems()) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_LIST) + .containsExactlyElementsOf(expectedExperimentItems); + + for (var j = 0; j < actualDatasetItem.experimentItems().size(); j++) { + var actualExperimentItem = actualDatasetItem.experimentItems().get(j); + var expectedExperimentItem = expectedExperimentItems.get(j); + + assertThat(actualExperimentItem.feedbackScores()) + .usingRecursiveComparison() + .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .ignoringCollectionOrder() + .isEqualTo(expectedExperimentItem.feedbackScores()); + + assertThat(actualExperimentItem.createdAt()) + .isAfter(expectedExperimentItem.createdAt()); + assertThat(actualExperimentItem.lastUpdatedAt()) + .isAfter(expectedExperimentItem.lastUpdatedAt()); + + assertThat(actualExperimentItem.createdBy()) + .isEqualTo(USER); + assertThat(actualExperimentItem.lastUpdatedBy()) + .isEqualTo(USER); + } + + assertThat(actualDatasetItem.createdAt()).isAfter(expectedDatasetItem.createdAt()); + assertThat(actualDatasetItem.lastUpdatedAt()).isAfter(expectedDatasetItem.lastUpdatedAt()); + } + } + } + @ParameterizedTest @ValueSource(strings = {"[wrong_payload]", "[0191377d-06ee-7026-8f63-cc5309d1f54b]"}) void findInvalidExperimentIds(String experimentIds) {