diff --git a/packages/connect-node/src/node-universal-handler.ts b/packages/connect-node/src/node-universal-handler.ts index 80c59c25c..49da585a2 100644 --- a/packages/connect-node/src/node-universal-handler.ts +++ b/packages/connect-node/src/node-universal-handler.ts @@ -127,7 +127,7 @@ export function universalRequestFromNodeRequest( header: nodeHeaderToWebHeader(nodeRequest.headers), body, signal: abortController.signal, - values: contextValues, + contextValues: contextValues, }; } diff --git a/packages/connect-web-bench/README.md b/packages/connect-web-bench/README.md index f9e44ee0a..50e64a63f 100644 --- a/packages/connect-web-bench/README.md +++ b/packages/connect-web-bench/README.md @@ -10,5 +10,5 @@ it like a web server would usually do. | code generator | bundle size | minified | compressed | |----------------|-------------------:|-----------------------:|---------------------:| -| connect | 113,658 b | 49,964 b | 13,486 b | +| connect | 114,538 b | 50,308 b | 13,557 b | | grpc-web | 414,071 b | 300,352 b | 53,255 b | diff --git a/packages/connect-web/src/connect-transport.ts b/packages/connect-web/src/connect-transport.ts index faf3f9fdd..0f4a222ef 100644 --- a/packages/connect-web/src/connect-transport.ts +++ b/packages/connect-web/src/connect-transport.ts @@ -30,8 +30,9 @@ import type { Transport, UnaryRequest, UnaryResponse, + ContextValues, } from "@connectrpc/connect"; -import { appendHeaders } from "@connectrpc/connect"; +import { appendHeaders, createContextValues } from "@connectrpc/connect"; import { createClientMethodSerializers, createEnvelopeReadableStream, @@ -142,6 +143,7 @@ export function createConnectTransport( timeoutMs: number | undefined, header: HeadersInit | undefined, message: PartialMessage, + contextValues?: ContextValues, ): Promise> { const { serialize, parse } = createClientMethodSerializers( method, @@ -176,6 +178,7 @@ export function createConnectTransport( timeoutMs, header, ), + contextValues: contextValues ?? createContextValues(), message, }, next: async (req: UnaryRequest): Promise> => { @@ -242,6 +245,7 @@ export function createConnectTransport( timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise> { const { serialize, parse } = createClientMethodSerializers( method, @@ -320,6 +324,7 @@ export function createConnectTransport( timeoutMs, header, ), + contextValues: contextValues ?? createContextValues(), message: input, }, next: async (req) => { diff --git a/packages/connect-web/src/grpc-web-transport.ts b/packages/connect-web/src/grpc-web-transport.ts index b4711680d..05c0c0e65 100644 --- a/packages/connect-web/src/grpc-web-transport.ts +++ b/packages/connect-web/src/grpc-web-transport.ts @@ -29,7 +29,9 @@ import type { Transport, UnaryRequest, UnaryResponse, + ContextValues, } from "@connectrpc/connect"; +import { createContextValues } from "@connectrpc/connect"; import { createClientMethodSerializers, createEnvelopeReadableStream, @@ -137,6 +139,7 @@ export function createGrpcWebTransport( timeoutMs: number | undefined, header: Headers, message: PartialMessage, + contextValues?: ContextValues, ): Promise> { const { serialize, parse } = createClientMethodSerializers( method, @@ -166,6 +169,7 @@ export function createGrpcWebTransport( mode: "cors", }, header: requestHeader(useBinaryFormat, timeoutMs, header), + contextValues: contextValues ?? createContextValues(), message, }, next: async (req: UnaryRequest): Promise> => { @@ -233,6 +237,7 @@ export function createGrpcWebTransport( timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise> { const { serialize, parse } = createClientMethodSerializers( method, @@ -323,6 +328,7 @@ export function createGrpcWebTransport( mode: "cors", }, header: requestHeader(useBinaryFormat, timeoutMs, header), + contextValues: contextValues ?? createContextValues(), message: input, }, next: async (req) => { diff --git a/packages/connect/src/call-options.ts b/packages/connect/src/call-options.ts index 874564f7e..ec2504d44 100644 --- a/packages/connect/src/call-options.ts +++ b/packages/connect/src/call-options.ts @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import type { ContextValues } from "./context-values.js"; + /** * Options for a call. Every client should accept CallOptions as optional * argument in its RPC methods. @@ -44,4 +46,9 @@ export interface CallOptions { * Called when response trailers are received. */ onTrailer?(trailers: Headers): void; + + /** + * ContextValues to pass to the interceptors. + */ + contextValues?: ContextValues; } diff --git a/packages/connect/src/callback-client.ts b/packages/connect/src/callback-client.ts index 4c94ca7d7..9826aab91 100644 --- a/packages/connect/src/callback-client.ts +++ b/packages/connect/src/callback-client.ts @@ -99,6 +99,7 @@ function createUnaryFn, O extends Message>( options.timeoutMs, options.headers, requestMessage, + options.contextValues, ) .then( (response) => { @@ -146,6 +147,7 @@ function createServerStreamingFn, O extends Message>( options.timeoutMs, options.headers, createAsyncIterable([input]), + options.contextValues, ); options.onHeader?.(response.header); for await (const message of response.message) { diff --git a/packages/connect/src/implementation.ts b/packages/connect/src/implementation.ts index 1364f4c88..97ea6d638 100644 --- a/packages/connect/src/implementation.ts +++ b/packages/connect/src/implementation.ts @@ -140,7 +140,7 @@ interface HandlerContextInit { requestHeader?: HeadersInit; responseHeader?: HeadersInit; responseTrailer?: HeadersInit; - values?: ContextValues; + contextValues?: ContextValues; } interface HandlerContextController extends HandlerContext { @@ -181,7 +181,7 @@ export function createHandlerContext( deadline.cleanup(); abortController.abort(reason); }, - values: init.values ?? createContextValues(), + values: init.contextValues ?? createContextValues(), }; } diff --git a/packages/connect/src/interceptor.ts b/packages/connect/src/interceptor.ts index f84cafd91..0d30f69ad 100644 --- a/packages/connect/src/interceptor.ts +++ b/packages/connect/src/interceptor.ts @@ -18,6 +18,7 @@ import type { MethodInfo, ServiceType, } from "@bufbuild/protobuf"; +import type { ContextValues } from "./context-values.js"; /** * An interceptor can add logic to clients, similar to the decorators @@ -164,6 +165,11 @@ interface RequestCommon, O extends Message> { * Headers that will be sent along with the request. */ readonly header: Headers; + + /** + * The context values for the current call. + */ + readonly contextValues: ContextValues; } interface ResponseCommon, O extends Message> { diff --git a/packages/connect/src/promise-client.spec.ts b/packages/connect/src/promise-client.spec.ts index 7cace483b..d613fb3d3 100644 --- a/packages/connect/src/promise-client.spec.ts +++ b/packages/connect/src/promise-client.spec.ts @@ -17,16 +17,24 @@ import { createBiDiStreamingFn, createClientStreamingFn, createServerStreamingFn, + createUnaryFn, } from "./promise-client.js"; import { createAsyncIterable } from "./protocol/async-iterable.js"; import { createRouterTransport } from "./router-transport.js"; import type { HandlerContext } from "./implementation"; import { ConnectError } from "./connect-error.js"; import { Code } from "./code.js"; +import { createContextKey, createContextValues } from "./context-values.js"; const TestService = { typeName: "handwritten.TestService", methods: { + unaryMethod: { + name: "Unary", + I: Int32Value, + O: StringValue, + kind: MethodKind.Unary, + }, clientStream: { name: "ClientStream", I: Int32Value, @@ -48,6 +56,54 @@ const TestService = { }, } as const; +const kString = createContextKey("foo"); + +describe("createUnaryFn()", function () { + it("passes the context values to interceptors", async () => { + const input = new Int32Value({ value: 1 }); + + const output = new StringValue({ value: "yield 1" }); + let interceptorCalled = false; + + const transport = createRouterTransport( + ({ service }) => { + service(TestService, { + unaryMethod: ( + // eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock + _input: Int32Value, + // eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock + _context: HandlerContext, + ) => Promise.resolve(output), + }); + }, + { + transport: { + interceptors: [ + (next) => { + return (req) => { + interceptorCalled = true; + expect(req.contextValues.get(kString)).toBe("bar"); + return next(req); + }; + }, + ], + }, + }, + ); + const fn = createUnaryFn( + transport, + TestService, + TestService.methods.unaryMethod, + ); + const res = await fn(input, { + contextValues: createContextValues().set(kString, "bar"), + }); + expect(res).toBeInstanceOf(StringValue); + expect(res.value).toEqual(output.value); + expect(interceptorCalled).toBe(true); + }); +}); + describe("createClientStreamingFn()", function () { it("works as expected on the happy path", async () => { const input = new Int32Value({ value: 1 }); @@ -78,6 +134,56 @@ describe("createClientStreamingFn()", function () { expect(res).toBeInstanceOf(StringValue); expect(res.value).toEqual(output.value); }); + it("passes the context values to interceptors", async () => { + const input = new Int32Value({ value: 1 }); + + const output = new StringValue({ value: "yield 1" }); + const kString = createContextKey("foo"); + let interceptorCalled = false; + + const transport = createRouterTransport( + ({ service }) => { + service(TestService, { + clientStream: ( + // eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock + _input: AsyncIterable, + // eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock + _context: HandlerContext, + ) => Promise.resolve(output), + }); + }, + { + transport: { + interceptors: [ + (next) => { + return (req) => { + interceptorCalled = true; + expect(req.contextValues.get(kString)).toBe("bar"); + return next(req); + }; + }, + ], + }, + }, + ); + const fn = createClientStreamingFn( + transport, + TestService, + TestService.methods.clientStream, + ); + const res = await fn( + // eslint-disable-next-line @typescript-eslint/require-await + (async function* () { + yield input; + })(), + { + contextValues: createContextValues().set(kString, "bar"), + }, + ); + expect(res).toBeInstanceOf(StringValue); + expect(res.value).toEqual(output.value); + expect(interceptorCalled).toBe(true); + }); it("closes the request iterable when response is received", async () => { const output = new StringValue({ value: "yield 1" }); const transport = createRouterTransport(({ service }) => { @@ -182,6 +288,51 @@ describe("createServerStreamingFn()", function () { } expect(receivedMessages).toEqual(output); }); + it("passes the context values to interceptors", async () => { + const output = [ + new StringValue({ value: "input1" }), + new StringValue({ value: "input2" }), + new StringValue({ value: "input3" }), + ]; + let interceptorCalled = false; + const transport = createRouterTransport( + ({ service }) => { + service(TestService, { + // eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock + serverStream: (_input: Int32Value, _context: HandlerContext) => + createAsyncIterable(output), + }); + }, + { + transport: { + interceptors: [ + (next) => { + return (req) => { + interceptorCalled = true; + expect(req.contextValues.get(kString)).toBe("bar"); + return next(req); + }; + }, + ], + }, + }, + ); + + const fn = createServerStreamingFn( + transport, + TestService, + TestService.methods.serverStream, + ); + const receivedMessages: StringValue[] = []; + const input = new Int32Value({ value: 123 }); + for await (const res of fn(input, { + contextValues: createContextValues().set(kString, "bar"), + })) { + receivedMessages.push(res); + } + expect(receivedMessages).toEqual(output); + expect(interceptorCalled).toBeTrue(); + }); it("doesn't support throw/return on the returned response", function () { const fn = createServerStreamingFn( createRouterTransport(({ service }) => { @@ -232,6 +383,57 @@ describe("createBiDiStreamingFn()", () => { expect(index).toBe(3); expect(bidiIndex).toBe(3); }); + it("passes the context values to interceptors", async () => { + const values = [123, 456, 789]; + + const input = createAsyncIterable( + values.map((value) => new Int32Value({ value })), + ); + let interceptorCalled = false; + let bidiIndex = 0; + const transport = createRouterTransport( + ({ service }) => { + service(TestService, { + bidiStream: async function* (input: AsyncIterable) { + for await (const thing of input) { + expect(thing.value).toBe(values[bidiIndex]); + bidiIndex += 1; + yield new StringValue({ value: thing.value.toString() }); + } + }, + }); + }, + { + transport: { + interceptors: [ + (next) => { + return (req) => { + interceptorCalled = true; + expect(req.contextValues.get(kString)).toBe("bar"); + return next(req); + }; + }, + ], + }, + }, + ); + const fn = createBiDiStreamingFn( + transport, + TestService, + TestService.methods.bidiStream, + ); + + let index = 0; + for await (const res of fn(input, { + contextValues: createContextValues().set(kString, "bar"), + })) { + expect(res).toEqual(new StringValue({ value: values[index].toString() })); + index += 1; + } + expect(index).toBe(3); + expect(bidiIndex).toBe(3); + expect(interceptorCalled).toBeTrue(); + }); it("closes the request iterable when response is received", async () => { const values = [123, 456, 789]; diff --git a/packages/connect/src/promise-client.ts b/packages/connect/src/promise-client.ts index b04d56575..8ac8c706c 100644 --- a/packages/connect/src/promise-client.ts +++ b/packages/connect/src/promise-client.ts @@ -77,7 +77,7 @@ type UnaryFn, O extends Message> = ( options?: CallOptions, ) => Promise; -function createUnaryFn, O extends Message>( +export function createUnaryFn, O extends Message>( transport: Transport, service: ServiceType, method: MethodInfo, @@ -90,6 +90,7 @@ function createUnaryFn, O extends Message>( options?.timeoutMs, options?.headers, input, + options?.contextValues, ); options?.onHeader?.(response.header); options?.onTrailer?.(response.trailer); @@ -123,6 +124,7 @@ export function createServerStreamingFn< options?.timeoutMs, options?.headers, createAsyncIterable([input]), + options?.contextValues, ), options, ); @@ -157,6 +159,7 @@ export function createClientStreamingFn< options?.timeoutMs, options?.headers, request, + options?.contextValues, ); options?.onHeader?.(response.header); let singleMessage: O | undefined; @@ -203,6 +206,7 @@ export function createBiDiStreamingFn< options?.timeoutMs, options?.headers, request, + options?.contextValues, ), options, ); diff --git a/packages/connect/src/protocol-connect/handler-factory.ts b/packages/connect/src/protocol-connect/handler-factory.ts index 4c245933d..e73bb4ddd 100644 --- a/packages/connect/src/protocol-connect/handler-factory.ts +++ b/packages/connect/src/protocol-connect/handler-factory.ts @@ -195,7 +195,7 @@ function createUnaryHandler, O extends Message>( ? contentTypeUnaryProto : contentTypeUnaryJson, }, - values: req.values, + contextValues: req.contextValues, }); const compression = compressionNegotiate( opt.acceptCompression, @@ -380,7 +380,7 @@ function createStreamHandler, O extends Message>( ? contentTypeStreamProto : contentTypeStreamJson, }, - values: req.values, + contextValues: req.contextValues, }); const compression = compressionNegotiate( opt.acceptCompression, diff --git a/packages/connect/src/protocol-connect/transport.ts b/packages/connect/src/protocol-connect/transport.ts index 238675e60..aac92d4e5 100644 --- a/packages/connect/src/protocol-connect/transport.ts +++ b/packages/connect/src/protocol-connect/transport.ts @@ -53,6 +53,8 @@ import { createMethodUrl } from "../protocol/create-method-url.js"; import { runUnaryCall, runStreamingCall } from "../protocol/run-call.js"; import { createMethodSerializationLookup } from "../protocol/serialization.js"; import type { Transport } from "../transport.js"; +import type { ContextValues } from "../context-values.js"; +import { createContextValues } from "../context-values.js"; /** * Create a Transport for the Connect protocol. @@ -69,6 +71,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, message: PartialMessage, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -100,6 +103,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message, }, next: async (req: UnaryRequest): Promise> => { @@ -188,6 +192,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -226,6 +231,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message: input, }, next: async (req: StreamRequest) => { diff --git a/packages/connect/src/protocol-grpc-web/handler-factory.ts b/packages/connect/src/protocol-grpc-web/handler-factory.ts index ea923d5da..3d821af6e 100644 --- a/packages/connect/src/protocol-grpc-web/handler-factory.ts +++ b/packages/connect/src/protocol-grpc-web/handler-factory.ts @@ -137,7 +137,7 @@ function createHandler, O extends Message>( responseTrailer: { [headerGrpcStatus]: grpcStatusOk, }, - values: req.values, + contextValues: req.contextValues, }); const compression = compressionNegotiate( opt.acceptCompression, diff --git a/packages/connect/src/protocol-grpc-web/transport.ts b/packages/connect/src/protocol-grpc-web/transport.ts index f4cdd0a9c..21063e5db 100644 --- a/packages/connect/src/protocol-grpc-web/transport.ts +++ b/packages/connect/src/protocol-grpc-web/transport.ts @@ -47,6 +47,8 @@ import { runUnaryCall, runStreamingCall } from "../protocol/run-call.js"; import { createMethodSerializationLookup } from "../protocol/serialization.js"; import type { CommonTransportOptions } from "../protocol/transport-options.js"; import type { Transport } from "../transport.js"; +import { createContextValues } from "../context-values.js"; +import type { ContextValues } from "../context-values.js"; /** * Create a Transport for the gRPC-web protocol. @@ -63,6 +65,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, message: PartialMessage, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -93,6 +96,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message, }, next: async (req: UnaryRequest): Promise> => { @@ -192,6 +196,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -226,6 +231,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message: input, }, next: async (req: StreamRequest) => { diff --git a/packages/connect/src/protocol-grpc/handler-factory.ts b/packages/connect/src/protocol-grpc/handler-factory.ts index 3c0953efd..44e4ed0c3 100644 --- a/packages/connect/src/protocol-grpc/handler-factory.ts +++ b/packages/connect/src/protocol-grpc/handler-factory.ts @@ -128,7 +128,7 @@ function createHandler, O extends Message>( responseTrailer: { [headerGrpcStatus]: grpcStatusOk, }, - values: req.values, + contextValues: req.contextValues, }); const compression = compressionNegotiate( opt.acceptCompression, diff --git a/packages/connect/src/protocol-grpc/transport.ts b/packages/connect/src/protocol-grpc/transport.ts index 2aaca18d1..3dc87b0e9 100644 --- a/packages/connect/src/protocol-grpc/transport.ts +++ b/packages/connect/src/protocol-grpc/transport.ts @@ -46,6 +46,8 @@ import { runUnaryCall, runStreamingCall } from "../protocol/run-call.js"; import { createMethodSerializationLookup } from "../protocol/serialization.js"; import type { CommonTransportOptions } from "../protocol/transport-options.js"; import type { Transport } from "../transport.js"; +import { createContextValues } from "../context-values.js"; +import type { ContextValues } from "../context-values.js"; /** * Create a Transport for the gRPC protocol. @@ -62,6 +64,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, message: PartialMessage, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -92,6 +95,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message, }, next: async (req: UnaryRequest): Promise> => { @@ -168,6 +172,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise> { const serialization = createMethodSerializationLookup( method, @@ -198,6 +203,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.acceptCompression, opt.sendCompression, ), + contextValues: contextValues ?? createContextValues(), message: input, }, next: async (req: StreamRequest) => { diff --git a/packages/connect/src/protocol/run-call.spec.ts b/packages/connect/src/protocol/run-call.spec.ts index 20ade9aed..016082a3f 100644 --- a/packages/connect/src/protocol/run-call.spec.ts +++ b/packages/connect/src/protocol/run-call.spec.ts @@ -22,6 +22,7 @@ import type { UnaryResponse, } from "../interceptor.js"; import { createAsyncIterable } from "./async-iterable.js"; +import { createContextValues } from "../context-values.js"; const TestService = { typeName: "TestService", @@ -51,6 +52,7 @@ describe("runUnaryCall()", function () { init: {}, header: new Headers(), message: { value: 123 }, + contextValues: createContextValues(), }; } @@ -133,6 +135,7 @@ describe("runStreamingCall()", function () { init: {}, header: new Headers(), message: createAsyncIterable([{ value: 1 }, { value: 2 }, { value: 3 }]), + contextValues: createContextValues(), }; } diff --git a/packages/connect/src/protocol/universal.ts b/packages/connect/src/protocol/universal.ts index c92b70252..3ef9c239a 100644 --- a/packages/connect/src/protocol/universal.ts +++ b/packages/connect/src/protocol/universal.ts @@ -66,7 +66,7 @@ export interface UniversalServerRequest { */ body: AsyncIterable | JsonValue; signal: AbortSignal; - values?: ContextValues; + contextValues?: ContextValues; } /** diff --git a/packages/connect/src/transport.ts b/packages/connect/src/transport.ts index be73d1b4f..357b33c4f 100644 --- a/packages/connect/src/transport.ts +++ b/packages/connect/src/transport.ts @@ -20,6 +20,7 @@ import type { ServiceType, } from "@bufbuild/protobuf"; import type { StreamResponse, UnaryResponse } from "./interceptor.js"; +import type { ContextValues } from "./context-values.js"; /** * Transport represents the underlying transport for a client. @@ -38,6 +39,7 @@ export interface Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, input: PartialMessage, + contextValues?: ContextValues, ): Promise>; /** @@ -51,5 +53,6 @@ export interface Transport { timeoutMs: number | undefined, header: HeadersInit | undefined, input: AsyncIterable>, + contextValues?: ContextValues, ): Promise>; }