Skip to content

Commit

Permalink
[OPIK-419] add last updated trace at to projects (#849)
Browse files Browse the repository at this point in the history
* OPIK-419 insert single and batch traces

* OPIK-419 null timestamp

* OPIK-419 create single and batch

* OPIK-419 coverage create single trace

* OPIK-419 add a couple of todos to not forget

* OPIK-419 refactor

* OPIK-419 refactor

* OPIK-419 cover batch

* OPIK-419 update trace failing test

* OPIK-419 update trace failing test green

* OPIK-419 remove comment

* OPIK-419 fix broken test

* OPIK-419 projectdao not public

* OPIK-419 pr comments

* OPIK-419 cover new functionality in service test

* OPIK-419 handle async update to last_trace_updated_at

* OPIK-419 default null
  • Loading branch information
idoberko2 authored Dec 11, 2024
1 parent 10b558e commit 2d07e60
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.comet.opik.api.events;

import com.comet.opik.infrastructure.events.BaseEvent;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;

import java.util.Set;
import java.util.UUID;

@Getter
@Accessors(fluent = true)
public class TracesCreated extends BaseEvent {
private final @NonNull Set<UUID> projectIds;

public TracesCreated(@NonNull Set<UUID> projectIds, @NonNull String workspaceId, @NonNull String userName) {
super(workspaceId, userName);
this.projectIds = projectIds;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.comet.opik.api.events;

import com.comet.opik.infrastructure.events.BaseEvent;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;

import java.util.Set;
import java.util.UUID;

@Getter
@Accessors(fluent = true)
public class TracesUpdated extends BaseEvent {
private final @NonNull Set<UUID> projectIds;

public TracesUpdated(@NonNull Set<UUID> projectIds, @NonNull String workspaceId, @NonNull String userName) {
super(workspaceId, userName);
this.projectIds = projectIds;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.comet.opik.api.resources.v1.events;

import com.comet.opik.api.ProjectIdLastUpdated;
import com.comet.opik.api.events.TracesCreated;
import com.comet.opik.api.events.TracesUpdated;
import com.comet.opik.domain.ProjectService;
import com.comet.opik.domain.TraceService;
import com.google.common.eventbus.EventBus;
import com.google.common.eventbus.Subscribe;
import jakarta.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;
import ru.vyarus.dropwizard.guice.module.installer.feature.eager.EagerSingleton;

import java.util.Set;
import java.util.UUID;

@EagerSingleton
@Slf4j
public class ProjectEventListener {
private final ProjectService projectService;
private final TraceService traceService;

@Inject
public ProjectEventListener(EventBus eventBus, ProjectService projectService, TraceService traceService) {
this.projectService = projectService;
this.traceService = traceService;
eventBus.register(this);
}

@Subscribe
public void onTracesCreated(TracesCreated event) {
updateProjectsLastUpdatedTraceAt(event.workspaceId(), event.projectIds());
}

@Subscribe
public void onTracesUpdated(TracesUpdated event) {
updateProjectsLastUpdatedTraceAt(event.workspaceId(), event.projectIds());
}

private void updateProjectsLastUpdatedTraceAt(String workspaceId, Set<UUID> projectIds) {
log.info("Recording last traces for projects '{}'", projectIds);

traceService.getLastUpdatedTraceAt(projectIds, workspaceId)
.flatMap(lastTraceByProjectId -> Mono.fromRunnable(() -> projectService.recordLastUpdatedTrace(
workspaceId,
lastTraceByProjectId.entrySet().stream()
.map(entry -> new ProjectIdLastUpdated(entry.getKey(), entry.getValue())).toList())))
.block();

log.info("Recorded last traces for projects '{}'", projectIds);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.jdbi.v3.sqlobject.customizer.BindList;
import org.jdbi.v3.sqlobject.customizer.BindMethods;
import org.jdbi.v3.sqlobject.customizer.Define;
import org.jdbi.v3.sqlobject.statement.SqlBatch;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import org.jdbi.v3.sqlobject.statement.SqlUpdate;
import org.jdbi.v3.stringtemplate4.UseStringTemplateEngine;
Expand Down Expand Up @@ -85,4 +86,11 @@ default Optional<Project> fetch(UUID id, String workspaceId) {

@SqlQuery("SELECT * FROM projects WHERE workspace_id = :workspaceId AND name IN (<names>)")
List<Project> findByNames(@Bind("workspaceId") String workspaceId, @BindList("names") Collection<String> names);

@SqlBatch("UPDATE projects SET last_updated_trace_at = :lastUpdatedAt " +
"WHERE workspace_id = :workspace_id" +
" AND id = :id" +
" AND (last_updated_trace_at IS NULL OR last_updated_trace_at < :lastUpdatedAt)")
int[] recordLastUpdatedTrace(@Bind("workspace_id") String workspaceId,
@BindMethods Collection<ProjectIdLastUpdated> lastUpdatedTraces);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.sql.SQLIntegrityConstraintViolationException;
import java.time.Instant;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -68,11 +69,15 @@ public interface ProjectService {

Page<Project> find(int page, int size, ProjectCriteria criteria, List<SortingField> sortingFields);

List<Project> findByIds(String workspaceId, Set<UUID> ids);

List<Project> findByNames(String workspaceId, List<String> names);

Project getOrCreate(String workspaceId, String projectName, String userName);

Project retrieveByName(String projectName);

void recordLastUpdatedTrace(String workspaceId, Collection<ProjectIdLastUpdated> lastUpdatedTraces);
}

@Slf4j
Expand Down Expand Up @@ -277,6 +282,16 @@ public Page<Project> find(int page, int size, @NonNull ProjectCriteria criteria,
sortingFactory.getSortableFields());
}

@Override
public List<Project> findByIds(String workspaceId, Set<UUID> ids) {
if (ids.isEmpty()) {
log.info("ids list is empty, returning");
return List.of();
}

return template.inTransaction(READ_ONLY, handle -> handle.attach(ProjectDAO.class).findByIds(ids, workspaceId));
}

private Page<Project> findWithLastTraceSorting(int page, int size, @NonNull ProjectCriteria criteria,
@NonNull SortingField sortingField) {
String workspaceId = requestContext.get().getWorkspaceId();
Expand Down Expand Up @@ -392,4 +407,10 @@ public Project retrieveByName(@NonNull String projectName) {
});
}

@Override
public void recordLastUpdatedTrace(String workspaceId, Collection<ProjectIdLastUpdated> lastUpdatedTraces) {
template.inTransaction(WRITE,
handle -> handle.attach(ProjectDAO.class).recordLastUpdatedTrace(workspaceId, lastUpdatedTraces));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.error.IdentifierMismatchException;
import com.comet.opik.api.events.TracesCreated;
import com.comet.opik.api.events.TracesUpdated;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.infrastructure.lock.LockService;
import com.comet.opik.utils.AsyncUtils;
import com.comet.opik.utils.WorkspaceUtils;
import com.google.common.base.Preconditions;
import com.google.common.eventbus.EventBus;
import com.google.inject.ImplementedBy;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -68,6 +71,8 @@ public interface TraceService {
Mono<ProjectStats> getStats(TraceSearchCriteria searchCriteria);

Mono<Long> getDailyCreatedCount();

Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(Set<UUID> projectIds, String workspaceId);
}

@Slf4j
Expand All @@ -85,6 +90,7 @@ class TraceServiceImpl implements TraceService {
private final @NonNull ProjectService projectService;
private final @NonNull IdGenerator idGenerator;
private final @NonNull LockService lockService;
private final @NonNull EventBus eventBus;

@Override
@WithSpan
Expand All @@ -93,12 +99,16 @@ public Mono<UUID> create(@NonNull Trace trace) {
String projectName = WorkspaceUtils.getProjectName(trace.projectName());
UUID id = trace.id() == null ? idGenerator.generateId() : trace.id();

return IdGenerator
return Mono.deferContextual(ctx -> IdGenerator
.validateVersionAsync(id, TRACE_KEY)
.then(Mono.defer(() -> getOrCreateProject(projectName)))
.flatMap(project -> lockService.executeWithLock(
new LockService.Lock(id, TRACE_KEY),
Mono.defer(() -> insertTrace(trace, project, id))));
Mono.defer(() -> insertTrace(trace, project, id)))
.doOnSuccess(__ -> eventBus.post(new TracesCreated(
Set.of(project.id()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME))))));
}

@WithSpan
Expand All @@ -113,14 +123,20 @@ public Mono<Long> create(TraceBatch batch) {
.distinct()
.toList();

Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());
return Mono.deferContextual(ctx -> {
Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());

return resolveProjects
.flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection)));
return resolveProjects
.flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection))
.doOnSuccess(__ -> eventBus.post(new TracesCreated(
traces.stream().map(Trace::projectId).collect(Collectors.toUnmodifiableSet()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME)))));
});
}

