diff --git a/bazel/external/apache_thrift.BUILD b/bazel/external/apache_thrift.BUILD new file mode 100644 index 000000000000..8b296fc00672 --- /dev/null +++ b/bazel/external/apache_thrift.BUILD @@ -0,0 +1,21 @@ +# The apache-thrift distribution does not keep the thrift files in a directory with the +# expected package name (it uses src/Thrift.py vs src/thrift/Thrift.py), so we provide a +# genrule to copy src/**/*.py to thrift/**/*.py. +src_files = glob(["src/**/*.py"]) + +genrule( + name = "thrift_files", + srcs = src_files, + outs = [f.replace("src/", "thrift/") for f in src_files], + cmd = '\n'.join( + ['mkdir -p $$(dirname $(location %s)) && cp $(location %s) $(location :%s)' % (f, f, f.replace('src/', 'thrift/')) for f in src_files] + ), + visibility = ["//visibility:private"], +) + +py_library( + name = "apache_thrift", + srcs = [":thrift_files"], + visibility = ["//visibility:public"], + deps = ["@six_archive//:six"], +) diff --git a/bazel/external/twitter_common_finagle_thrift.BUILD b/bazel/external/twitter_common_finagle_thrift.BUILD new file mode 100644 index 000000000000..1ca6af126c59 --- /dev/null +++ b/bazel/external/twitter_common_finagle_thrift.BUILD @@ -0,0 +1,7 @@ +py_library( + name = "twitter_common_finagle_thrift", + srcs = glob([ + "gen/**/*.py", + ]), + visibility = ["//visibility:public"], +) diff --git a/bazel/external/twitter_common_lang.BUILD b/bazel/external/twitter_common_lang.BUILD new file mode 100644 index 000000000000..f4300b37b05d --- /dev/null +++ b/bazel/external/twitter_common_lang.BUILD @@ -0,0 +1,7 @@ +py_library( + name = "twitter_common_lang", + srcs = glob([ + "twitter/**/*.py", + ]), + visibility = ["//visibility:public"], +) diff --git a/bazel/external/twitter_common_rpc.BUILD b/bazel/external/twitter_common_rpc.BUILD new file mode 100644 index 000000000000..7a13ec511a66 --- /dev/null +++ b/bazel/external/twitter_common_rpc.BUILD @@ -0,0 +1,11 @@ +py_library( + name = "twitter_common_rpc", + srcs = glob([ + "twitter/**/*.py", + ]), + visibility = ["//visibility:public"], + deps = [ + "@com_github_twitter_common_lang//:twitter_common_lang", + "@com_github_twitter_common_finagle_thrift//:twitter_common_finagle_thrift" + ], +) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index d267eef32793..6063eb7c42af 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -141,6 +141,22 @@ def _python_deps(): name = "jinja2", actual = "@com_github_pallets_jinja//:jinja2", ) + _repository_impl( + name = "com_github_apache_thrift", + build_file = "@envoy//bazel/external:apache_thrift.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_lang", + build_file = "@envoy//bazel/external:twitter_common_lang.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_rpc", + build_file = "@envoy//bazel/external:twitter_common_rpc.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_finagle_thrift", + build_file = "@envoy//bazel/external:twitter_common_finagle_thrift.BUILD", + ) # Bazel native C++ dependencies. For the depedencies that doesn't provide autoconf/automake builds. def _cc_deps(): diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index 43279152c4bc..539c12893a07 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -8,6 +8,11 @@ REPOSITORY_LOCATIONS = dict( commit = "92020a042c0cd46979db9f6f0cb32783dc07765e", # 2018-06-08 remote = "https://github.com/abseil/abseil-cpp", ), + com_github_apache_thrift = dict( + sha256 = "7d59ac4fdcb2c58037ebd4a9da5f9a49e3e034bf75b3f26d9fe48ba3d8806e6b", + urls = ["https://files.pythonhosted.org/packages/c6/b4/510617906f8e0c5660e7d96fbc5585113f83ad547a3989b80297ac72a74c/thrift-0.11.0.tar.gz"], # 0.11.0 + strip_prefix = "thrift-0.11.0", + ), com_github_bombela_backward = dict( commit = "44ae9609e860e3428cd057f7052e505b4819eb84", # 2018-02-06 remote = "https://github.com/bombela/backward-cpp", @@ -80,6 +85,21 @@ REPOSITORY_LOCATIONS = dict( commit = "f54b0e47a08782a6131cc3d60f94d038fa6e0a51", # v1.1.0 remote = "https://github.com/tencent/rapidjson", ), + com_github_twitter_common_lang = dict( + sha256 = "56d1d266fd4767941d11c27061a57bc1266a3342e551bde3780f9e9eb5ad0ed1", + urls = ["https://files.pythonhosted.org/packages/08/bc/d6409a813a9dccd4920a6262eb6e5889e90381453a5f58938ba4cf1d9420/twitter.common.lang-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.lang-0.3.9/src", + ), + com_github_twitter_common_rpc = dict( + sha256 = "0792b63fb2fb32d970c2e9a409d3d00633190a22eb185145fe3d9067fdaa4514", + urls = ["https://files.pythonhosted.org/packages/be/97/f5f701b703d0f25fbf148992cd58d55b4d08d3db785aad209255ee67e2d0/twitter.common.rpc-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.rpc-0.3.9/src", + ), + com_github_twitter_common_finagle_thrift = dict( + sha256 = "1e3a57d11f94f58745e6b83348ecd4fa74194618704f45444a15bc391fde497a", + urls = ["https://files.pythonhosted.org/packages/f9/e7/4f80d582578f8489226370762d2cf6bc9381175d1929eba1754e03f70708/twitter.common.finagle-thrift-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.finagle-thrift-0.3.9/src", + ), com_google_googletest = dict( commit = "43863938377a9ea1399c0596269e0890b5c5515a", remote = "https://github.com/google/googletest", diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol.cc b/source/extensions/filters/network/thrift_proxy/binary_protocol.cc index bf9e05dd1235..c885cbc33edc 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol.cc +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol.cc @@ -97,7 +97,11 @@ bool BinaryProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& n if (buffer.length() < 3) { return false; } - field_id = BufferHelper::peekI16(buffer, 1); + int16_t id = BufferHelper::peekI16(buffer, 1); + if (id < 0) { + throw EnvoyException(fmt::format("invalid binary protocol field id {}", id)); + } + field_id = id; buffer.drain(3); } diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol.cc b/source/extensions/filters/network/thrift_proxy/compact_protocol.cc index 527be64b7631..8892c84c2196 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol.cc +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol.cc @@ -140,7 +140,7 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& return false; } - if (id <= 0 || id > INT16_MAX) { + if (id < 0 || id > INT16_MAX) { throw EnvoyException(fmt::format("invalid compact protocol field id {}", id)); } @@ -390,7 +390,7 @@ bool CompactProtocolImpl::readString(Buffer::Instance& buffer, std::string& valu } int len_size; - int32_t str_len = BufferHelper::peekZigZagI32(buffer, 0, len_size); + int32_t str_len = BufferHelper::peekVarIntI32(buffer, 0, len_size); if (len_size < 0) { return false; } diff --git a/test/extensions/extensions_build_system.bzl b/test/extensions/extensions_build_system.bzl index 1ea8f3630c1e..113c47f8374e 100644 --- a/test/extensions/extensions_build_system.bzl +++ b/test/extensions/extensions_build_system.bzl @@ -1,4 +1,4 @@ -load("//bazel:envoy_build_system.bzl", "envoy_cc_test", "envoy_cc_mock") +load("//bazel:envoy_build_system.bzl", "envoy_cc_test", "envoy_cc_test_library", "envoy_cc_mock") load("@envoy_build_config//:extensions_build_config.bzl", "EXTENSIONS") # All extension tests should use this version of envoy_cc_test(). It allows compiling out @@ -12,10 +12,18 @@ def envoy_extension_cc_test(name, envoy_cc_test(name, **kwargs) +def envoy_extension_cc_test_library(name, + extension_name, + **kwargs): + if not extension_name in EXTENSIONS: + return + + envoy_cc_test_library(name, **kwargs) + def envoy_extension_cc_mock(name, extension_name, **kwargs): if not extension_name in EXTENSIONS: return - envoy_cc_mock(name, **kwargs) \ No newline at end of file + envoy_cc_mock(name, **kwargs) diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index da7a3d6c933b..616b15fd1225 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -2,30 +2,32 @@ licenses(["notice"]) # Apache 2 load( "//bazel:envoy_build_system.bzl", - "envoy_cc_mock", - "envoy_cc_test_library", "envoy_package", ) load( "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_mock", "envoy_extension_cc_test", + "envoy_extension_cc_test_library", ) envoy_package() -envoy_cc_mock( +envoy_extension_cc_mock( name = "mocks", srcs = ["mocks.cc"], hdrs = ["mocks.h"], + extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/extensions/filters/network/thrift_proxy:transport_lib", "//test/test_common:printers_lib", ], ) -envoy_cc_test_library( +envoy_extension_cc_test_library( name = "utility_lib", hdrs = ["utility.h"], + extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/common/buffer:buffer_lib", "//source/common/common:byte_order_lib", @@ -87,6 +89,7 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", + ":utility_lib", "//source/extensions/filters/network/thrift_proxy:decoder_lib", "//test/test_common:printers_lib", "//test/test_common:utility_lib", @@ -131,3 +134,21 @@ envoy_extension_cc_test( "//test/test_common:utility_lib", ], ) + +envoy_extension_cc_test( + name = "filter_integration_test", + srcs = ["filter_integration_test.cc"], + data = [ + "//test/extensions/filters/network/thrift_proxy/driver:generate_fixture", + ], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + "//source/extensions/filters/network/tcp_proxy:config", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy:filter_lib", + "//test/integration:integration_lib", + "//test/test_common:environment_lib", + "//test/test_common:network_utility_lib", + "//test/test_common:printers_lib", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc index c4f2c6f44339..b32438229f3c 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc @@ -224,7 +224,7 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { EXPECT_EQ(field_id, 1); } - // Non-terminal field + // Non-stop field { Buffer::OwnedImpl buffer; std::string name = "-"; @@ -241,6 +241,24 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { EXPECT_EQ(field_id, 99); EXPECT_EQ(buffer.length(), 0); } + + // field id < 0 + { + Buffer::OwnedImpl buffer; + std::string name = "-"; + FieldType field_type = FieldType::String; + int16_t field_id = 1; + + addInt8(buffer, FieldType::I32); + addInt16(buffer, -1); + + EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), + EnvoyException, "invalid binary protocol field id -1"); + EXPECT_EQ(name, "-"); + EXPECT_EQ(field_type, FieldType::String); + EXPECT_EQ(field_id, 1); + EXPECT_EQ(buffer.length(), 3); + } } TEST(BinaryProtocolTest, ReadFieldEnd) { diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc index 7bc0fe7ab8a2..87eaeb0b1fb4 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc @@ -336,7 +336,7 @@ TEST(CompactProtocolTest, ReadFieldBegin) { EXPECT_EQ(buffer.length(), 6); } - // Long-form field header, field id out of range + // Long-form field header, field id > 32767 { Buffer::OwnedImpl buffer; std::string name = "-"; @@ -344,16 +344,34 @@ TEST(CompactProtocolTest, ReadFieldBegin) { int16_t field_id = 1; addInt8(buffer, 0x05); - addSeq(buffer, {0xFE, 0xFF, 0x7F}); // zigzag(0x1FFFFE) = 0xFFFFF + addSeq(buffer, {0x80, 0x80, 0x04}); // zigzag(0x10000) = 0x8000 EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), - EnvoyException, "invalid compact protocol field id 1048575"); + EnvoyException, "invalid compact protocol field id 32768"); EXPECT_EQ(name, "-"); EXPECT_EQ(field_type, FieldType::String); EXPECT_EQ(field_id, 1); EXPECT_EQ(buffer.length(), 4); } + // Long-form field header, field id < 0 + { + Buffer::OwnedImpl buffer; + std::string name = "-"; + FieldType field_type = FieldType::String; + int16_t field_id = 1; + + addInt8(buffer, 0x05); + addSeq(buffer, {0x01}); // zigzag(1) = -1 + + EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), + EnvoyException, "invalid compact protocol field id -1"); + EXPECT_EQ(name, "-"); + EXPECT_EQ(field_type, FieldType::String); + EXPECT_EQ(field_id, 1); + EXPECT_EQ(buffer.length(), 2); + } + // Unknown compact protocol field type { Buffer::OwnedImpl buffer; @@ -962,7 +980,7 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x8); // zigzag(8) = 4 + addInt8(buffer, 0x4); EXPECT_FALSE(proto.readString(buffer, value)); EXPECT_EQ(value, "-"); @@ -974,12 +992,12 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x01); // zigzag(1) = -1 + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 EXPECT_THROW_WITH_MESSAGE(proto.readString(buffer, value), EnvoyException, "negative compact protocol string/binary length -1"); EXPECT_EQ(value, "-"); - EXPECT_EQ(buffer.length(), 1); + EXPECT_EQ(buffer.length(), 5); } // empty string @@ -999,7 +1017,7 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x0C); // zigzag(0x0C) = 0x06 + addInt8(buffer, 0x06); addString(buffer, "string"); EXPECT_TRUE(proto.readString(buffer, value)); @@ -1015,7 +1033,7 @@ TEST(CompactProtocolTest, ReadBinary) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x0C); // zigzag(0x0C) = 0x06 + addInt8(buffer, 0x06); addString(buffer, "string"); EXPECT_TRUE(proto.readBinary(buffer, value)); diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 751da9c8e943..8d19f4e2b45b 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -3,6 +3,7 @@ #include "extensions/filters/network/thrift_proxy/decoder.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" @@ -21,6 +22,7 @@ using testing::Return; using testing::ReturnRef; using testing::SetArgReferee; using testing::StrictMock; +using testing::TestParamInfo; using testing::TestWithParam; using testing::Values; using testing::_; @@ -108,29 +110,44 @@ ExpectationSet expectContainerEnd(NiceMock& proto, FieldType field class DecoderStateMachineNonValueTest : public TestWithParam {}; +static std::string protoStateParamToString(const TestParamInfo& params) { + return ProtocolStateNameValues::name(params.param); +} + INSTANTIATE_TEST_CASE_P(NonValueProtocolStates, DecoderStateMachineNonValueTest, Values(ProtocolState::MessageBegin, ProtocolState::MessageEnd, ProtocolState::StructBegin, ProtocolState::StructEnd, ProtocolState::FieldBegin, ProtocolState::FieldEnd, ProtocolState::MapBegin, ProtocolState::MapEnd, ProtocolState::ListBegin, ProtocolState::ListEnd, - ProtocolState::SetBegin, ProtocolState::SetEnd)); + ProtocolState::SetBegin, ProtocolState::SetEnd), + protoStateParamToString); class DecoderStateMachineValueTest : public TestWithParam {}; INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, DecoderStateMachineValueTest, Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, - FieldType::I32, FieldType::I64, FieldType::String)); + FieldType::I32, FieldType::I64, FieldType::String), + fieldTypeParamToString); class DecoderStateMachineNestingTest : public TestWithParam> {}; +static std::string nestedFieldTypesParamToString( + const TestParamInfo>& params) { + FieldType outer_field_type, inner_type, value_type; + std::tie(outer_field_type, inner_type, value_type) = params.param; + return fmt::format("{}Of{}Of{}", fieldTypeToString(outer_field_type), + fieldTypeToString(inner_type), fieldTypeToString(value_type)); +} + INSTANTIATE_TEST_CASE_P( NestedTypes, DecoderStateMachineNestingTest, Combine(Values(FieldType::Struct, FieldType::List, FieldType::Map, FieldType::Set), Values(FieldType::Struct, FieldType::List, FieldType::Map, FieldType::Set), Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, - FieldType::I32, FieldType::I64, FieldType::String))); + FieldType::I32, FieldType::I64, FieldType::String)), + nestedFieldTypesParamToString); TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); diff --git a/test/extensions/filters/network/thrift_proxy/driver/BUILD b/test/extensions/filters/network/thrift_proxy/driver/BUILD new file mode 100644 index 000000000000..beed670c9e20 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +filegroup( + name = "generate_fixture", + srcs = ["generate_fixture.sh"], + data = [ + ":client", + ":server", + ], +) + +py_binary( + name = "client", + srcs = ["client.py"], + deps = [ + "//test/extensions/filters/network/thrift_proxy/driver/fbthrift:fbthrift_lib", + "//test/extensions/filters/network/thrift_proxy/driver/finagle:finagle_lib", + "//test/extensions/filters/network/thrift_proxy/driver/generated/example:example_lib", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) + +py_binary( + name = "server", + srcs = ["server.py"], + deps = [ + "//test/extensions/filters/network/thrift_proxy/driver/fbthrift:fbthrift_lib", + "//test/extensions/filters/network/thrift_proxy/driver/finagle:finagle_lib", + "//test/extensions/filters/network/thrift_proxy/driver/generated/example:example_lib", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/README.md b/test/extensions/filters/network/thrift_proxy/driver/README.md new file mode 100644 index 000000000000..251d1542abdc --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/README.md @@ -0,0 +1,33 @@ +Thrift Integration Test Driver +============================== + +The code in this package provides `client.py` and `server.py` which +can be used as a thrift client and server pair. Both scripts support +all the Thrift transport and protocol variations that Envoy's Thrift +proxy supports (or will eventually support): + +Transports: framed, unframed, header +Protocols: binary, compact, json, ttwitter (e.g., finagle-thrift) + +The client script can be configured to write its request and the +server's response to a file. The server script can be configured to +return successful responses, IDL-defined exceptions, or server +(application) exceptions. + +Envoy's thrift_proxy integration tests use the `generate_fixtures.sh` +script to create request and response files for various combinations +of transport, protocol, service multiplexing. In addition, the +integration tests generate IDL and application exception responses. +The generated data is used with the Envoy's integration test +infrastructure to simulate downstream and upstream connections. +Generated files are used instead of running the client and server +scripts directly to eliminate the need to select a Thrift upstream +server port (or determine its self-selected port). + +Regenerating example.thrift +--------------------------- + +Install the Apache thrift library (from source or a package) so that +the `thrift` command is available. The `generate_bindings.sh` script +will regenerate the Python bindings which are checked into the +repository. diff --git a/test/extensions/filters/network/thrift_proxy/driver/client.py b/test/extensions/filters/network/thrift_proxy/driver/client.py new file mode 100755 index 000000000000..bbc1293cee55 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/client.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python + +import argparse +import io +import sys + +from generated.example import Example +from generated.example.ttypes import ( + Param, TheWorks, AppException +) + +from thrift import Thrift +from thrift.protocol import ( + TBinaryProtocol, TCompactProtocol, TJSONProtocol, TMultiplexedProtocol +) +from thrift.transport import TSocket +from thrift.transport import TTransport +from fbthrift import THeaderTransport +from twitter.common.rpc.finagle.protocol import TFinagleProtocol + + +class TRecordingTransport(TTransport.TTransportBase): + def __init__(self, underlying, writehandle, readhandle): + self._underlying = underlying + self._whandle = writehandle + self._rhandle = readhandle + + def isOpen(self): + return self._underlying.isOpen() + + def open(self): + if not self._underlying.isOpen(): + self._underlying.open() + + def close(self): + self._underlying.close() + self._whandle.close() + self._rhandle.close() + + def read(self, sz): + buf = self._underlying.read(sz) + if len(buf) != 0: + self._rhandle.write(buf) + return buf + + def write(self, buf): + if len(buf) != 0: + self._whandle.write(buf) + self._underlying.write(buf) + + def flush(self): + self._underlying.flush() + self._whandle.flush() + self._rhandle.flush() + + +def main(cfg, reqhandle, resphandle): + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid unix domain socket: {}".format(cfg.addr)) + socket = TSocket.TSocket(unix_socket=cfg.addr) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + if host == "": + host = "localhost" + socket = TSocket.TSocket(host=host, port=int(port)) + except ValueError: + sys.exit("invalid address: {}".format(cfg.addr)) + + transport = TRecordingTransport(socket, reqhandle, resphandle) + + if cfg.transport == "framed": + transport = TTransport.TFramedTransport(transport) + elif cfg.transport == "unframed": + transport = TTransport.TBufferedTransport(transport) + elif cfg.transport == "header": + transport = THeaderTransport.THeaderTransport( + transport, + client_type=THeaderTransport.CLIENT_TYPE.HEADER, + ) + else: + sys.exit("unknown transport {0}".format(cfg.transport)) + + transport.open() + + if cfg.protocol == "binary": + protocol = TBinaryProtocol.TBinaryProtocol(transport) + elif cfg.protocol == "compact": + protocol = TCompactProtocol.TCompactProtocol(transport) + elif cfg.protocol == "json": + protocol = TJSONProtocol.TJSONProtocol(transport) + elif cfg.protocol == "finagle": + protocol = TFinagleProtocol(transport, client_id="thrift-playground") + else: + sys.exit("unknown protocol {0}".format(cfg.protocol)) + + if cfg.service is not None: + protocol = TMultiplexedProtocol.TMultiplexedProtocol(protocol, cfg.service) + + client = Example.Client(protocol) + + try: + if cfg.method == "ping": + client.ping() + print("client: pinged") + elif cfg.method == "poke": + client.poke() + print("client: poked") + elif cfg.method == "add": + if len(cfg.params) != 2: + sys.exit("add takes 2 arguments, got: {0}".format(cfg.params)) + + a = int(cfg.params[0]) + b = int(cfg.params[1]) + v = client.add(a, b) + print("client: added {0} + {1} = {2}".format(a, b, v)) + elif cfg.method == "execute": + param = Param( + return_fields=cfg.params, + the_works=TheWorks( + field_1=True, + field_2=0x7f, + field_3=0x7fff, + field_4=0x7fffffff, + field_5=0x7fffffffffffffff, + field_6=-1.5, + field_7=u"string is UTF-8: \U0001f60e", + field_8=b"binary is bytes: \x80\x7f\x00\x01", + field_9={1: "one", 2: "two", 3: "three"}, + field_10=[1, 2, 4, 8], + field_11=set(["a", "b", "c"]), + field_12=False, + ) + ) + + try: + result = client.execute(param) + print("client: executed {0}: {1}".format(param, result)) + except AppException as e: + print("client: execute failed with IDL Exception: {0}".format(e.why)) + else: + sys.exit("unknown method {0}".format(cfg.method)) + except Thrift.TApplicationException as e: + print("client exception: {0}: {1}".format(e.type, e.message)) + + if cfg.request is None: + req = "".join( [ "%02X " % ord( x ) for x in reqhandle.getvalue() ] ).strip() + print("request: {}".format(req)) + if cfg.response is None: + resp = "".join( [ "%02X " % ord( x ) for x in resphandle.getvalue() ] ).strip() + print("response: {}".format(resp)) + + transport.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Thrift client tool.", + ) + parser.add_argument( + "method", + metavar="METHOD", + help="Name of the service method to invoke.", + ) + parser.add_argument( + "params", + metavar="PARAMS", + nargs="*", + help="Method parameters", + ) + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + required=True, + help="Target address for requests in the form host:port. The host is optional. If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + help="selects a protocol.", + ) + parser.add_argument( + "--request", + metavar="FILE", + dest="request", + help="Writes the Thrift request to a file.", + ) + parser.add_argument( + "--response", + metavar="FILE", + dest="response", + help="Writes the Thrift response to a file.", + ) + parser.add_argument( + "-t", + "--transport", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + help="selects a transport.", + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + cfg = parser.parse_args() + + reqhandle = io.BytesIO() + resphandle = io.BytesIO() + if cfg.request is not None: + try: + reqhandle = io.open(cfg.request, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) + if cfg.response is not None: + try: + resphandle = io.open(cfg.response, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) + try: + main(cfg, reqhandle, resphandle) + except Thrift.TException as tx: + sys.exit("Unhandled Thrift Exception: {0}".format(tx.message)) diff --git a/test/extensions/filters/network/thrift_proxy/driver/example.thrift b/test/extensions/filters/network/thrift_proxy/driver/example.thrift new file mode 100644 index 000000000000..eda22a8d7d1e --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/example.thrift @@ -0,0 +1,40 @@ +// TheWorks contains one instance of each type of field. Envoy does not +// concern itself with the optionality of fields, so we leave it +// defaulted. +struct TheWorks { + 1: bool field_1, + 2: i8 field_2, + 3: i16 field_3, + 4: i32 field_4, + 5: i64 field_5, + 6: double field_6, + 7: string field_7, + 8: binary field_8, + 9: map field_9, + 10: list field_10, + 11: set field_11, + 12: bool field_12, +} + +struct Param { + 1: list return_fields, + 2: TheWorks the_works, +} + +struct Result { + 1: TheWorks the_works, +} + +exception AppException { + 1: string why, +} + +service Example { + void ping(), + + oneway void poke(), + + i32 add(1:i32 a, 2:i32 b), + + Result execute(1:Param input) throws (1:AppException appex), +} diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD new file mode 100644 index 000000000000..a1b33006f10f --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD @@ -0,0 +1,16 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "fbthrift_lib", + srcs = [ + "THeaderTransport.py", + "__init__.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py new file mode 100644 index 000000000000..cba5ec0651d9 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py @@ -0,0 +1,662 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# INFO:(zuercher): Adapted from +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/transport/THeaderTransport.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +if sys.version_info[0] >= 3: + from http import server + BaseHTTPServer = server + xrange = range + from io import BytesIO as StringIO + PY3 = True +else: + import BaseHTTPServer + from cStringIO import StringIO + PY3 = False + +from struct import pack, unpack +import zlib + +from thrift.Thrift import TApplicationException +from thrift.transport.TTransport import TTransportException, TTransportBase, CReadableTransport + +# INFO:(zuercher): Instead of importing these constants from TBinaryProtocol and TCompactProtocol +BINARY_PROTO_ID = 0x80 +COMPACT_PROTO_ID = 0x82 + + +# INFO:(zuercher): Copied from: +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py +def getVarint(n): + out = [] + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + if sys.version_info[0] >= 3: + return bytes(out) + else: + return b''.join(map(chr, out)) + + +# INFO:(zuercher): Copied from +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py +def readVarint(trans): + result = 0 + shift = 0 + while True: + x = trans.read(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 + + +# Import the snappy module if it is available +try: + import snappy +except ImportError: + # If snappy is not available, don't fail immediately. + # Only raise an error if we actually ever need to perform snappy + # compression. + class DummySnappy(object): + def compress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') + + def decompress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') + snappy = DummySnappy() # type: ignore + + +# Definitions from THeader.h + + +class CLIENT_TYPE: + HEADER = 0 + FRAMED_DEPRECATED = 1 + UNFRAMED_DEPRECATED = 2 + HTTP_SERVER = 3 + HTTP_CLIENT = 4 + FRAMED_COMPACT = 5 + HEADER_SASL = 6 + HTTP_GET = 7 + UNKNOWN = 8 + UNFRAMED_COMPACT_DEPRECATED = 9 + + +class HEADER_FLAG: + SUPPORT_OUT_OF_ORDER = 0x01 + DUPLEX_REVERSE = 0x08 + SASL = 0x10 + + +class TRANSFORM: + NONE = 0x00 + ZLIB = 0x01 + HMAC = 0x02 + SNAPPY = 0x03 + QLZ = 0x04 + ZSTD = 0x05 + + +class INFO: + NORMAL = 1 + PERSISTENT = 2 + + +T_BINARY_PROTOCOL = 0 +T_COMPACT_PROTOCOL = 2 +HEADER_MAGIC = 0x0FFF0000 +PACKED_HEADER_MAGIC = pack(b'!H', HEADER_MAGIC >> 16) +HEADER_MASK = 0xFFFF0000 +FLAGS_MASK = 0x0000FFFF +HTTP_SERVER_MAGIC = 0x504F5354 # POST +HTTP_CLIENT_MAGIC = 0x48545450 # HTTP +HTTP_GET_CLIENT_MAGIC = 0x47455420 # GET +HTTP_HEAD_CLIENT_MAGIC = 0x48454144 # HEAD +BIG_FRAME_MAGIC = 0x42494746 # BIGF +MAX_FRAME_SIZE = 0x3FFFFFFF +MAX_BIG_FRAME_SIZE = 2 ** 61 - 1 + + +class THeaderTransport(TTransportBase, CReadableTransport): + """Transport that sends headers. Also understands framed/unframed/HTTP + transports and will do the right thing""" + + __max_frame_size = MAX_FRAME_SIZE + + # Defaults to current user, but there is also a setter below. + __identity = None + IDENTITY_HEADER = "identity" + ID_VERSION_HEADER = "id_version" + ID_VERSION = "1" + + def __init__(self, trans, client_types=None, client_type=None): + self.__trans = trans + self.__rbuf = StringIO() + self.__rbuf_frame = False + self.__wbuf = StringIO() + self.seq_id = 0 + self.__flags = 0 + self.__read_transforms = [] + self.__write_transforms = [] + self.__supported_client_types = set(client_types or + (CLIENT_TYPE.HEADER,)) + self.__proto_id = T_COMPACT_PROTOCOL # default to compact like c++ + self.__client_type = client_type or CLIENT_TYPE.HEADER + self.__read_headers = {} + self.__read_persistent_headers = {} + self.__write_headers = {} + self.__write_persistent_headers = {} + + self.__supported_client_types.add(self.__client_type) + + # If we support unframed binary / framed binary also support compact + if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add( + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED) + if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add( + CLIENT_TYPE.FRAMED_COMPACT) + + def set_header_flag(self, flag): + self.__flags |= flag + + def clear_header_flag(self, flag): + self.__flags &= ~ flag + + def header_flags(self): + return self.__flags + + def set_max_frame_size(self, size): + if size > MAX_BIG_FRAME_SIZE: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s" % + MAX_BIG_FRAME_SIZE) + if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER: + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s for clients other than HEADER" + % MAX_FRAME_SIZE) + self.__max_frame_size = size + + def get_peer_identity(self): + if self.IDENTITY_HEADER in self.__read_headers: + if self.__read_headers[self.ID_VERSION_HEADER] == self.ID_VERSION: + return self.__read_headers[self.IDENTITY_HEADER] + return None + + def set_identity(self, identity): + self.__identity = identity + + def get_protocol_id(self): + return self.__proto_id + + def set_protocol_id(self, proto_id): + self.__proto_id = proto_id + + def set_header(self, str_key, str_value): + self.__write_headers[str_key] = str_value + + def get_write_headers(self): + return self.__write_headers + + def get_headers(self): + return self.__read_headers + + def clear_headers(self): + self.__write_headers.clear() + + def set_persistent_header(self, str_key, str_value): + self.__write_persistent_headers[str_key] = str_value + + def get_write_persistent_headers(self): + return self.__write_persistent_headers + + def clear_persistent_headers(self): + self.__write_persistent_headers.clear() + + def add_transform(self, trans_id): + self.__write_transforms.append(trans_id) + + def _reset_protocol(self): + # HTTP calls that are one way need to flush here. + if self.__client_type == CLIENT_TYPE.HTTP_SERVER: + self.flush() + # set to anything except unframed + self.__client_type = CLIENT_TYPE.UNKNOWN + # Read header bytes to check which protocol to decode + self.readFrame(0) + + def getTransport(self): + return self.__trans + + def isOpen(self): + return self.getTransport().isOpen() + + def open(self): + return self.getTransport().open() + + def close(self): + return self.getTransport().close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) == sz: + return ret + + if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + return ret + self.getTransport().readAll(sz - len(ret)) + + self.readFrame(sz - len(ret)) + return ret + self.__rbuf.read(sz - len(ret)) + + readAll = read # TTransportBase.readAll does a needless copy here. + + def readFrame(self, req_sz): + self.__rbuf_frame = True + word1 = self.getTransport().readAll(4) + sz = unpack('!I', word1)[0] + proto_id = word1[0] if PY3 else ord(word1[0]) + if proto_id == BINARY_PROTO_ID: + # unframed + self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read( + req_sz - 4)) + elif proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED + self.__proto_id = T_COMPACT_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read( + req_sz - 4)) + elif sz == HTTP_SERVER_MAGIC: + self.__client_type = CLIENT_TYPE.HTTP_SERVER + mf = self.getTransport().handle.makefile('rb', -1) + + self.handler = RequestHandler(mf, + 'client_address:port', '') + self.header = self.handler.wfile + self.__rbuf = StringIO(self.handler.data) + else: + if sz == BIG_FRAME_MAGIC: + sz = unpack('!Q', self.getTransport().readAll(8))[0] + # could be header format or framed. Check next two bytes. + magic = self.getTransport().readAll(2) + proto_id = magic[0] if PY3 else ord(magic[0]) + if proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_COMPACT + self.__proto_id = T_COMPACT_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll( + sz - 2)) + elif proto_id == BINARY_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll( + sz - 2)) + elif magic == PACKED_HEADER_MAGIC: + self.__client_type = CLIENT_TYPE.HEADER + _frame_size_check(sz, self.__max_frame_size) + # flags(2), seq_id(4), header_size(2) + n_header_meta = self.getTransport().readAll(8) + self.__flags, self.seq_id, header_size = unpack('!HIH', + n_header_meta) + data = StringIO() + data.write(magic) + data.write(n_header_meta) + data.write(self.getTransport().readAll(sz - 10)) + data.seek(10) + self.read_header_format(sz - 10, header_size, data) + else: + self.__client_type = CLIENT_TYPE.UNKNOWN + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Could not detect client transport type") + + if self.__client_type not in self.__supported_client_types: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Client type {} not supported on server" + .format(self.__client_type)) + + def read_header_format(self, sz, header_size, data): + # clear out any previous transforms + self.__read_transforms = [] + + header_size = header_size * 4 + if header_size > sz: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Header size is larger than frame") + end_header = header_size + data.tell() + + self.__proto_id = readVarint(data) + num_headers = readVarint(data) + + if self.__proto_id == 1 and self.__client_type != \ + CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to recv JSON encoding over binary") + + # Read the headers. Data for each header varies. + for _ in range(0, num_headers): + trans_id = readVarint(data) + if trans_id == TRANSFORM.ZLIB: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.SNAPPY: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.HMAC: + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Hmac transform is no longer supported: %i" % trans_id) + else: + # TApplicationException will be sent back to client + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Unknown transform in client request: %i" % trans_id) + + # Clear out previous info headers. + self.__read_headers.clear() + + # Read the info headers. + while data.tell() < end_header: + info_id = readVarint(data) + if info_id == INFO.NORMAL: + _read_info_headers( + data, end_header, self.__read_headers) + elif info_id == INFO.PERSISTENT: + _read_info_headers( + data, end_header, self.__read_persistent_headers) + else: + break # Unknown header. Stop info processing. + + if self.__read_persistent_headers: + self.__read_headers.update(self.__read_persistent_headers) + + # Skip the rest of the header + data.seek(end_header) + + payload = data.read(sz - header_size) + + # Read the data section. + self.__rbuf = StringIO(self.untransform(payload)) + + def write(self, buf): + self.__wbuf.write(buf) + + def transform(self, buf): + for trans_id in self.__write_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.compress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.compress(buf) + else: + raise TTransportException(TTransportException.INVALID_TRANSFORM, + "Unknown transform during send") + return buf + + def untransform(self, buf): + for trans_id in self.__read_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.decompress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.decompress(buf) + if trans_id not in self.__write_transforms: + self.__write_transforms.append(trans_id) + return buf + + def flush(self): + self.flushImpl(False) + + def onewayFlush(self): + self.flushImpl(True) + + def _flushHeaderMessage(self, buf, wout, wsz): + """Write a message for CLIENT_TYPE.HEADER + + @param buf(StringIO): Buffer to write message to + @param wout(str): Payload + @param wsz(int): Payload length + """ + transform_data = StringIO() + # For now, all transforms don't require data. + num_transforms = len(self.__write_transforms) + for trans_id in self.__write_transforms: + transform_data.write(getVarint(trans_id)) + + # Add in special flags. + if self.__identity: + self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION + self.__write_headers[self.IDENTITY_HEADER] = self.__identity + + info_data = StringIO() + + # Write persistent kv-headers + _flush_info_headers(info_data, + self.get_write_persistent_headers(), + INFO.PERSISTENT) + + # Write non-persistent kv-headers + _flush_info_headers(info_data, + self.__write_headers, + INFO.NORMAL) + + header_data = StringIO() + header_data.write(getVarint(self.__proto_id)) + header_data.write(getVarint(num_transforms)) + + header_size = transform_data.tell() + header_data.tell() + \ + info_data.tell() + + padding_size = 4 - (header_size % 4) + header_size = header_size + padding_size + + # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2) + wsz += header_size + 10 + if wsz > MAX_FRAME_SIZE: + buf.write(pack("!I", BIG_FRAME_MAGIC)) + buf.write(pack("!Q", wsz)) + else: + buf.write(pack("!I", wsz)) + buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags)) + buf.write(pack("!I", self.seq_id)) + buf.write(pack("!H", header_size // 4)) + + buf.write(header_data.getvalue()) + buf.write(transform_data.getvalue()) + buf.write(info_data.getvalue()) + + # Pad out the header with 0x00 + for _ in range(0, padding_size, 1): + buf.write(pack("!c", b'\0')) + + # Send data section + buf.write(wout) + + def flushImpl(self, oneway): + wout = self.__wbuf.getvalue() + wout = self.transform(wout) + wsz = len(wout) + + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf.seek(0) + self.__wbuf.truncate() + + if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to send JSON encoding over binary") + + buf = StringIO() + if self.__client_type == CLIENT_TYPE.HEADER: + self._flushHeaderMessage(buf, wout, wsz) + elif self.__client_type in (CLIENT_TYPE.FRAMED_DEPRECATED, + CLIENT_TYPE.FRAMED_COMPACT): + buf.write(pack("!i", wsz)) + buf.write(wout) + elif self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + buf.write(wout) + elif self.__client_type == CLIENT_TYPE.HTTP_SERVER: + # Reset the client type if we sent something - + # oneway calls via HTTP expect a status response otherwise + buf.write(self.header.getvalue()) + buf.write(wout) + self.__client_type == CLIENT_TYPE.HEADER + elif self.__client_type == CLIENT_TYPE.UNKNOWN: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Unknown client type") + + # We don't include the framing bytes as part of the frame size check + frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12) + _frame_size_check(frame_size, + self.__max_frame_size, + header=self.__client_type == CLIENT_TYPE.HEADER) + self.getTransport().write(buf.getvalue()) + if oneway: + self.getTransport().onewayFlush() + else: + self.getTransport().flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + if not self.__rbuf_frame: + self.readFrame(0) + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastproto doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + + # On unframed clients, there is a chance there is something left + # in rbuf, and the read pointer is not advanced by fastproto + # so seek to the end to be safe + self.__rbuf.seek(0, 2) + while len(prefix) < reqlen: + prefix += self.read(reqlen) + self.__rbuf = StringIO(prefix) + return self.__rbuf + + +def _serialize_string(str_): + if PY3 and not isinstance(str_, bytes): + str_ = str_.encode() + return getVarint(len(str_)) + str_ + + +def _flush_info_headers(info_data, write_headers, type): + if (len(write_headers) > 0): + info_data.write(getVarint(type)) + info_data.write(getVarint(len(write_headers))) + write_headers_iter = write_headers.items() + for str_key, str_value in write_headers_iter: + info_data.write(_serialize_string(str_key)) + info_data.write(_serialize_string(str_value)) + write_headers.clear() + + +def _read_string(bufio, buflimit): + str_sz = readVarint(bufio) + if str_sz + bufio.tell() > buflimit: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "String read too big") + return bufio.read(str_sz) + + +def _read_info_headers(data, end_header, read_headers): + num_keys = readVarint(data) + for _ in xrange(num_keys): + str_key = _read_string(data, end_header) + str_value = _read_string(data, end_header) + read_headers[str_key] = str_value + + +def _frame_size_check(sz, set_max_size, header=True): + if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE): + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "%s transport frame was too large" % 'Header' if header else 'Framed' + ) + + +class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): + + # Same as superclass function, but append 'POST' because we + # stripped it in the calling function. Would be nice if + # we had an ungetch instead + def handle_one_request(self): + self.raw_requestline = self.rfile.readline() + if not self.raw_requestline: + self.close_connection = 1 + return + self.raw_requestline = "POST" + self.raw_requestline + if not self.parse_request(): + # An error code has been sent, just exit + return + mname = 'do_' + self.command + if not hasattr(self, mname): + self.send_error(501, "Unsupported method (%r)" % self.command) + return + method = getattr(self, mname) + method() + + def setup(self): + self.rfile = self.request + self.wfile = StringIO() # New output buffer + + def finish(self): + if not self.rfile.closed: + self.rfile.close() + # leave wfile open for reading. + + def do_POST(self): + if int(self.headers['Content-Length']) > 0: + self.data = self.rfile.read(int(self.headers['Content-Length'])) + else: + self.data = "" + + # Prepare a response header, to be sent later. + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + +# INFO:(zuercher): Added to simplify usage +class THeaderTransportFactory: + def getTransport(self, trans): + return THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD b/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD new file mode 100644 index 000000000000..71fa29d64063 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD @@ -0,0 +1,19 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "finagle_lib", + srcs = [ + "TFinagleServerProcessor.py", + "TFinagleServerProtocol.py", + "__init__.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + "@com_github_twitter_common_finagle_thrift//:twitter_common_finagle_thrift", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py new file mode 100644 index 000000000000..3b207152ea21 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py @@ -0,0 +1,56 @@ +import logging + +from thrift.Thrift import TProcessor, TMessageType, TException +from thrift.protocol import TProtocolDecorator +from gen.twitter.finagle.thrift.ttypes import (ConnectionOptions, UpgradeReply) + +# Matches twitter/common/rpc/finagle/protocol.py +UPGRADE_METHOD = "__can__finagle__trace__v3__" + +# Twitter's TFinagleProcessor only works for the client side of an RPC. +class TFinagleServerProcessor(TProcessor): + def __init__(self, underlying): + self._underlying = underlying + + def process(self, iprot, oprot): + try: + if iprot.upgraded() is not None: + return self._underlying.process(iprot, oprot) + except AttributeError as e: + logging.exception("underlying protocol object is not a TFinagleServerProtocol", e) + return self._underlying.process(iprot, oprot) + + (name, ttype, seqid) = iprot.readMessageBegin() + if ttype != TMessageType.CALL and ttype != TMessageType.ONEWAY: + raise TException("TFinagle protocol only supports CALL & ONEWAY") + + # Check if this is an upgrade request. + if name == UPGRADE_METHOD: + connection_options = ConnectionOptions() + connection_options.read(iprot) + iprot.readMessageEnd() + + oprot.writeMessageBegin(UPGRADE_METHOD, TMessageType.REPLY, seqid) + upgrade_reply = UpgradeReply() + upgrade_reply.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + iprot.set_upgraded(True) + oprot.set_upgraded(True) + return True + + # Not upgraded. Replay the message begin to the underlying processor. + iprot.set_upgraded(False) + oprot.set_upgraded(False) + msg = (name, ttype, seqid) + return self._underlying.process(StoredMessageProtocol(iprot, msg), oprot) + + +class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, messageBegin): + TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) + self.messageBegin = messageBegin + + def readMessageBegin(self): + return self.messageBegin diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py new file mode 100644 index 000000000000..dcdad5122e0e --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py @@ -0,0 +1,33 @@ +from thrift.protocol import TBinaryProtocol +from gen.twitter.finagle.thrift.ttypes import (RequestHeader, ResponseHeader) + + +class TFinagleServerProtocolFactory(object): + def getProtocol(self, trans): + return TFinagleServerProtocol(trans) + + +class TFinagleServerProtocol(TBinaryProtocol.TBinaryProtocol): + def __init__(self, *args, **kw): + self._last_request = None + self._upgraded = None + TBinaryProtocol.TBinaryProtocol.__init__(self, *args, **kw) + + def upgraded(self): + return self._upgraded + + def set_upgraded(self, upgraded): + self._upgraded = upgraded + + def writeMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = ResponseHeader() # .. TODO set some fields + header.write(self) + return TBinaryProtocol.TBinaryProtocol.writeMessageBegin(self, *args, **kwargs) + + def readMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = RequestHeader() + header.read(self) + self._last_request = header + return TBinaryProtocol.TBinaryProtocol.readMessageBegin(self, *args, **kwargs) diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh b/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh new file mode 100755 index 000000000000..6b65871512c0 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Generates the thrift bindings for example.thrift. Requires that +# apache-thrift's thrift generator is installed and on the path. + +DIR=$(cd `dirname $0` && pwd) +cd "${DIR}" + +thrift --gen py --out ./generated example.thrift diff --git a/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh b/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh new file mode 100755 index 000000000000..be83b3eb6599 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +# Generates request and response fixtures for integration tests. + +# Usage: generate_fixture.sh [multiplex-service] -- method [param...] + +set -e + +function usage() { + echo "Usage: $0 [multiplex-service] -- method [param...]" + echo "where mode is success, exception, or idl-exception" + exit 1 +} + +FIXTURE_DIR="${TEST_TMPDIR}" +mkdir -p "${FIXTURE_DIR}" + +DRIVER_DIR="${TEST_RUNDIR}/test/extensions/filters/network/thrift_proxy/driver" + +if [[ -z "${TEST_UDSDIR}" ]]; then + TEST_UDSDIR=`mktemp -d /tmp/envoy_test_thrift.XXXXXX` +fi + +MODE="$1" +TRANSPORT="$2" +PROTOCOL="$3" +MULTIPLEX="$4" +if ! shift 4; then + usage +fi + +if [[ -z "${MODE}" || -z "${TRANSPORT}" || -z "${PROTOCOL}" || -z "${MULTIPLEX}" ]]; then + usage +fi + +if [[ "${MULTIPLEX}" != "--" ]]; then + if [[ "$1" != "--" ]]; then + echo "expected -- after multiplex service name" + exit 1 + fi + shift +else + MULTIPLEX="" +fi + +METHOD="$1" +if [[ "${METHOD}" == "" ]]; then + usage +fi +shift + +SOCKET="${TEST_UDSDIR}/fixture.sock" +rm -f "${SOCKET}" + +SERVICE_FLAGS=("--addr" "${SOCKET}" + "--unix" + "--response" "${MODE}" + "--transport" "${TRANSPORT}" + "--protocol" "${PROTOCOL}") + +if [[ -n "$MULTIPLEX" ]]; then + SERVICE_FLAGS[9]="--multiplex" + SERVICE_FLAGS[10]="${MULTIPLEX}" + + REQUEST_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MULTIPLEX}-${MODE}.request" + RESPONSE_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MULTIPLEX}-${MODE}.response" +else + REQUEST_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MODE}.request" + RESPONSE_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MODE}.response" +fi + +# start server +"${DRIVER_DIR}/server" "${SERVICE_FLAGS[@]}" & +SERVER_PID="$!" + +trap "kill ${SERVER_PID}" EXIT; + +while [[ ! -a "${SOCKET}" ]]; do + sleep 0.1 + + if ! kill -0 "${SERVER_PID}"; then + echo "server failed to start" + exit 1 + fi +done + +"${DRIVER_DIR}/client" "${SERVICE_FLAGS[@]}" \ + --request "${REQUEST_FILE}" \ + --response "${RESPONSE_FILE}" \ + "${METHOD}" "$@" diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/generated/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD b/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD new file mode 100644 index 000000000000..6c9595737b16 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD @@ -0,0 +1,18 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "example_lib", + srcs = [ + "Example.py", + "__init__.py", + "constants.py", + "ttypes.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote new file mode 100755 index 000000000000..11d032908d65 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +import sys +import pprint +if sys.version_info[0] > 2: + from urllib.parse import urlparse +else: + from urlparse import urlparse +from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient +from thrift.protocol.TBinaryProtocol import TBinaryProtocol + +from example import Example +from example.ttypes import * + +if len(sys.argv) <= 1 or sys.argv[1] == '--help': + print('') + print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') + print('') + print('Functions:') + print(' void ping()') + print(' void poke()') + print(' i32 add(i32 a, i32 b)') + print(' Result execute(Param input)') + print('') + sys.exit(0) + +pp = pprint.PrettyPrinter(indent=2) +host = 'localhost' +port = 9090 +uri = '' +framed = False +ssl = False +validate = True +ca_certs = None +keyfile = None +certfile = None +http = False +argi = 1 + +if sys.argv[argi] == '-h': + parts = sys.argv[argi + 1].split(':') + host = parts[0] + if len(parts) > 1: + port = int(parts[1]) + argi += 2 + +if sys.argv[argi] == '-u': + url = urlparse(sys.argv[argi + 1]) + parts = url[1].split(':') + host = parts[0] + if len(parts) > 1: + port = int(parts[1]) + else: + port = 80 + uri = url[2] + if url[4]: + uri += '?%s' % url[4] + http = True + argi += 2 + +if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed': + framed = True + argi += 1 + +if sys.argv[argi] == '-s' or sys.argv[argi] == '-ssl': + ssl = True + argi += 1 + +if sys.argv[argi] == '-novalidate': + validate = False + argi += 1 + +if sys.argv[argi] == '-ca_certs': + ca_certs = sys.argv[argi+1] + argi += 2 + +if sys.argv[argi] == '-keyfile': + keyfile = sys.argv[argi+1] + argi += 2 + +if sys.argv[argi] == '-certfile': + certfile = sys.argv[argi+1] + argi += 2 + +cmd = sys.argv[argi] +args = sys.argv[argi + 1:] + +if http: + transport = THttpClient.THttpClient(host, port, uri) +else: + if ssl: + socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile) + else: + socket = TSocket.TSocket(host, port) + if framed: + transport = TTransport.TFramedTransport(socket) + else: + transport = TTransport.TBufferedTransport(socket) +protocol = TBinaryProtocol(transport) +client = Example.Client(protocol) +transport.open() + +if cmd == 'ping': + if len(args) != 0: + print('ping requires 0 args') + sys.exit(1) + pp.pprint(client.ping()) + +elif cmd == 'poke': + if len(args) != 0: + print('poke requires 0 args') + sys.exit(1) + pp.pprint(client.poke()) + +elif cmd == 'add': + if len(args) != 2: + print('add requires 2 args') + sys.exit(1) + pp.pprint(client.add(eval(args[0]), eval(args[1]),)) + +elif cmd == 'execute': + if len(args) != 1: + print('execute requires 1 args') + sys.exit(1) + pp.pprint(client.execute(eval(args[0]),)) + +else: + print('Unrecognized method %s' % cmd) + sys.exit(1) + +transport.close() diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py new file mode 100644 index 000000000000..325cbff2bae3 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py @@ -0,0 +1,660 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys +import logging +from .ttypes import * +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +all_structs = [] + + +class Iface(object): + def ping(self): + pass + + def poke(self): + pass + + def add(self, a, b): + """ + Parameters: + - a + - b + """ + pass + + def execute(self, input): + """ + Parameters: + - input + """ + pass + + +class Client(Iface): + def __init__(self, iprot, oprot=None): + self._iprot = self._oprot = iprot + if oprot is not None: + self._oprot = oprot + self._seqid = 0 + + def ping(self): + self.send_ping() + self.recv_ping() + + def send_ping(self): + self._oprot.writeMessageBegin('ping', TMessageType.CALL, self._seqid) + args = ping_args() + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_ping(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = ping_result() + result.read(iprot) + iprot.readMessageEnd() + return + + def poke(self): + self.send_poke() + + def send_poke(self): + self._oprot.writeMessageBegin('poke', TMessageType.ONEWAY, self._seqid) + args = poke_args() + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def add(self, a, b): + """ + Parameters: + - a + - b + """ + self.send_add(a, b) + return self.recv_add() + + def send_add(self, a, b): + self._oprot.writeMessageBegin('add', TMessageType.CALL, self._seqid) + args = add_args() + args.a = a + args.b = b + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_add(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = add_result() + result.read(iprot) + iprot.readMessageEnd() + if result.success is not None: + return result.success + raise TApplicationException(TApplicationException.MISSING_RESULT, "add failed: unknown result") + + def execute(self, input): + """ + Parameters: + - input + """ + self.send_execute(input) + return self.recv_execute() + + def send_execute(self, input): + self._oprot.writeMessageBegin('execute', TMessageType.CALL, self._seqid) + args = execute_args() + args.input = input + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_execute(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = execute_result() + result.read(iprot) + iprot.readMessageEnd() + if result.success is not None: + return result.success + if result.appex is not None: + raise result.appex + raise TApplicationException(TApplicationException.MISSING_RESULT, "execute failed: unknown result") + + +class Processor(Iface, TProcessor): + def __init__(self, handler): + self._handler = handler + self._processMap = {} + self._processMap["ping"] = Processor.process_ping + self._processMap["poke"] = Processor.process_poke + self._processMap["add"] = Processor.process_add + self._processMap["execute"] = Processor.process_execute + + def process(self, iprot, oprot): + (name, type, seqid) = iprot.readMessageBegin() + if name not in self._processMap: + iprot.skip(TType.STRUCT) + iprot.readMessageEnd() + x = TApplicationException(TApplicationException.UNKNOWN_METHOD, 'Unknown function %s' % (name)) + oprot.writeMessageBegin(name, TMessageType.EXCEPTION, seqid) + x.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + return + else: + self._processMap[name](self, seqid, iprot, oprot) + return True + + def process_ping(self, seqid, iprot, oprot): + args = ping_args() + args.read(iprot) + iprot.readMessageEnd() + result = ping_result() + try: + self._handler.ping() + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("ping", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + def process_poke(self, seqid, iprot, oprot): + args = poke_args() + args.read(iprot) + iprot.readMessageEnd() + try: + self._handler.poke() + except TTransport.TTransportException: + raise + except Exception: + logging.exception('Exception in oneway handler') + + def process_add(self, seqid, iprot, oprot): + args = add_args() + args.read(iprot) + iprot.readMessageEnd() + result = add_result() + try: + result.success = self._handler.add(args.a, args.b) + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("add", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + def process_execute(self, seqid, iprot, oprot): + args = execute_args() + args.read(iprot) + iprot.readMessageEnd() + result = execute_result() + try: + result.success = self._handler.execute(args.input) + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except AppException as appex: + msg_type = TMessageType.REPLY + result.appex = appex + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("execute", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + +# HELPER FUNCTIONS AND STRUCTURES + + +class ping_args(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('ping_args') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(ping_args) +ping_args.thrift_spec = ( +) + + +class ping_result(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('ping_result') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(ping_result) +ping_result.thrift_spec = ( +) + + +class poke_args(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('poke_args') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(poke_args) +poke_args.thrift_spec = ( +) + + +class add_args(object): + """ + Attributes: + - a + - b + """ + + + def __init__(self, a=None, b=None,): + self.a = a + self.b = b + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.I32: + self.a = iprot.readI32() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.b = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('add_args') + if self.a is not None: + oprot.writeFieldBegin('a', TType.I32, 1) + oprot.writeI32(self.a) + oprot.writeFieldEnd() + if self.b is not None: + oprot.writeFieldBegin('b', TType.I32, 2) + oprot.writeI32(self.b) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(add_args) +add_args.thrift_spec = ( + None, # 0 + (1, TType.I32, 'a', None, None, ), # 1 + (2, TType.I32, 'b', None, None, ), # 2 +) + + +class add_result(object): + """ + Attributes: + - success + """ + + + def __init__(self, success=None,): + self.success = success + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 0: + if ftype == TType.I32: + self.success = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('add_result') + if self.success is not None: + oprot.writeFieldBegin('success', TType.I32, 0) + oprot.writeI32(self.success) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(add_result) +add_result.thrift_spec = ( + (0, TType.I32, 'success', None, None, ), # 0 +) + + +class execute_args(object): + """ + Attributes: + - input + """ + + + def __init__(self, input=None,): + self.input = input + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRUCT: + self.input = Param() + self.input.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('execute_args') + if self.input is not None: + oprot.writeFieldBegin('input', TType.STRUCT, 1) + self.input.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(execute_args) +execute_args.thrift_spec = ( + None, # 0 + (1, TType.STRUCT, 'input', [Param, None], None, ), # 1 +) + + +class execute_result(object): + """ + Attributes: + - success + - appex + """ + + + def __init__(self, success=None, appex=None,): + self.success = success + self.appex = appex + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 0: + if ftype == TType.STRUCT: + self.success = Result() + self.success.read(iprot) + else: + iprot.skip(ftype) + elif fid == 1: + if ftype == TType.STRUCT: + self.appex = AppException() + self.appex.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('execute_result') + if self.success is not None: + oprot.writeFieldBegin('success', TType.STRUCT, 0) + self.success.write(oprot) + oprot.writeFieldEnd() + if self.appex is not None: + oprot.writeFieldBegin('appex', TType.STRUCT, 1) + self.appex.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(execute_result) +execute_result.thrift_spec = ( + (0, TType.STRUCT, 'success', [Result, None], None, ), # 0 + (1, TType.STRUCT, 'appex', [AppException, None], None, ), # 1 +) +fix_spec(all_structs) +del all_structs + diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py new file mode 100644 index 000000000000..a53ccc6084ee --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py @@ -0,0 +1 @@ +__all__ = ['ttypes', 'constants', 'Example'] diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py new file mode 100644 index 000000000000..0c217ceda691 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py @@ -0,0 +1,14 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys +from .ttypes import * diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py new file mode 100644 index 000000000000..89aa4a9f6233 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py @@ -0,0 +1,445 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys + +from thrift.transport import TTransport +all_structs = [] + + +class TheWorks(object): + """ + Attributes: + - field_1 + - field_2 + - field_3 + - field_4 + - field_5 + - field_6 + - field_7 + - field_8 + - field_9 + - field_10 + - field_11 + - field_12 + """ + + + def __init__(self, field_1=None, field_2=None, field_3=None, field_4=None, field_5=None, field_6=None, field_7=None, field_8=None, field_9=None, field_10=None, field_11=None, field_12=None,): + self.field_1 = field_1 + self.field_2 = field_2 + self.field_3 = field_3 + self.field_4 = field_4 + self.field_5 = field_5 + self.field_6 = field_6 + self.field_7 = field_7 + self.field_8 = field_8 + self.field_9 = field_9 + self.field_10 = field_10 + self.field_11 = field_11 + self.field_12 = field_12 + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.BOOL: + self.field_1 = iprot.readBool() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.BYTE: + self.field_2 = iprot.readByte() + else: + iprot.skip(ftype) + elif fid == 3: + if ftype == TType.I16: + self.field_3 = iprot.readI16() + else: + iprot.skip(ftype) + elif fid == 4: + if ftype == TType.I32: + self.field_4 = iprot.readI32() + else: + iprot.skip(ftype) + elif fid == 5: + if ftype == TType.I64: + self.field_5 = iprot.readI64() + else: + iprot.skip(ftype) + elif fid == 6: + if ftype == TType.DOUBLE: + self.field_6 = iprot.readDouble() + else: + iprot.skip(ftype) + elif fid == 7: + if ftype == TType.STRING: + self.field_7 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + else: + iprot.skip(ftype) + elif fid == 8: + if ftype == TType.STRING: + self.field_8 = iprot.readBinary() + else: + iprot.skip(ftype) + elif fid == 9: + if ftype == TType.MAP: + self.field_9 = {} + (_ktype1, _vtype2, _size0) = iprot.readMapBegin() + for _i4 in range(_size0): + _key5 = iprot.readI32() + _val6 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.field_9[_key5] = _val6 + iprot.readMapEnd() + else: + iprot.skip(ftype) + elif fid == 10: + if ftype == TType.LIST: + self.field_10 = [] + (_etype10, _size7) = iprot.readListBegin() + for _i11 in range(_size7): + _elem12 = iprot.readI32() + self.field_10.append(_elem12) + iprot.readListEnd() + else: + iprot.skip(ftype) + elif fid == 11: + if ftype == TType.SET: + self.field_11 = set() + (_etype16, _size13) = iprot.readSetBegin() + for _i17 in range(_size13): + _elem18 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.field_11.add(_elem18) + iprot.readSetEnd() + else: + iprot.skip(ftype) + elif fid == 12: + if ftype == TType.BOOL: + self.field_12 = iprot.readBool() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('TheWorks') + if self.field_1 is not None: + oprot.writeFieldBegin('field_1', TType.BOOL, 1) + oprot.writeBool(self.field_1) + oprot.writeFieldEnd() + if self.field_2 is not None: + oprot.writeFieldBegin('field_2', TType.BYTE, 2) + oprot.writeByte(self.field_2) + oprot.writeFieldEnd() + if self.field_3 is not None: + oprot.writeFieldBegin('field_3', TType.I16, 3) + oprot.writeI16(self.field_3) + oprot.writeFieldEnd() + if self.field_4 is not None: + oprot.writeFieldBegin('field_4', TType.I32, 4) + oprot.writeI32(self.field_4) + oprot.writeFieldEnd() + if self.field_5 is not None: + oprot.writeFieldBegin('field_5', TType.I64, 5) + oprot.writeI64(self.field_5) + oprot.writeFieldEnd() + if self.field_6 is not None: + oprot.writeFieldBegin('field_6', TType.DOUBLE, 6) + oprot.writeDouble(self.field_6) + oprot.writeFieldEnd() + if self.field_7 is not None: + oprot.writeFieldBegin('field_7', TType.STRING, 7) + oprot.writeString(self.field_7.encode('utf-8') if sys.version_info[0] == 2 else self.field_7) + oprot.writeFieldEnd() + if self.field_8 is not None: + oprot.writeFieldBegin('field_8', TType.STRING, 8) + oprot.writeBinary(self.field_8) + oprot.writeFieldEnd() + if self.field_9 is not None: + oprot.writeFieldBegin('field_9', TType.MAP, 9) + oprot.writeMapBegin(TType.I32, TType.STRING, len(self.field_9)) + for kiter19, viter20 in self.field_9.items(): + oprot.writeI32(kiter19) + oprot.writeString(viter20.encode('utf-8') if sys.version_info[0] == 2 else viter20) + oprot.writeMapEnd() + oprot.writeFieldEnd() + if self.field_10 is not None: + oprot.writeFieldBegin('field_10', TType.LIST, 10) + oprot.writeListBegin(TType.I32, len(self.field_10)) + for iter21 in self.field_10: + oprot.writeI32(iter21) + oprot.writeListEnd() + oprot.writeFieldEnd() + if self.field_11 is not None: + oprot.writeFieldBegin('field_11', TType.SET, 11) + oprot.writeSetBegin(TType.STRING, len(self.field_11)) + for iter22 in self.field_11: + oprot.writeString(iter22.encode('utf-8') if sys.version_info[0] == 2 else iter22) + oprot.writeSetEnd() + oprot.writeFieldEnd() + if self.field_12 is not None: + oprot.writeFieldBegin('field_12', TType.BOOL, 12) + oprot.writeBool(self.field_12) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class Param(object): + """ + Attributes: + - return_fields + - the_works + """ + + + def __init__(self, return_fields=None, the_works=None,): + self.return_fields = return_fields + self.the_works = the_works + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.LIST: + self.return_fields = [] + (_etype26, _size23) = iprot.readListBegin() + for _i27 in range(_size23): + _elem28 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.return_fields.append(_elem28) + iprot.readListEnd() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.STRUCT: + self.the_works = TheWorks() + self.the_works.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('Param') + if self.return_fields is not None: + oprot.writeFieldBegin('return_fields', TType.LIST, 1) + oprot.writeListBegin(TType.STRING, len(self.return_fields)) + for iter29 in self.return_fields: + oprot.writeString(iter29.encode('utf-8') if sys.version_info[0] == 2 else iter29) + oprot.writeListEnd() + oprot.writeFieldEnd() + if self.the_works is not None: + oprot.writeFieldBegin('the_works', TType.STRUCT, 2) + self.the_works.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class Result(object): + """ + Attributes: + - the_works + """ + + + def __init__(self, the_works=None,): + self.the_works = the_works + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRUCT: + self.the_works = TheWorks() + self.the_works.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('Result') + if self.the_works is not None: + oprot.writeFieldBegin('the_works', TType.STRUCT, 1) + self.the_works.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class AppException(TException): + """ + Attributes: + - why + """ + + + def __init__(self, why=None,): + self.why = why + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.why = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('AppException') + if self.why is not None: + oprot.writeFieldBegin('why', TType.STRING, 1) + oprot.writeString(self.why.encode('utf-8') if sys.version_info[0] == 2 else self.why) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __str__(self): + return repr(self) + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(TheWorks) +TheWorks.thrift_spec = ( + None, # 0 + (1, TType.BOOL, 'field_1', None, None, ), # 1 + (2, TType.BYTE, 'field_2', None, None, ), # 2 + (3, TType.I16, 'field_3', None, None, ), # 3 + (4, TType.I32, 'field_4', None, None, ), # 4 + (5, TType.I64, 'field_5', None, None, ), # 5 + (6, TType.DOUBLE, 'field_6', None, None, ), # 6 + (7, TType.STRING, 'field_7', 'UTF8', None, ), # 7 + (8, TType.STRING, 'field_8', 'BINARY', None, ), # 8 + (9, TType.MAP, 'field_9', (TType.I32, None, TType.STRING, 'UTF8', False), None, ), # 9 + (10, TType.LIST, 'field_10', (TType.I32, None, False), None, ), # 10 + (11, TType.SET, 'field_11', (TType.STRING, 'UTF8', False), None, ), # 11 + (12, TType.BOOL, 'field_12', None, None, ), # 12 +) +all_structs.append(Param) +Param.thrift_spec = ( + None, # 0 + (1, TType.LIST, 'return_fields', (TType.STRING, 'UTF8', False), None, ), # 1 + (2, TType.STRUCT, 'the_works', [TheWorks, None], None, ), # 2 +) +all_structs.append(Result) +Result.thrift_spec = ( + None, # 0 + (1, TType.STRUCT, 'the_works', [TheWorks, None], None, ), # 1 +) +all_structs.append(AppException) +AppException.thrift_spec = ( + None, # 0 + (1, TType.STRING, 'why', 'UTF8', None, ), # 1 +) +fix_spec(all_structs) +del all_structs diff --git a/test/extensions/filters/network/thrift_proxy/driver/server.py b/test/extensions/filters/network/thrift_proxy/driver/server.py new file mode 100755 index 000000000000..094a8d2338bf --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/server.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python + +import argparse +import logging +import sys + +from generated.example import Example +from generated.example.ttypes import ( + Result, TheWorks, AppException +) + +from thrift import Thrift, TMultiplexedProcessor +from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol +from thrift.server import TServer +from thrift.transport import TSocket +from thrift.transport import TTransport +from fbthrift import THeaderTransport +from finagle import TFinagleServerProcessor, TFinagleServerProtocol + + +class SuccessHandler: + def ping(self): + print("server: ping()") + + def poke(self): + print("server: poke()") + + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result + + def execute(self, param): + print("server: execute({0})".format(param)) + if "all" in param.return_fields: + return Result(param.the_works) + elif "none" in param.return_fields: + return Result(TheWorks()) + the_works = TheWorks() + for field, value in vars(param.the_works).items(): + if field in param.return_fields: + setattr(the_works, field, value) + return Result(the_works) + + +class IDLExceptionHandler: + def ping(self): + print("server: ping()") + + def poke(self): + print("server: poke()") + + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result + + def execute(self, param): + print("server: app error: execute failed") + raise AppException("execute failed") + + +class ExceptionHandler: + def ping(self): + print("server: ping failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for ping", + ) + + def poke(self): + print("server: poke failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for poke", + ) + + def add(self, a, b): + print("server: add failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for add", + ) + + def execute(self, param): + print("server: execute failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for execute", + ) + + +def main(cfg): + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid listener unix domain socket: {}".format(cfg.addr)) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + port = int(port) + except ValueError: + sys.exit("invalid listener address: {}".format(cfg.addr)) + + if cfg.response == "success": + handler = SuccessHandler() + elif cfg.response == "idl-exception": + handler = IDLExceptionHandler() + elif cfg.response == "exception": + # squelch traceback for the exception we throw + logging.getLogger().setLevel(logging.CRITICAL) + handler = ExceptionHandler() + else: + sys.exit("unknown server response mode {0}".format(cfg.response)) + + processor = Example.Processor(handler) + if cfg.service is not None: + # wrap processor with multiplexor + multi = TMultiplexedProcessor.TMultiplexedProcessor() + multi.registerProcessor(cfg.service, processor) + processor = multi + + if cfg.protocol == "finagle": + # wrap processor with finagle request/response header handler + processor = TFinagleServerProcessor.TFinagleServerProcessor(processor) + + if cfg.unix: + transport = TSocket.TServerSocket(unix_socket=cfg.addr) + else: + transport = TSocket.TServerSocket(host=host, port=port) + + if cfg.transport == "framed": + transport_factory = TTransport.TFramedTransportFactory() + elif cfg.transport == "unframed": + transport_factory = TTransport.TBufferedTransportFactory() + elif cfg.transport == "header": + transport_factory = THeaderTransport.THeaderTransportFactory() + else: + sys.exit("unknown transport {0}".format(cfg.transport)) + + if cfg.protocol == "binary": + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + elif cfg.protocol == "compact": + protocol_factory = TCompactProtocol.TCompactProtocolFactory() + elif cfg.protocol == "json": + protocol_factory = TJSONProtocol.TJSONProtocolFactory() + elif cfg.protocol == "finagle": + protocol_factory = TFinagleServerProtocol.TFinagleServerProtocolFactory() + else: + sys.exit("unknown protocol {0}".format(cfg.protocol)) + + print("Thrift Server listening on {0} for {1} {2} requests".format( + cfg.addr, cfg.transport, cfg.protocol)) + if cfg.service is not None: + print("Thrift Server service name {0}".format(cfg.service)) + if cfg.response == "idl-exception": + print("Thrift Server will throw IDL exceptions when defined") + elif cfg.response == "exception": + print("Thrift Server will throw Thrift exceptions for all messages") + + server = TServer.TSimpleServer(processor, transport, transport_factory, protocol_factory) + try: + server.serve() + except KeyboardInterrupt: + print + + +if __name__ == "__main__": + logging.basicConfig() + parser = argparse.ArgumentParser(description="Thrift server to match client.py.") + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + default=":0", + help="Listener address for server in the form host:port. The host is optional. If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + help="Selects a protocol.", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + ) + parser.add_argument( + "-r", + "--response", + dest="response", + default="success", + choices=["success", "idl-exception", "exception"], + help="Controls how the server responds to requests", + ) + parser.add_argument( + "-t", + "--transport", + help="Selects a transport.", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + cfg = parser.parse_args() + + try: + main(cfg) + except Thrift.TException as tx: + sys.exit("Thrift exception: {0}".format(tx.message)) diff --git a/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc b/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc new file mode 100644 index 000000000000..f4ca2a32f850 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc @@ -0,0 +1,266 @@ +#include + +#include + +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "test/integration/integration.h" +#include "test/test_common/environment.h" +#include "test/test_common/network_utility.h" + +#include "gtest/gtest.h" + +using testing::Combine; +using testing::TestParamInfo; +using testing::TestWithParam; +using testing::Values; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +std::string thrift_config; + +enum class CallResult { + Success, + IDLException, + Exception, +}; + +class ThriftFilterIntegrationTest + : public BaseIntegrationTest, + public TestWithParam> { +public: + ThriftFilterIntegrationTest() + : BaseIntegrationTest(Network::Address::IpVersion::v4, thrift_config) {} + + static void SetUpTestCase() { + thrift_config = ConfigHelper::BASE_CONFIG + R"EOF( + filter_chains: + filters: + - name: envoy.filters.network.thrift_proxy + config: + stat_prefix: thrift_stats + - name: envoy.tcp_proxy + config: + stat_prefix: tcp_stats + cluster: cluster_0 + )EOF"; + } + + void initializeCall(CallResult result) { + std::tie(transport_, protocol_, multiplexed_) = GetParam(); + + std::string result_mode; + switch (result) { + case CallResult::Success: + result_mode = "success"; + break; + case CallResult::IDLException: + result_mode = "idl-exception"; + break; + case CallResult::Exception: + result_mode = "exception"; + break; + default: + NOT_REACHED; + } + + preparePayloads(result_mode, "execute"); + ASSERT(request_bytes_.length() > 0); + ASSERT(response_bytes_.length() > 0); + + BaseIntegrationTest::initialize(); + } + + void initializeOneway() { + std::tie(transport_, protocol_, multiplexed_) = GetParam(); + + preparePayloads("success", "poke"); + ASSERT(request_bytes_.length() > 0); + ASSERT(response_bytes_.length() == 0); + + BaseIntegrationTest::initialize(); + } + + void preparePayloads(std::string result_mode, std::string method) { + std::vector args = { + TestEnvironment::runfilesPath( + "test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh"), + result_mode, + transport_, + protocol_, + }; + + if (multiplexed_) { + args.push_back("svcname"); + } + args.push_back("--"); + args.push_back(method); + + TestEnvironment::exec(args); + + std::stringstream file_base; + file_base << "{{ test_tmpdir }}/" << transport_ << "-" << protocol_ << "-"; + if (multiplexed_) { + file_base << "svcname-"; + } + file_base << result_mode; + + readAll(file_base.str() + ".request", request_bytes_); + readAll(file_base.str() + ".response", response_bytes_); + } + + void TearDown() override { + test_server_.reset(); + fake_upstreams_.clear(); + } + +protected: + void readAll(std::string file, Buffer::OwnedImpl& buffer) { + file = TestEnvironment::substitute(file, version_); + + std::ifstream is(file, std::ios::binary | std::ios::ate); + RELEASE_ASSERT(!is.fail()); + + std::ifstream::pos_type len = is.tellg(); + if (len > 0) { + std::vector bytes(len, 0); + is.seekg(0, std::ios::beg); + RELEASE_ASSERT(!is.fail()); + + is.read(bytes.data(), len); + RELEASE_ASSERT(!is.fail()); + + buffer.add(bytes.data(), len); + } + } + + std::string transport_; + std::string protocol_; + bool multiplexed_; + + std::string result_; + + Buffer::OwnedImpl request_bytes_; + Buffer::OwnedImpl response_bytes_; +}; + +static std::string +paramToString(const TestParamInfo>& params) { + std::string transport, protocol; + bool multiplexed; + std::tie(transport, protocol, multiplexed) = params.param; + transport = StringUtil::toUpper(absl::string_view(transport).substr(0, 1)) + transport.substr(1); + protocol = StringUtil::toUpper(absl::string_view(protocol).substr(0, 1)) + protocol.substr(1); + if (multiplexed) { + return fmt::format("{}{}Multiplexed", transport, protocol); + } + return fmt::format("{}{}", transport, protocol); +} + +INSTANTIATE_TEST_CASE_P( + TransportAndProtocol, ThriftFilterIntegrationTest, + Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED), + Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), Values(false, true)), + paramToString); + +TEST_P(ThriftFilterIntegrationTest, Success) { + initializeCall(CallResult::Success); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(TestUtility::bufferToString(request_bytes_)); + + FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + Buffer::OwnedImpl upstream_request( + fake_upstream_connection->waitForData(request_bytes_.length())); + EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + + fake_upstream_connection->write(TestUtility::bufferToString(response_bytes_)); + + tcp_client->waitForData(TestUtility::bufferToString(response_bytes_)); + tcp_client->close(); + fake_upstream_connection->waitForDisconnect(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_success"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftFilterIntegrationTest, IDLException) { + initializeCall(CallResult::IDLException); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(TestUtility::bufferToString(request_bytes_)); + + FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + Buffer::OwnedImpl upstream_request( + fake_upstream_connection->waitForData(request_bytes_.length())); + EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + + fake_upstream_connection->write(TestUtility::bufferToString(response_bytes_)); + + tcp_client->waitForData(TestUtility::bufferToString(response_bytes_)); + tcp_client->close(); + fake_upstream_connection->waitForDisconnect(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_error"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftFilterIntegrationTest, Exception) { + initializeCall(CallResult::Exception); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(TestUtility::bufferToString(request_bytes_)); + + FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + Buffer::OwnedImpl upstream_request( + fake_upstream_connection->waitForData(request_bytes_.length())); + EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + + fake_upstream_connection->write(TestUtility::bufferToString(response_bytes_)); + + tcp_client->waitForData(TestUtility::bufferToString(response_bytes_)); + tcp_client->close(); + fake_upstream_connection->waitForDisconnect(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_exception"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftFilterIntegrationTest, Oneway) { + initializeOneway(); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(TestUtility::bufferToString(request_bytes_)); + + FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + Buffer::OwnedImpl upstream_request( + fake_upstream_connection->waitForData(request_bytes_.length())); + EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + + tcp_client->close(); + fake_upstream_connection->waitForDisconnect(); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_oneway"); + EXPECT_EQ(1U, counter->value()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/utility.h b/test/extensions/filters/network/thrift_proxy/utility.h index 5b85d4f9533a..8058f9b33799 100644 --- a/test/extensions/filters/network/thrift_proxy/utility.h +++ b/test/extensions/filters/network/thrift_proxy/utility.h @@ -7,6 +7,10 @@ #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "gtest/gtest.h" + +using testing::TestParamInfo; + namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -52,6 +56,43 @@ inline std::string bufferToString(Buffer::Instance& buffer) { return std::string(data, buffer.length()); } +inline std::string fieldTypeToString(const FieldType& field_type) { + switch (field_type) { + case FieldType::Stop: + return "Stop"; + case FieldType::Void: + return "Void"; + case FieldType::Bool: + return "Bool"; + case FieldType::Byte: + return "Byte"; + case FieldType::Double: + return "Double"; + case FieldType::I16: + return "I16"; + case FieldType::I32: + return "I32"; + case FieldType::I64: + return "I64"; + case FieldType::String: + return "String"; + case FieldType::Struct: + return "Struct"; + case FieldType::Map: + return "Map"; + case FieldType::Set: + return "Set"; + case FieldType::List: + return "List"; + default: + return "UnknownFieldType"; + } +} + +inline std::string fieldTypeParamToString(const TestParamInfo& params) { + return fieldTypeToString(params.param); +} + } // namespace } // namespace ThriftProxy } // namespace NetworkFilters