diff --git a/examples/multiplexer/multiplexed_client.py b/examples/multiplexer/multiplexed_client.py index e4612af..fc27f33 100644 --- a/examples/multiplexer/multiplexed_client.py +++ b/examples/multiplexer/multiplexed_client.py @@ -2,18 +2,31 @@ import thriftpy from thriftpy.rpc import client_context +from thriftpy.protocol import ( + TBinaryProtocolFactory, + TMultiplexingProtocolFactory + ) dd_thrift = thriftpy.load("dingdong.thrift", module_name="dd_thrift") pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift") +DD_SERVICE_NAME = "dd_thrift" +PP_SERVICE_NAME = "pp_thrift" + + def main(): - with client_context(dd_thrift.DingService, '127.0.0.1', 9090) as c: + binary_factory = TBinaryProtocolFactory() + dd_factory = TMultiplexingProtocolFactory(binary_factory, DD_SERVICE_NAME) + with client_context(dd_thrift.DingService, '127.0.0.1', 9090, + proto_factory=dd_factory) as c: # ring that doorbell dong = c.ding() print(dong) - with client_context(pp_thrift.PingService, '127.0.0.1', 9090) as c: + pp_factory = TMultiplexingProtocolFactory(binary_factory, PP_SERVICE_NAME) + with client_context(pp_thrift.PingService, '127.0.0.1', 9090, + proto_factory=pp_factory) as c: # play table tennis like a champ pong = c.ping() print(pong) diff --git a/examples/multiplexer/multiplexed_server.py b/examples/multiplexer/multiplexed_server.py index 90f051b..485571a 100644 --- a/examples/multiplexer/multiplexed_server.py +++ b/examples/multiplexer/multiplexed_server.py @@ -10,6 +10,8 @@ dd_thrift = thriftpy.load("dingdong.thrift", module_name="dd_thrift") pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift") +DD_SERVICE_NAME = "dd_thrift" +PP_SERVICE_NAME = "pp_thrift" class DingDispatcher(object): def ding(self): @@ -28,8 +30,8 @@ def main(): pp_proc = TProcessor(pp_thrift.PingService, PingDispatcher()) mux_proc = TMultiplexingProcessor() - mux_proc.register_processor(dd_proc) - mux_proc.register_processor(pp_proc) + mux_proc.register_processor(DD_SERVICE_NAME, dd_proc) + mux_proc.register_processor(PP_SERVICE_NAME, pp_proc) server = TThreadedServer(mux_proc, TServerSocket(), iprot_factory=TBinaryProtocolFactory(), diff --git a/tests/test_multiplexed.py b/tests/test_multiplexed.py index ad53ea2..6757a23 100644 --- a/tests/test_multiplexed.py +++ b/tests/test_multiplexed.py @@ -9,7 +9,10 @@ import pytest import thriftpy -from thriftpy.protocol import TBinaryProtocolFactory +from thriftpy.protocol import ( + TBinaryProtocolFactory, + TMultiplexingProtocolFactory + ) from thriftpy.rpc import client_context from thriftpy.server import TThreadedServer from thriftpy.thrift import TProcessor, TMultiplexingProcessor @@ -37,8 +40,8 @@ def server(request): p2 = TProcessor(mux.ThingTwoService, DispatcherTwo()) mux_proc = TMultiplexingProcessor() - mux_proc.register_processor(p1) - mux_proc.register_processor(p2) + mux_proc.register_processor("ThingOneService", p1) + mux_proc.register_processor("ThingTwoService", p2) _server = TThreadedServer(mux_proc, TServerSocket(unix_socket=sock_path), iprot_factory=TBinaryProtocolFactory(), @@ -58,13 +61,21 @@ def fin(): def client_one(timeout=3000): + binary_factory = TBinaryProtocolFactory() + multiplexing_factory = TMultiplexingProtocolFactory(binary_factory, + "ThingOneService") return client_context(mux.ThingOneService, unix_socket=sock_path, - timeout=timeout) + timeout=timeout, + proto_factory=multiplexing_factory) def client_two(timeout=3000): + binary_factory = TBinaryProtocolFactory() + multiplexing_factory = TMultiplexingProtocolFactory(binary_factory, + "ThingTwoService") return client_context(mux.ThingTwoService, unix_socket=sock_path, - timeout=timeout) + timeout=timeout, + proto_factory=multiplexing_factory) def test_multiplexed_server(server): diff --git a/thriftpy/protocol/__init__.py b/thriftpy/protocol/__init__.py index 7277adc..37c5d1e 100644 --- a/thriftpy/protocol/__init__.py +++ b/thriftpy/protocol/__init__.py @@ -4,6 +4,7 @@ from .binary import TBinaryProtocol, TBinaryProtocolFactory from .json import TJSONProtocol, TJSONProtocolFactory +from .multiplex import TMultiplexingProtocol, TMultiplexingProtocolFactory from thriftpy._compat import PYPY, CYTHON if not PYPY: @@ -19,4 +20,5 @@ __all__ = ['TBinaryProtocol', 'TBinaryProtocolFactory', 'TCyBinaryProtocol', 'TCyBinaryProtocolFactory', - 'TJSONProtocol', 'TJSONProtocolFactory'] + 'TJSONProtocol', 'TJSONProtocolFactory', + 'TMultiplexingProtocol', 'TMultiplexingProtocolFactory'] diff --git a/thriftpy/protocol/multiplex.py b/thriftpy/protocol/multiplex.py new file mode 100644 index 0000000..4a7cce7 --- /dev/null +++ b/thriftpy/protocol/multiplex.py @@ -0,0 +1,37 @@ +from thriftpy.thrift import TMultiplexingProcessor + + +class TMultiplexingProtocol(object): + + """ + + Multiplex protocol + + for writing message begin, it prepend the service name to the api + for other functions, it simply delegate to the original protocol + + """ + + def __init__(self, proto, service_name): + self.service_name = service_name + self.proto = proto + + def __getattr__(self, name): + return getattr(self.proto, name) + + def write_message_begin(self, name, ttype, seqid): + self.proto.write_message_begin( + self.service_name + TMultiplexingProcessor.SEPARATOR + name, + ttype, seqid) + + +class TMultiplexingProtocolFactory(object): + + def __init__(self, proto_factory, service_name): + self.proto_factory = proto_factory + self.service_name = service_name + + def get_protocol(self, trans): + proto = self.proto_factory.get_protocol(trans) + multi_proto = TMultiplexingProtocol(proto, self.service_name) + return multi_proto diff --git a/thriftpy/thrift.py b/thriftpy/thrift.py index 9b64ed9..b570595 100644 --- a/thriftpy/thrift.py +++ b/thriftpy/thrift.py @@ -10,7 +10,6 @@ from __future__ import absolute_import import functools -import inspect from ._compat import init_func_generator, with_metaclass @@ -248,41 +247,34 @@ def process(self, iprot, oprot): class TMultiplexingProcessor(TProcessor): - processors = {} - service_map = {} + SEPARATOR = ":" def __init__(self): + self.processors = {} pass - def register_processor(self, processor): - service = processor._service - module = inspect.getmodule(processor) - name = '{0}:{1}'.format(module.__name__, service.__name__) - if name in self.processors: + def register_processor(self, service_name, processor): + + if service_name in self.processors: raise TApplicationException( type=TApplicationException.INTERNAL_ERROR, - message='processor for `{0}` already registered'.format(name)) - - for srv in service.thrift_services: - if srv in self.service_map: - raise TApplicationException( - type=TApplicationException.INTERNAL_ERROR, - message='cannot multiplex processor for `{0}`; ' - '`{1}` is already a registered method for `{2}`' - .format(name, srv, self.service_map[srv])) - self.service_map[srv] = name + message='processor for `{0}` already registered' + .format(service_name)) - self.processors[name] = processor + self.processors[service_name] = processor def process_in(self, iprot): api, type, seqid = iprot.read_message_begin() - if api not in self.service_map: + + service_name, api = api.split(TMultiplexingProcessor.SEPARATOR) + + if service_name not in self.processors: iprot.skip(TType.STRUCT) iprot.read_message_end() e = TApplicationException(TApplicationException.UNKNOWN_METHOD) return api, seqid, e, None # noqa - proc = self.processors[self.service_map[api]] + proc = self.processors[service_name] args = getattr(proc._service, api + "_args")() args.read(iprot) iprot.read_message_end()