diff --git a/internal-packages/run-engine/src/engine/systems/checkpointSystem.ts b/internal-packages/run-engine/src/engine/systems/checkpointSystem.ts index c6586643ba..91a0ad0c9f 100644 --- a/internal-packages/run-engine/src/engine/systems/checkpointSystem.ts +++ b/internal-packages/run-engine/src/engine/systems/checkpointSystem.ts @@ -283,6 +283,7 @@ export class CheckpointSystem { environmentType: snapshot.environmentType, projectId: snapshot.projectId, organizationId: snapshot.organizationId, + batchId: snapshot.batchId ?? undefined, completedWaitpoints: snapshot.completedWaitpoints, workerId, runnerId, diff --git a/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts b/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts index c49484b3ca..912c9e2d25 100644 --- a/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts +++ b/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts @@ -374,6 +374,7 @@ export class DequeueSystem { projectId: snapshot.projectId, organizationId: snapshot.organizationId, checkpointId: snapshot.checkpointId ?? undefined, + batchId: snapshot.batchId ?? undefined, completedWaitpoints: snapshot.completedWaitpoints, workerId, runnerId, diff --git a/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts b/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts index 6d5c9028f6..9827d7ec1d 100644 --- a/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts +++ b/internal-packages/run-engine/src/engine/systems/runAttemptSystem.ts @@ -218,6 +218,8 @@ export class RunAttemptSystem { environmentType: latestSnapshot.environmentType, projectId: latestSnapshot.projectId, organizationId: latestSnapshot.organizationId, + batchId: latestSnapshot.batchId ?? undefined, + completedWaitpoints: latestSnapshot.completedWaitpoints, workerId, runnerId, }); diff --git a/internal-packages/run-engine/src/engine/tests/checkpoints.test.ts b/internal-packages/run-engine/src/engine/tests/checkpoints.test.ts index cb1a56d29a..01d46986f6 100644 --- a/internal-packages/run-engine/src/engine/tests/checkpoints.test.ts +++ b/internal-packages/run-engine/src/engine/tests/checkpoints.test.ts @@ -1,4 +1,3 @@ -//todo checkpoint tests import { containerTest, assertNonNullable } from "@internal/testcontainers"; import { trace } from "@internal/tracing"; import { expect } from "vitest"; @@ -6,6 +5,7 @@ import { RunEngine } from "../index.js"; import { setTimeout } from "node:timers/promises"; import { EventBusEventArgs } from "../eventBus.js"; import { setupAuthenticatedEnvironment, setupBackgroundWorker } from "./setup.js"; +import { generateFriendlyId } from "@trigger.dev/core/v3/isomorphic"; vi.setConfig({ testTimeout: 60_000 }); @@ -983,4 +983,406 @@ describe("RunEngine checkpoints", () => { } } ); + + containerTest("batchTriggerAndWait resume after checkpoint", async ({ prisma, redisOptions }) => { + //create environment + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + + const engine = new RunEngine({ + prisma, + worker: { + redis: redisOptions, + workers: 1, + tasksPerWorker: 10, + pollIntervalMs: 20, + }, + queue: { + redis: redisOptions, + }, + runLock: { + redis: redisOptions, + }, + machines: { + defaultMachine: "small-1x", + machines: { + "small-1x": { + name: "small-1x" as const, + cpu: 0.5, + memory: 0.5, + centsPerMs: 0.0001, + }, + }, + baseCostInCents: 0.0001, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + try { + const parentTask = "parent-task"; + const childTask = "child-task"; + + //create background worker + await setupBackgroundWorker(engine, authenticatedEnvironment, [parentTask, childTask]); + + //create a batch + const batch = await prisma.batchTaskRun.create({ + data: { + friendlyId: generateFriendlyId("batch"), + runtimeEnvironmentId: authenticatedEnvironment.id, + }, + }); + + //trigger the run + const parentRun = await engine.trigger( + { + number: 1, + friendlyId: "run_p1234", + environment: authenticatedEnvironment, + taskIdentifier: parentTask, + payload: "{}", + payloadType: "application/json", + context: {}, + traceContext: {}, + traceId: "t12345", + spanId: "s12345", + masterQueue: "main", + queue: `task/${parentTask}`, + isTest: false, + tags: [], + }, + prisma + ); + + //dequeue parent + const dequeued = await engine.dequeueFromMasterQueue({ + consumerId: "test_12345", + masterQueue: parentRun.masterQueue, + maxRunCount: 10, + }); + + //create an attempt + const initialExecutionData = await engine.getRunExecutionData({ runId: parentRun.id }); + assertNonNullable(initialExecutionData); + const attemptResult = await engine.startRunAttempt({ + runId: parentRun.id, + snapshotId: initialExecutionData.snapshot.id, + }); + + //block using the batch + await engine.blockRunWithCreatedBatch({ + runId: parentRun.id, + batchId: batch.id, + environmentId: authenticatedEnvironment.id, + projectId: authenticatedEnvironment.projectId, + organizationId: authenticatedEnvironment.organizationId, + }); + + const afterBlockedByBatch = await engine.getRunExecutionData({ runId: parentRun.id }); + assertNonNullable(afterBlockedByBatch); + expect(afterBlockedByBatch.snapshot.executionStatus).toBe("EXECUTING_WITH_WAITPOINTS"); + + const child1 = await engine.trigger( + { + number: 1, + friendlyId: "run_c1234", + environment: authenticatedEnvironment, + taskIdentifier: childTask, + payload: "{}", + payloadType: "application/json", + context: {}, + traceContext: {}, + traceId: "t12345", + spanId: "s12345", + masterQueue: "main", + queue: `task/${childTask}`, + isTest: false, + tags: [], + resumeParentOnCompletion: true, + parentTaskRunId: parentRun.id, + batch: { id: batch.id, index: 0 }, + }, + prisma + ); + + const parentAfterChild1 = await engine.getRunExecutionData({ runId: parentRun.id }); + assertNonNullable(parentAfterChild1); + expect(parentAfterChild1.snapshot.executionStatus).toBe("EXECUTING_WITH_WAITPOINTS"); + + const child2 = await engine.trigger( + { + number: 2, + friendlyId: "run_c12345", + environment: authenticatedEnvironment, + taskIdentifier: childTask, + payload: "{}", + payloadType: "application/json", + context: {}, + traceContext: {}, + traceId: "t123456", + spanId: "s123456", + masterQueue: "main", + queue: `task/${childTask}`, + isTest: false, + tags: [], + resumeParentOnCompletion: true, + parentTaskRunId: parentRun.id, + batch: { id: batch.id, index: 1 }, + }, + prisma + ); + + const parentAfterChild2 = await engine.getRunExecutionData({ runId: parentRun.id }); + assertNonNullable(parentAfterChild2); + expect(parentAfterChild2.snapshot.executionStatus).toBe("EXECUTING_WITH_WAITPOINTS"); + + //check the waitpoint blocking the parent run + const runWaitpoints = await prisma.taskRunWaitpoint.findMany({ + where: { + taskRunId: parentRun.id, + }, + include: { + waitpoint: true, + }, + orderBy: { + createdAt: "asc", + }, + }); + expect(runWaitpoints.length).toBe(3); + const child1Waitpoint = runWaitpoints.find( + (w) => w.waitpoint.completedByTaskRunId === child1.id + ); + expect(child1Waitpoint?.waitpoint.type).toBe("RUN"); + expect(child1Waitpoint?.waitpoint.completedByTaskRunId).toBe(child1.id); + expect(child1Waitpoint?.batchId).toBe(batch.id); + expect(child1Waitpoint?.batchIndex).toBe(0); + const child2Waitpoint = runWaitpoints.find( + (w) => w.waitpoint.completedByTaskRunId === child2.id + ); + expect(child2Waitpoint?.waitpoint.type).toBe("RUN"); + expect(child2Waitpoint?.waitpoint.completedByTaskRunId).toBe(child2.id); + expect(child2Waitpoint?.batchId).toBe(batch.id); + expect(child2Waitpoint?.batchIndex).toBe(1); + const batchWaitpoint = runWaitpoints.find((w) => w.waitpoint.type === "BATCH"); + expect(batchWaitpoint?.waitpoint.type).toBe("BATCH"); + expect(batchWaitpoint?.waitpoint.completedByBatchId).toBe(batch.id); + + await engine.unblockRunForCreatedBatch({ + runId: parentRun.id, + batchId: batch.id, + environmentId: authenticatedEnvironment.id, + projectId: authenticatedEnvironment.projectId, + }); + + // Create a checkpoint + const checkpointResult = await engine.createCheckpoint({ + runId: parentRun.id, + snapshotId: parentAfterChild2.snapshot.id, + checkpoint: { + type: "DOCKER", + reason: "TEST_CHECKPOINT", + location: "test-location", + imageRef: "test-image-ref", + }, + }); + + expect(checkpointResult.ok).toBe(true); + + const snapshot = checkpointResult.ok ? checkpointResult.snapshot : null; + + assertNonNullable(snapshot); + + const checkpointRun = checkpointResult.ok ? checkpointResult.run : null; + assertNonNullable(checkpointRun); + + // Verify checkpoint creation + expect(snapshot.executionStatus).toBe("SUSPENDED"); + expect(checkpointRun.status).toBe("WAITING_TO_RESUME"); + + // Get execution data to verify state + const executionData = await engine.getRunExecutionData({ runId: parentRun.id }); + assertNonNullable(executionData); + expect(executionData.snapshot.executionStatus).toBe("SUSPENDED"); + expect(executionData.checkpoint).toBeDefined(); + expect(executionData.checkpoint?.type).toBe("DOCKER"); + expect(executionData.checkpoint?.reason).toBe("TEST_CHECKPOINT"); + + //dequeue and start the 1st child + const dequeuedChild = await engine.dequeueFromMasterQueue({ + consumerId: "test_12345", + masterQueue: child1.masterQueue, + maxRunCount: 1, + }); + + expect(dequeuedChild.length).toBe(1); + + const childAttempt1 = await engine.startRunAttempt({ + runId: dequeuedChild[0].run.id, + snapshotId: dequeuedChild[0].snapshot.id, + }); + + // complete the 1st child + await engine.completeRunAttempt({ + runId: childAttempt1.run.id, + snapshotId: childAttempt1.snapshot.id, + completion: { + id: child1.id, + ok: true, + output: '{"foo":"bar"}', + outputType: "application/json", + }, + }); + + //child snapshot + const childExecutionDataAfter = await engine.getRunExecutionData({ + runId: childAttempt1.run.id, + }); + assertNonNullable(childExecutionDataAfter); + expect(childExecutionDataAfter.snapshot.executionStatus).toBe("FINISHED"); + + const child1WaitpointAfter = await prisma.waitpoint.findFirst({ + where: { + id: child1Waitpoint?.waitpointId, + }, + }); + expect(child1WaitpointAfter?.completedAt).not.toBeNull(); + expect(child1WaitpointAfter?.status).toBe("COMPLETED"); + expect(child1WaitpointAfter?.output).toBe('{"foo":"bar"}'); + + await setTimeout(500); + + const runWaitpointsAfterFirstChild = await prisma.taskRunWaitpoint.findMany({ + where: { + taskRunId: parentRun.id, + }, + include: { + waitpoint: true, + }, + }); + expect(runWaitpointsAfterFirstChild.length).toBe(3); + + //parent snapshot + const parentExecutionDataAfterFirstChildComplete = await engine.getRunExecutionData({ + runId: parentRun.id, + }); + assertNonNullable(parentExecutionDataAfterFirstChildComplete); + expect(parentExecutionDataAfterFirstChildComplete.snapshot.executionStatus).toBe("SUSPENDED"); + expect(parentExecutionDataAfterFirstChildComplete.batch?.id).toBe(batch.id); + expect(parentExecutionDataAfterFirstChildComplete.completedWaitpoints.length).toBe(0); + + expect(await engine.runQueue.lengthOfEnvQueue(authenticatedEnvironment)).toBe(1); + + //dequeue and start the 2nd child + const dequeuedChild2 = await engine.dequeueFromMasterQueue({ + consumerId: "test_12345", + masterQueue: child2.masterQueue, + maxRunCount: 1, + }); + + expect(dequeuedChild2.length).toBe(1); + + const childAttempt2 = await engine.startRunAttempt({ + runId: child2.id, + snapshotId: dequeuedChild2[0].snapshot.id, + }); + await engine.completeRunAttempt({ + runId: child2.id, + snapshotId: childAttempt2.snapshot.id, + completion: { + id: child2.id, + ok: true, + output: '{"baz":"qux"}', + outputType: "application/json", + }, + }); + + //child snapshot + const child2ExecutionDataAfter = await engine.getRunExecutionData({ runId: child1.id }); + assertNonNullable(child2ExecutionDataAfter); + expect(child2ExecutionDataAfter.snapshot.executionStatus).toBe("FINISHED"); + + const child2WaitpointAfter = await prisma.waitpoint.findFirst({ + where: { + id: child2Waitpoint?.waitpointId, + }, + }); + expect(child2WaitpointAfter?.completedAt).not.toBeNull(); + expect(child2WaitpointAfter?.status).toBe("COMPLETED"); + expect(child2WaitpointAfter?.output).toBe('{"baz":"qux"}'); + + await setTimeout(500); + + const runWaitpointsAfterSecondChild = await prisma.taskRunWaitpoint.findMany({ + where: { + taskRunId: parentRun.id, + }, + include: { + waitpoint: true, + }, + }); + expect(runWaitpointsAfterSecondChild.length).toBe(0); + + //parent snapshot + const parentExecutionDataAfterSecondChildComplete = await engine.getRunExecutionData({ + runId: parentRun.id, + }); + assertNonNullable(parentExecutionDataAfterSecondChildComplete); + expect(parentExecutionDataAfterSecondChildComplete.snapshot.executionStatus).toBe("QUEUED"); + expect(parentExecutionDataAfterSecondChildComplete.batch?.id).toBe(batch.id); + expect(parentExecutionDataAfterSecondChildComplete.completedWaitpoints.length).toBe(3); + + // Dequeue the run + const dequeuedParentAfterCheckpoint = await engine.dequeueFromMasterQueue({ + consumerId: "test_12345", + masterQueue: parentRun.masterQueue, + maxRunCount: 10, + }); + + expect(dequeuedParentAfterCheckpoint.length).toBe(1); + expect(dequeuedParentAfterCheckpoint[0].run.id).toBe(parentRun.id); + expect(dequeuedParentAfterCheckpoint[0].snapshot.executionStatus).toBe("PENDING_EXECUTING"); + + // Create an attempt + const parentResumed = await engine.continueRunExecution({ + runId: dequeuedParentAfterCheckpoint[0].run.id, + snapshotId: dequeuedParentAfterCheckpoint[0].snapshot.id, + }); + + expect(parentResumed.snapshot.executionStatus).toBe("EXECUTING"); + + const execution = await engine.getRunExecutionData({ runId: parentRun.id }); + expect(execution?.snapshot.executionStatus).toBe("EXECUTING"); + expect(execution?.batch?.id).toBe(batch.id); + expect(execution?.completedWaitpoints.length).toBe(3); + + const completedWaitpoint0 = execution?.completedWaitpoints.find((w) => w.index === 0); + assertNonNullable(completedWaitpoint0); + expect(completedWaitpoint0.id).toBe(child1Waitpoint!.waitpointId); + expect(completedWaitpoint0.completedByTaskRun?.id).toBe(child1.id); + expect(completedWaitpoint0.completedByTaskRun?.batch?.id).toBe(batch.id); + expect(completedWaitpoint0.output).toBe('{"foo":"bar"}'); + expect(completedWaitpoint0.index).toBe(0); + + const completedWaitpoint1 = execution?.completedWaitpoints.find((w) => w.index === 1); + assertNonNullable(completedWaitpoint1); + expect(completedWaitpoint1.id).toBe(child2Waitpoint!.waitpointId); + expect(completedWaitpoint1.completedByTaskRun?.id).toBe(child2.id); + expect(completedWaitpoint1.completedByTaskRun?.batch?.id).toBe(batch.id); + expect(completedWaitpoint1.index).toBe(1); + expect(completedWaitpoint1.output).toBe('{"baz":"qux"}'); + + const batchWaitpointAfter = execution?.completedWaitpoints.find((w) => w.type === "BATCH"); + expect(batchWaitpointAfter?.id).toBe(batchWaitpoint?.waitpointId); + expect(batchWaitpointAfter?.completedByBatch?.id).toBe(batch.id); + expect(batchWaitpointAfter?.index).toBeUndefined(); + + const batchAfter = await prisma.batchTaskRun.findUnique({ + where: { + id: batch.id, + }, + }); + expect(batchAfter?.status === "COMPLETED"); + } finally { + engine.quit(); + } + }); });