Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interop: Custom creds for stress test client #6809

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions interop/stress/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"math/rand"
"net"
"os"
"strconv"
"strings"
"sync"
Expand All @@ -34,27 +35,37 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/google"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/interop"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"

_ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath.

testgrpc "google.golang.org/grpc/interop/grpc_testing"
metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
)

const (
googleDefaultCredsName = "google_default_credentials"
computeEngineCredsName = "compute_engine_channel_creds"
)

var (
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")

totalNumCalls int64
logger = grpclog.Component("stress")
Expand All @@ -71,12 +82,13 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
testCaseStrings := strings.Split(testCaseString, ",")
testCases := make([]testCaseWithWeight, len(testCaseStrings))
for i, str := range testCaseStrings {
testCase := strings.Split(str, ":")
if len(testCase) != 2 {
testCaseNameAndWeight := strings.Split(str, ":")
if len(testCaseNameAndWeight) != 2 {
panic(fmt.Sprintf("invalid test case with weight: %s", str))
}
// Check if test case is supported.
switch testCase[0] {
testCaseName := strings.ToLower(testCaseNameAndWeight[0])
switch testCaseName {
case
"empty_unary",
"large_unary",
Expand All @@ -90,10 +102,10 @@ func parseTestCases(testCaseString string) []testCaseWithWeight {
"status_code_and_message",
"custom_metadata":
default:
panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0]))
}
testCases[i].name = testCase[0]
w, err := strconv.Atoi(testCase[1])
testCases[i].name = testCaseName
w, err := strconv.Atoi(testCaseNameAndWeight[1])
if err != nil {
panic(fmt.Sprintf("%v", err))
}
Expand Down Expand Up @@ -263,6 +275,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
logger.Infof("use_tls: %t", *useTLS)
logger.Infof("use_test_ca: %t", *testCA)
logger.Infof("server_host_override: %s", *tlsServerName)
logger.Infof("custom_credentials_type: %s", *customCredentialsType)

logger.Infoln("addresses:")
for i, addr := range addresses {
Expand All @@ -276,7 +289,15 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {

func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
var opts []grpc.DialOption
if useTLS {
if *customCredentialsType != "" {
if *customCredentialsType == googleDefaultCredsName {
opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
} else if *customCredentialsType == computeEngineCredsName {
opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
} else {
logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType)
}
Comment on lines +292 to +299
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider:

switch *customCredentialsType {
	case googleDefaultCredsName:
	case computeEngineCredsName:
	case "":
		if useTLS {
		}
	default:
		// unknown custom creds error
}

Also if we can change the spec here, consider custom_creds_type=="tls" means to use tls instead of having a separate flag for it. Having two flags for one setting is problematic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's how the flags work in the interop test client across languages as well as the java stress test one. Makes sense to revisit, but I think that would be a broad change outside the scope of this PR.

} else if useTLS {
var sn string
if tlsServerName != "" {
sn = tlsServerName
Expand All @@ -303,6 +324,7 @@ func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.C

func main() {
flag.Parse()
resolver.SetDefaultScheme("dns")
addresses := strings.Split(*serverAddresses, ",")
tests := parseTestCases(*testCases)
logParameterInfo(addresses, tests)
Expand Down Expand Up @@ -337,6 +359,6 @@ func main() {
close(stop)
}
wg.Wait()
logger.Infof("Total calls made: %v", totalNumCalls)
fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
logger.Infof(" ===== ALL DONE ===== ")
}
Loading