diff --git a/__tests__/forbiddenFlags.test.ts b/__tests__/forbiddenFlags.test.ts
new file mode 100644
index 00000000..b582263c
--- /dev/null
+++ b/__tests__/forbiddenFlags.test.ts
@@ -0,0 +1,193 @@
+import {
+  makeWorkerUtils,
+  runTaskListOnce,
+  Task,
+  WorkerSharedOptions,
+} from "../src/index";
+import {
+  ESCAPED_GRAPHILE_WORKER_SCHEMA,
+  reset,
+  TEST_CONNECTION_STRING,
+  withPgClient,
+  withPgPool,
+} from "./helpers";
+
+const options: WorkerSharedOptions = {};
+
+test("supports the flags API", () =>
+  withPgClient(async (pgClient) => {
+    await reset(pgClient, options);
+
+    // Schedule a job
+    const utils = await makeWorkerUtils({
+      connectionString: TEST_CONNECTION_STRING,
+    });
+    await utils.addJob("job1", { a: 1 }, { flags: ["a", "b"] });
+    await utils.release();
+
+    // Assert that it has an entry in jobs / job_queues
+    const { rows: jobs } = await pgClient.query(
+      `select * from ${ESCAPED_GRAPHILE_WORKER_SCHEMA}.jobs`,
+    );
+    expect(jobs).toHaveLength(1);
+    expect(jobs[0]).toHaveProperty("flags");
+    expect(jobs[0].flags).toHaveLength(2);
+
+    const task: Task = jest.fn();
+    const taskList = { task };
+    await runTaskListOnce(options, taskList, pgClient);
+  }));
+
+test("get_job skips forbidden flags with string[] arg", () =>
+  withPgPool(async (pgPool) => {
+    await reset(pgPool, options);
+
+    const badFlag = "d";
+
+    const shouldRun = jest.fn();
+    const shouldSkip = jest.fn();
+
+    const job: Task = async (_payload, helpers) => {
+      const flags = helpers.job.flags || [];
+
+      if (flags.includes(badFlag)) {
+        shouldSkip();
+      } else {
+        shouldRun();
+      }
+    };
+
+    // Schedule one allowed job and one carrying the forbidden flag
+    const utils = await makeWorkerUtils({ pgPool });
+
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["a", "b"] });
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["c", "d"] });
+    await utils.release();
+
+    // Run with "d" forbidden; always release the client, even on failure
+    const pgClient = await pgPool.connect();
+    try {
+      await runTaskListOnce(
+        { forbiddenFlags: ["d"] },
+        { "flag-test": job },
+        pgClient,
+      );
+    } finally {
+      pgClient.release();
+    }
+
+    expect(shouldRun).toHaveBeenCalled();
+    expect(shouldSkip).not.toHaveBeenCalled();
+  }));
+
+test("get_job skips forbidden flags with () => string[] arg", () =>
+  withPgPool(async (pgPool) => {
+    await reset(pgPool, options);
+
+    const badFlag = "d";
+
+    const shouldRun = jest.fn();
+    const shouldSkip = jest.fn();
+
+    const job: Task = async (_payload, helpers) => {
+      const flags = helpers.job.flags || [];
+
+      if (flags.includes(badFlag)) {
+        shouldSkip();
+      } else {
+        shouldRun();
+      }
+    };
+
+    // Schedule one allowed job and one carrying the forbidden flag
+    const utils = await makeWorkerUtils({ pgPool });
+
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["a", "b"] });
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["c", "d"] });
+    await utils.release();
+
+    // Run with "d" forbidden; always release the client, even on failure
+    const pgClient = await pgPool.connect();
+    try {
+      await runTaskListOnce(
+        { forbiddenFlags: () => ["d"] },
+        { "flag-test": job },
+        pgClient,
+      );
+    } finally {
+      pgClient.release();
+    }
+
+    expect(shouldRun).toHaveBeenCalled();
+    expect(shouldSkip).not.toHaveBeenCalled();
+  }));
+
+test("get_job skips forbidden flags with () => Promise arg", () =>
+  withPgPool(async (pgPool) => {
+    await reset(pgPool, options);
+
+    const badFlag = "d";
+
+    const shouldRun = jest.fn();
+    const shouldSkip = jest.fn();
+
+    const job: Task = async (_payload, helpers) => {
+      const flags = helpers.job.flags || [];
+
+      if (flags.includes(badFlag)) {
+        shouldSkip();
+      } else {
+        shouldRun();
+      }
+    };
+
+    // Schedule one allowed job and one carrying the forbidden flag
+    const utils = await makeWorkerUtils({ pgPool });
+
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["a", "b"] });
+    await utils.addJob("flag-test", { a: 1 }, { flags: ["c", "d"] });
+    await utils.release();
+
+    // Run with "d" forbidden; always release the client, even on failure
+    const pgClient = await pgPool.connect();
+    try {
+      await runTaskListOnce(
+        { forbiddenFlags: async () => ["d"] },
+        { "flag-test": job },
+        pgClient,
+      );
+    } finally {
+      pgClient.release();
+    }
+
+    expect(shouldRun).toHaveBeenCalled();
+    expect(shouldSkip).not.toHaveBeenCalled();
+  }));
+
+test("get_job still runs jobs without flags when flags are forbidden", () =>
+  withPgPool(async (pgPool) => {
+    await reset(pgPool, options);
+
+    const shouldRun = jest.fn();
+    const job: Task = async () => {
+      shouldRun();
+    };
+
+    // Schedule a job that carries no flags at all (flags column is null)
+    const utils = await makeWorkerUtils({ pgPool });
+    await utils.addJob("flag-test", { a: 1 }, {});
+    await utils.release();
+
+    const pgClient = await pgPool.connect();
+    try {
+      await runTaskListOnce(
+        { forbiddenFlags: ["d"] },
+        { "flag-test": job },
+        pgClient,
+      );
+    } finally {
+      pgClient.release();
+    }
+
+    expect(shouldRun).toHaveBeenCalledTimes(1);
+  }));
diff --git a/__tests__/helpers.ts b/__tests__/helpers.ts
index 536a22b0..bd360b0b 100644
--- a/__tests__/helpers.ts
+++ b/__tests__/helpers.ts
@@ -109,6 +109,7 @@ export function makeMockJob(taskIdentifier: string): Job {
     locked_at: null,
     locked_by: null,
     key: null,
+    flags: null,
   };
 }
 
