Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: do not allow defining new actions from within other actions/flows #725

Merged
merged 9 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/errors/no_new_actions_at_runtime.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# No new actions at runtime error

Defining new actions at runtime is not allowed.

✅ DO:

```ts
const prompt = defineDotprompt({...})

const flow = defineFlow({...}, async (input) => {
await prompt.generate(...);
})
```

❌ DON'T:

```ts
const flow = defineFlow({...}, async (input) => {
const prompt = defineDotprompt({...})
prompt.generate(...);
})
```
34 changes: 31 additions & 3 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ import { JSONSchema7 } from 'json-schema';
import { AsyncLocalStorage } from 'node:async_hooks';
import { performance } from 'node:perf_hooks';
import * as z from 'zod';
import { ActionType, lookupPlugin, registerAction } from './registry.js';
import {
ActionType,
initializeAllPlugins,
lookupPlugin,
registerAction,
} from './registry.js';
import { parseSchema } from './schema.js';
import * as telemetry from './telemetry.js';
import {
Expand Down Expand Up @@ -216,9 +221,16 @@ export function defineAction<
},
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O> {
const act = action(config, (i: I): Promise<z.infer<O>> => {
if (isInRuntimeContext()) {
throw new Error(
'Cannot define new actions at runtime.\n' +
'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md'
);
}
const act = action(config, async (i: I): Promise<z.infer<O>> => {
setCustomMetadataAttributes({ subtype: config.actionType });
return fn(i);
await initializeAllPlugins();
return await runInActionRuntimeContext(() => fn(i));
});
act.__action.actionType = config.actionType;
registerAction(config.actionType, act);
Expand Down Expand Up @@ -252,3 +264,19 @@ export function getStreamingCallback<S>(): StreamingCallback<S> | undefined {
}
return cb;
}

const runtimeCtxAls = new AsyncLocalStorage<any>();

/**
* Checks whether the caller is currently in the runtime context of an action.
*/
export function isInRuntimeContext() {
return !!runtimeCtxAls.getStore();
}

/**
* Execute the provided function in the action runtime context.
*/
export function runInActionRuntimeContext<R>(fn: () => R) {
return runtimeCtxAls.run('runtime', fn);
}
8 changes: 7 additions & 1 deletion js/core/src/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { z } from 'zod';
import { Action } from './action.js';
import { Action, isInRuntimeContext } from './action.js';
import { FlowStateStore } from './flowTypes.js';
import { LoggerConfig, TelemetryConfig } from './telemetryTypes.js';
import { TraceStore } from './tracing.js';
Expand Down Expand Up @@ -60,6 +60,12 @@ export function genkitPlugin<T extends PluginInit>(
pluginName: string,
initFn: T
): Plugin<Parameters<T>> {
if (isInRuntimeContext()) {
throw new Error(
'Cannot define new plugins at runtime.\n' +
'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md'
);
}
return (...args: Parameters<T>) => ({
name: pluginName,
initializer: async () => {
Expand Down
16 changes: 14 additions & 2 deletions js/core/src/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,19 @@ type ActionsRecord = Record<string, Action<z.ZodTypeAny, z.ZodTypeAny>>;
* Returns all actions in the registry.
*/
export async function listActions(): Promise<ActionsRecord> {
await initializeAllPlugins();
return Object.assign({}, actionsById());
}

let allPluginsInitialized = false;
export async function initializeAllPlugins() {
if (allPluginsInitialized) {
return;
}
for (const pluginName of Object.keys(pluginsByName())) {
await initializePlugin(pluginName);
}
return Object.assign({}, actionsById());
allPluginsInitialized = true;
}

/**
Expand Down Expand Up @@ -195,14 +204,17 @@ export async function lookupFlowStateStore(
* Registers a flow state store for the given environment.
*/
export function registerPluginProvider(name: string, provider: PluginProvider) {
allPluginsInitialized = false;
let cached;
let isInitialized = false;
pluginsByName()[name] = {
name: provider.name,
initializer: () => {
if (cached) {
if (isInitialized) {
return cached;
}
cached = provider.initializer();
isInitialized = true;
return cached;
},
};
Expand Down
2 changes: 2 additions & 0 deletions js/flow/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
StreamingCallback,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { initializeAllPlugins } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import {
newTrace,
Expand Down Expand Up @@ -391,6 +392,7 @@ export class Flow<
labels: Record<string, string> | undefined
) {
const startTimeMs = performance.now();
await initializeAllPlugins();
await runWithActiveContext(ctx, async () => {
let traceContext;
if (ctx.state.traceContext) {
Expand Down
3 changes: 2 additions & 1 deletion js/flow/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

import { runInActionRuntimeContext } from '@genkit-ai/core';
import { AsyncLocalStorage } from 'node:async_hooks';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
Expand Down Expand Up @@ -45,7 +46,7 @@ export function runWithActiveContext<R>(
ctx: Context<z.ZodTypeAny, z.ZodTypeAny, z.ZodTypeAny>,
fn: () => R
) {
return ctxAsyncLocalStorage.run(ctx, fn);
return ctxAsyncLocalStorage.run(ctx, () => runInActionRuntimeContext(fn));
}

/**
Expand Down
Loading