diff --git a/integration/lower-case-svc-methods/lower-case-svc-methods-test.ts b/integration/lower-case-svc-methods/lower-case-svc-methods-test.ts new file mode 100644 index 000000000..7a05c4c47 --- /dev/null +++ b/integration/lower-case-svc-methods/lower-case-svc-methods-test.ts @@ -0,0 +1,33 @@ +import { MathServiceClientImpl } from './math'; + +function getRpc() { + return { + request: jest.fn(() => ({then: () => null})) + }; +} + +function getContext() { + const dataLoaderReturnValue = {load: jest.fn()}; + return { + dataLoaderReturnValue, + getDataLoader: jest.fn(() => dataLoaderReturnValue) + } +} + +describe('lower-case-svc-methods', () => { + it('lower-caseifies normal functions', () => { + const rpc = getRpc(), ctx = getContext(); + const client = new MathServiceClientImpl(rpc); + client.absoluteValue(ctx, {num: -1}); + + expect(rpc.request).toBeCalledWith(ctx, 'MathService', 'AbsoluteValue', expect.any(Buffer)); + }); + it('lower-caseifies batch functions', () => { + const rpc = getRpc(), ctx = getContext(); + const client = new MathServiceClientImpl(rpc); + client.getDouble(ctx, -1); + + expect(ctx.getDataLoader).toBeCalledWith('MathService.BatchDouble', expect.any(Function)); + expect(ctx.dataLoaderReturnValue.load).toBeCalledWith(-1); + }); +}); \ No newline at end of file diff --git a/integration/lower-case-svc-methods/math.bin b/integration/lower-case-svc-methods/math.bin new file mode 100644 index 000000000..0b8f2e6c0 Binary files /dev/null and b/integration/lower-case-svc-methods/math.bin differ diff --git a/integration/lower-case-svc-methods/math.proto b/integration/lower-case-svc-methods/math.proto new file mode 100644 index 000000000..2c45d34d5 --- /dev/null +++ b/integration/lower-case-svc-methods/math.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +message NumPair { + double num1 = 1; + double num2 = 2; +} + +message NumSingle { + double num = 1; +} + +message Numbers { + repeated double num = 1; +} + +service MathService { + rpc Add(NumPair) returns (NumSingle); + rpc AbsoluteValue(NumSingle) returns (NumSingle); + rpc BatchDouble(Numbers) returns (Numbers); +} \ No newline at end of file diff --git a/integration/lower-case-svc-methods/math.ts b/integration/lower-case-svc-methods/math.ts new file mode 100644 index 000000000..a82e96343 --- /dev/null +++ b/integration/lower-case-svc-methods/math.ts @@ -0,0 +1,296 @@ +/* eslint-disable */ +import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import * as Long from 'long'; +import * as DataLoader from 'dataloader'; +import * as hash from 'object-hash'; + +export const protobufPackage = ''; + +export interface NumPair { + num1: number; + num2: number; +} + +export interface NumSingle { + num: number; +} + +export interface Numbers { + num: number[]; +} + +const baseNumPair: object = { num1: 0, num2: 0 }; + +export const NumPair = { + encode(message: NumPair, writer: Writer = Writer.create()): Writer { + if (message.num1 !== 0) { + writer.uint32(9).double(message.num1); + } + if (message.num2 !== 0) { + writer.uint32(17).double(message.num2); + } + return writer; + }, + + decode(input: Reader | Uint8Array, length?: number): NumPair { + const reader = input instanceof Reader ? input : new Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = { ...baseNumPair } as NumPair; + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + message.num1 = reader.double(); + break; + case 2: + message.num2 = reader.double(); + break; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }, + + fromJSON(object: any): NumPair { + const message = { ...baseNumPair } as NumPair; + if (object.num1 !== undefined && object.num1 !== null) { + message.num1 = Number(object.num1); + } else { + message.num1 = 0; + } + if (object.num2 !== undefined && object.num2 !== null) { + message.num2 = Number(object.num2); + } else { + message.num2 = 0; + } + return message; + }, + + toJSON(message: NumPair): unknown { + const obj: any = {}; + message.num1 !== undefined && (obj.num1 = message.num1); + message.num2 !== undefined && (obj.num2 = message.num2); + return obj; + }, + + fromPartial(object: DeepPartial): NumPair { + const message = { ...baseNumPair } as NumPair; + if (object.num1 !== undefined && object.num1 !== null) { + message.num1 = object.num1; + } else { + message.num1 = 0; + } + if (object.num2 !== undefined && object.num2 !== null) { + message.num2 = object.num2; + } else { + message.num2 = 0; + } + return message; + }, +}; + +const baseNumSingle: object = { num: 0 }; + +export const NumSingle = { + encode(message: NumSingle, writer: Writer = Writer.create()): Writer { + if (message.num !== 0) { + writer.uint32(9).double(message.num); + } + return writer; + }, + + decode(input: Reader | Uint8Array, length?: number): NumSingle { + const reader = input instanceof Reader ? input : new Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = { ...baseNumSingle } as NumSingle; + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + message.num = reader.double(); + break; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }, + + fromJSON(object: any): NumSingle { + const message = { ...baseNumSingle } as NumSingle; + if (object.num !== undefined && object.num !== null) { + message.num = Number(object.num); + } else { + message.num = 0; + } + return message; + }, + + toJSON(message: NumSingle): unknown { + const obj: any = {}; + message.num !== undefined && (obj.num = message.num); + return obj; + }, + + fromPartial(object: DeepPartial): NumSingle { + const message = { ...baseNumSingle } as NumSingle; + if (object.num !== undefined && object.num !== null) { + message.num = object.num; + } else { + message.num = 0; + } + return message; + }, +}; + +const baseNumbers: object = { num: 0 }; + +export const Numbers = { + encode(message: Numbers, writer: Writer = Writer.create()): Writer { + writer.uint32(10).fork(); + for (const v of message.num) { + writer.double(v); + } + writer.ldelim(); + return writer; + }, + + decode(input: Reader | Uint8Array, length?: number): Numbers { + const reader = input instanceof Reader ? input : new Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = { ...baseNumbers } as Numbers; + message.num = []; + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if ((tag & 7) === 2) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.num.push(reader.double()); + } + } else { + message.num.push(reader.double()); + } + break; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }, + + fromJSON(object: any): Numbers { + const message = { ...baseNumbers } as Numbers; + message.num = []; + if (object.num !== undefined && object.num !== null) { + for (const e of object.num) { + message.num.push(Number(e)); + } + } + return message; + }, + + toJSON(message: Numbers): unknown { + const obj: any = {}; + if (message.num) { + obj.num = message.num.map((e) => e); + } else { + obj.num = []; + } + return obj; + }, + + fromPartial(object: DeepPartial): Numbers { + const message = { ...baseNumbers } as Numbers; + message.num = []; + if (object.num !== undefined && object.num !== null) { + for (const e of object.num) { + message.num.push(e); + } + } + return message; + }, +}; + +export interface MathService { + add(ctx: Context, request: NumPair): Promise; + absoluteValue(ctx: Context, request: NumSingle): Promise; + batchDouble(ctx: Context, request: Numbers): Promise; + getDouble(ctx: Context, nu: number): Promise; +} + +export class MathServiceClientImpl implements MathService { + private readonly rpc: Rpc; + constructor(rpc: Rpc) { + this.rpc = rpc; + this.add = this.add.bind(this); + this.absoluteValue = this.absoluteValue.bind(this); + this.batchDouble = this.batchDouble.bind(this); + } + add(ctx: Context, request: NumPair): Promise { + const data = NumPair.encode(request).finish(); + const promise = this.rpc.request(ctx, 'MathService', 'Add', data); + return promise.then((data) => NumSingle.decode(new Reader(data))); + } + + absoluteValue(ctx: Context, request: NumSingle): Promise { + const data = NumSingle.encode(request).finish(); + const promise = this.rpc.request(ctx, 'MathService', 'AbsoluteValue', data); + return promise.then((data) => NumSingle.decode(new Reader(data))); + } + + getDouble(ctx: Context, nu: number): Promise { + const dl = ctx.getDataLoader('MathService.BatchDouble', () => { + return new DataLoader( + (num) => { + const request = { num }; + return this.batchDouble(ctx, request).then((res) => res.num); + }, + { cacheKeyFn: hash, ...ctx.rpcDataLoaderOptions } + ); + }); + return dl.load(nu); + } + + batchDouble(ctx: Context, request: Numbers): Promise { + const data = Numbers.encode(request).finish(); + const promise = this.rpc.request(ctx, 'MathService', 'BatchDouble', data); + return promise.then((data) => Numbers.decode(new Reader(data))); + } +} + +interface Rpc { + request(ctx: Context, service: string, method: string, data: Uint8Array): Promise; +} + +export interface DataLoaderOptions { + cache?: boolean; +} + +export interface DataLoaders { + rpcDataLoaderOptions?: DataLoaderOptions; + getDataLoader(identifier: string, constructorFn: () => T): T; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; +export type DeepPartial = T extends Builtin + ? T + : T extends Array + ? Array> + : T extends ReadonlyArray + ? ReadonlyArray> + : T extends {} + ? { [K in keyof T]?: DeepPartial } + : Partial; + +// If you get a compile-error about 'Constructor and ... have no overlap', +// add '--ts_proto_opt=esModuleInterop=true' as a flag when calling 'protoc'. +if (util.Long !== Long) { + util.Long = Long as any; + configure(); +} diff --git a/integration/lower-case-svc-methods/parameters.txt b/integration/lower-case-svc-methods/parameters.txt new file mode 100644 index 000000000..4d8fd099d --- /dev/null +++ b/integration/lower-case-svc-methods/parameters.txt @@ -0,0 +1 @@ +lowerCaseServiceMethods=true,context=true diff --git a/src/generate-grpc-js.ts b/src/generate-grpc-js.ts index 765f50efa..e26d1a882 100644 --- a/src/generate-grpc-js.ts +++ b/src/generate-grpc-js.ts @@ -4,7 +4,7 @@ import { camelCase } from './case'; import { Context } from './context'; import SourceInfo, { Fields } from './sourceInfo'; import { messageToTypeName, wrapperTypeName } from './types'; -import { maybeAddComment, maybePrefixPackage } from './utils'; +import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage } from './utils'; import { generateDecoder, generateEncoder } from './encode'; const CallOptions = imp('CallOptions@@grpc/grpc-js'); @@ -60,6 +60,8 @@ function generateServiceDefinition( `); for (const [index, methodDesc] of serviceDesc.method.entries()) { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); + const inputType = messageToTypeName(ctx, methodDesc.inputType); const outputType = messageToTypeName(ctx, methodDesc.outputType); @@ -73,7 +75,7 @@ function generateServiceDefinition( const outputDecoder = generateDecoder(ctx, methodDesc.outputType); chunks.push(code` - ${camelCase(methodDesc.name)}: { + ${methodDesc.formattedName}: { path: '/${maybePrefixPackage(fileDesc, serviceDesc.name)}/${methodDesc.name}', requestStream: ${methodDesc.clientStreaming}, responseStream: ${methodDesc.serverStreaming}, @@ -98,6 +100,8 @@ function generateServerStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S chunks.push(code`export interface ${def(`${serviceDesc.name}Server`)} extends ${UntypedServiceImplementation} {`); for (const [index, methodDesc] of serviceDesc.method.entries()) { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); + const inputType = messageToTypeName(ctx, methodDesc.inputType); const outputType = messageToTypeName(ctx, methodDesc.outputType); @@ -113,7 +117,7 @@ function generateServerStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S : handleUnaryCall; chunks.push(code` - ${camelCase(methodDesc.name)}: ${callType}<${inputType}, ${outputType}>; + ${methodDesc.formattedName}: ${callType}<${inputType}, ${outputType}>; `); } @@ -128,6 +132,8 @@ function generateClientStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S chunks.push(code`export interface ${def(`${serviceDesc.name}Client`)} extends ${Client} {`); for (const [index, methodDesc] of serviceDesc.method.entries()) { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); + const inputType = messageToTypeName(ctx, methodDesc.inputType); const outputType = messageToTypeName(ctx, methodDesc.outputType); @@ -140,11 +146,11 @@ function generateClientStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S if (methodDesc.serverStreaming) { // bidi streaming chunks.push(code` - ${camelCase(methodDesc.name)}(): ${ClientDuplexStream}<${inputType}, ${outputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}(): ${ClientDuplexStream}<${inputType}, ${outputType}>; + ${methodDesc.formattedName}( options: Partial<${CallOptions}>, ): ${ClientDuplexStream}<${inputType}, ${outputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( metadata: ${Metadata}, options?: Partial<${CallOptions}>, ): ${ClientDuplexStream}<${inputType}, ${outputType}>; @@ -152,18 +158,18 @@ function generateClientStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S } else { // client streaming chunks.push(code` - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( callback: ${responseCallback}, ): ${ClientWritableStream}<${inputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( metadata: ${Metadata}, callback: ${responseCallback}, ): ${ClientWritableStream}<${inputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( options: Partial<${CallOptions}>, callback: ${responseCallback}, ): ${ClientWritableStream}<${inputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( metadata: ${Metadata}, options: Partial<${CallOptions}>, callback: ${responseCallback}, @@ -174,11 +180,11 @@ function generateClientStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S if (methodDesc.serverStreaming) { // server streaming chunks.push(code` - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( request: ${inputType}, options?: Partial<${CallOptions}>, ): ${ClientReadableStream}<${outputType}>; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( request: ${inputType}, metadata?: ${Metadata}, options?: Partial<${CallOptions}>, @@ -187,16 +193,16 @@ function generateClientStub(ctx: Context, sourceInfo: SourceInfo, serviceDesc: S } else { // unary chunks.push(code` - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( request: ${inputType}, callback: ${responseCallback}, ): ${ClientUnaryCall}; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( request: ${inputType}, metadata: ${Metadata}, callback: ${responseCallback}, ): ${ClientUnaryCall}; - ${camelCase(methodDesc.name)}( + ${methodDesc.formattedName}( request: ${inputType}, metadata: ${Metadata}, options: Partial<${CallOptions}>, diff --git a/src/generate-grpc-web.ts b/src/generate-grpc-web.ts index 4909f2fac..fb5bae222 100644 --- a/src/generate-grpc-web.ts +++ b/src/generate-grpc-web.ts @@ -2,7 +2,7 @@ import { MethodDescriptorProto, FileDescriptorProto, ServiceDescriptorProto } fr import { requestType, responseObservable, responsePromise, responseType } from './types'; import { Code, code, imp, joinCode } from 'ts-poet'; import { Context } from './context'; -import { maybePrefixPackage } from './utils'; +import { assertInstanceOf, FormattedMethodDescriptor, maybePrefixPackage } from './utils'; const grpc = imp('grpc@@improbable-eng/grpc-web'); const share = imp('share@rxjs/operators'); @@ -32,7 +32,8 @@ export function generateGrpcClientImpl( chunks.push(code`this.rpc = rpc;`); // Bind each FooService method to the FooServiceImpl class for (const methodDesc of serviceDesc.method) { - chunks.push(code`this.${methodDesc.name} = this.${methodDesc.name}.bind(this);`); + assertInstanceOf(methodDesc, FormattedMethodDescriptor); + chunks.push(code`this.${methodDesc.formattedName} = this.${methodDesc.formattedName}.bind(this);`); } chunks.push(code`}\n`); @@ -47,6 +48,7 @@ export function generateGrpcClientImpl( /** Creates the RPC methods that client code actually calls. */ function generateRpcMethod(ctx: Context, serviceDesc: ServiceDescriptorProto, methodDesc: MethodDescriptorProto) { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const { options, utils } = ctx; const inputType = requestType(ctx, methodDesc); const partialInputType = code`${utils.DeepPartial}<${inputType}>`; @@ -56,7 +58,7 @@ function generateRpcMethod(ctx: Context, serviceDesc: ServiceDescriptorProto, me : responsePromise(ctx, methodDesc); const method = methodDesc.serverStreaming ? 'invoke' : 'unary'; return code` - ${methodDesc.name}( + ${methodDesc.formattedName}( request: ${partialInputType}, metadata?: grpc.Metadata, ): ${returns} { diff --git a/src/generate-nestjs.ts b/src/generate-nestjs.ts index cfdfca22f..08450b703 100644 --- a/src/generate-nestjs.ts +++ b/src/generate-nestjs.ts @@ -10,7 +10,7 @@ import { } from './types'; import SourceInfo, { Fields } from './sourceInfo'; import { contextTypeVar } from './main'; -import { maybeAddComment, singular } from './utils'; +import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, singular } from './utils'; import { camelCase } from './case'; import { Context } from './context'; @@ -32,6 +32,7 @@ export function generateNestjsServiceController( `); serviceDesc.method.forEach((methodDesc, index) => { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const info = sourceInfo.lookup(Fields.service.method, index); maybeAddComment(info, chunks, serviceDesc.options?.deprecated); @@ -64,18 +65,16 @@ export function generateNestjsServiceController( `; } - const name = options.lowerCaseServiceMethods ? camelCase(methodDesc.name) : methodDesc.name; chunks.push(code` - ${name}(${joinCode(params, { on: ', ' })}): ${returns}; + ${methodDesc.formattedName}(${joinCode(params, { on: ', ' })}): ${returns}; `); if (options.context) { const batchMethod = detectBatchMethod(ctx, fileDesc, serviceDesc, methodDesc); if (batchMethod) { - const name = batchMethod.methodDesc.name.replace('Batch', 'Get'); const maybeCtx = options.context ? 'ctx: Context,' : ''; chunks.push(code` - ${name}( + ${batchMethod.singleMethodName}( ${maybeCtx} ${singular(batchMethod.inputFieldName)}: ${batchMethod.inputType}, ): Promise<${batchMethod.outputType}>; @@ -104,10 +103,7 @@ export function generateNestjsServiceClient( `); serviceDesc.method.forEach((methodDesc, index) => { - if (options.lowerCaseServiceMethods) { - methodDesc.name = camelCase(methodDesc.name); - } - + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const params: Code[] = []; if (options.context) { params.push(code`ctx: Context`); @@ -128,7 +124,7 @@ export function generateNestjsServiceClient( const info = sourceInfo.lookup(Fields.service.method, index); maybeAddComment(info, chunks, methodDesc.options?.deprecated); chunks.push(code` - ${methodDesc.name}( + ${methodDesc.formattedName}( ${joinCode(params, { on: ',' })} ): ${returns}; `); @@ -136,10 +132,9 @@ export function generateNestjsServiceClient( if (options.context) { const batchMethod = detectBatchMethod(ctx, fileDesc, serviceDesc, methodDesc); if (batchMethod) { - const name = batchMethod.methodDesc.name.replace('Batch', 'Get'); const maybeContext = options.context ? `ctx: Context,` : ''; chunks.push(code` - ${name}( + ${batchMethod.singleMethodName}( ${maybeContext} ${singular(batchMethod.inputFieldName)} ): Promise<${batchMethod.inputType}>; @@ -159,12 +154,18 @@ export function generateNestjsGrpcServiceMethodsDecorator(ctx: Context, serviceD const grpcMethods = serviceDesc.method .filter((m) => !m.clientStreaming) - .map((m) => (options.lowerCaseServiceMethods ? camelCase(m.name) : m.name)) + .map((m) => { + assertInstanceOf(m, FormattedMethodDescriptor); + return m.formattedName; + }) .map((n) => `"${n}"`); const grpcStreamMethods = serviceDesc.method .filter((m) => m.clientStreaming) - .map((m) => (options.lowerCaseServiceMethods ? camelCase(m.name) : m.name)) + .map((m) => { + assertInstanceOf(m, FormattedMethodDescriptor); + return m.formattedName; + }) .map((n) => `"${n}"`); return code` diff --git a/src/generate-services.ts b/src/generate-services.ts index 00fec5c02..f175ca177 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -8,7 +8,7 @@ import { responsePromise, responseType, } from './types'; -import { maybeAddComment, maybePrefixPackage, singular } from './utils'; +import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage, singular } from './utils'; import SourceInfo, { Fields } from './sourceInfo'; import { camelCase } from './case'; import { contextTypeVar } from './main'; @@ -42,8 +42,7 @@ export function generateService( chunks.push(code`export interface ${serviceDesc.name}${maybeTypeVar} {`); serviceDesc.method.forEach((methodDesc, index) => { - const name = options.lowerCaseServiceMethods ? camelCase(methodDesc.name) : methodDesc.name; - + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const info = sourceInfo.lookup(Fields.service.method, index); maybeAddComment(info, chunks, methodDesc.options?.deprecated); @@ -81,14 +80,13 @@ export function generateService( returnType = responsePromise(ctx, methodDesc); } - chunks.push(code`${name}(${joinCode(params, { on: ',' })}): ${returnType};`); + chunks.push(code`${methodDesc.formattedName}(${joinCode(params, { on: ',' })}): ${returnType};`); // If this is a batch method, auto-generate the singular version of it if (options.context) { const batchMethod = detectBatchMethod(ctx, fileDesc, serviceDesc, methodDesc); if (batchMethod) { - const name = batchMethod.methodDesc.name.replace('Batch', 'Get'); - chunks.push(code`${name}( + chunks.push(code`${batchMethod.singleMethodName}( ctx: Context, ${singular(batchMethod.inputFieldName)}: ${batchMethod.inputType}, ): Promise<${batchMethod.outputType}>;`); @@ -107,6 +105,7 @@ function generateRegularRpcMethod( serviceDesc: ServiceDescriptorProto, methodDesc: MethodDescriptorProto ): Code { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const { options } = ctx; const Reader = imp('Reader@protobufjs/minimal'); const inputType = requestType(ctx, methodDesc); @@ -116,7 +115,7 @@ function generateRegularRpcMethod( const maybeCtx = options.context ? 'ctx,' : ''; return code` - ${methodDesc.name}( + ${methodDesc.formattedName}( ${joinCode(params, { on: ',' })} ): ${responsePromise(ctx, methodDesc)} { const data = ${inputType}.encode(request).finish(); @@ -152,7 +151,8 @@ export function generateServiceClientImpl( chunks.push(code`this.rpc = rpc;`); // Bind each FooService method to the FooServiceImpl class for (const methodDesc of serviceDesc.method) { - chunks.push(code`this.${methodDesc.name} = this.${methodDesc.name}.bind(this);`); + assertInstanceOf(methodDesc, FormattedMethodDescriptor); + chunks.push(code`this.${methodDesc.formattedName} = this.${methodDesc.formattedName}.bind(this);`); } chunks.push(code`}`); @@ -189,6 +189,7 @@ function generateBatchingRpcMethod(ctx: Context, batchMethod: BatchMethod): Code mapType, uniqueIdentifier, } = batchMethod; + assertInstanceOf(methodDesc, FormattedMethodDescriptor); // Create the `(keys) => ...` lambda we'll pass to the DataLoader constructor const lambda: Code[] = []; @@ -199,14 +200,14 @@ function generateBatchingRpcMethod(ctx: Context, batchMethod: BatchMethod): Code if (mapType) { // If the return type is a map, lookup each key in the result lambda.push(code` - return this.${methodDesc.name}(ctx, request).then(res => { + return this.${methodDesc.formattedName}(ctx, request).then(res => { return ${inputFieldName}.map(key => res.${outputFieldName}[key]) }); `); } else { // Otherwise assume they come back in order lambda.push(code` - return this.${methodDesc.name}(ctx, request).then(res => res.${outputFieldName}) + return this.${methodDesc.formattedName}(ctx, request).then(res => res.${outputFieldName}) `); } lambda.push(code`}`); @@ -234,6 +235,7 @@ function generateCachingRpcMethod( serviceDesc: ServiceDescriptorProto, methodDesc: MethodDescriptorProto ): Code { + assertInstanceOf(methodDesc, FormattedMethodDescriptor); const inputType = requestType(ctx, methodDesc); const outputType = responseType(ctx, methodDesc); const uniqueIdentifier = `${maybePrefixPackage(fileDesc, serviceDesc.name)}.${methodDesc.name}`; @@ -251,7 +253,7 @@ function generateCachingRpcMethod( `; return code` - ${methodDesc.name}( + ${methodDesc.formattedName}( ctx: Context, request: ${inputType}, ): Promise<${outputType}> { diff --git a/src/main.ts b/src/main.ts index b5f9bbcb8..b0c4e43d3 100644 --- a/src/main.ts +++ b/src/main.ts @@ -4,6 +4,7 @@ import { FieldDescriptorProto, FileDescriptorProto, FieldDescriptorProto_Type, + MethodDescriptorProto, } from 'ts-proto-descriptors'; import { basicLongWireType, @@ -33,7 +34,7 @@ import { valueTypeName, } from './types'; import SourceInfo, { Fields } from './sourceInfo'; -import { maybeAddComment, maybePrefixPackage } from './utils'; +import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage } from './utils'; import { camelToSnake, capitalize, maybeSnakeToCamel } from './case'; import { generateNestjsGrpcServiceMethodsDecorator, @@ -88,6 +89,13 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri const headerComment = sourceInfo.lookup(Fields.file.syntax, undefined); maybeAddComment(headerComment, chunks, fileDesc.options?.deprecated); + // Apply formatting to methods here, so they propagate globally + for (let svc of fileDesc.service) { + for (let i = 0; i < svc.method.length; i++) { + svc.method[i] = new FormattedMethodDescriptor(svc.method[i], options); + } + } + // first make all the type declarations visit( fileDesc, @@ -225,6 +233,16 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri }) ); + // Finally, reset method definitions to their original state (unformatted) + // This is mainly so that the `meta-typings` tests pass + for (let svc of fileDesc.service) { + for (let i = 0; i < svc.method.length; i++) { + const methodInfo = svc.method[i]; + assertInstanceOf(methodInfo, FormattedMethodDescriptor); + svc.method[i] = methodInfo.getSource(); + } + } + return [moduleName, joinCode(chunks, { on: '\n\n' })]; } diff --git a/src/types.ts b/src/types.ts index ff2c8e27b..5b882a7ed 100644 --- a/src/types.ts +++ b/src/types.ts @@ -12,7 +12,7 @@ import { import { code, Code, imp, Import } from 'ts-poet'; import { DateOption, EnvOption, LongOption, OneofOption, Options } from './options'; import { visit } from './visit'; -import { fail, maybePrefixPackage } from './utils'; +import { fail, FormattedMethodDescriptor, maybePrefixPackage } from './utils'; import SourceInfo from './sourceInfo'; import { camelCase } from './case'; import { Context } from './context'; @@ -595,9 +595,9 @@ export function detectBatchMethod( } const uniqueIdentifier = `${maybePrefixPackage(fileDesc, serviceDesc.name)}.${methodDesc.name}`; return { - methodDesc, + methodDesc: methodDesc, uniqueIdentifier, - singleMethodName, + singleMethodName: FormattedMethodDescriptor.formatName(singleMethodName, ctx.options), inputFieldName, inputType, outputFieldName, diff --git a/src/utils.ts b/src/utils.ts index 3afbec3a2..87b875075 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,7 +1,9 @@ import { code, Code } from 'ts-poet'; -import { CodeGeneratorRequest, FileDescriptorProto } from 'ts-proto-descriptors'; +import { CodeGeneratorRequest, FileDescriptorProto, MethodDescriptorProto, MethodOptions } from 'ts-proto-descriptors'; import ReadStream = NodeJS.ReadStream; import { SourceDescription } from './sourceInfo'; +import { Options } from './options'; +import { camelCase } from './case'; export function protoFilesToGenerate(request: CodeGeneratorRequest): FileDescriptorProto[] { return request.protoFile.filter((f) => request.fileToGenerate.includes(f.name)); @@ -97,3 +99,71 @@ export function maybePrefixPackage(fileDesc: FileDescriptorProto, rest: string): const prefix = fileDesc.package === '' ? '' : `${fileDesc.package}.`; return `${prefix}${rest}`; } + +/** + * Asserts that an object is an instance of a certain class + * @param obj The object to check + * @param constructor The constructor of the class to check + */ +export function assertInstanceOf(obj: unknown, constructor: { new (...args: any[]): T }): asserts obj is T { + if (!(obj instanceof constructor)) { + throw new Error(`Expected instance of ${constructor.name}`); + } +} + +/** + * A MethodDescriptorProto subclass that adds formatted properties + */ +export class FormattedMethodDescriptor implements MethodDescriptorProto { + public name: string; + public inputType: string; + public outputType: string; + public options: MethodOptions | undefined; + public clientStreaming: boolean; + public serverStreaming: boolean; + + private original: MethodDescriptorProto; + private ctxOptions: Options; + /** + * The name of this method with formatting applied according to the `Options` object passed to the constructor. + * Automatically updates to any changes to the `Options` or `name` of this object + */ + public get formattedName() { + return FormattedMethodDescriptor.formatName(this.name, this.ctxOptions); + } + + constructor(src: MethodDescriptorProto, options: Options) { + this.ctxOptions = options; + this.original = src; + this.name = src.name; + this.inputType = src.inputType; + this.outputType = src.outputType; + this.options = src.options; + this.clientStreaming = src.clientStreaming; + this.serverStreaming = src.serverStreaming; + } + + /** + * Retrieve the source `MethodDescriptorProto` used to construct this object + * @returns The source `MethodDescriptorProto` used to construct this object + */ + public getSource(): MethodDescriptorProto { + return this.original; + } + + /** + * Applies formatting rules to a gRPC method name. + * @param methodName The original method name + * @param options The options object containing rules to apply + * @returns The formatted method name + */ + public static formatName(methodName: string, options: Options) { + let result = methodName; + + if (options.lowerCaseServiceMethods || options.outputServices === 'grpc-js') { + result = camelCase(result); + } + + return result; + } +}