Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-419] add last updated trace at to projects #849

Merged
merged 17 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.comet.opik.api.events;

import com.comet.opik.infrastructure.events.BaseEvent;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;

import java.util.Set;
import java.util.UUID;

@Getter
@Accessors(fluent = true)
public class TracesCreated extends BaseEvent {
private final @NonNull Set<UUID> projectIds;

public TracesCreated(@NonNull Set<UUID> projectIds, @NonNull String workspaceId, @NonNull String userName) {
super(workspaceId, userName);
this.projectIds = projectIds;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.comet.opik.api.events;

import com.comet.opik.infrastructure.events.BaseEvent;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;

import java.util.Set;
import java.util.UUID;

@Getter
@Accessors(fluent = true)
public class TracesUpdated extends BaseEvent {
private final @NonNull Set<UUID> projectIds;

public TracesUpdated(@NonNull Set<UUID> projectIds, @NonNull String workspaceId, @NonNull String userName) {
super(workspaceId, userName);
this.projectIds = projectIds;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.comet.opik.api.resources.v1.events;

import com.comet.opik.api.ProjectIdLastUpdated;
import com.comet.opik.api.events.TracesCreated;
import com.comet.opik.api.events.TracesUpdated;
import com.comet.opik.domain.ProjectService;
import com.comet.opik.domain.TraceService;
import com.google.common.eventbus.EventBus;
import com.google.common.eventbus.Subscribe;
import jakarta.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;
import ru.vyarus.dropwizard.guice.module.installer.feature.eager.EagerSingleton;

import java.util.Set;
import java.util.UUID;

@EagerSingleton
@Slf4j
public class ProjectEventListener {
private final ProjectService projectService;
private final TraceService traceService;

@Inject
public ProjectEventListener(EventBus eventBus, ProjectService projectService, TraceService traceService) {
this.projectService = projectService;
this.traceService = traceService;
eventBus.register(this);
}

@Subscribe
public void onTracesCreated(TracesCreated event) {
updateProjectsLastUpdatedTraceAt(event.workspaceId(), event.projectIds());
}

@Subscribe
public void onTracesUpdated(TracesUpdated event) {
updateProjectsLastUpdatedTraceAt(event.workspaceId(), event.projectIds());
}

private void updateProjectsLastUpdatedTraceAt(String workspaceId, Set<UUID> projectIds) {
log.info("Recording last traces for projects '{}'", projectIds);

traceService.getLastUpdatedTraceAt(projectIds, workspaceId)
.flatMap(lastTraceByProjectId -> Mono.fromRunnable(() -> projectService.recordLastUpdatedTrace(
workspaceId,
lastTraceByProjectId.entrySet().stream()
.map(entry -> new ProjectIdLastUpdated(entry.getKey(), entry.getValue())).toList())))
.block();

log.info("Recorded last traces for projects '{}'", projectIds);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.jdbi.v3.sqlobject.customizer.BindList;
import org.jdbi.v3.sqlobject.customizer.BindMethods;
import org.jdbi.v3.sqlobject.customizer.Define;
import org.jdbi.v3.sqlobject.statement.SqlBatch;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import org.jdbi.v3.sqlobject.statement.SqlUpdate;
import org.jdbi.v3.stringtemplate4.UseStringTemplateEngine;
Expand Down Expand Up @@ -85,4 +86,11 @@ default Optional<Project> fetch(UUID id, String workspaceId) {

@SqlQuery("SELECT * FROM projects WHERE workspace_id = :workspaceId AND name IN (<names>)")
List<Project> findByNames(@Bind("workspaceId") String workspaceId, @BindList("names") Collection<String> names);

@SqlBatch("UPDATE projects SET last_updated_trace_at = :lastUpdatedAt " +
"WHERE workspace_id = :workspace_id" +
" AND id = :id" +
" AND (last_updated_trace_at IS NULL OR last_updated_trace_at < :lastUpdatedAt)")
int[] recordLastUpdatedTrace(@Bind("workspace_id") String workspaceId,
@BindMethods Collection<ProjectIdLastUpdated> lastUpdatedTraces);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.sql.SQLIntegrityConstraintViolationException;
import java.time.Instant;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -68,11 +69,15 @@ public interface ProjectService {

Page<Project> find(int page, int size, ProjectCriteria criteria, List<SortingField> sortingFields);

List<Project> findByIds(String workspaceId, Set<UUID> ids);

List<Project> findByNames(String workspaceId, List<String> names);

Project getOrCreate(String workspaceId, String projectName, String userName);

Project retrieveByName(String projectName);

void recordLastUpdatedTrace(String workspaceId, Collection<ProjectIdLastUpdated> lastUpdatedTraces);
}

@Slf4j
Expand Down Expand Up @@ -277,6 +282,16 @@ public Page<Project> find(int page, int size, @NonNull ProjectCriteria criteria,
sortingFactory.getSortableFields());
}

@Override
public List<Project> findByIds(String workspaceId, Set<UUID> ids) {
if (ids.isEmpty()) {
log.info("ids list is empty, returning");
return List.of();
}

return template.inTransaction(READ_ONLY, handle -> handle.attach(ProjectDAO.class).findByIds(ids, workspaceId));
}

private Page<Project> findWithLastTraceSorting(int page, int size, @NonNull ProjectCriteria criteria,
@NonNull SortingField sortingField) {
String workspaceId = requestContext.get().getWorkspaceId();
Expand Down Expand Up @@ -392,4 +407,10 @@ public Project retrieveByName(@NonNull String projectName) {
});
}

@Override
public void recordLastUpdatedTrace(String workspaceId, Collection<ProjectIdLastUpdated> lastUpdatedTraces) {
template.inTransaction(WRITE,
handle -> handle.attach(ProjectDAO.class).recordLastUpdatedTrace(workspaceId, lastUpdatedTraces));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.error.IdentifierMismatchException;
import com.comet.opik.api.events.TracesCreated;
import com.comet.opik.api.events.TracesUpdated;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.infrastructure.lock.LockService;
import com.comet.opik.utils.AsyncUtils;
import com.comet.opik.utils.WorkspaceUtils;
import com.google.common.base.Preconditions;
import com.google.common.eventbus.EventBus;
import com.google.inject.ImplementedBy;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -68,6 +71,8 @@ public interface TraceService {
Mono<ProjectStats> getStats(TraceSearchCriteria searchCriteria);

Mono<Long> getDailyCreatedCount();

Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(Set<UUID> projectIds, String workspaceId);
}

@Slf4j
Expand All @@ -85,6 +90,7 @@ class TraceServiceImpl implements TraceService {
private final @NonNull ProjectService projectService;
private final @NonNull IdGenerator idGenerator;
private final @NonNull LockService lockService;
private final @NonNull EventBus eventBus;

@Override
@WithSpan
Expand All @@ -93,12 +99,16 @@ public Mono<UUID> create(@NonNull Trace trace) {
String projectName = WorkspaceUtils.getProjectName(trace.projectName());
UUID id = trace.id() == null ? idGenerator.generateId() : trace.id();

return IdGenerator
return Mono.deferContextual(ctx -> IdGenerator
.validateVersionAsync(id, TRACE_KEY)
.then(Mono.defer(() -> getOrCreateProject(projectName)))
.flatMap(project -> lockService.executeWithLock(
new LockService.Lock(id, TRACE_KEY),
Mono.defer(() -> insertTrace(trace, project, id))));
Mono.defer(() -> insertTrace(trace, project, id)))
.doOnSuccess(__ -> eventBus.post(new TracesCreated(
Set.of(project.id()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME))))));
}

@WithSpan
Expand All @@ -113,14 +123,20 @@ public Mono<Long> create(TraceBatch batch) {
.distinct()
.toList();

Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());
return Mono.deferContextual(ctx -> {
Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());

return resolveProjects
.flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection)));
return resolveProjects
.flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection))
.doOnSuccess(__ -> eventBus.post(new TracesCreated(
traces.stream().map(Trace::projectId).collect(Collectors.toUnmodifiableSet()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME)))));
});
}

