Skip to content

Commit

Permalink
fix(sdks/actor/runtime): fix throttling code missing final call
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanFlurry committed Dec 30, 2024
1 parent e4b2ce7 commit a729d28
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 46 deletions.
17 changes: 17 additions & 0 deletions sdks/actor/client/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { assertUnreachable } from "../common/utils.ts";

export type WebSocketMessage = string | Blob | ArrayBuffer | Uint8Array;

export function messageLength(message: WebSocketMessage): number {
if (message instanceof Blob) {
return message.size;
} else if (message instanceof ArrayBuffer) {
return message.byteLength;
} else if (message instanceof Uint8Array) {
return message.byteLength;
} else if (typeof message === "string") {
return message.length;
} else {
assertUnreachable(message);
}
}
122 changes: 76 additions & 46 deletions sdks/actor/runtime/src/actor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { assertExists } from "@std/assert/exists";
import { deadline } from "@std/async/deadline";
import { throttle } from "@std/async/unstable-throttle";
import type { Logger } from "@std/log/get-logger";
import { Hono, type Context as HonoContext } from "hono";
import { type Context as HonoContext, Hono } from "hono";
import { upgradeWebSocket } from "hono/deno";
import type { WSEvents } from "hono/ws";
import onChange from "on-change";
Expand Down Expand Up @@ -109,14 +109,11 @@ export abstract class Actor<
#connections = new Map<ConnectionId, Connection<this>>();
#eventSubscriptions = new Map<string, Set<Connection<this>>>();

#lastSaveTime = 0;
#pendingSaveTimeout?: number;

protected constructor(config?: Partial<ActorConfig>) {
this.#config = mergeActorConfig(config);

this.#saveStateThrottled = throttle(() => {
this.#saveStateInner().catch((error) => {
logger().error("failed to save state", { error });
});
}, this.#config.state.saveInterval);
}

// This is called by Rivet when the actor is exported as the default
Expand Down Expand Up @@ -166,15 +163,34 @@ export abstract class Actor<

#saveStateLock = new Lock<void>(void 0);

/** Throttled save state method. Used to write to KV at a reasonable cadence. */
#saveStateThrottled: () => void;

/** Promise used to wait for a save to complete. This is required since you cannot await `#saveStateThrottled`. */
#onStateSavedPromise?: PromiseWithResolvers<void>;

/** Saves the state to the database. You probably want to use #saveStateThrottled instead except for a few edge cases. */
/** Throttled save state method. Used to write to KV at a reasonable cadence. */
#saveStateThrottled() {
const now = Date.now();
const timeSinceLastSave = now - this.#lastSaveTime;
const saveInterval = this.#config.state.saveInterval;

// If we're within the throttle window and not already scheduled, schedule the next save.
if (timeSinceLastSave < saveInterval) {
if (this.#pendingSaveTimeout === undefined) {
this.#pendingSaveTimeout = setTimeout(() => {
this.#pendingSaveTimeout = undefined;
this.#saveStateInner();
}, saveInterval - timeSinceLastSave);
}
} else {
// If we're outside the throttle window, save immediately
this.#saveStateInner();
}
}

/** Saves the state to KV. You probably want to use #saveStateThrottled instead except for a few edge cases. */
async #saveStateInner() {
try {
this.#lastSaveTime = Date.now();

if (this.#stateChanged) {
// Use a lock in order to avoid race conditions with multiple
// parallel promises writing to KV. This should almost never happen
Expand Down Expand Up @@ -241,6 +257,8 @@ export abstract class Actor<
logger().error("error in `_onStateChange`", { error });
}
}

// State will be flushed at the end of the RPC
},
{
ignoreDetached: true,
Expand All @@ -261,11 +279,14 @@ export abstract class Actor<
// KEYS.STATE.INITIALIZED,
// KEYS.STATE.DATA,
//]);
const getStateBatch = Object.fromEntries(await this.#ctx.kv.getBatch([
KEYS.STATE.INITIALIZED,
KEYS.STATE.DATA,
]));
const initialized = getStateBatch[String(KEYS.STATE.INITIALIZED)] as boolean;
const getStateBatch = Object.fromEntries(
await this.#ctx.kv.getBatch([
KEYS.STATE.INITIALIZED,
KEYS.STATE.DATA,
]),
);
const initialized =
getStateBatch[String(KEYS.STATE.INITIALIZED)] as boolean;
const stateData = getStateBatch[String(KEYS.STATE.DATA)] as State;

if (!initialized) {
Expand Down Expand Up @@ -403,8 +424,9 @@ export abstract class Actor<
}

const protocolFormatRaw = c.req.query("format");
const { data: protocolFormat, success } =
ProtocolFormatSchema.safeParse(protocolFormatRaw);
const { data: protocolFormat, success } = ProtocolFormatSchema.safeParse(
protocolFormatRaw,
);
if (!success) {
logger().warn("invalid protocol format", {
protocolFormat: protocolFormatRaw,
Expand All @@ -425,8 +447,9 @@ export abstract class Actor<
// Parse and validate params
let params: ConnParams;
try {
params =
typeof paramsStr === "string" ? JSON.parse(paramsStr) : undefined;
params = typeof paramsStr === "string"
? JSON.parse(paramsStr)
: undefined;
} catch (error) {
logger().warn("malformed connection parameters", { error });
throw new errors.MalformedConnectionParameters(error);
Expand Down Expand Up @@ -530,14 +553,16 @@ export abstract class Actor<
const output = await this.#executeRpc(ctx, name, args);

conn._sendWebSocketMessage(
conn._serialize({
body: {
ro: {
i: id,
o: output,
conn._serialize(
{
body: {
ro: {
i: id,
o: output,
},
},
},
} satisfies wsToClient.ToClient),
} satisfies wsToClient.ToClient,
),
);
} else if ("sr" in message.body) {
// Subscription request
Expand Down Expand Up @@ -579,28 +604,32 @@ export abstract class Actor<
// Build response
if (rpcRequestId !== undefined) {
conn._sendWebSocketMessage(
conn._serialize({
body: {
re: {
i: rpcRequestId,
c: code,
m: message,
md: metadata,
conn._serialize(
{
body: {
re: {
i: rpcRequestId,
c: code,
m: message,
md: metadata,
},
},
},
} satisfies wsToClient.ToClient),
} satisfies wsToClient.ToClient,
),
);
} else {
conn._sendWebSocketMessage(
conn._serialize({
body: {
er: {
c: code,
m: message,
md: metadata,
conn._serialize(
{
body: {
er: {
c: code,
m: message,
md: metadata,
},
},
},
} satisfies wsToClient.ToClient),
} satisfies wsToClient.ToClient,
),
);
}
}
Expand Down Expand Up @@ -734,8 +763,9 @@ export abstract class Actor<
for (const connection of subscriptions) {
// Lazily serialize the appropriate format
if (!(connection._protocolFormat in serialized)) {
serialized[connection._protocolFormat] =
connection._serialize(toClient);
serialized[connection._protocolFormat] = connection._serialize(
toClient,
);
}

connection._sendWebSocketMessage(serialized[connection._protocolFormat]);
Expand Down

0 comments on commit a729d28

Please sign in to comment.