private List<Trace> bindTraceToProjectAndId(TraceBatch batch, List<Project> projects) {
Expand Down Expand Up @@ -223,7 +239,7 @@ public Mono<Void> update(@NonNull TraceUpdate traceUpdate, @NonNull UUID id) {

var projectName = WorkspaceUtils.getProjectName(traceUpdate.projectName());

return getProjectById(traceUpdate)
return Mono.deferContextual(ctx -> getProjectById(traceUpdate)
.switchIfEmpty(Mono.defer(() -> getOrCreateProject(projectName)))
.subscribeOn(Schedulers.boundedElastic())
.flatMap(project -> lockService.executeWithLock(
Expand All @@ -232,8 +248,12 @@ public Mono<Void> update(@NonNull TraceUpdate traceUpdate, @NonNull UUID id) {
.flatMap(trace -> updateOrFail(traceUpdate, id, trace, project).thenReturn(id))
.switchIfEmpty(Mono.defer(() -> insertUpdate(project, traceUpdate, id))
.thenReturn(id))
.onErrorResume(this::handleDBError))))
.then();
.onErrorResume(this::handleDBError)
.doOnSuccess(__ -> eventBus.post(new TracesUpdated(
Set.of(project.id()),
ctx.get(RequestContext.WORKSPACE_ID),
ctx.get(RequestContext.USER_NAME)))))))
.then());
}

