diff --git a/pw_web/webconsole/common/logService.ts b/pw_web/webconsole/common/logService.ts index 2d49380af..106520c90 100644 --- a/pw_web/webconsole/common/logService.ts +++ b/pw_web/webconsole/common/logService.ts @@ -12,29 +12,14 @@ // License for the specific language governing permissions and limitations under // the License. -import {Device, pw_rpc} from "pigweedjs"; -type Client = pw_rpc.Client; - -function createDefaultRPCLogService(client: Client) { - const logService = client.channel()! - .methodStub('pw.log.Logs.Listen'); - - return logService; -} +import {Device} from "pigweedjs"; export async function listenToDefaultLogService( device: Device, onFrame: (frame: Uint8Array) => void) { - const client = device.client; - // @ts-ignore - const logService: pw_rpc.ServerStreamingMethodStub = (createDefaultRPCLogService(client))!; - const request = new logService.method.responseType(); - // @ts-ignore - const call = logService.invoke(request, (msg) => { - // @ts-ignore - msg.getEntriesList().forEach(entry => onFrame(entry.getMessage())); - }); - + const call = device.rpcs.pw.log.Logs.Listen((msg: any) => { + msg.getEntriesList().forEach((entry: any) => onFrame(entry.getMessage())); + }) return () => { call.cancel(); }; diff --git a/ts/device/index.ts b/ts/device/index.ts index cb68c7f71..325007039 100644 --- a/ts/device/index.ts +++ b/ts/device/index.ts @@ -14,10 +14,20 @@ import objectPath from 'object-path'; import {Decoder, Encoder} from 'pigweedjs/pw_hdlc'; -import {Client, Channel, ServiceClient, UnaryMethodStub, MethodStub} from 'pigweedjs/pw_rpc'; +import { + Client, + Channel, + ServiceClient, + UnaryMethodStub, + MethodStub, + ServerStreamingMethodStub +} from 'pigweedjs/pw_rpc'; import {WebSerialTransport} from '../transport/web_serial_transport'; import {ProtoCollection} from 'pigweedjs/pw_protobuf_compiler'; +function protoFieldToMethodName(string) { + return string.split("_").map(titleCase).join(""); +} function titleCase(string) { return string.charAt(0).toUpperCase() + string.slice(1); } @@ -85,7 +95,9 @@ export class Device { let methodMap = {}; let methodKeys = Array.from(service.methodsByName.keys()); methodKeys - .filter((method: any) => service.methodsByName.get(method) instanceof UnaryMethodStub) + .filter((method: any) => + service.methodsByName.get(method) instanceof UnaryMethodStub + || service.methodsByName.get(method) instanceof ServerStreamingMethodStub) .forEach(key => { let fn = this.createMethodWrapper( service.methodsByName.get(key), @@ -97,9 +109,32 @@ export class Device { return methodMap; } - private createMethodWrapper(realMethod: MethodStub, methodName: string, fullMethodPath: string) { - const requestType = realMethod.method.descriptor.getInputType().replace(/^\./, ''); - const requestProtoDescriptor = this.protoCollection.getDescriptorProto(requestType); + private createMethodWrapper( + realMethod: MethodStub, + methodName: string, + fullMethodPath: string) { + if (realMethod instanceof UnaryMethodStub) { + return this.createUnaryMethodWrapper( + realMethod, + methodName, + fullMethodPath); + } + else if (realMethod instanceof ServerStreamingMethodStub) { + return this.createServerStreamingMethodWrapper( + realMethod, + methodName, + fullMethodPath); + } + } + + private createUnaryMethodWrapper( + realMethod: UnaryMethodStub, + methodName: string, + fullMethodPath: string) { + const requestType = + realMethod.method.descriptor.getInputType().replace(/^\./, ''); + const requestProtoDescriptor = + this.protoCollection.getDescriptorProto(requestType); const requestFields = requestProtoDescriptor.getFieldList(); const functionArguments = requestFields .map(field => field.getName()) @@ -116,12 +151,46 @@ export class Device { let fn = new Function(...functionArguments).bind((args) => { const request = new realMethod.method.requestType(); requestFields.forEach((field, index) => { - console.log("setting", `set${titleCase(field.getName())}`, args[index]); request[`set${titleCase(field.getName())}`](args[index]); }) - if (realMethod instanceof UnaryMethodStub) { - return realMethod.call(request); - } + return realMethod.call(request); + }); + return fn; + } + + private createServerStreamingMethodWrapper( + realMethod: ServerStreamingMethodStub, + methodName: string, + fullMethodPath: string) { + const requestType = realMethod.method.descriptor.getInputType().replace(/^\./, ''); + const requestProtoDescriptor = + this.protoCollection.getDescriptorProto(requestType); + const requestFields = requestProtoDescriptor.getFieldList(); + const functionArguments = requestFields + .map(field => field.getName()) + .concat( + [ + 'onNext', + 'onComplete', + 'onError', + 'return this(arguments);' + ] + ); + + // We store field names so REPL can show hints in autocomplete using these. + this.nameToMethodArgumentsMap[fullMethodPath] = requestFields + .map(field => field.getName()); + + // We create a new JS function dynamically here that takes + // proto message fields as arguments and calls the actual RPC method. + let fn = new Function(...functionArguments).bind((args) => { + const request = new realMethod.method.requestType(); + requestFields.forEach((field, index) => { + request[`set${protoFieldToMethodName(field.getName())}`](args[index]); + }) + const callbacks = Array.from(args).slice(requestFields.length); + // @ts-ignore + return realMethod.invoke(request, callbacks[0], callbacks[1], callbacks[2]); }); return fn; } diff --git a/ts/device/index_test.ts b/ts/device/index_test.ts index 1218dbbc9..3a383bfc3 100644 --- a/ts/device/index_test.ts +++ b/ts/device/index_test.ts @@ -18,10 +18,60 @@ import {Device} from "./" import {ProtoCollection} from 'pigweedjs/protos/collection'; import {WebSerialTransport} from '../transport/web_serial_transport'; import {Serial} from 'pigweedjs/types/serial'; +import {Message} from 'google-protobuf'; +import {RpcPacket, PacketType} from 'pigweedjs/protos/pw_rpc/internal/packet_pb'; +import {Method, ServerStreamingMethodStub} from 'pigweedjs/pw_rpc'; +import {Status} from 'pigweedjs/pw_status'; +import { + Response, +} from 'pigweedjs/protos/pw_rpc/ts/test_pb'; describe('WebSerialTransport', () => { let device: Device; let serialMock: SerialMock; + + function newResponse(payload = '._.'): Message { + const response = new Response(); + response.setPayload(payload); + return response; + } + + function generateResponsePacket( + channelId: number, + method: Method, + status: Status, + response?: Message + ) { + const packet = new RpcPacket(); + packet.setType(PacketType.RESPONSE); + packet.setChannelId(channelId); + packet.setServiceId(method.service.id); + packet.setMethodId(method.id); + packet.setStatus(status); + if (response === undefined) { + packet.setPayload(new Uint8Array()); + } else { + packet.setPayload(response.serializeBinary()); + } + return packet.serializeBinary(); + } + + function generateStreamingPacket( + channelId: number, + method: Method, + response: Message, + status: Status = Status.OK + ) { + const packet = new RpcPacket(); + packet.setType(PacketType.SERVER_STREAM); + packet.setChannelId(channelId); + packet.setServiceId(method.service.id); + packet.setMethodId(method.id); + packet.setPayload(response.serializeBinary()); + packet.setStatus(status); + return packet.serializeBinary(); + } + beforeEach(() => { serialMock = new SerialMock(); device = new Device(new ProtoCollection(), new WebSerialTransport(serialMock as Serial)); @@ -45,10 +95,32 @@ describe('WebSerialTransport', () => { 71, 139, 109, 127, 108, 165, 126]); await device.connect(); - console.log(device.rpcs.pw.rpc.EchoService.Echo); serialMock.dataFromDevice(helloResponse); const [status, response] = await device.rpcs.pw.rpc.EchoService.Echo("hello"); expect(response.getMsg()).toBe("hello"); expect(status).toBe(0); }); + + it('server streaming rpc sends response', async () => { + await device.connect(); + const response1 = newResponse('!!!'); + const response2 = newResponse('?'); + const serverStreaming = device.client + .channel() + ?.methodStub( + 'pw.rpc.test1.TheTestService.SomeServerStreaming' + )! as ServerStreamingMethodStub; + const onNext = jest.fn(); + const onCompleted = jest.fn(); + const onError = jest.fn(); + + device.rpcs.pw.rpc.test1.TheTestService.SomeServerStreaming(4, onNext, onCompleted, onError); + device.client.processPacket(generateStreamingPacket(1, serverStreaming.method, response1)); + device.client.processPacket(generateStreamingPacket(1, serverStreaming.method, response2)); + device.client.processPacket(generateResponsePacket(1, serverStreaming.method, Status.ABORTED)); + + expect(onNext).toBeCalledWith(response1); + expect(onNext).toBeCalledWith(response2); + expect(onCompleted).toBeCalledWith(Status.ABORTED); + }); });