diff --git a/__tests__/migrate.test.ts b/__tests__/migrate.test.ts
index a7f5f35f..c1d8bc0e 100644
--- a/__tests__/migrate.test.ts
+++ b/__tests__/migrate.test.ts
@@ -33,7 +33,7 @@ test("migration installs schema; second migration does no harm", async () => {
     const { rows: migrationRows } = await pgClient.query(
       `select * from ${ESCAPED_GRAPHILE_WORKER_SCHEMA}.migrations`,
     );
-    expect(migrationRows).toHaveLength(4);
+    expect(migrationRows).toHaveLength(5);
     const migration = migrationRows[0];
     expect(migration.id).toEqual(1);
 
diff --git a/sql/000005.sql b/sql/000005.sql
new file mode 100644
index 00000000..ad3ca099
--- /dev/null
+++ b/sql/000005.sql
@@ -0,0 +1,180 @@
+alter table :GRAPHILE_WORKER_SCHEMA.jobs add column revision int default 0 not null;
+alter table :GRAPHILE_WORKER_SCHEMA.jobs add column flags text[] default null;
+
+drop function :GRAPHILE_WORKER_SCHEMA.add_job;
+create function :GRAPHILE_WORKER_SCHEMA.add_job(
+  identifier text,
+  payload json = null,
+  queue_name text = null,
+  run_at timestamptz = null,
+  max_attempts int = null,
+  job_key text = null,
+  priority int = null,
+  flags text[] = null
+) returns :GRAPHILE_WORKER_SCHEMA.jobs as $$
+declare
+  v_job :GRAPHILE_WORKER_SCHEMA.jobs;
+begin
+  -- Apply rationality checks
+  if length(identifier) > 128 then
+    raise exception 'Task identifier is too long (max length: 128).' using errcode = 'GWBID';
+  end if;
+  if queue_name is not null and length(queue_name) > 128 then
+    raise exception 'Job queue name is too long (max length: 128).' using errcode = 'GWBQN';
+  end if;
+  if job_key is not null and length(job_key) > 512 then
+    raise exception 'Job key is too long (max length: 512).' using errcode = 'GWBJK';
+  end if;
+  if max_attempts < 1 then
+    raise exception 'Job maximum attempts must be at least 1' using errcode = 'GWBMA';
+  end if;
+
+  if job_key is not null then
+    -- Upsert job
+    insert into :GRAPHILE_WORKER_SCHEMA.jobs (
+      task_identifier,
+      payload,
+      queue_name,
+      run_at,
+      max_attempts,
+      key,
+      priority,
+      flags
+    )
+      values(
+        identifier,
+        coalesce(payload, '{}'::json),
+        queue_name,
+        coalesce(run_at, now()),
+        coalesce(max_attempts, 25),
+        job_key,
+        coalesce(priority, 0),
+        flags
+      )
+      on conflict (key) do update set
+        task_identifier=excluded.task_identifier,
+        payload=excluded.payload,
+        queue_name=excluded.queue_name,
+        max_attempts=excluded.max_attempts,
+        run_at=excluded.run_at,
+        priority=excluded.priority,
+        revision=jobs.revision + 1,
+        flags=excluded.flags,
+        -- always reset error/retry state
+        attempts=0,
+        last_error=null
+      where jobs.locked_at is null
+      returning *
+      into v_job;
+
+    -- If upsert succeeded (insert or update), return early
+    if not (v_job is null) then
+      return v_job;
+    end if;
+
+    -- Upsert failed -> there must be an existing job that is locked. Remove
+    -- existing key to allow a new one to be inserted, and prevent any
+    -- subsequent retries by bumping attempts to the max allowed.
+    update :GRAPHILE_WORKER_SCHEMA.jobs
+      set
+        key = null,
+        attempts = jobs.max_attempts
+      where key = job_key;
+  end if;
+
+  -- insert the new job. Assume no conflicts due to the update above
+  insert into :GRAPHILE_WORKER_SCHEMA.jobs(
+    task_identifier,
+    payload,
+    queue_name,
+    run_at,
+    max_attempts,
+    key,
+    priority,
+    flags
+  )
+    values(
+      identifier,
+      coalesce(payload, '{}'::json),
+      queue_name,
+      coalesce(run_at, now()),
+      coalesce(max_attempts, 25),
+      job_key,
+      coalesce(priority, 0),
+      flags
+    )
+    returning *
+    into v_job;
+
+  return v_job;
+end;
+$$ language plpgsql volatile;
+
+drop function :GRAPHILE_WORKER_SCHEMA.get_job;
+
+create function :GRAPHILE_WORKER_SCHEMA.get_job(
+  worker_id text,
+  task_identifiers text[] = null,
+  job_expiry interval = interval '4 hours',
+  forbidden_flags text[] = null
+) returns :GRAPHILE_WORKER_SCHEMA.jobs as $$
+declare
+  v_job_id bigint;
+  v_queue_name text;
+  v_row :GRAPHILE_WORKER_SCHEMA.jobs;
+  v_now timestamptz = now();
+begin
+  if worker_id is null or length(worker_id) < 10 then
+    raise exception 'invalid worker id';
+  end if;
+
+  select jobs.queue_name, jobs.id into v_queue_name, v_job_id
+    from :GRAPHILE_WORKER_SCHEMA.jobs
+    where (jobs.locked_at is null or jobs.locked_at < (v_now - job_expiry))
+    and (
+      jobs.queue_name is null
+    or
+      exists (
+        select 1
+        from :GRAPHILE_WORKER_SCHEMA.job_queues
+        where job_queues.queue_name = jobs.queue_name
+        and (job_queues.locked_at is null or job_queues.locked_at < (v_now - job_expiry))
+        for update
+        skip locked
+      )
+    )
+    and run_at <= v_now
+    and attempts < max_attempts
+    and (task_identifiers is null or task_identifier = any(task_identifiers))
+    -- Use `is not true` rather than `not`: for a job whose flags are null the
+    -- overlap test yields null, and `not null` would wrongly filter the job out.
+    and (forbidden_flags is null or (forbidden_flags && flags) is not true)
+    order by priority asc, run_at asc, id asc
+    limit 1
+    for update
+    skip locked;
+
+  if v_job_id is null then
+    return null;
+  end if;
+
+  if v_queue_name is not null then
+    update :GRAPHILE_WORKER_SCHEMA.job_queues
+      set
+        locked_by = worker_id,
+        locked_at = v_now
+      where job_queues.queue_name = v_queue_name;
+  end if;
+
+  update :GRAPHILE_WORKER_SCHEMA.jobs
+    set
+      attempts = attempts + 1,
+      locked_by = worker_id,
+      locked_at = v_now
+    where id = v_job_id
+    returning * into v_row;
+
+  return v_row;
+end;
+$$ language plpgsql volatile;
+
diff --git a/src/helpers.ts b/src/helpers.ts
index 8dd664e1..94efd39e 100644
--- a/src/helpers.ts
+++ b/src/helpers.ts
@@ -27,7 +27,8 @@ export function makeAddJob(
           run_at => $4::timestamptz,
           max_attempts => $5::int,
           job_key => $6::text,
-          priority => $7::int
+          priority => $7::int,
+          flags => $8::text[]
         );
         `,
         [
@@ -38,6 +39,7 @@ export function makeAddJob(
           spec.maxAttempts || null,
           spec.jobKey || null,
           spec.priority || null,
+          spec.flags || null,
         ],
       );
       const job: Job = rows[0];
diff --git a/src/interfaces.ts b/src/interfaces.ts
index 6387c114..f1257b5f 100644
--- a/src/interfaces.ts
+++ b/src/interfaces.ts
@@ -175,6 +175,7 @@ export interface Job {
   key: string | null;
   locked_at: Date | null;
   locked_by: string | null;
+  flags: string[] | null;
 }
 
 export interface Worker {
@@ -223,8 +224,18 @@ export interface TaskSpec {
    * Unique identifier for the job, can be used to update or remove it later if needed. (Default: null)
    */
   jobKey?: string;
+
+  /**
+   * Flags for the job, can be used to dynamically filter which jobs can and cannot run at runtime
+   */
+  flags?: string[];
 }
 
+/**
+ * Returns (or resolves to) the flags whose jobs the worker must not run.
+ */
+export type ForbiddenFlagsFn = () => string[] | Promise<string[]>;
+
 /**
  * These options are common Graphile Worker pools, workers, and utils.
  */
@@ -260,6 +271,14 @@ export interface SharedOptions {
    * example if you wish to use Graphile Worker with pgBouncer or similar.
    */
   noPreparedStatements?: boolean;
+
+  /**
+   * An array of strings or function returning an array of strings or promise resolving to
+   * an array of strings that represent flags
+   *
+   * Graphile worker will skip the execution of any jobs that contain these flags
+   */
+  forbiddenFlags?: null | string[] | ForbiddenFlagsFn;
 }
 
 /**
diff --git a/src/worker.ts b/src/worker.ts
index b01e2912..54538368 100644
--- a/src/worker.ts
+++ b/src/worker.ts
@@ -23,6 +23,7 @@ export function makeNewWorker(
     workerId = `worker-${randomBytes(9).toString("hex")}`,
     pollInterval = defaults.pollInterval,
     noPreparedStatements,
+    forbiddenFlags,
   } = options;
   const {
     workerSchema,
@@ -77,6 +78,16 @@ export function makeNewWorker(
       const supportedTaskNames = Object.keys(tasks);
       assert(supportedTaskNames.length, "No runnable tasks!");
 
+      let flagsToSkip: null | string[] = null;
+
+      if (Array.isArray(forbiddenFlags)) {
+        flagsToSkip = forbiddenFlags;
+      } else if (typeof forbiddenFlags === "function") {
+        // Awaiting a plain array yields it unchanged, so a single await
+        // handles both the sync and the promise-returning function forms.
+        flagsToSkip = await forbiddenFlags();
+      }
+
       const {
         rows: [jobRow],
       } = await withPgClient((client) =>
@@ -84,8 +95,8 @@ export function makeNewWorker(
           text:
             // TODO: breaking change; change this to more optimal:
             // `SELECT id, queue_name, task_identifier, payload FROM ${escapedWorkerSchema}.get_job($1, $2);`,
-            `SELECT * FROM ${escapedWorkerSchema}.get_job($1, $2);`,
-          values: [workerId, supportedTaskNames],
+            `SELECT * FROM ${escapedWorkerSchema}.get_job($1, $2, forbidden_flags => $3);`,
+          values: [workerId, supportedTaskNames, flagsToSkip],
           name: noPreparedStatements ? undefined : `get_job/${workerSchema}`,
         }),
       );