private Mono<Void> insertUpdate(Project project, TraceUpdate traceUpdate, UUID id) {
Expand Down Expand Up @@ -373,4 +393,9 @@ public Mono<Long> getDailyCreatedCount() {
return dao.getDailyTraces();
}

@Override
public Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(Set<UUID> projectIds, String workspaceId) {
return template
.nonTransaction(connection -> dao.getLastUpdatedTraceAt(projectIds, workspaceId, connection));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset idoberko2:add_projects_last_updated_trace_at

ALTER TABLE projects ADD COLUMN last_updated_trace_at TIMESTAMP(6) DEFAULT NULL;

--rollback ALTER TABLE projects DROP COLUMN last_updated_trace_at;
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.comet.opik;

import java.time.Instant;

public class TestComparators {
public static int compareMicroNanoTime(Instant i1, Instant i2) {
// Calculate the difference in nanoseconds
long nanoDifference = Math.abs(i1.getNano() - i2.getNano());
if (nanoDifference < 1_000) {
return 0; // Consider equal if within a microsecond
}
return i1.compareTo(i2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.resources.utils.TestUtils;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
Expand Down Expand Up @@ -126,4 +128,16 @@ public void deleteTraces(BatchDelete request, String workspaceName, String apiKe
}
}

public void updateTrace(UUID id, TraceUpdate traceUpdate, String apiKey, String workspaceName) {
try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI))
.path(id.toString())
.request()
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.method(HttpMethod.PATCH, Entity.json(traceUpdate))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204);
assertThat(actualResponse.hasEntity()).isFalse();
}
}
}
Loading

0 comments on commit 2d07e60

Please sign in to comment.