Skip to content

Commit

Permalink
Add context values to clients
Browse files Browse the repository at this point in the history
  • Loading branch information
srikrsna-buf committed Oct 2, 2023
1 parent 05f34d1 commit fb45873
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 3 deletions.
2 changes: 1 addition & 1 deletion packages/connect-web-bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ it like a web server would usually do.

| code generator | bundle size | minified | compressed |
|----------------|-------------------:|-----------------------:|---------------------:|
| connect | 113,658 b | 49,964 b | 13,486 b |
| connect | 114,440 b | 50,266 b | 13,576 b |
| grpc-web | 414,071 b | 300,352 b | 53,255 b |
7 changes: 6 additions & 1 deletion packages/connect-web/src/connect-transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ import type {
Transport,
UnaryRequest,
UnaryResponse,
ContextValues,
} from "@connectrpc/connect";
import { appendHeaders } from "@connectrpc/connect";
import { appendHeaders, createContextValues } from "@connectrpc/connect";
import {
createClientMethodSerializers,
createEnvelopeReadableStream,
Expand Down Expand Up @@ -142,6 +143,7 @@ export function createConnectTransport(
timeoutMs: number | undefined,
header: HeadersInit | undefined,
message: PartialMessage<I>,
values?: ContextValues,
): Promise<UnaryResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down Expand Up @@ -176,6 +178,7 @@ export function createConnectTransport(
timeoutMs,
header,
),
values: values ?? createContextValues(),
message,
},
next: async (req: UnaryRequest<I, O>): Promise<UnaryResponse<I, O>> => {
Expand Down Expand Up @@ -242,6 +245,7 @@ export function createConnectTransport(
timeoutMs: number | undefined,
header: HeadersInit | undefined,
input: AsyncIterable<PartialMessage<I>>,
values?: ContextValues,
): Promise<StreamResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down Expand Up @@ -320,6 +324,7 @@ export function createConnectTransport(
timeoutMs,
header,
),
values: values ?? createContextValues(),
message: input,
},
next: async (req) => {
Expand Down
6 changes: 6 additions & 0 deletions packages/connect-web/src/grpc-web-transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import type {
Transport,
UnaryRequest,
UnaryResponse,
ContextValues,
} from "@connectrpc/connect";
import { createContextValues } from "@connectrpc/connect";
import {
createClientMethodSerializers,
createEnvelopeReadableStream,
Expand Down Expand Up @@ -137,6 +139,7 @@ export function createGrpcWebTransport(
timeoutMs: number | undefined,
header: Headers,
message: PartialMessage<I>,
values?: ContextValues,
): Promise<UnaryResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down Expand Up @@ -166,6 +169,7 @@ export function createGrpcWebTransport(
mode: "cors",
},
header: requestHeader(useBinaryFormat, timeoutMs, header),
values: values ?? createContextValues(),
message,
},
next: async (req: UnaryRequest<I, O>): Promise<UnaryResponse<I, O>> => {
Expand Down Expand Up @@ -233,6 +237,7 @@ export function createGrpcWebTransport(
timeoutMs: number | undefined,
header: HeadersInit | undefined,
input: AsyncIterable<PartialMessage<I>>,
values?: ContextValues,
): Promise<StreamResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down Expand Up @@ -323,6 +328,7 @@ export function createGrpcWebTransport(
mode: "cors",
},
header: requestHeader(useBinaryFormat, timeoutMs, header),
values: values ?? createContextValues(),
message: input,
},
next: async (req) => {
Expand Down
7 changes: 7 additions & 0 deletions packages/connect/src/call-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import type { ContextValues } from "./context-values";

/**
* Options for a call. Every client should accept CallOptions as optional
* argument in its RPC methods.
Expand Down Expand Up @@ -44,4 +46,9 @@ export interface CallOptions {
* Called when response trailers are received.
*/
onTrailer?(trailers: Headers): void;

/**
* ContextValues to pass to the interceptors.
*/
values?: ContextValues;
}
2 changes: 2 additions & 0 deletions packages/connect/src/callback-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ function createUnaryFn<I extends Message<I>, O extends Message<O>>(
options.timeoutMs,
options.headers,
requestMessage,
options.values,
)
.then(
(response) => {
Expand Down Expand Up @@ -146,6 +147,7 @@ function createServerStreamingFn<I extends Message<I>, O extends Message<O>>(
options.timeoutMs,
options.headers,
createAsyncIterable([input]),
options.values,
);
options.onHeader?.(response.header);
for await (const message of response.message) {
Expand Down
6 changes: 6 additions & 0 deletions packages/connect/src/interceptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type {
MethodInfo,
ServiceType,
} from "@bufbuild/protobuf";
import type { ContextValues } from "./context-values";

/**
* An interceptor can add logic to clients, similar to the decorators
Expand Down Expand Up @@ -164,6 +165,11 @@ interface RequestCommon<I extends Message<I>, O extends Message<O>> {
* Headers that will be sent along with the request.
*/
readonly header: Headers;

/**
* The context values for the current call.
*/
readonly values: ContextValues;
}

interface ResponseCommon<I extends Message<I>, O extends Message<O>> {
Expand Down
202 changes: 202 additions & 0 deletions packages/connect/src/promise-client.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@ import {
createBiDiStreamingFn,
createClientStreamingFn,
createServerStreamingFn,
createUnaryFn,
} from "./promise-client.js";
import { createAsyncIterable } from "./protocol/async-iterable.js";
import { createRouterTransport } from "./router-transport.js";
import type { HandlerContext } from "./implementation";
import { ConnectError } from "./connect-error.js";
import { Code } from "./code.js";
import { createContextKey, createContextValues } from "./context-values.js";

const TestService = {
typeName: "handwritten.TestService",
methods: {
unaryMethod: {
name: "Unary",
I: Int32Value,
O: StringValue,
kind: MethodKind.Unary,
},
clientStream: {
name: "ClientStream",
I: Int32Value,
Expand All @@ -48,6 +56,54 @@ const TestService = {
},
} as const;

const kString = createContextKey("foo");

describe("createUnaryFn()", function () {
it("passes the context values to interceptors", async () => {
const input = new Int32Value({ value: 1 });

const output = new StringValue({ value: "yield 1" });
let interceptorCalled = false;

const transport = createRouterTransport(
({ service }) => {
service(TestService, {
unaryMethod: (
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock
_input: Int32Value,
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock
_context: HandlerContext,
) => Promise.resolve(output),
});
},
{
transport: {
interceptors: [
(next) => {
return (req) => {
interceptorCalled = true;
expect(req.values.get(kString)).toBe("bar");
return next(req);
};
},
],
},
},
);
const fn = createUnaryFn(
transport,
TestService,
TestService.methods.unaryMethod,
);
const res = await fn(input, {
values: createContextValues().set(kString, "bar"),
});
expect(res).toBeInstanceOf(StringValue);
expect(res.value).toEqual(output.value);
expect(interceptorCalled).toBe(true);
});
});

describe("createClientStreamingFn()", function () {
it("works as expected on the happy path", async () => {
const input = new Int32Value({ value: 1 });
Expand Down Expand Up @@ -78,6 +134,56 @@ describe("createClientStreamingFn()", function () {
expect(res).toBeInstanceOf(StringValue);
expect(res.value).toEqual(output.value);
});
it("passes the context values to interceptors", async () => {
const input = new Int32Value({ value: 1 });

const output = new StringValue({ value: "yield 1" });
const kString = createContextKey("foo");
let interceptorCalled = false;

const transport = createRouterTransport(
({ service }) => {
service(TestService, {
clientStream: (
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock
_input: AsyncIterable<Int32Value>,
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock
_context: HandlerContext,
) => Promise.resolve(output),
});
},
{
transport: {
interceptors: [
(next) => {
return (req) => {
interceptorCalled = true;
expect(req.values.get(kString)).toBe("bar");
return next(req);
};
},
],
},
},
);
const fn = createClientStreamingFn(
transport,
TestService,
TestService.methods.clientStream,
);
const res = await fn(
// eslint-disable-next-line @typescript-eslint/require-await
(async function* () {
yield input;
})(),
{
values: createContextValues().set(kString, "bar"),
},
);
expect(res).toBeInstanceOf(StringValue);
expect(res.value).toEqual(output.value);
expect(interceptorCalled).toBe(true);
});
it("closes the request iterable when response is received", async () => {
const output = new StringValue({ value: "yield 1" });
const transport = createRouterTransport(({ service }) => {
Expand Down Expand Up @@ -182,6 +288,51 @@ describe("createServerStreamingFn()", function () {
}
expect(receivedMessages).toEqual(output);
});
it("passes the context values to interceptors", async () => {
const output = [
new StringValue({ value: "input1" }),
new StringValue({ value: "input2" }),
new StringValue({ value: "input3" }),
];
let interceptorCalled = false;
const transport = createRouterTransport(
({ service }) => {
service(TestService, {
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- arguments not used for mock
serverStream: (_input: Int32Value, _context: HandlerContext) =>
createAsyncIterable(output),
});
},
{
transport: {
interceptors: [
(next) => {
return (req) => {
interceptorCalled = true;
expect(req.values.get(kString)).toBe("bar");
return next(req);
};
},
],
},
},
);

const fn = createServerStreamingFn(
transport,
TestService,
TestService.methods.serverStream,
);
const receivedMessages: StringValue[] = [];
const input = new Int32Value({ value: 123 });
for await (const res of fn(input, {
values: createContextValues().set(kString, "bar"),
})) {
receivedMessages.push(res);
}
expect(receivedMessages).toEqual(output);
expect(interceptorCalled).toBeTrue();
});
it("doesn't support throw/return on the returned response", function () {
const fn = createServerStreamingFn(
createRouterTransport(({ service }) => {
Expand Down Expand Up @@ -232,6 +383,57 @@ describe("createBiDiStreamingFn()", () => {
expect(index).toBe(3);
expect(bidiIndex).toBe(3);
});
it("passes the context values to interceptors", async () => {
const values = [123, 456, 789];

const input = createAsyncIterable(
values.map((value) => new Int32Value({ value })),
);
let interceptorCalled = false;
let bidiIndex = 0;
const transport = createRouterTransport(
({ service }) => {
service(TestService, {
bidiStream: async function* (input: AsyncIterable<Int32Value>) {
for await (const thing of input) {
expect(thing.value).toBe(values[bidiIndex]);
bidiIndex += 1;
yield new StringValue({ value: thing.value.toString() });
}
},
});
},
{
transport: {
interceptors: [
(next) => {
return (req) => {
interceptorCalled = true;
expect(req.values.get(kString)).toBe("bar");
return next(req);
};
},
],
},
},
);
const fn = createBiDiStreamingFn(
transport,
TestService,
TestService.methods.bidiStream,
);

let index = 0;
for await (const res of fn(input, {
values: createContextValues().set(kString, "bar"),
})) {
expect(res).toEqual(new StringValue({ value: values[index].toString() }));
index += 1;
}
expect(index).toBe(3);
expect(bidiIndex).toBe(3);
expect(interceptorCalled).toBeTrue();
});
it("closes the request iterable when response is received", async () => {
const values = [123, 456, 789];

Expand Down
Loading

0 comments on commit fb45873

Please sign in to comment.