diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 4f20b0b35b4..d20a6916f37 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -55,23 +55,15 @@ var PDCli pd.Client func requirePDClient(cmd *cobra.Command, _ []string) error { var ( - caPath string - err error + tlsConfig *tls.Config + err error ) - caPath, err = cmd.Flags().GetString("cacert") - if err == nil && len(caPath) != 0 { - var certPath, keyPath string - certPath, err = cmd.Flags().GetString("cert") - if err != nil { - return err - } - keyPath, err = cmd.Flags().GetString("key") - if err != nil { - return err - } - return initNewPDClientWithTLS(cmd, caPath, certPath, keyPath) + tlsConfig, err = parseTLSConfig(cmd) + if err != nil { + return err } - return initNewPDClient(cmd) + + return initNewPDClient(cmd, pd.WithTLSConfig(tlsConfig)) } // shouldInitPDClient checks whether we should create a new PD client according to the cluster information. @@ -111,44 +103,36 @@ func initNewPDClient(cmd *cobra.Command, opts ...pd.ClientOption) error { return nil } -func initNewPDClientWithTLS(cmd *cobra.Command, caPath, certPath, keyPath string) error { - tlsConfig, err := initTLSConfig(caPath, certPath, keyPath) - if err != nil { - return err - } - initNewPDClient(cmd, pd.WithTLSConfig(tlsConfig)) - return nil -} - // TODO: replace dialClient with the PD HTTP client completely. var dialClient = &http.Client{ Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID), } -// RequireHTTPSClient creates a HTTPS client if the related flags are set -func RequireHTTPSClient(cmd *cobra.Command, args []string) error { +func parseTLSConfig(cmd *cobra.Command) (*tls.Config, error) { caPath, err := cmd.Flags().GetString("cacert") - if err == nil && len(caPath) != 0 { - certPath, err := cmd.Flags().GetString("cert") - if err != nil { - return err - } - keyPath, err := cmd.Flags().GetString("key") - if err != nil { - return err - } - err = initHTTPSClient(caPath, certPath, keyPath) - if err != nil { - cmd.Println(err) - return err - } + if err != nil || len(caPath) == 0 { + return nil, err + } + certPath, err := cmd.Flags().GetString("cert") + if err != nil { + return nil, err + } + keyPath, err := cmd.Flags().GetString("key") + if err != nil { + return nil, err } - return nil -} - -func initHTTPSClient(caPath, certPath, keyPath string) error { tlsConfig, err := initTLSConfig(caPath, certPath, keyPath) if err != nil { + return nil, err + } + + return tlsConfig, nil +} + +// RequireHTTPSClient creates a HTTPS client if the related flags are set +func RequireHTTPSClient(cmd *cobra.Command, _ []string) error { + tlsConfig, err := parseTLSConfig(cmd) + if err != nil || tlsConfig == nil { return err } dialClient = &http.Client{ diff --git a/tools/pd-ctl/pdctl/command/global_test.go b/tools/pd-ctl/pdctl/command/global_test.go new file mode 100644 index 00000000000..86eb4366d04 --- /dev/null +++ b/tools/pd-ctl/pdctl/command/global_test.go @@ -0,0 +1,58 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package command + +import ( + "os" + "os/exec" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestParseTLSConfig(t *testing.T) { + re := require.New(t) + + rootCmd := &cobra.Command{ + Use: "pd-ctl", + Short: "Placement Driver control", + SilenceErrors: true, + } + certPath := "../../tests/cert" + rootCmd.Flags().String("cacert", certPath+"/ca.pem", "path of file that contains list of trusted SSL CAs") + rootCmd.Flags().String("cert", certPath+"/client.pem", "path of file that contains X509 certificate in PEM format") + rootCmd.Flags().String("key", certPath+"/client-key.pem", "path of file that contains X509 key in PEM format") + + // generate certs + if err := os.Mkdir(certPath, 0755); err != nil { + t.Fatal(err) + } + certScript := "../../tests/cert_opt.sh" + if err := exec.Command(certScript, "generate", certPath).Run(); err != nil { + t.Fatal(err) + } + defer func() { + if err := exec.Command(certScript, "cleanup", certPath).Run(); err != nil { + t.Fatal(err) + } + if err := os.RemoveAll(certPath); err != nil { + t.Fatal(err) + } + }() + + tlsConfig, err := parseTLSConfig(rootCmd) + re.NoError(err) + re.NotNil(tlsConfig) +} diff --git a/tools/pd-ctl/pdctl/ctl.go b/tools/pd-ctl/pdctl/ctl.go index 5790911d79f..8b0c62a920f 100644 --- a/tools/pd-ctl/pdctl/ctl.go +++ b/tools/pd-ctl/pdctl/ctl.go @@ -30,6 +30,7 @@ import ( func init() { cobra.EnablePrefixMatching = true + cobra.EnableTraverseRunHooks = true } // GetRootCmd is exposed for integration tests. But it can be embedded into another suite, too. diff --git a/tools/pd-ctl/tests/health/health_test.go b/tools/pd-ctl/tests/health/health_test.go index 9150a56c91b..f1d3c7cfbf1 100644 --- a/tools/pd-ctl/tests/health/health_test.go +++ b/tools/pd-ctl/tests/health/health_test.go @@ -17,14 +17,21 @@ package health_test import ( "context" "encoding/json" + "os" + "os/exec" + "path/filepath" + "strings" "testing" "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/utils/grpcutil" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/cluster" + "github.com/tikv/pd/server/config" pdTests "github.com/tikv/pd/tests" ctl "github.com/tikv/pd/tools/pd-ctl/pdctl" "github.com/tikv/pd/tools/pd-ctl/tests" + "go.etcd.io/etcd/pkg/transport" ) func TestHealth(t *testing.T) { @@ -68,3 +75,80 @@ func TestHealth(t *testing.T) { re.NoError(json.Unmarshal(output, &h)) re.Equal(healths, h) } + +func TestHealthTLS(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + certPath := "../cert" + certScript := "../cert_opt.sh" + // generate certs + if err := os.Mkdir(certPath, 0755); err != nil { + t.Fatal(err) + } + if err := exec.Command(certScript, "generate", certPath).Run(); err != nil { + t.Fatal(err) + } + defer func() { + if err := exec.Command(certScript, "cleanup", certPath).Run(); err != nil { + t.Fatal(err) + } + if err := os.RemoveAll(certPath); err != nil { + t.Fatal(err) + } + }() + + tlsInfo := transport.TLSInfo{ + KeyFile: filepath.Join(certPath, "pd-server-key.pem"), + CertFile: filepath.Join(certPath, "pd-server.pem"), + TrustedCAFile: filepath.Join(certPath, "ca.pem"), + } + tc, err := pdTests.NewTestCluster(ctx, 1, func(conf *config.Config, _ string) { + conf.Security.TLSConfig = grpcutil.TLSConfig{ + KeyPath: tlsInfo.KeyFile, + CertPath: tlsInfo.CertFile, + CAPath: tlsInfo.TrustedCAFile, + } + conf.AdvertiseClientUrls = strings.ReplaceAll(conf.AdvertiseClientUrls, "http", "https") + conf.ClientUrls = strings.ReplaceAll(conf.ClientUrls, "http", "https") + conf.AdvertisePeerUrls = strings.ReplaceAll(conf.AdvertisePeerUrls, "http", "https") + conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https") + conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https") + }) + re.NoError(err) + defer tc.Destroy() + err = tc.RunInitialServers() + re.NoError(err) + tc.WaitLeader() + cmd := ctl.GetRootCmd() + + client := tc.GetEtcdClient() + members, err := cluster.GetMembers(client) + re.NoError(err) + healthMembers := cluster.CheckHealth(tc.GetHTTPClient(), members) + healths := []api.Health{} + for _, member := range members { + h := api.Health{ + Name: member.Name, + MemberID: member.MemberId, + ClientUrls: member.ClientUrls, + Health: false, + } + if _, ok := healthMembers[member.GetMemberId()]; ok { + h.Health = true + } + healths = append(healths, h) + } + + pdAddr := tc.GetConfig().GetClientURL() + pdAddr = strings.ReplaceAll(pdAddr, "http", "https") + args := []string{"-u", pdAddr, "health", + "--cacert=../cert/ca.pem", + "--cert=../cert/client.pem", + "--key=../cert/client-key.pem"} + output, err := tests.ExecuteCommand(cmd, args...) + re.NoError(err) + h := make([]api.Health, len(healths)) + re.NoError(json.Unmarshal(output, &h)) + re.Equal(healths, h) +}