Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix config/output type URLs #1002

Merged
merged 13 commits into from
Oct 25, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import contextlib
import inspect
import pathlib
import sys
import weakref
from contextvars import ContextVar
from typing import Any, Callable, Dict, Generator, List, Optional

import grpc
from google.protobuf import any_pb2
from google.protobuf import any_pb2, descriptor_pool, message_factory

from ni_measurement_plugin_sdk_service._internal.parameter import decoder, encoder
from ni_measurement_plugin_sdk_service._internal.parameter.metadata import (
Expand Down Expand Up @@ -181,6 +182,8 @@ def __init__(
self._measure_function = measure_function
self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle
self._service_info = service_info
self._configuration_parameters_message_type = service_info.service_class + ".Configurations"
self._outputs_message_type = service_info.service_class + ".Outputs"

def GetMetadata( # noqa: N802 - function name should be lowercase
self, request: v1_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext
Expand All @@ -191,8 +194,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase
)

measurement_signature = v1_measurement_service_pb2.MeasurementSignature(
configuration_parameters_message_type="ni.measurementlink.measurement.v1.MeasurementConfigurations",
outputs_message_type="ni.measurementlink.measurement.v1.MeasurementOutputs",
configuration_parameters_message_type=self._configuration_parameters_message_type,
outputs_message_type=self._outputs_message_type,
)

for field_number, configuration_metadata in self._configuration_metadata.items():
Expand All @@ -206,7 +209,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase
measurement_signature.configuration_parameters.append(configuration_parameter)

measurement_signature.configuration_defaults.value = encoder.serialize_default_values(
self._configuration_metadata, self._service_info.service_class + ".Configurations"
self._configuration_metadata, self._configuration_parameters_message_type
)

for field_number, output_metadata in self._output_metadata.items():
Expand Down Expand Up @@ -236,10 +239,11 @@ def Measure( # noqa: N802 - function name should be lowercase
self, request: v1_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext
) -> v1_measurement_service_pb2.MeasureResponse:
"""RPC API that executes the registered measurement method."""
self._validate_parameters(request)
mapping_by_id = decoder.deserialize_parameters(
self._configuration_metadata,
request.configuration_parameters.value,
self._service_info.service_class + ".Configurations",
self._configuration_parameters_message_type,
)
mapping_by_variable_name = _get_mapping_by_parameter_name(
mapping_by_id, self._measure_function
Expand Down Expand Up @@ -271,10 +275,21 @@ def _serialize_response(
outputs: Any,
) -> v1_measurement_service_pb2.MeasureResponse:
return v1_measurement_service_pb2.MeasureResponse(
outputs=_serialize_outputs(
self._output_metadata, outputs, self._service_info.service_class + ".Outputs"
)
outputs=_serialize_outputs(self._output_metadata, outputs, self._outputs_message_type)
)

def _validate_parameters(self, request: v1_measurement_service_pb2.MeasureRequest) -> None:
pool = descriptor_pool.Default()
configuration_proto = pool.FindMessageTypeByName(
self._configuration_parameters_message_type
)
configuration_message = message_factory.GetMessageClass(configuration_proto)()
if not request.configuration_parameters.Is(configuration_message.DESCRIPTOR):
expected = "type.googleapis.com/" + configuration_message.DESCRIPTOR.full_name
actual = request.configuration_parameters.type_url
sys.stderr.write(f"Note: Wrong message type. Expected {expected!r} but got {actual!r}")
elif not request.configuration_parameters.Unpack(configuration_message):
sys.stderr.write("Warning: Unpack failed")


class MeasurementServiceServicerV2(v2_measurement_service_pb2_grpc.MeasurementServiceServicer):
Expand All @@ -297,6 +312,8 @@ def __init__(
self._measure_function = measure_function
self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle
self._service_info = service_info
self._configuration_parameters_message_type = service_info.service_class + ".Configurations"
self._outputs_message_type = service_info.service_class + ".Outputs"

def GetMetadata( # noqa: N802 - function name should be lowercase
self, request: v2_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext
Expand All @@ -307,8 +324,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase
)

measurement_signature = v2_measurement_service_pb2.MeasurementSignature(
configuration_parameters_message_type="ni.measurementlink.measurement.v2.MeasurementConfigurations",
outputs_message_type="ni.measurementlink.measurement.v2.MeasurementOutputs",
configuration_parameters_message_type=self._configuration_parameters_message_type,
outputs_message_type=self._outputs_message_type,
)

for field_number, configuration_metadata in self._configuration_metadata.items():
Expand All @@ -323,7 +340,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase
measurement_signature.configuration_parameters.append(configuration_parameter)

measurement_signature.configuration_defaults.value = encoder.serialize_default_values(
self._configuration_metadata, self._service_info.service_class + ".Configurations"
self._configuration_metadata, self._configuration_parameters_message_type
)

for field_number, output_metadata in self._output_metadata.items():
Expand Down Expand Up @@ -355,10 +372,11 @@ def Measure( # noqa: N802 - function name should be lowercase
self, request: v2_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext
) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]:
"""RPC API that executes the registered measurement method."""
self._validate_parameters(request)
mapping_by_id = decoder.deserialize_parameters(
self._configuration_metadata,
request.configuration_parameters.value,
self._service_info.service_class + ".Configurations",
self._configuration_parameters_message_type,
)
mapping_by_variable_name = _get_mapping_by_parameter_name(
mapping_by_id, self._measure_function
Expand Down Expand Up @@ -386,7 +404,18 @@ def Measure( # noqa: N802 - function name should be lowercase

def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse:
return v2_measurement_service_pb2.MeasureResponse(
outputs=_serialize_outputs(
self._output_metadata, outputs, self._service_info.service_class + ".Outputs"
)
outputs=_serialize_outputs(self._output_metadata, outputs, self._outputs_message_type)
)

def _validate_parameters(self, request: v2_measurement_service_pb2.MeasureRequest) -> None:
pool = descriptor_pool.Default()
configuration_proto = pool.FindMessageTypeByName(
self._configuration_parameters_message_type
)
configuration_message = message_factory.GetMessageClass(configuration_proto)()
if not request.configuration_parameters.Is(configuration_message.DESCRIPTOR):
expected = "type.googleapis.com/" + configuration_message.DESCRIPTOR.full_name
actual = request.configuration_parameters.type_url
sys.stderr.write(f"Note: Wrong message type. Expected {expected!r} but got {actual!r}")
elif not request.configuration_parameters.Unpack(configuration_message):
sys.stderr.write("Warning: Unpack failed")
Loading