Skip to content

Commit

Permalink
syncer: support TLS for region syncer (tikv#1728) (tikv#1739)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <rleungx@gmail.com>
  • Loading branch information
rleungx authored and nolouch committed Nov 12, 2019
1 parent 0e3a054 commit 49a1bc0
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 79 deletions.
46 changes: 3 additions & 43 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@ package pd

import (
"context"
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"
"strings"
"sync"
"time"

opentracing "github.com/opentracing/opentracing-go"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/pdpb"
log "github.com/pingcap/log"
"github.com/pingcap/log"
"github.com/pingcap/pd/pkg/grpcutil"
"github.com/pkg/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// Client is a PD (Placement Driver) client.
Expand Down Expand Up @@ -272,43 +268,7 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) {
return conn, nil
}

opt := grpc.WithInsecure()
if len(c.security.CAPath) != 0 {

certificates := []tls.Certificate{}
if len(c.security.CertPath) != 0 && len(c.security.KeyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(c.security.CertPath, c.security.KeyPath)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(c.security.CAPath)
if err != nil {
return nil, errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

creds := credentials.NewTLS(&tls.Config{
Certificates: certificates,
RootCAs: certPool,
})

opt = grpc.WithTransportCredentials(creds)
}
u, err := url.Parse(addr)
if err != nil {
return nil, errors.WithStack(err)
}
cc, err := grpc.Dial(u.Host, opt)
cc, err := grpcutil.GetClientConn(addr, c.security.CAPath, c.security.CertPath, c.security.KeyPath)
if err != nil {
return nil, errors.WithStack(err)
}
Expand Down
69 changes: 69 additions & 0 deletions pkg/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package grpcutil

import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"

"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// GetClientConn returns a gRPC client connection.
func GetClientConn(addr string, caPath string, certPath string, keyPath string) (*grpc.ClientConn, error) {
opt := grpc.WithInsecure()
if len(caPath) != 0 {
certificates := []tls.Certificate{}
if len(certPath) != 0 && len(keyPath) != 0 {
// Load the client certificates from disk
certificate, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
certificates = append(certificates, certificate)
}

// Create a certificate pool from the certificate authority
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, errors.Errorf("could not read ca certificate: %s", err)
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

creds := credentials.NewTLS(&tls.Config{
Certificates: certificates,
RootCAs: certPool,
})

opt = grpc.WithTransportCredentials(creds)
}
u, err := url.Parse(addr)
if err != nil {
return nil, errors.WithStack(err)
}
cc, err := grpc.Dial(u.Host, opt)
if err != nil {
return nil, errors.WithStack(err)
}
return cc, nil
}
18 changes: 13 additions & 5 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,15 +776,23 @@ type SecurityConfig struct {
KeyPath string `toml:"key-path" json:"key-path"`
}

// ConvertToMap is used to convert SecurityConfig to a map.
func (s *SecurityConfig) ConvertToMap() map[string]string {
return map[string]string{
"caPath": s.CAPath,
"certPath": s.CertPath,
"keyPath": s.KeyPath}
}

// ToTLSConfig generatres tls config.
func (s SecurityConfig) ToTLSConfig() (*tls.Config, error) {
if len(s.CertPath) == 0 && len(s.KeyPath) == 0 {
func ToTLSConfig(config map[string]string) (*tls.Config, error) {
if len(config["certPath"]) == 0 && len(config["keyPath"]) == 0 {
return nil, nil
}
tlsInfo := transport.TLSInfo{
CertFile: s.CertPath,
KeyFile: s.KeyPath,
TrustedCAFile: s.CAPath,
CertFile: config["certPath"],
KeyFile: config["keyPath"],
TrustedCAFile: config["caPath"],
}
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type testConfigSuite struct{}

func (s *testConfigSuite) TestTLS(c *C) {
cfg := NewConfig()
tls, err := cfg.Security.ToTLSConfig()
tls, err := ToTLSConfig(cfg.Security.ConvertToMap())
c.Assert(err, IsNil)
c.Assert(tls, IsNil)
}
Expand Down
2 changes: 1 addition & 1 deletion server/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func PrepareJoinCluster(cfg *Config) error {
}

// Below are cases without data directory.
tlsConfig, err := cfg.Security.ToTLSConfig()
tlsConfig, err := ToTLSConfig(cfg.Security.ConvertToMap())
if err != nil {
return err
}
Expand Down
15 changes: 5 additions & 10 deletions server/region_syncer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ package syncer

import (
"context"
"net/url"
"time"

"github.com/pingcap/kvproto/pkg/pdpb"
log "github.com/pingcap/log"
"github.com/pingcap/log"
"github.com/pingcap/pd/pkg/grpcutil"
"github.com/pingcap/pd/server/core"
"github.com/pkg/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -51,14 +51,9 @@ func (s *RegionSyncer) reset() {
func (s *RegionSyncer) establish(addr string) (ClientStream, error) {
s.reset()

u, err := url.Parse(addr)
cc, err := grpcutil.GetClientConn(addr, s.securityConfig["caPath"], s.securityConfig["certPath"], s.securityConfig["keyPath"])
if err != nil {
return nil, err
}

cc, err := grpc.Dial(u.Host, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(msgSize)))
if err != nil {
return nil, err
return nil, errors.WithStack(err)
}

ctx, cancel := context.WithCancel(s.server.Context())
Expand Down
31 changes: 17 additions & 14 deletions server/region_syncer/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"github.com/juju/ratelimit"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/pdpb"
log "github.com/pingcap/log"
"github.com/pingcap/log"
"github.com/pingcap/pd/server/core"
"github.com/pkg/errors"
"go.uber.org/zap"
Expand Down Expand Up @@ -59,19 +59,21 @@ type Server interface {
GetStorage() *core.KV
Name() string
GetMetaRegions() []*metapb.Region
GetSecurityConfig() map[string]string
}

// RegionSyncer is used to sync the region information without raft.
type RegionSyncer struct {
sync.RWMutex
streams map[string]ServerStream
ctx context.Context
cancel context.CancelFunc
server Server
closed chan struct{}
wg sync.WaitGroup
history *historyBuffer
limit *ratelimit.Bucket
streams map[string]ServerStream
ctx context.Context
cancel context.CancelFunc
server Server
closed chan struct{}
wg sync.WaitGroup
history *historyBuffer
limit *ratelimit.Bucket
securityConfig map[string]string
}

// NewRegionSyncer returns a region syncer.
Expand All @@ -81,11 +83,12 @@ type RegionSyncer struct {
// no longer etcd but go-leveldb.
func NewRegionSyncer(s Server) *RegionSyncer {
return &RegionSyncer{
streams: make(map[string]ServerStream),
server: s,
closed: make(chan struct{}),
history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionKV()),
limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity),
streams: make(map[string]ServerStream),
server: s,
closed: make(chan struct{}),
history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionKV()),
limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity),
securityConfig: s.GetSecurityConfig(),
}
}

Expand Down
8 changes: 4 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (s *Server) startEtcd(ctx context.Context) error {
if err != nil {
return errors.WithStack(err)
}
tlsConfig, err := s.cfg.Security.ToTLSConfig()
tlsConfig, err := ToTLSConfig(s.cfg.Security.ConvertToMap())
if err != nil {
return err
}
Expand Down Expand Up @@ -727,9 +727,9 @@ func (s *Server) GetClusterVersion() semver.Version {
return s.scheduleOpt.loadClusterVersion()
}

// GetSecurityConfig get the security config.
func (s *Server) GetSecurityConfig() *SecurityConfig {
return &s.cfg.Security
// GetSecurityConfig get paths of the security config.
func (s *Server) GetSecurityConfig() map[string]string {
return s.cfg.Security.ConvertToMap()
}

// IsNamespaceExist returns whether the namespace exists.
Expand Down
2 changes: 1 addition & 1 deletion server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func subTimeByWallClock(after time.Time, before time.Time) time.Duration {

// InitHTTPClient initials a http client.
func InitHTTPClient(svr *Server) error {
tlsConfig, err := svr.GetSecurityConfig().ToTLSConfig()
tlsConfig, err := ToTLSConfig(svr.GetSecurityConfig())
if err != nil {
return err
}
Expand Down

0 comments on commit 49a1bc0

Please sign in to comment.