private List<Trace> bindTraceToProjectAndId(TraceBatch batch, List<Project> projects) {
Expand Down Expand Up @@ -223,7 +239,7 @@ public Mono<Void> update(@NonNull TraceUpdate traceUpdate, @NonNull UUID id) {

var projectName = WorkspaceUtils.getProjectName(traceUpdate.projectName());

return getProjectById(traceUpdate)
return Mono.deferContextual(ctx -> getProjectById(traceUpdate)
.switchIfEmpty(Mono.defer(() -> getOrCreateProject(projectName)))
.subscribeOn(Schedulers.boundedElastic())
.flatMap(project -> lockService.executeWithLock(
Expand All @@ -232,8 +248,12 @@ public Mono<Void> update(@NonNull TraceUpdate traceUpdate, @NonNull UUID id) {
.flatMap(trace -> updateOrFail(traceUpdate, id, trace, project).thenReturn(id))
.switchIfEmpty(Mono.defer(() -> insertUpdate(project, traceUpdate, id))
.thenReturn(id))
.onErrorResume(this::handleDBError))))
.then();
.onErrorResume(this::handleDBError)
.doOnSuccess(__ -> eventBus.post(new TracesUpdated(
Set.of(project.id()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME)))))))
.then());
}

private Mono<Void> insertUpdate(Project project, TraceUpdate traceUpdate, UUID id) {
Expand Down Expand Up @@ -373,4 +393,9 @@ public Mono<Long> getDailyCreatedCount() {
return dao.getDailyTraces();
}

@Override
public Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(Set<UUID> projectIds, String workspaceId) {
return template
.nonTransaction(connection -> dao.getLastUpdatedTraceAt(projectIds, workspaceId, connection));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset idoberko2:add_projects_last_updated_trace_at

ALTER TABLE projects ADD COLUMN last_updated_trace_at TIMESTAMP(6);
idoberko2 marked this conversation as resolved.
Show resolved Hide resolved

--rollback ALTER TABLE projects DROP COLUMN last_updated_trace_at;
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.comet.opik;

import java.time.Instant;

public class TestComparators {
public static int compareMicroNanoTime(Instant i1, Instant i2) {
// Calculate the difference in nanoseconds
long nanoDifference = Math.abs(i1.getNano() - i2.getNano());
if (nanoDifference < 1_000) {
return 0; // Consider equal if within a microsecond
}
return i1.compareTo(i2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.resources.utils.TestUtils;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
Expand Down Expand Up @@ -126,4 +128,16 @@ public void deleteTraces(BatchDelete request, String workspaceName, String apiKe
}
}

public void updateTrace(UUID id, TraceUpdate traceUpdate, String apiKey, String workspaceName) {
try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI))
.path(id.toString())
.request()
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.method(HttpMethod.PATCH, Entity.json(traceUpdate))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204);
assertThat(actualResponse.hasEntity()).isFalse();
}
}
}
Loading
Loading