From 3e144cb900369171448baeb9bcc3bcbec91ea1e5 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Apr 2018 14:37:33 -0700 Subject: [PATCH] Teleport certificate authority rotation. This commit implements #1860 During the rotation procedure, issuing TLS and SSH certificate authorities are re-generated and all internal components of the cluster re-register to get new credentials. The rotation procedure is based on a distributed state machine algorithm - certificate authorities have explicit rotation state and all parts of the cluster sync local state machines by following transitions between phases. Operator can launch CA rotation in auto or manual modes. In manual mode operator moves cluster between rotation states and watches the states of the components to sync. In auto mode state transitions are happening automatically on a specified schedule. The design documentation is embedded in the code: lib/auth/rotate.go --- Gopkg.lock | 242 +------- Gopkg.toml | 2 +- constants.go | 17 +- e | 2 +- integration/helpers.go | 157 ++++- integration/integration_test.go | 529 +++++++++++++++- lib/auth/apiserver.go | 46 +- lib/auth/auth.go | 73 ++- lib/auth/auth_test.go | 181 ++---- lib/auth/auth_with_roles.go | 69 ++- lib/auth/clt.go | 90 ++- lib/auth/helpers.go | 28 +- lib/auth/init.go | 581 +++++++---------- lib/auth/init_test.go | 58 +- lib/auth/methods.go | 142 ----- lib/auth/methods_test.go | 52 -- lib/auth/middleware.go | 14 +- lib/auth/password_test.go | 15 +- lib/auth/permissions.go | 73 ++- lib/auth/register.go | 32 +- lib/auth/rotate.go | 575 +++++++++++++++++ lib/auth/state.go | 246 ++++++++ lib/auth/tls_test.go | 409 ++++++++++++ lib/backend/backend.go | 5 + lib/backend/boltbk/boltbk.go | 46 ++ lib/backend/boltbk/boltbk_test.go | 4 + lib/backend/dir/impl.go | 62 +- lib/backend/dir/impl_test.go | 18 +- lib/backend/dynamo/dynamodbbk.go | 45 ++ lib/backend/dynamo/dynamodbbk_test.go | 4 + lib/backend/etcdbk/etcd.go | 20 + lib/backend/etcdbk/etcd_test.go | 4 + 
lib/backend/test/suite.go | 30 + lib/config/fileconf.go | 2 +- lib/defaults/defaults.go | 7 + lib/fixtures/fixtures.go | 10 +- lib/httplib/httplib.go | 4 +- lib/multiplexer/multiplexer.go | 5 +- lib/reversetunnel/remotesite.go | 114 ++-- lib/reversetunnel/srv.go | 14 +- lib/service/cfg.go | 15 + lib/service/cfg_test.go | 42 +- lib/service/connect.go | 434 +++++++++++++ lib/service/service.go | 584 +++++++++++++----- lib/service/signals.go | 61 +- lib/service/supervisor.go | 153 ++++- lib/services/authority.go | 256 +++++++- lib/services/local/configuration_test.go | 4 +- lib/services/local/presence_test.go | 4 +- lib/services/local/trust.go | 46 +- lib/services/local/users.go | 2 +- lib/services/parser.go | 41 +- lib/services/resource.go | 10 + lib/services/role.go | 2 +- lib/services/server.go | 7 +- lib/services/suite/suite.go | 24 +- lib/services/trust.go | 11 +- lib/srv/regular/sshserver.go | 36 +- lib/state/cachingaccesspoint.go | 11 +- lib/utils/copy.go | 22 + tool/tctl/common/auth_command.go | 41 +- tool/tctl/common/status_command.go | 118 ++++ tool/tctl/common/tctl.go | 8 +- tool/tctl/main.go | 1 + tool/teleport/common/teleport.go | 11 +- vendor/github.com/cenkalti/backoff/.gitignore | 22 - .../github.com/cenkalti/backoff/.travis.yml | 9 - vendor/github.com/cenkalti/backoff/LICENSE | 20 - vendor/github.com/cenkalti/backoff/README.md | 30 - vendor/github.com/cenkalti/backoff/backoff.go | 66 -- .../cenkalti/backoff/backoff_test.go | 27 - vendor/github.com/cenkalti/backoff/context.go | 60 -- .../cenkalti/backoff/context_test.go | 26 - .../cenkalti/backoff/example_test.go | 73 --- .../cenkalti/backoff/exponential.go | 156 ----- .../cenkalti/backoff/exponential_test.go | 108 ---- vendor/github.com/cenkalti/backoff/retry.go | 78 --- .../github.com/cenkalti/backoff/retry_test.go | 99 --- vendor/github.com/cenkalti/backoff/ticker.go | 81 --- .../cenkalti/backoff/ticker_test.go | 94 --- vendor/github.com/cenkalti/backoff/tries.go | 35 -- 
.../github.com/cenkalti/backoff/tries_test.go | 55 -- .../vulcand/predicate/builder/builder.go | 169 +++++ vendor/github.com/vulcand/predicate/lib.go | 8 + vendor/github.com/vulcand/predicate/parse.go | 26 + .../vulcand/predicate/parse_test.go | 43 +- .../github.com/vulcand/predicate/predicate.go | 18 + 87 files changed, 4843 insertions(+), 2431 deletions(-) delete mode 100644 lib/auth/methods_test.go create mode 100644 lib/auth/rotate.go create mode 100644 lib/auth/state.go create mode 100644 lib/service/connect.go create mode 100644 tool/tctl/common/status_command.go delete mode 100644 vendor/github.com/cenkalti/backoff/.gitignore delete mode 100644 vendor/github.com/cenkalti/backoff/.travis.yml delete mode 100644 vendor/github.com/cenkalti/backoff/LICENSE delete mode 100644 vendor/github.com/cenkalti/backoff/README.md delete mode 100644 vendor/github.com/cenkalti/backoff/backoff.go delete mode 100644 vendor/github.com/cenkalti/backoff/backoff_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/context.go delete mode 100644 vendor/github.com/cenkalti/backoff/context_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/example_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/exponential.go delete mode 100644 vendor/github.com/cenkalti/backoff/exponential_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/retry.go delete mode 100644 vendor/github.com/cenkalti/backoff/retry_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/ticker.go delete mode 100644 vendor/github.com/cenkalti/backoff/ticker_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/tries.go delete mode 100644 vendor/github.com/cenkalti/backoff/tries_test.go create mode 100644 vendor/github.com/vulcand/predicate/builder/builder.go diff --git a/Gopkg.lock b/Gopkg.lock index fc284566a2db8..061ed07991ec8 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -4,19 +4,13 @@ [[projects]] branch = "master" name = "github.com/Azure/go-ansiterm" - packages = [ 
- ".", - "winterm" - ] + packages = [".","winterm"] revision = "19f72df4d05d31cbe1c56bfc8045c96babff6c7e" [[projects]] branch = "master" name = "github.com/alecthomas/template" - packages = [ - ".", - "parse" - ] + packages = [".","parse"] revision = "a0175ee3bccc567396460bf5acd36800cb10c49c" [[projects]] @@ -27,39 +21,7 @@ [[projects]] name = "github.com/aws/aws-sdk-go" - packages = [ - "aws", - "aws/awserr", - "aws/awsutil", - "aws/client", - "aws/client/metadata", - "aws/corehandlers", - "aws/credentials", - "aws/credentials/ec2rolecreds", - "aws/credentials/endpointcreds", - "aws/credentials/stscreds", - "aws/defaults", - "aws/ec2metadata", - "aws/endpoints", - "aws/request", - "aws/session", - "aws/signer/v4", - "internal/shareddefaults", - "private/protocol", - "private/protocol/json/jsonutil", - "private/protocol/jsonrpc", - "private/protocol/query", - "private/protocol/query/queryutil", - "private/protocol/rest", - "private/protocol/restxml", - "private/protocol/xml/xmlutil", - "service/dynamodb", - "service/dynamodb/dynamodbattribute", - "service/s3", - "service/s3/s3iface", - "service/s3/s3manager", - "service/sts" - ] + packages = ["aws","aws/awserr","aws/awsutil","aws/client","aws/client/metadata","aws/corehandlers","aws/credentials","aws/credentials/ec2rolecreds","aws/credentials/endpointcreds","aws/credentials/stscreds","aws/defaults","aws/ec2metadata","aws/endpoints","aws/request","aws/session","aws/signer/v4","internal/shareddefaults","private/protocol","private/protocol/json/jsonutil","private/protocol/jsonrpc","private/protocol/query","private/protocol/query/queryutil","private/protocol/rest","private/protocol/restxml","private/protocol/xml/xmlutil","service/dynamodb","service/dynamodb/dynamodbattribute","service/s3","service/s3/s3iface","service/s3/s3manager","service/sts"] revision = "a201bf33b18ad4ab54344e4bc26b87eb6ad37b8e" version = "v1.12.25" @@ -81,19 +43,9 @@ [[projects]] name = "github.com/boombuler/barcode" - packages = [ - ".", - "qr", 
- "utils" - ] + packages = [".","qr","utils"] revision = "fe0f26ff6d26693948ee8189aa064ee8c54141fa" -[[projects]] - name = "github.com/cenkalti/backoff" - packages = ["."] - revision = "61153c768f31ee5f130071d08fc82b85208528de" - version = "v1.1.0" - [[projects]] name = "github.com/codahale/hdrhistogram" packages = ["."] @@ -101,28 +53,14 @@ [[projects]] name = "github.com/coreos/etcd" - packages = [ - "client", - "pkg/pathutil", - "pkg/srv", - "pkg/tlsutil", - "pkg/transport", - "pkg/types", - "version" - ] + packages = ["client","pkg/pathutil","pkg/srv","pkg/tlsutil","pkg/transport","pkg/types","version"] revision = "9d43462d174c664f5edf313dec0de31e1ef4ed47" version = "v3.2.6" [[projects]] branch = "master" name = "github.com/coreos/go-oidc" - packages = [ - "http", - "jose", - "key", - "oauth2", - "oidc" - ] + packages = ["http","jose","key","oauth2","oidc"] revision = "e51edf2e47e65e5708600d4da6fca1388ee437b4" source = "github.com/gravitational/go-oidc" @@ -134,11 +72,7 @@ [[projects]] name = "github.com/coreos/pkg" - packages = [ - "health", - "httputil", - "timeutil" - ] + packages = ["health","httputil","timeutil"] revision = "1914e367e85eaf0c25d495b48e060dfe6190f8d0" [[projects]] @@ -177,31 +111,18 @@ [[projects]] name = "github.com/golang/protobuf" - packages = [ - "jsonpb", - "proto", - "protoc-gen-go/descriptor", - "ptypes/any", - "ptypes/empty" - ] + packages = ["jsonpb","proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"] revision = "8ee79997227bf9b34611aee7946ae64735e6fd93" [[projects]] name = "github.com/google/gops" - packages = [ - "agent", - "internal", - "signal" - ] + packages = ["agent","internal","signal"] revision = "fa6968806ca68b7db113256f300d82b5206a2c3c" version = "v0.3.1" [[projects]] name = "github.com/gravitational/configure" - packages = [ - "cstrings", - "jsonschema" - ] + packages = ["cstrings","jsonschema"] revision = "1db4b84fe9dbbbaf40827aa714dcca17b368de2c" [[projects]] @@ -218,20 +139,13 @@ [[projects]] name = 
"github.com/gravitational/license" - packages = [ - ".", - "constants" - ] + packages = [".","constants"] revision = "102213511ace56c97ccf1eef645835e16f84d130" version = "0.0.4" [[projects]] name = "github.com/gravitational/reporting" - packages = [ - ".", - "client", - "types" - ] + packages = [".","client","types"] revision = "3c4a4e96fb5896e14fe29da7fcce14b8d93f3965" version = "0.0.4" @@ -255,12 +169,7 @@ [[projects]] name = "github.com/grpc-ecosystem/grpc-gateway" - packages = [ - "runtime", - "runtime/internal", - "third_party/googleapis/google/api", - "utilities" - ] + packages = ["runtime","runtime/internal","third_party/googleapis/google/api","utilities"] revision = "a8f25bd1ab549f8b87afd48aa9181221e9d439bb" version = "v1.1.0" @@ -299,10 +208,7 @@ [[projects]] name = "github.com/mailgun/lemma" - packages = [ - "random", - "secret" - ] + packages = ["random","secret"] revision = "e8b0cd607f5855f9a4a33f8ae5d033178f559964" version = "0.0.2" @@ -319,10 +225,7 @@ [[projects]] branch = "alexander/copy" name = "github.com/mailgun/oxy" - packages = [ - "forward", - "utils" - ] + packages = ["forward","utils"] revision = "0c3e45a1f7b20e1f818612ad4f42abfc0923f3d5" [[projects]] @@ -343,11 +246,7 @@ [[projects]] branch = "master" name = "github.com/mdp/rsc" - packages = [ - "gf256", - "qr", - "qr/coding" - ] + packages = ["gf256","qr","qr/coding"] revision = "90f07065088deccf50b28eb37c93dad3078c0f3c" [[projects]] @@ -364,11 +263,7 @@ [[projects]] name = "github.com/pquerna/otp" - packages = [ - ".", - "hotp", - "totp" - ] + packages = [".","hotp","totp"] revision = "54653902c20e47f3417541d35435cb6d6162e28a" [[projects]] @@ -385,11 +280,7 @@ [[projects]] name = "github.com/prometheus/common" - packages = [ - "expfmt", - "internal/bitbucket.org/ww/goautoneg", - "model" - ] + packages = ["expfmt","internal/bitbucket.org/ww/goautoneg","model"] revision = "50022896a67062a54a54f268fb6fe4bf90b34859" [[projects]] @@ -399,20 +290,13 @@ [[projects]] name = 
"github.com/russellhaering/gosaml2" - packages = [ - ".", - "types" - ] + packages = [".","types"] revision = "8908227c114abe0b63b1f0606abae72d11bf632a" [[projects]] branch = "master" name = "github.com/russellhaering/goxmldsig" - packages = [ - ".", - "etreeutils", - "types" - ] + packages = [".","etreeutils","types"] revision = "605161228693b2efadce55323c9c661a40c5fbaa" [[projects]] @@ -422,10 +306,7 @@ [[projects]] name = "github.com/sirupsen/logrus" - packages = [ - ".", - "hooks/syslog" - ] + packages = [".","hooks/syslog"] revision = "8ab1e1b91d5f1a6124287906f8b0402844d3a2b3" source = "github.com/gravitational/logrus" version = "1.0.0" @@ -443,18 +324,14 @@ [[projects]] name = "github.com/vulcand/oxy" - packages = [ - "connlimit", - "ratelimit", - "utils" - ] + packages = ["connlimit","ratelimit","utils"] revision = "5725fecc9a4f3aa6fdc3ffd29cef771241809add" [[projects]] name = "github.com/vulcand/predicate" - packages = ["."] - revision = "939c094524d124c55fa8afe0e077701db4a865e2" - version = "v1.0.0" + packages = [".","builder"] + revision = "8fbfb3ab0e94276b6b58bec378600829adc7a203" + version = "v1.1.0" [[projects]] name = "github.com/xeipuuv/gojsonpointer" @@ -475,64 +352,23 @@ [[projects]] branch = "master" name = "golang.org/x/crypto" - packages = [ - "bcrypt", - "blowfish", - "curve25519", - "ed25519", - "ed25519/internal/edwards25519", - "internal/chacha20", - "nacl/secretbox", - "poly1305", - "salsa20/salsa", - "ssh", - "ssh/agent", - "ssh/terminal" - ] + packages = ["bcrypt","blowfish","curve25519","ed25519","ed25519/internal/edwards25519","internal/chacha20","nacl/secretbox","poly1305","salsa20/salsa","ssh","ssh/agent","ssh/terminal"] revision = "b2aa35443fbc700ab74c586ae79b81c171851023" [[projects]] name = "golang.org/x/net" - packages = [ - "context", - "http2", - "http2/hpack", - "idna", - "internal/timeseries", - "lex/httplex", - "trace", - "websocket" - ] + packages = 
["context","http2","http2/hpack","idna","internal/timeseries","lex/httplex","trace","websocket"] revision = "48359f4f600b3a2d5cf657458e3f940021631a56" [[projects]] branch = "master" name = "golang.org/x/sys" - packages = [ - "unix", - "windows" - ] + packages = ["unix","windows"] revision = "1d206c9fa8975fb4cf00df1dc8bf3283dc24ba0e" [[projects]] name = "golang.org/x/text" - packages = [ - "encoding", - "encoding/internal", - "encoding/internal/identifier", - "encoding/unicode", - "internal/gen", - "internal/triegen", - "internal/ucd", - "internal/utf8internal", - "runes", - "secure/bidirule", - "transform", - "unicode/bidi", - "unicode/cldr", - "unicode/norm", - "unicode/rangetable" - ] + packages = ["encoding","encoding/internal","encoding/internal/identifier","encoding/unicode","internal/gen","internal/triegen","internal/ucd","internal/utf8internal","runes","secure/bidirule","transform","unicode/bidi","unicode/cldr","unicode/norm","unicode/rangetable"] revision = "19e51611da83d6be54ddafce4a4af510cb3e9ea4" [[projects]] @@ -543,23 +379,7 @@ [[projects]] name = "google.golang.org/grpc" - packages = [ - ".", - "codes", - "connectivity", - "credentials", - "grpclb/grpc_lb_v1", - "grpclog", - "internal", - "keepalive", - "metadata", - "naming", - "peer", - "stats", - "status", - "tap", - "transport" - ] + packages = [".","codes","connectivity","credentials","grpclb/grpc_lb_v1","grpclog","internal","keepalive","metadata","naming","peer","stats","status","tap","transport"] revision = "b3ddf786825de56a4178401b7e174ee332173b66" version = "v1.5.2" @@ -581,6 +401,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "43348de969afeba8e1ab0c22221ee1a15b09c40f24aa8b7b96945a0c6bc030b3" + inputs-digest = "600efc6f221f0a0c99270cf1caa4d7bf3e607974c5e52c8b7ba3210bee185ec9" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 3f2cb24b5f292..7cda049addfac 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -36,7 +36,7 @@ ignored = 
["github.com/Sirupsen/logrus"] [[constraint]] name = "github.com/vulcand/predicate" - version = "v1.0.0" + version = "v1.1.0" [[constraint]] name = "github.com/docker/docker" diff --git a/constants.go b/constants.go index d687e19bc0fd9..a1c4010e62859 100644 --- a/constants.go +++ b/constants.go @@ -53,10 +53,10 @@ const ( const ( // ComponentAuthority is a TLS and an SSH certificate authority - ComponentAuthority = "authority" + ComponentAuthority = "ca" // ComponentProcess is a main control process - ComponentProcess = "process" + ComponentProcess = "proc" // ComponentReverseTunnelServer is reverse tunnel server // that together with agent establish a bi-directional SSH revers tunnel @@ -81,7 +81,7 @@ const ( ComponentProxy = "proxy" // ComponentDiagnostic is a diagnostic service - ComponentDiagnostic = "diagnostic" + ComponentDiagnostic = "diag" // ComponentClient is a client ComponentClient = "client" @@ -105,7 +105,7 @@ const ( ComponentRemoteSubsystem = "subsystem:remote" // ComponentAuditLog is audit log component - ComponentAuditLog = "auditlog" + ComponentAuditLog = "audit" // ComponentKeyAgent is an agent that has loaded the sessions keys and // certificates for a user connected to a proxy. 
@@ -234,6 +234,15 @@ const ( // Syslog is a mode for syslog logging Syslog = "syslog" + + // HumanDateFormat is a human readable date formatting + HumanDateFormat = "Jan _2 15:04 UTC" + + // HumanDateFormatSeconds is a human readable date formatting with seconds + HumanDateFormatSeconds = "Jan _2 15:04:05 UTC" + + // HumanDateFormatMilli is a human readable date formatting with milliseconds + HumanDateFormatMilli = "Jan _2 15:04:05.000 UTC" ) // Component generates "component:subcomponent1:subcomponent2" strings used diff --git a/e b/e index b1f3899427704..df8533f556558 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit b1f38994277044dd50aa04a3eaf5ab8bb5536e15 +Subproject commit df8533f556558db175c3ba2df57c390c08b74369 diff --git a/integration/helpers.go b/integration/helpers.go index 641b2167eda5d..6ff1f62443b5b 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/rsa" "crypto/x509/pkix" "encoding/json" @@ -61,7 +62,7 @@ type TeleInstance struct { // Slice of TCP ports used by Teleport services Ports []int - // Hostname is the name of the host where i isnstance is running + // Hostname is the name of the host where instance is running Hostname string // Internal stuff... 
@@ -260,7 +261,7 @@ func (s *InstanceSecrets) AsSlice() []*InstanceSecrets { } func (s *InstanceSecrets) GetIdentity() *auth.Identity { - i, err := auth.ReadIdentityFromKeyPair(s.PrivKey, s.Cert, s.TLSCert, s.TLSCACert) + i, err := auth.ReadIdentityFromKeyPair(s.PrivKey, s.Cert, s.TLSCert, [][]byte{s.TLSCACert}) fatalIf(err) return i } @@ -317,16 +318,104 @@ func (i *TeleInstance) Create(trustedSecrets []*InstanceSecrets, enableSSH bool, return i.CreateEx(trustedSecrets, tconf) } -// CreateEx creates a new instance of Teleport which trusts a list of other clusters (other -// instances) -// -// Unlike Create() it allows for greater customization because it accepts -// a full Teleport config structure -func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *service.Config) error { +// UserCreds holds user client credentials +type UserCreds struct { + // Key is user client key and certificate + Key client.Key + // HostCA is a trusted host certificate authority + HostCA services.CertAuthority +} + +// SetupUserCreds sets up user credentials for client +func SetupUserCreds(tc *client.TeleportClient, proxyHost string, creds UserCreds) error { + _, err := tc.AddKey(proxyHost, &creds.Key) + if err != nil { + return trace.Wrap(err) + } + err = tc.AddTrustedCA(creds.HostCA) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// SetupUser sets up user in the cluster +func SetupUser(process *service.TeleportProcess, username string, roles []services.Role) error { + auth := process.GetAuthServer() + teleUser, err := services.NewUser(username) + if err != nil { + return trace.Wrap(err) + } + if len(roles) == 0 { + role := services.RoleForUser(teleUser) + role.SetLogins(services.Allow, []string{username}) + + // allow tests to forward agent, still needs to be passed in client + roleOptions := role.GetOptions() + roleOptions.Set(services.ForwardAgent, true) + role.SetOptions(roleOptions) + + err = auth.UpsertRole(role, backend.Forever) + if err != 
nil { + return trace.Wrap(err) + } + teleUser.AddRole(role.GetMetadata().Name) + roles = append(roles, role) + } else { + for _, role := range roles { + err := auth.UpsertRole(role, backend.Forever) + if err != nil { + return trace.Wrap(err) + } + teleUser.AddRole(role.GetName()) + } + } + err = auth.UpsertUser(teleUser) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// GenerateUserCreds generates key to be used by client +func GenerateUserCreds(process *service.TeleportProcess, username string) (*UserCreds, error) { + priv, pub, err := testauthority.New().GenerateKeyPair("") + if err != nil { + return nil, trace.Wrap(err) + } + a := process.GetAuthServer() + sshCert, x509Cert, err := a.GenerateUserCerts(pub, username, time.Hour, teleport.CertificateFormatStandard) + if err != nil { + return nil, trace.Wrap(err) + } + clusterName, err := a.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + ca, err := a.GetCertAuthority(services.CertAuthID{ + Type: services.HostCA, + DomainName: clusterName.GetClusterName(), + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + return &UserCreds{ + HostCA: ca, + Key: client.Key{ + Priv: priv, + Pub: pub, + Cert: sshCert, + TLSCert: x509Cert, + }, + }, nil +} + +// GenerateConfig generates instance config +func (i *TeleInstance) GenerateConfig(trustedSecrets []*InstanceSecrets, tconf *service.Config) (*service.Config, error) { var err error dataDir, err := ioutil.TempDir("", "cluster-"+i.Secrets.SiteName) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if tconf == nil { tconf = service.MakeDefaultConfig() @@ -337,7 +426,7 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic ClusterName: i.Secrets.SiteName, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } tconf.Auth.StaticTokens, err = services.NewStaticTokens(services.StaticTokensSpecV2{ StaticTokens: []services.ProvisionToken{ @@ -348,7 +437,7 @@ func 
(i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic }, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } tconf.Auth.Authorities = append(tconf.Auth.Authorities, i.Secrets.GetCAs()...) tconf.Identities = append(tconf.Identities, i.Secrets.GetIdentity()) @@ -375,7 +464,20 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic } tconf.Keygen = testauthority.New() + i.Config = tconf + return tconf, nil +} +// CreateEx creates a new instance of Teleport which trusts a list of other clusters (other +// instances) +// +// Unlike Create() it allows for greater customization because it accepts +// a full Teleport config structure +func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *service.Config) error { + tconf, err := i.GenerateConfig(trustedSecrets, tconf) + if err != nil { + return trace.Wrap(err) + } i.Config = tconf i.Process, err = service.NewTeleport(tconf) if err != nil { @@ -423,7 +525,6 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic teleUser.AddRole(role.GetName()) } } - err = auth.UpsertUser(teleUser) if err != nil { return trace.Wrap(err) @@ -719,9 +820,22 @@ type ClientConfig struct { ForwardAgent bool } -// NewClient returns a fully configured and pre-authenticated client +// NewClientWithCreds creates client with credentials +func (i *TeleInstance) NewClientWithCreds(cfg ClientConfig, creds UserCreds) (tc *client.TeleportClient, err error) { + clt, err := i.NewUnauthenticatedClient(cfg) + if err != nil { + return nil, trace.Wrap(err) + } + err = SetupUserCreds(clt, i.Config.Proxy.SSHAddr.Addr, creds) + if err != nil { + return nil, trace.Wrap(err) + } + return clt, nil +} + +// NewUnauthenticatedClient returns a fully configured and pre-authenticated client // (pre-authenticated with server CAs and signed session key) -func (i *TeleInstance) NewClient(cfg ClientConfig) (tc *client.TeleportClient, err error) { +func (i 
*TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.TeleportClient, err error) { keyDir, err := ioutil.TempDir(i.Config.DataDir, "tsh") if err != nil { return nil, err @@ -764,11 +878,18 @@ func (i *TeleInstance) NewClient(cfg ClientConfig) (tc *client.TeleportClient, e } cconf.SetProxy(proxyHost, proxyWebPort, proxySSHPort) - tc, err = client.NewClient(cconf) + return client.NewClient(cconf) +} + +// NewClient returns a fully configured and pre-authenticated client +// (pre-authenticated with server CAs and signed session key) +func (i *TeleInstance) NewClient(cfg ClientConfig) (*client.TeleportClient, error) { + tc, err := i.NewUnauthenticatedClient(cfg) if err != nil { - return nil, err + return nil, trace.Wrap(err) } - // confnigures the client authenticate using the keys from 'secrets': + + // configures the client authenticate using the keys from 'secrets': user, ok := i.Secrets.Users[cfg.Login] if !ok { return nil, trace.BadParameter("unknown login %q", cfg.Login) @@ -832,7 +953,7 @@ func startAndWait(process *service.TeleportProcess, expectedEvents []string) ([] // register to listen for all ready events on the broadcast channel broadcastCh := make(chan service.Event) for _, eventName := range expectedEvents { - process.WaitForEvent(eventName, broadcastCh, make(chan struct{})) + process.WaitForEvent(context.TODO(), eventName, broadcastCh) } // start the process diff --git a/integration/integration_test.go b/integration/integration_test.go index 7ae5ee479f816..0cbf6497cfbe2 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -112,9 +112,9 @@ func (s *IntSuite) SetUpSuite(c *check.C) { } } -// newTeleport helper returns a running Teleport instance pre-configured +// newTeleport helper returns a created but not started Teleport instance pre-configured // with the current user os.user.Current(). 
-func (s *IntSuite) newTeleport(c *check.C, logins []string, enableSSH bool) *TeleInstance { +func (s *IntSuite) newUnstartedTeleport(c *check.C, logins []string, enableSSH bool) *TeleInstance { t := NewInstance(InstanceConfig{ClusterName: Site, HostID: HostID, NodeName: Host, Ports: s.getPorts(5), Priv: s.priv, Pub: s.pub}) // use passed logins, but use suite's default login if nothing was passed if logins == nil || len(logins) == 0 { @@ -126,6 +126,13 @@ func (s *IntSuite) newTeleport(c *check.C, logins []string, enableSSH bool) *Tel if err := t.Create(nil, enableSSH, nil); err != nil { c.Fatalf("Unexpected response from Create: %v", err) } + return t +} + +// newTeleport helper returns a running Teleport instance pre-configured +// with the current user os.user.Current(). +func (s *IntSuite) newTeleport(c *check.C, logins []string, enableSSH bool) *TeleInstance { + t := s.newUnstartedTeleport(c, logins, enableSSH) if err := t.Start(); err != nil { c.Fatalf("Unexpected response from Start: %v", err) } @@ -148,10 +155,10 @@ func (s *IntSuite) newTeleportWithConfig(c *check.C, logins []string, instanceSe // create a new teleport instance with passed in configuration if err := t.CreateEx(instanceSecrets, teleportConfig); err != nil { - c.Fatalf("Unexpected response from CreateEx: %v", err) + c.Fatalf("Unexpected response from CreateEx: %v", trace.DebugReport(err)) } if err := t.Start(); err != nil { - c.Fatalf("Unexpected response from Start: %v", err) + c.Fatalf("Unexpected response from Start: %v", trace.DebugReport(err)) } return t @@ -2004,6 +2011,520 @@ func (s *IntSuite) TestPAM(c *check.C) { } } +// TestRotateSuccess tests full cycle cert authority rotation +func (s *IntSuite) TestRotateSuccess(c *check.C) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tconf := rotationConfig(true) + t := NewInstance(InstanceConfig{ClusterName: Site, HostID: HostID, NodeName: Host, Ports: s.getPorts(5), Priv: s.priv, Pub: s.pub}) + logins := 
[]string{s.me.Username} + for _, login := range logins { + t.AddUser(login, []string{login}) + } + config, err := t.GenerateConfig(nil, tconf) + c.Assert(err, check.IsNil) + + serviceC := make(chan *service.TeleportProcess, 20) + + runCtx, runCancel := context.WithCancel(context.TODO()) + go func() { + defer runCancel() + service.Run(ctx, *config, func(cfg *service.Config) (service.Process, error) { + svc, err := service.NewTeleport(cfg) + if err == nil { + serviceC <- svc + } + return svc, err + }) + }() + + l := log.WithFields(log.Fields{trace.Component: teleport.Component("test", "rotate")}) + + svc, err := waitForReload(serviceC, nil) + c.Assert(err, check.IsNil) + + // Setup user in the cluster + err = SetupUser(svc, s.me.Username, nil) + c.Assert(err, check.IsNil) + + // capture credentials before reload started to simulate old client + initialCreds, err := GenerateUserCreds(svc, s.me.Username) + c.Assert(err, check.IsNil) + + l.Infof("Service started. Setting rotation state to %v", services.RotationPhaseUpdateClients) + + // start rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reload + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + cfg := ClientConfig{ + Login: s.me.Username, + Host: "127.0.0.1", + Port: t.GetPortSSHInt(), + } + clt, err := t.NewClientWithCreds(cfg, *initialCreds) + c.Assert(err, check.IsNil) + + // client works as is before servers have been rotated + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. 
Setting rotation state to %v", services.RotationPhaseUpdateServers) + + // move to the next phase + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + // new credentials will work from this phase to others + newCreds, err := GenerateUserCreds(svc, s.me.Username) + c.Assert(err, check.IsNil) + + clt, err = t.NewClientWithCreds(cfg, *newCreds) + c.Assert(err, check.IsNil) + + // new client works + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Setting rotation state to %v.", services.RotationPhaseStandby) + + // complete rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseStandby, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + // new client still works + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Rotation has completed. 
Shuttting down service.") + + // shut down the service + cancel() + // close the service without waiting for the connections to drain + svc.Close() + + select { + case <-runCtx.Done(): + case <-time.After(20 * time.Second): + c.Fatalf("failed to shut down the server") + } +} + +// TestRotateRollback tests cert authority rollback +func (s *IntSuite) TestRotateRollback(c *check.C) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tconf := rotationConfig(true) + t := NewInstance(InstanceConfig{ClusterName: Site, HostID: HostID, NodeName: Host, Ports: s.getPorts(5), Priv: s.priv, Pub: s.pub}) + logins := []string{s.me.Username} + for _, login := range logins { + t.AddUser(login, []string{login}) + } + config, err := t.GenerateConfig(nil, tconf) + c.Assert(err, check.IsNil) + + serviceC := make(chan *service.TeleportProcess, 20) + + runCtx, runCancel := context.WithCancel(context.TODO()) + go func() { + defer runCancel() + service.Run(ctx, *config, func(cfg *service.Config) (service.Process, error) { + svc, err := service.NewTeleport(cfg) + if err == nil { + serviceC <- svc + } + return svc, err + }) + }() + + l := log.WithFields(log.Fields{trace.Component: teleport.Component("test", "rotate")}) + + svc, err := waitForReload(serviceC, nil) + c.Assert(err, check.IsNil) + + // Setup user in the cluster + err = SetupUser(svc, s.me.Username, nil) + c.Assert(err, check.IsNil) + + // capture credentials before reload started to simulate old client + initialCreds, err := GenerateUserCreds(svc, s.me.Username) + c.Assert(err, check.IsNil) + + l.Infof("Service started. 
Setting rotation state to %v", services.RotationPhaseUpdateClients) + + // start rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reload + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + cfg := ClientConfig{ + Login: s.me.Username, + Host: "127.0.0.1", + Port: t.GetPortSSHInt(), + } + clt, err := t.NewClientWithCreds(cfg, *initialCreds) + c.Assert(err, check.IsNil) + + // client works as is before servers have been rotated + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Setting rotation state to %v", services.RotationPhaseUpdateServers) + + // move to the next phase + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Setting rotation state to %v.", services.RotationPhaseRollback) + + // complete rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseRollback, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + // old client works + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Rotation has completed. 
Shuttting down service.") + + // shut down the service + cancel() + // close the service without waiting for the connections to drain + svc.Close() + + select { + case <-runCtx.Done(): + case <-time.After(20 * time.Second): + c.Fatalf("failed to shut down the server") + } +} + +// TestRotateTrustedClusters tests CA rotation support for trusted clusters +func (s *IntSuite) TestRotateTrustedClusters(c *check.C) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clusterMain := "rotate-main" + clusterAux := "rotate-aux" + + tconf := rotationConfig(false) + main := NewInstance(InstanceConfig{ClusterName: clusterMain, HostID: HostID, NodeName: Host, Ports: s.getPorts(5), Priv: s.priv, Pub: s.pub}) + aux := NewInstance(InstanceConfig{ClusterName: clusterAux, HostID: HostID, NodeName: Host, Ports: s.getPorts(5), Priv: s.priv, Pub: s.pub}) + + logins := []string{s.me.Username} + for _, login := range logins { + main.AddUser(login, []string{login}) + } + config, err := main.GenerateConfig(nil, tconf) + c.Assert(err, check.IsNil) + + serviceC := make(chan *service.TeleportProcess, 20) + runCtx, runCancel := context.WithCancel(context.TODO()) + go func() { + defer runCancel() + service.Run(ctx, *config, func(cfg *service.Config) (service.Process, error) { + svc, err := service.NewTeleport(cfg) + if err == nil { + serviceC <- svc + } + return svc, err + }) + }() + + l := log.WithFields(log.Fields{trace.Component: teleport.Component("test", "rotate")}) + + svc, err := waitForReload(serviceC, nil) + c.Assert(err, check.IsNil) + + // main cluster has a local user and belongs to role "main-devs" + mainDevs := "main-devs" + role, err := services.NewRole(mainDevs, services.RoleSpecV3{ + Allow: services.RoleConditions{ + Logins: []string{s.me.Username}, + }, + }) + c.Assert(err, check.IsNil) + + err = SetupUser(svc, s.me.Username, []services.Role{role}) + c.Assert(err, check.IsNil) + + // create auxillary cluster and setup trust + 
c.Assert(aux.CreateEx(nil, rotationConfig(false)), check.IsNil) + + // auxiliary cluster has a role aux-devs + // connect aux cluster to main cluster + // using trusted clusters, so remote user will be allowed to assume + // role specified by mapping remote role "main-devs" to local role "aux-devs" + auxDevs := "aux-devs" + role, err = services.NewRole(auxDevs, services.RoleSpecV3{ + Allow: services.RoleConditions{ + Logins: []string{s.me.Username}, + }, + }) + c.Assert(err, check.IsNil) + err = aux.Process.GetAuthServer().UpsertRole(role, backend.Forever) + c.Assert(err, check.IsNil) + trustedClusterToken := "trusted-clsuter-token" + err = svc.GetAuthServer().UpsertToken(trustedClusterToken, []teleport.Role{teleport.RoleTrustedCluster}, backend.Forever) + c.Assert(err, check.IsNil) + trustedCluster := main.Secrets.AsTrustedCluster(trustedClusterToken, services.RoleMap{ + {Remote: mainDevs, Local: []string{auxDevs}}, + }) + c.Assert(aux.Start(), check.IsNil) + + // try and upsert a trusted cluster + lib.SetInsecureDevMode(true) + defer lib.SetInsecureDevMode(false) + var upsertSuccess bool + for i := 0; i < 10; i++ { + log.Debugf("Will create trusted cluster %v, attempt %v", trustedCluster, i) + _, err = aux.Process.GetAuthServer().UpsertTrustedCluster(trustedCluster) + if err != nil { + if trace.IsConnectionProblem(err) { + log.Debugf("retrying on connection problem: %v", err) + continue + } + c.Fatalf("got non connection problem %v", err) + } + upsertSuccess = true + break + } + // make sure we upsert a trusted cluster + c.Assert(upsertSuccess, check.Equals, true) + + // capture credentials before reload has started to simulate old client + initialCreds, err := GenerateUserCreds(svc, s.me.Username) + c.Assert(err, check.IsNil) + + // credentials should work + cfg := ClientConfig{ + Login: s.me.Username, + Host: "127.0.0.1", + Cluster: clusterAux, + Port: aux.GetPortSSHInt(), + } + clt, err := main.NewClientWithCreds(cfg, *initialCreds) + c.Assert(err, check.IsNil) + 
+ err = runAndMatch(clt, 6, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service started. Setting rotation state to %v", services.RotationPhaseUpdateClients) + + // start rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reload + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + // waitForPhase polls aux cluster (up to 10 times) until its copy of main cluster's host CA reports the phase + waitForPhase := func(phase string) error { + var lastPhase string + for i := 0; i < 10; i++ { + ca, err := aux.Process.GetAuthServer().GetCertAuthority(services.CertAuthID{ + Type: services.HostCA, + DomainName: clusterMain, + }, false) + c.Assert(err, check.IsNil) + if ca.GetRotation().Phase == phase { + return nil + } + lastPhase = ca.GetRotation().Phase + time.Sleep(tconf.PollingPeriod / 2) + } + return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + } + + err = waitForPhase(services.RotationPhaseUpdateClients) + c.Assert(err, check.IsNil) + + // old client should work as is + err = runAndMatch(clt, 6, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. 
Setting rotation state to %v", services.RotationPhaseUpdateServers) + + // move to the next phase + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + err = waitForPhase(services.RotationPhaseUpdateServers) + c.Assert(err, check.IsNil) + + // new credentials will work from this phase to others + newCreds, err := GenerateUserCreds(svc, s.me.Username) + c.Assert(err, check.IsNil) + + clt, err = main.NewClientWithCreds(cfg, *newCreds) + c.Assert(err, check.IsNil) + + // new client works + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Setting rotation state to %v.", services.RotationPhaseStandby) + + // complete rotation + err = svc.GetAuthServer().RotateCertAuthority(auth.RotateRequest{ + TargetPhase: services.RotationPhaseStandby, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // wait until service reloaded + svc, err = waitForReload(serviceC, svc) + c.Assert(err, check.IsNil) + + err = waitForPhase(services.RotationPhaseStandby) + c.Assert(err, check.IsNil) + + // new client still works + err = runAndMatch(clt, 3, []string{"echo", "hello world"}, ".*hello world.*") + c.Assert(err, check.IsNil) + + l.Infof("Service reloaded. Rotation has completed. 
Shuttting down service.") + + // shut down the service + cancel() + // close the service without waiting for the connections to drain + svc.Close() + + select { + case <-runCtx.Done(): + case <-time.After(20 * time.Second): + c.Fatalf("failed to shut down the server") + } +} + +// rotationConfig sets up default config used for CA rotation tests +func rotationConfig(disableWebService bool) *service.Config { + tconf := service.MakeDefaultConfig() + tconf.SSH.Enabled = true + tconf.Proxy.DisableWebService = disableWebService + tconf.Proxy.DisableWebInterface = true + tconf.PollingPeriod = 500 * time.Millisecond + tconf.ClientTimeout = time.Second + tconf.ShutdownTimeout = 2 * tconf.ClientTimeout + return tconf +} + +// waitForReload waits for multiple events to happen: +// +// 1. new service to be created and started +// 2. old service, if present to shut down +// +// this helper function allows to serialize tests for reloads. +func waitForReload(serviceC chan *service.TeleportProcess, old *service.TeleportProcess) (*service.TeleportProcess, error) { + var svc *service.TeleportProcess + select { + case svc = <-serviceC: + case <-time.After(60 * time.Second): + return nil, trace.BadParameter("timeout waiting for service to start") + } + + eventC := make(chan service.Event, 1) + svc.WaitForEvent(context.TODO(), service.TeleportReadyEvent, eventC) + select { + case <-eventC: + + case <-time.After(20 * time.Second): + return nil, trace.BadParameter("timeout waiting for service to broadcast ready status") + } + + // if old service is present, wait for it to complete shut down procedure + if old != nil { + ctx, cancel := context.WithCancel(context.TODO()) + go func() { + defer cancel() + old.Supervisor.Wait() + }() + select { + case <-ctx.Done(): + case <-time.After(60 * time.Second): + return nil, trace.BadParameter("timeout waiting for old service to stop") + } + } + return svc, nil + +} + +// runAndMatch runs command and makes sure it matches the pattern +func 
runAndMatch(tc *client.TeleportClient, attempts int, command []string, pattern string) error { + output := &bytes.Buffer{} + tc.Stdout = output + var err error + for i := 0; i < attempts; i++ { + err = tc.SSH(context.TODO(), command, false) + if err != nil { + continue + } + out := output.String() + out = string(replaceNewlines(out)) + matched, _ := regexp.MatchString(pattern, out) + if matched { + return nil + } + err = trace.CompareFailed("output %q did not match pattern %q", out, pattern) + time.Sleep(250 * time.Millisecond) + } + return err +} + // runCommand is a shortcut for running SSH command, it creates a client // connected to proxy of the passed in instance, runs the command, and returns // the result. If multiple attempts are requested, a 250 millisecond delay is diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 83429bd38e2fd..121d8144df3ec 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -64,16 +64,14 @@ func NewAPIServer(config *APIConfig) http.Handler { // Operations on certificate authorities srv.GET("/:version/domain", srv.withAuth(srv.getDomainName)) + srv.POST("/:version/authorities/:type", srv.withAuth(srv.upsertCertAuthority)) + srv.POST("/:version/authorities/:type/rotate", srv.withAuth(srv.rotateCertAuthority)) + srv.POST("/:version/authorities/:type/rotate/external", srv.withAuth(srv.rotateExternalCertAuthority)) srv.DELETE("/:version/authorities/:type/:domain", srv.withAuth(srv.deleteCertAuthority)) srv.GET("/:version/authorities/:type/:domain", srv.withAuth(srv.getCertAuthority)) srv.GET("/:version/authorities/:type", srv.withAuth(srv.getCertAuthorities)) - // DELETE IN: 2.6.0 - // Certificate exchange for cluster upgrades used to upgrade from 2.4.0 - // to 2.5.0 clusters. 
- srv.POST("/:version/exchangecerts", srv.withAuth(srv.exchangeCerts)) - // Generating certificates for user and host authorities srv.POST("/:version/ca/host/certs", srv.withAuth(srv.generateHostCert)) srv.POST("/:version/ca/user/certs", srv.withAuth(srv.generateUserCert)) @@ -647,14 +645,6 @@ func (s *APIServer) authenticateSSHUser(auth ClientI, w http.ResponseWriter, r * return auth.AuthenticateSSHUser(req) } -func (s *APIServer) exchangeCerts(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - var req ExchangeCertsRequest - if err := httplib.ReadJSON(r, &req); err != nil { - return nil, trace.Wrap(err) - } - return auth.ExchangeCerts(req) -} - func (s *APIServer) changePassword(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { var req services.ChangePasswordReq if err := httplib.ReadJSON(r, &req); err != nil { @@ -931,6 +921,17 @@ func (s *APIServer) generateServerKeys(auth ClientI, w http.ResponseWriter, r *h return keys, nil } +func (s *APIServer) rotateCertAuthority(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req RotateRequest + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + if err := auth.RotateCertAuthority(req); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + type upsertCertAuthorityRawReq struct { CA json.RawMessage `json:"ca"` TTL time.Duration `json:"ttl"` @@ -954,6 +955,25 @@ func (s *APIServer) upsertCertAuthority(auth ClientI, w http.ResponseWriter, r * return message("ok"), nil } +type rotateExternalCertAuthorityRawReq struct { + CA json.RawMessage `json:"ca"` +} + +func (s *APIServer) rotateExternalCertAuthority(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req rotateExternalCertAuthorityRawReq + if err := 
httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + ca, err := services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(req.CA) + if err != nil { + return nil, trace.Wrap(err) + } + if err := auth.RotateExternalCertAuthority(ca); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + func (s *APIServer) getCertAuthorities(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { loadKeys, _, err := httplib.ParseBool(r.URL.Query(), "load_keys") if err != nil { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 4a04cc93ad737..56f3bc47cfecd 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -24,8 +24,10 @@ limitations under the License. package auth import ( + "context" "crypto/x509" "fmt" + "math/rand" "net/url" "sync" "time" @@ -78,6 +80,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro if cfg.AuditLog == nil { cfg.AuditLog = events.NewDiscardAuditLog() } + closeCtx, cancelFunc := context.WithCancel(context.TODO()) as := AuthServer{ clusterName: cfg.ClusterName, bk: cfg.Backend, @@ -93,6 +96,8 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro oidcClients: make(map[string]*oidcClient), samlProviders: make(map[string]*samlProvider), githubClients: make(map[string]*githubClient), + cancelFunc: cancelFunc, + closeCtx: closeCtx, } for _, o := range opts { o(&as) @@ -100,6 +105,12 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro if as.clock == nil { as.clock = clockwork.NewRealClock() } + if !cfg.SkipPeriodicOperations { + log.Infof("Auth server is running periodic operations.") + go as.runPeriodicOperations() + } else { + log.Infof("Auth server is skipping periodic operations.") + } return &as, nil } @@ -119,6 +130,9 @@ type AuthServer struct { clock clockwork.Clock bk backend.Backend + closeCtx context.Context + cancelFunc context.CancelFunc + 
sshca.Authority // AuthServiceName is a human-readable name of this CA. If several Auth services are running @@ -135,17 +149,58 @@ type AuthServer struct { events.IAuditLog clusterName services.ClusterName + + // privateKey is used in tests to use pre-generated private keys + privateKey []byte +} + +// runPeriodicOperations runs some periodic bookkeeping operations +// performed by auth server +func (a *AuthServer) runPeriodicOperations() { + // run periodic functions with a semi-random period + // to avoid contention on the database in case if there are multiple + // auth servers running - so they don't compete trying + // to update the same resources. + r := rand.New(rand.NewSource(a.GetClock().Now().UnixNano())) + period := defaults.HighResPollingPeriod + time.Duration(r.Intn(int(defaults.HighResPollingPeriod/time.Second)))*time.Second + log.Debugf("Ticking with period: %v.", period) + ticker := time.NewTicker(period) + defer ticker.Stop() + for { + select { + case <-a.closeCtx.Done(): + return + case <-ticker.C: + err := a.autoRotateCertAuthorities() + if err != nil { + if trace.IsCompareFailed(err) { + log.Debugf("Cert authority has been updated concurrently: %v.", err) + } else { + log.Errorf("Failed to perform cert rotation check: %v.", err) + } + } + } + } } func (a *AuthServer) Close() error { + a.cancelFunc() if a.bk != nil { return trace.Wrap(a.bk.Close()) } return nil } +func (a *AuthServer) GetClock() clockwork.Clock { + a.lock.Lock() + defer a.lock.Unlock() + return a.clock +} + // SetClock sets clock, used in tests func (a *AuthServer) SetClock(clock clockwork.Clock) { + a.lock.Lock() + defer a.lock.Unlock() a.clock = clock } @@ -303,7 +358,11 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) { if err != nil { return nil, trace.Wrap(err) } - hostCA, err := s.Trust.GetCertAuthority(services.CertAuthID{ + // CHANGE IN (2.7.0) Use user CA and not host CA here, + // currently host CA is used for backwards compatibility, + // 
because pre 2.6.0 remote clusters did not have TLS CA + // in user certificate authorities. + userCA, err := s.Trust.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: clusterName, }, true) @@ -311,7 +370,7 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) { return nil, trace.Wrap(err) } // generate TLS certificate - tlsAuthority, err := hostCA.TLSCA() + tlsAuthority, err := userCA.TLSCA() if err != nil { return nil, trace.Wrap(err) } @@ -588,10 +647,18 @@ func (s *AuthServer) GenerateToken(req GenerateTokenRequest) (string, error) { // ClientCertPool returns trusted x509 cerificate authority pool func (s *AuthServer) ClientCertPool() (*x509.CertPool, error) { pool := x509.NewCertPool() - authorities, err := s.GetCertAuthorities(services.HostCA, false) + var authorities []services.CertAuthority + hostCAs, err := s.GetCertAuthorities(services.HostCA, false) if err != nil { return nil, trace.Wrap(err) } + userCAs, err := s.GetCertAuthorities(services.UserCA, false) + if err != nil { + return nil, trace.Wrap(err) + } + authorities = append(authorities, hostCAs...) + authorities = append(authorities, userCAs...) 
+ for _, auth := range authorities { for _, keyPair := range auth.GetTLSKeyPairs() { cert, err := tlsca.ParseCertificatePEM(keyPair.Cert) diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index c23fd1578ff55..9a4bfb2494eac 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -19,6 +19,8 @@ package auth import ( "encoding/json" "fmt" + "io/ioutil" + "path/filepath" "testing" "time" @@ -44,8 +46,9 @@ import ( func TestAPI(t *testing.T) { TestingT(t) } type AuthSuite struct { - bk backend.Backend - a *AuthServer + bk backend.Backend + a *AuthServer + dataDir string } var _ = Suite(&AuthSuite{}) @@ -57,7 +60,8 @@ func (s *AuthSuite) SetUpSuite(c *C) { func (s *AuthSuite) SetUpTest(c *C) { var err error - s.bk, err = boltbk.New(backend.Params{"path": c.MkDir()}) + s.dataDir = c.MkDir() + s.bk, err = boltbk.New(backend.Params{"path": s.dataDir}) c.Assert(err, IsNil) clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ @@ -65,9 +69,10 @@ func (s *AuthSuite) SetUpTest(c *C) { }) c.Assert(err, IsNil) authConfig := &InitConfig{ - ClusterName: clusterName, - Backend: s.bk, - Authority: authority.New(), + ClusterName: clusterName, + Backend: s.bk, + Authority: authority.New(), + SkipPeriodicOperations: true, } s.a, err = NewAuthServer(authConfig) c.Assert(err, IsNil) @@ -144,7 +149,7 @@ func (s *AuthSuite) TestUserLock(c *C) { c.Assert(ws, NotNil) fakeClock := clockwork.NewFakeClock() - s.a.clock = fakeClock + s.a.SetClock(fakeClock) for i := 0; i <= defaults.MaxLoginAttempts; i++ { _, err = s.a.SignIn(user, []byte("wrong pass")) @@ -246,7 +251,7 @@ func (s *AuthSuite) TestTokensCRUD(c *C) { c.Assert(err, IsNil) // try to use after TTL: - s.a.clock = clockwork.NewFakeClockAt(time.Now().UTC().Add(time.Hour + 1)) + s.a.SetClock(clockwork.NewFakeClockAt(time.Now().UTC().Add(time.Hour + 1))) _, err = s.a.RegisterUsingToken(RegisterUsingTokenRequest{ Token: multiUseToken, HostID: "late.bird", @@ -547,9 +552,10 @@ func (s *AuthSuite) 
TestUpdateConfig(c *C) { c.Assert(err, IsNil) // use same backend but start a new auth server with different config. authConfig := &InitConfig{ - ClusterName: clusterName, - Backend: s.bk, - Authority: authority.New(), + ClusterName: clusterName, + Backend: s.bk, + Authority: authority.New(), + SkipPeriodicOperations: true, } authServer, err := NewAuthServer(authConfig) c.Assert(err, IsNil) @@ -590,129 +596,74 @@ func (s *AuthSuite) TestUpdateConfig(c *C) { }}) } -// TestMigrateRemote cluster creates remote cluster resource -// after the migration -func (s *AuthSuite) TestMigrateRemoteCluster(c *C) { - clusterName := "remote.example.com" - - hostCA := suite.NewTestCA(services.HostCA, clusterName) - hostCA.SetName(clusterName) - c.Assert(s.a.UpsertCertAuthority(hostCA), IsNil) - - err := migrateRemoteClusters(s.a) - c.Assert(err, IsNil) - - remoteCluster, err := s.a.GetRemoteCluster(clusterName) - c.Assert(err, IsNil) - c.Assert(remoteCluster.GetName(), Equals, clusterName) -} - -// TestMigrateEnabledTrustedCluster tests migrations of enabled trusted cluster -func (s *AuthSuite) TestMigrateEnabledTrustedCluster(c *C) { - clusterName := "example.com" - resourceName := "trustedcluster1" - - tunnel := services.NewReverseTunnel(resourceName, []string{"addr:5000"}) - err := s.a.UpsertReverseTunnel(tunnel) - c.Assert(err, IsNil) - - hostCA := suite.NewTestCA(services.HostCA, clusterName) - hostCA.SetName(resourceName) - c.Assert(s.a.UpsertCertAuthority(hostCA), IsNil) +// TestMigrateIdentity tests migration of the identity +func (s *AuthSuite) TestMigrateIdentity(c *C) { + c.Assert(s.a.UpsertCertAuthority( + suite.NewTestCA(services.UserCA, "me.localhost")), IsNil) - userCA := suite.NewTestCA(services.UserCA, clusterName) - userCA.SetName(resourceName) - c.Assert(s.a.UpsertCertAuthority(userCA), IsNil) + c.Assert(s.a.UpsertCertAuthority( + suite.NewTestCA(services.HostCA, "me.localhost")), IsNil) - tc, err := services.NewTrustedCluster(resourceName, 
services.TrustedClusterSpecV2{ - Enabled: true, - Token: "shmoken", - ProxyAddress: "addr:5000", - RoleMap: services.RoleMap{ - {Local: []string{"local"}, Remote: "remote"}, - }, + role := teleport.RoleAdmin + id := IdentityID{ + HostUUID: "test", + NodeName: "test", + Role: role, + } + packedKeys, err := s.a.GenerateServerKeys(GenerateServerKeysRequest{ + HostID: id.HostUUID, + NodeName: id.NodeName, + Roles: teleport.Roles{id.Role}, }) c.Assert(err, IsNil) - _, err = s.a.Presence.UpsertTrustedCluster(tc) - c.Assert(err, IsNil) - - err = migrateTrustedClusters(s.a) - c.Assert(err, IsNil) - - _, err = s.a.GetTrustedCluster(resourceName) - fixtures.ExpectNotFound(c, err) - - _, err = s.a.GetTrustedCluster(clusterName) + err = writeKeys(s.dataDir, id, packedKeys.Key, packedKeys.Cert, packedKeys.TLSCert, packedKeys.TLSCACerts[0]) c.Assert(err, IsNil) - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: clusterName}, false) + oldid, err := readIdentityCompat(s.dataDir, id) c.Assert(err, IsNil) - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: resourceName}, false) - fixtures.ExpectNotFound(c, err) - - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: clusterName}, false) + // migrate identities to the new format + err = migrateIdentities(s.dataDir) c.Assert(err, IsNil) - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: resourceName}, false) - fixtures.ExpectNotFound(c, err) - - _, err = s.a.GetReverseTunnel(resourceName) + // identity has been migrated, old identity has been removed + _, err = readIdentityCompat(s.dataDir, id) fixtures.ExpectNotFound(c, err) - _, err = s.a.GetReverseTunnel(clusterName) + newid, err := ReadLocalIdentity(filepath.Join(s.dataDir, teleport.ComponentProcess), id) + newid.ID.NodeName = id.NodeName c.Assert(err, IsNil) -} - -// TestMigrateDisabledTrustedCluster tests migrations of disabled trusted 
cluster -func (s *AuthSuite) TestMigrateDisabledTrustedCluster(c *C) { - clusterName := "example.com" - resourceName := "trustedcluster1" - - hostCA := suite.NewTestCA(services.HostCA, clusterName) - hostCA.SetName(resourceName) - c.Assert(s.a.UpsertCertAuthority(hostCA), IsNil) - userCA := suite.NewTestCA(services.UserCA, clusterName) - userCA.SetName(resourceName) - c.Assert(s.a.UpsertCertAuthority(userCA), IsNil) + fixtures.DeepCompare(c, newid, oldid) - err := s.a.DeactivateCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: resourceName}) + // migrate identities to the new format does nothing + // if migration has already happened + err = migrateIdentities(s.dataDir) c.Assert(err, IsNil) - err = s.a.DeactivateCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: resourceName}) + newid, err = ReadLocalIdentity(filepath.Join(s.dataDir, teleport.ComponentProcess), id) + newid.ID.NodeName = id.NodeName c.Assert(err, IsNil) - tc, err := services.NewTrustedCluster(resourceName, services.TrustedClusterSpecV2{ - Enabled: false, - Token: "shmoken", - ProxyAddress: "addr", - RoleMap: services.RoleMap{ - {Local: []string{"local"}, Remote: "remote"}, - }, - }) - c.Assert(err, IsNil) - _, err = s.a.Presence.UpsertTrustedCluster(tc) - c.Assert(err, IsNil) - - err = migrateTrustedClusters(s.a) - c.Assert(err, IsNil) - - _, err = s.a.GetTrustedCluster(resourceName) - fixtures.ExpectNotFound(c, err) - - _, err = s.a.GetTrustedCluster(clusterName) - c.Assert(err, IsNil) - - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: clusterName}, false) - fixtures.ExpectNotFound(c, err) - - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: resourceName}, false) - fixtures.ExpectNotFound(c, err) + fixtures.DeepCompare(c, newid, oldid) +} - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: clusterName}, false) - fixtures.ExpectNotFound(c, err) +// 
writeKeys saves the key/cert pair for a given domain onto disk. This usually means the +// domain trusts us (signed our public key) +func writeKeys(dataDir string, id IdentityID, key []byte, sshCert []byte, tlsCert []byte, tlsCACert []byte) error { + path := keysPath(dataDir, id) - _, err = s.a.GetCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: resourceName}, false) - fixtures.ExpectNotFound(c, err) + if err := ioutil.WriteFile(path.key, key, teleport.FileMaskOwnerOnly); err != nil { + return trace.Wrap(err) + } + if err := ioutil.WriteFile(path.sshCert, sshCert, teleport.FileMaskOwnerOnly); err != nil { + return trace.Wrap(err) + } + if err := ioutil.WriteFile(path.tlsCert, tlsCert, teleport.FileMaskOwnerOnly); err != nil { + return trace.Wrap(err) + } + if err := ioutil.WriteFile(path.tlsCACert, tlsCACert, teleport.FileMaskOwnerOnly); err != nil { + return trace.Wrap(err) + } + return nil } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 95ca76dacbe09..e2f884920f7fe 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -41,6 +41,10 @@ type AuthWithRoles struct { alog events.IAuditLog } +func (a *AuthWithRoles) actionWithContext(ctx *services.Context, namespace string, resource string, action string) error { + return a.checker.CheckAccessToRule(ctx, namespace, resource, action) +} + func (a *AuthWithRoles) action(namespace string, resource string, action string) error { return a.checker.CheckAccessToRule(&services.Context{User: a.user}, namespace, resource, action) } @@ -91,16 +95,6 @@ func (a *AuthWithRoles) AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLog return a.authServer.AuthenticateSSHUser(req) } -// ExchangeCerts exchanges TLS certificates for established host certificate authorities -func (a *AuthWithRoles) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - // exchange request has it's own authentication, however this limits the requests - // types 
to proxies to make it harder to break - if !a.checker.HasRole(string(teleport.RoleProxy)) { - return nil, trace.AccessDenied("this request can be only executed by proxy") - } - return a.authServer.ExchangeCerts(req) -} - func (a *AuthWithRoles) GetSessions(namespace string) ([]session.Session, error) { if err := a.action(namespace, services.KindSSHSession, services.VerbList); err != nil { return nil, trace.Wrap(err) @@ -134,16 +128,61 @@ func (a *AuthWithRoles) CreateCertAuthority(ca services.CertAuthority) error { return trace.BadParameter("not implemented") } -func (a *AuthWithRoles) UpsertCertAuthority(ca services.CertAuthority) error { +// RotateCertAuthority starts or restarts certificate authority rotation process. +func (a *AuthWithRoles) RotateCertAuthority(req RotateRequest) error { + if err := req.CheckAndSetDefaults(a.authServer.GetClock()); err != nil { + return trace.Wrap(err) + } if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { return trace.Wrap(err) } if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { return trace.Wrap(err) } + return a.authServer.RotateCertAuthority(req) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is called by a remote trusted cluster and is used to update +// only public keys and certificates of the certificate authority. +func (a *AuthWithRoles) RotateExternalCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + ctx := &services.Context{User: a.user, Resource: ca} + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbRotate); err != nil { + return trace.Wrap(err) + } + return a.authServer.RotateExternalCertAuthority(ca) +} + +// UpsertCertAuthority inserts a new cert authority or updates the existing one. 
+func (a *AuthWithRoles) UpsertCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + ctx := &services.Context{User: a.user, Resource: ca} + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { + return trace.Wrap(err) + } + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { + return trace.Wrap(err) + } return a.authServer.UpsertCertAuthority(ca) } +// CompareAndSwapCertAuthority updates existing cert authority if the existing cert authority +// value matches the value stored in the backend. +func (a *AuthWithRoles) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { + return trace.Wrap(err) + } + if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { + return trace.Wrap(err) + } + return a.authServer.CompareAndSwapCertAuthority(new, existing) +} + func (a *AuthWithRoles) GetCertAuthorities(caType services.CertAuthType, loadKeys bool) ([]services.CertAuthority, error) { if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbList); err != nil { return nil, trace.Wrap(err) @@ -156,7 +195,6 @@ func (a *AuthWithRoles) GetCertAuthorities(caType services.CertAuthType, loadKey return nil, trace.Wrap(err) } } - return a.authServer.GetCertAuthorities(caType, loadKeys) } @@ -172,13 +210,6 @@ func (a *AuthWithRoles) GetCertAuthority(id services.CertAuthID, loadKeys bool) return a.authServer.GetCertAuthority(id, loadKeys) } -func (a *AuthWithRoles) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbReadNoSecrets); err != nil { - return nil, trace.Wrap(err) - } - return 
a.authServer.GetAnyCertAuthority(id) -} - func (a *AuthWithRoles) GetDomainName() (string, error) { // anyone can read it, no harm in that return a.authServer.GetDomainName() diff --git a/lib/auth/clt.go b/lib/auth/clt.go index c394a3b919387..63a491421f765 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -97,6 +97,20 @@ func NewAddrDialer(addrs []utils.NetAddr) DialContext { } } +// ClientTimeout sets idle connection and response header timeouts of the HTTP transport +// used by the client. +func ClientTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + transport, ok := (c.HTTPClient().Transport).(*http.Transport) + if !ok { + return nil + } + transport.IdleConnTimeout = timeout + transport.ResponseHeaderTimeout = timeout + return nil + } +} + // NewTLSClientWithDialer returns new TLS client that uses mutual TLS authenticate // and dials the remote server using dialer func NewTLSClientWithDialer(dialContext DialContext, cfg *tls.Config, params ...roundtrip.ClientParam) (*Client, error) { @@ -127,11 +141,11 @@ func NewTLSClientWithDialer(dialContext DialContext, cfg *tls.Config, params ... } } - params = append(params, roundtrip.HTTPClient(&http.Client{ - Transport: transport, - })) - - roundtripClient, err := roundtrip.NewClient("https://"+teleport.APIDomain, CurrentVersion, params...) + clientParams := append( + []roundtrip.ClientParam{roundtrip.HTTPClient(&http.Client{Transport: transport})}, + params..., + ) + roundtripClient, err := roundtrip.NewClient("https://"+teleport.APIDomain, CurrentVersion, clientParams...) if err != nil { return nil, trace.Wrap(err) } @@ -303,6 +317,32 @@ func (c *Client) CreateCertAuthority(ca services.CertAuthority) error { return trace.BadParameter("not implemented") } +// RotateCertAuthority starts or restarts certificate authority rotation process.
+func (c *Client) RotateCertAuthority(req RotateRequest) error { + caType := "all" + if req.Type != "" { + caType = string(req.Type) + } + _, err := c.PostJSON(c.Endpoint("authorities", caType, "rotate"), req) + return trace.Wrap(err) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is used to update only public keys and certificates of the +// certificate authorities of trusted clusters. +func (c *Client) RotateExternalCertAuthority(ca services.CertAuthority) error { + if err := ca.Check(); err != nil { + return trace.Wrap(err) + } + data, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), + &rotateExternalCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + // UpsertCertAuthority updates or inserts new cert authority func (c *Client) UpsertCertAuthority(ca services.CertAuthority) error { if err := ca.Check(); err != nil { @@ -317,6 +357,12 @@ func (c *Client) UpsertCertAuthority(ca services.CertAuthority) error { return trace.Wrap(err) } +// CompareAndSwapCertAuthority updates existing cert authority if the existing cert authority +// value matches the value stored in the backend.
+func (c *Client) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + return trace.BadParameter("this function is not supported on the client") +} + // GetCertAuthorities returns a list of certificate authorities func (c *Client) GetCertAuthorities(caType services.CertAuthType, loadKeys bool) ([]services.CertAuthority, error) { if err := caType.Check(); err != nil { @@ -358,11 +404,6 @@ func (c *Client) GetCertAuthority(id services.CertAuthID, loadSigningKeys bool) return services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(out.Bytes()) } -// GetAnyCertAuthority returns certificate authority by given id whether it's activated or not -func (c *Client) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - return nil, trace.BadParameter("not implemented") -} - // DeleteCertAuthority deletes cert authority by ID func (c *Client) DeleteCertAuthority(id services.CertAuthID) error { if err := id.Check(); err != nil { @@ -925,23 +966,6 @@ func (c *Client) CreateWebSession(user string) (services.WebSession, error) { return services.GetWebSessionMarshaler().UnmarshalWebSession(out.Bytes()) } -// DELETE IN: 2.6.0 -// ExchangeCerts exchanges TLS certificates for established host certificate authorities -func (c *Client) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - out, err := c.PostJSON( - c.Endpoint("exchangecerts"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - var re ExchangeCertsResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - // AuthenticateWebUser authenticates web user, creates and returns web session // in case if authentication is successfull func (c *Client) AuthenticateWebUser(req AuthenticateUserRequest) (services.WebSession, error) { @@ -2266,6 +2290,14 @@ type ClientI interface { session.Service services.ClusterConfiguration + // RotateCertAuthority starts or restarts certificate 
authority rotation process. + RotateCertAuthority(req RotateRequest) error + + // RotateExternalCertAuthority rotates external certificate authority, + // this method is used to update only public keys and certificates of the + // certificate authorities of trusted clusters. + RotateExternalCertAuthority(ca services.CertAuthority) error + // ValidateTrustedCluster validates trusted cluster token with // main cluster, in case if validation is successfull, main cluster // adds remote cluster @@ -2281,8 +2313,4 @@ type ClientI interface { // AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH // short lived certificates as a result AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLoginResponse, error) - - // DELETE IN: 2.6.0 - // ExchangeCerts exchanges TLS certificates between host certificate authorities of trusted clusters - ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) } diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 87ea47eac6438..8dd50d1283009 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -123,12 +123,13 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { } srv.AuthServer, err = NewAuthServer(&InitConfig{ - ClusterName: clusterName, - Backend: srv.Backend, - Authority: authority.New(), - Access: access, - Identity: identity, - AuditLog: srv.AuditLog, + ClusterName: clusterName, + Backend: srv.Backend, + Authority: authority.New(), + Access: access, + Identity: identity, + AuditLog: srv.AuditLog, + SkipPeriodicOperations: true, }) if err != nil { return nil, trace.Wrap(err) @@ -249,7 +250,7 @@ func (a *TestAuthServer) NewCertificate(identity TestIdentity) (*tls.Certificate // Clock returns clock used by auth server func (a *TestAuthServer) Clock() clockwork.Clock { - return a.AuthServer.clock + return a.AuthServer.GetClock() } // Trust adds other server host certificate authority as trusted @@ -490,6 +491,17 @@ func (t
*TestTLSServer) ClientTLSConfig(identity TestIdentity) (*tls.Config, err return tlsConfig, nil } +// CloneClient uses the same credentials as the passed client +// but forces the client to be recreated +func (t *TestTLSServer) CloneClient(clt *Client) *Client { + addr := []utils.NetAddr{{Addr: t.Addr().String(), AddrNetwork: t.Addr().Network()}} + newClient, err := NewTLSClient(addr, clt.TLSConfig()) + if err != nil { + panic(err) + } + return newClient +} + // NewClient returns new client to test server authenticated with identity func (t *TestTLSServer) NewClient(identity TestIdentity) (*Client, error) { tlsConfig, err := t.ClientTLSConfig(identity) @@ -549,7 +561,7 @@ func NewServerIdentity(clt *AuthServer, hostID string, role teleport.Role) (*Ide if err != nil { return nil, trace.Wrap(err) } - return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // clt limits required interface to the necessary methods diff --git a/lib/auth/init.go b/lib/auth/init.go index d16738d4fb2c6..fc9908a7e8587 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -22,7 +22,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "io/ioutil" "os" "path/filepath" "strings" @@ -115,33 +114,37 @@ type InitConfig struct { // factor (off, otp, u2f) passed in from a configuration file. AuthPreference services.AuthPreference - // AuditLog is used for emitting events to audit log + // AuditLog is used for emitting events to audit log. AuditLog events.IAuditLog // ClusterConfig holds cluster level configuration. ClusterConfig services.ClusterConfig + + // SkipPeriodicOperations turns off periodic operations + // used in tests that don't need periodic operations.
+ SkipPeriodicOperations bool } // Init instantiates and configures an instance of AuthServer -func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, error) { +func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, error) { if cfg.DataDir == "" { - return nil, nil, trace.BadParameter("DataDir: data dir can not be empty") + return nil, trace.BadParameter("DataDir: data dir can not be empty") } if cfg.HostUUID == "" { - return nil, nil, trace.BadParameter("HostUUID: host UUID can not be empty") + return nil, trace.BadParameter("HostUUID: host UUID can not be empty") } domainName := cfg.ClusterName.GetClusterName() err := cfg.Backend.AcquireLock(domainName, 30*time.Second) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } defer cfg.Backend.ReleaseLock(domainName) // check that user CA and host CA are present and set the certs if needed asrv, err := NewAuthServer(&cfg, opts...) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // INTERNAL: Authorities (plus Roles) and ReverseTunnels don't follow the @@ -149,7 +152,7 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err // singletons). However, we need to keep them around while Telekube uses them. 
for _, role := range cfg.Roles { if err := asrv.UpsertRole(role, backend.Forever); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created role: %v.", role) } @@ -157,17 +160,22 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err ca := cfg.Authorities[i] ca, err = services.GetCertAuthorityMarshaler().GenerateCertAuthority(ca) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - - if err := asrv.Trust.UpsertCertAuthority(ca); err != nil { - return nil, nil, trace.Wrap(err) + // Don't re-create CA if it already exists, otherwise + // the existing cluster configuration will be corrupted; + // this part of code is only used in tests. + if err := asrv.Trust.CreateCertAuthority(ca); err != nil { + if !trace.IsAlreadyExists(err) { + return nil, trace.Wrap(err) + } + } else { + log.Infof("Created trusted certificate authority: %q, type: %q.", ca.GetName(), ca.GetType()) } - log.Infof("Created trusted certificate authority: %q, type: %q.", ca.GetName(), ca.GetType()) } for _, tunnel := range cfg.ReverseTunnels { if err := asrv.UpsertReverseTunnel(tunnel); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created reverse tunnel: %v.", tunnel) } @@ -175,7 +183,7 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err // set cluster level config on the backend and then force a sync of the cache. 
clusterConfig, err := asrv.GetClusterConfig() if err != nil && !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // init a unique cluster ID, it must be set once only during the first // start so if it's already there, reuse it @@ -186,42 +194,42 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err } err = asrv.SetClusterConfig(cfg.ClusterConfig) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // cluster name can only be set once. if it has already been set and we are // trying to update it to something else, hard fail. err = asrv.SetClusterName(cfg.ClusterName) if err != nil && !trace.IsAlreadyExists(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if trace.IsAlreadyExists(err) { cn, err := asrv.GetClusterName() if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if cn.GetClusterName() != cfg.ClusterName.GetClusterName() { - return nil, nil, trace.BadParameter("cannot rename cluster %q to %q: clusters cannot be renamed", cn.GetClusterName(), cfg.ClusterName.GetClusterName()) + return nil, trace.BadParameter("cannot rename cluster %q to %q: clusters cannot be renamed", cn.GetClusterName(), cfg.ClusterName.GetClusterName()) } } log.Debugf("Cluster configuration: %v.", cfg.ClusterName) err = asrv.SetStaticTokens(cfg.StaticTokens) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Updating cluster configuration: %v.", cfg.StaticTokens) err = asrv.SetAuthPreference(cfg.AuthPreference) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Updating cluster configuration: %v.", cfg.AuthPreference) // always create the default namespace err = asrv.UpsertNamespace(services.NewNamespace(defaults.Namespace)) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created namespace: %q.", 
defaults.Namespace) @@ -229,22 +237,31 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err defaultRole := services.NewAdminRole() err = asrv.CreateRole(defaultRole, backend.Forever) if err != nil && !trace.IsAlreadyExists(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if !trace.IsAlreadyExists(err) { log.Infof("Created default admin role: %q.", defaultRole.GetName()) } // generate a user certificate authority if it doesn't exist - if _, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.UserCA}, false); err != nil { + userCA, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.UserCA}, true) + if err != nil { if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("First start: generating user certificate authority.") priv, pub, err := asrv.GenerateKeyPair("") if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) + } + + keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: cfg.ClusterName.GetClusterName(), + Organization: []string{cfg.ClusterName.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return nil, trace.Wrap(err) } userCA := &services.CertAuthorityV2{ @@ -259,11 +276,25 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err Type: services.UserCA, SigningKeys: [][]byte{priv}, CheckingKeys: [][]byte{pub}, + TLSKeyPairs: []services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}, }, } if err := asrv.Trust.UpsertCertAuthority(userCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) + } + } else if len(userCA.GetTLSKeyPairs()) == 0 { + log.Infof("Migrate: generating TLS CA for existing user CA.") + keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: cfg.ClusterName.GetClusterName(), + Organization: 
[]string{cfg.ClusterName.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return nil, trace.Wrap(err) + } + userCA.SetTLSKeyPairs([]services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}) + if err := asrv.Trust.UpsertCertAuthority(userCA); err != nil { + return nil, trace.Wrap(err) } } @@ -271,13 +302,13 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err hostCA, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.HostCA}, true) if err != nil { if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("First start: generating host certificate authority.") priv, pub, err := asrv.GenerateKeyPair("") if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ @@ -285,10 +316,9 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err Organization: []string{cfg.ClusterName.GetClusterName()}, }, nil, defaults.CATTL) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - - hostCA := &services.CertAuthorityV2{ + hostCA = &services.CertAuthorityV2{ Kind: services.KindCertAuthority, Version: services.V2, Metadata: services.Metadata{ @@ -304,28 +334,28 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err }, } if err := asrv.Trust.UpsertCertAuthority(hostCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } } else if len(hostCA.GetTLSKeyPairs()) == 0 { log.Infof("Migrate: generating TLS CA for existing host CA.") privateKey, err := ssh.ParseRawPrivateKey(hostCA.GetSigningKeys()[0]) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } privateKeyRSA, ok := privateKey.(*rsa.PrivateKey) if !ok { - return nil, nil, trace.BadParameter("expected RSA private key, got %T", privateKey) + return nil, 
trace.BadParameter("expected RSA private key, got %T", privateKey) } keyPEM, certPEM, err := tlsca.GenerateSelfSignedCAWithPrivateKey(privateKeyRSA, pkix.Name{ CommonName: cfg.ClusterName.GetClusterName(), Organization: []string{cfg.ClusterName.GetClusterName()}, }, nil, defaults.CATTL) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } hostCA.SetTLSKeyPairs([]services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}) if err := asrv.Trust.UpsertCertAuthority(hostCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -336,24 +366,13 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err log.Warn(warningMessage) } - // read host keys from disk or create them if they don't exist - iid := IdentityID{ - HostUUID: cfg.HostUUID, - NodeName: cfg.NodeName, - Role: teleport.RoleAdmin, - } - identity, err := initKeys(asrv, cfg.DataDir, iid) - if err != nil { - return nil, nil, trace.Wrap(err) - } - // migrate any legacy resources to new format err = migrateLegacyResources(cfg, asrv) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - return asrv, identity, nil + return asrv, nil } func migrateLegacyResources(cfg InitConfig, asrv *AuthServer) error { @@ -361,223 +380,77 @@ func migrateLegacyResources(cfg InitConfig, asrv *AuthServer) error { if err != nil { return trace.Wrap(err) } - - err = migrateRoles(asrv) - if err != nil { - return trace.Wrap(err) - } - err = migrateRemoteClusters(asrv) if err != nil { return trace.Wrap(err) } - - err = migrateTrustedClusters(asrv) + err = migrateIdentities(cfg.DataDir) if err != nil { return trace.Wrap(err) } - return nil } -func migrateUsers(asrv *AuthServer) error { - users, err := asrv.GetUsers() +func migrateIdentities(dataDir string) error { + storage, err := NewProcessStorage(filepath.Join(dataDir, teleport.ComponentProcess)) if err != nil { return trace.Wrap(err) } - - for i := range users { - user := users[i] - 
raw, ok := (user.GetRawObject()).(services.UserV1) - if !ok { - continue - } - log.Infof("Migrating legacy user: %v.", user.GetName()) - - // create role for user and upsert to backend - role := services.RoleForUser(user) - role.SetLogins(services.Allow, raw.AllowedLogins) - err = asrv.UpsertRole(role, backend.Forever) - if err != nil { - return trace.Wrap(err) - } - - // upsert new user to backend - user.AddRole(role.GetName()) - if err := asrv.UpsertUser(user); err != nil { + for _, role := range []teleport.Role{teleport.RoleAdmin, teleport.RoleProxy, teleport.RoleNode} { + if err := migrateIdentity(role, dataDir, storage); err != nil { return trace.Wrap(err) } } - return nil } -// DELETE IN: 2.6.0 -// All users will be migrated to the new roles in Teleport 2.5.0, which means -// this entire function can be removed in Teleport 2.6.0. -func migrateRoles(asrv *AuthServer) error { - roles, err := asrv.GetRoles() +func migrateIdentity(role teleport.Role, dataDir string, storage *ProcessStorage) error { + identity, err := readIdentityCompat(dataDir, IdentityID{Role: role}) if err != nil { - return trace.Wrap(err) - } - - // loop over all roles and make sure any v3 roles have the default value for - // certificate format - for i, _ := range roles { - role := roles[i] - - roleOptions := role.GetOptions() - - _, err = roleOptions.GetString(services.CertificateFormat) - if err != nil { - roleOptions.Set(services.CertificateFormat, teleport.CertificateFormatStandard) - role.SetOptions(roleOptions) - } - - err = asrv.UpsertRole(role, backend.Forever) - if err != nil { + if !trace.IsNotFound(err) { return trace.Wrap(err) } - log.Infof("Migrating role: %v to include default for the cert_format option.", role.GetName()) + return nil } - - return nil -} - -// DELETE IN: 2.6.0 -// This migration adds remote cluster resource migrating from 2.5.0 -// where the presence of remote cluster was identified only by presence -// of host certificate authority with cluster name not equal 
local cluster name -func migrateRemoteClusters(asrv *AuthServer) error { - clusterName, err := asrv.GetClusterName() + err = storage.WriteIdentity(IdentityCurrent, *identity) if err != nil { return trace.Wrap(err) } - certAuthorities, err := asrv.GetCertAuthorities(services.HostCA, false) + err = removeIdentityCompat(dataDir, IdentityID{Role: role}) if err != nil { - return trace.Wrap(err) - } - - // loop over all roles and make sure any v3 roles have permit port - // forward and forward agent allowed - for _, certAuthority := range certAuthorities { - if certAuthority.GetName() == clusterName.GetClusterName() { - log.Debugf("Migrations: skipping local cluster cert authority %q.", certAuthority.GetName()) - continue - } - // remote cluster already exists - _, err = asrv.GetRemoteCluster(certAuthority.GetName()) - if err == nil { - log.Debugf("Migrations: remote cluster already exists for cert authority %q.", certAuthority.GetName()) - continue - } - if !trace.IsNotFound(err) { - return trace.Wrap(err) - } - // the cert authority is associated with trusted cluster - _, err = asrv.GetTrustedCluster(certAuthority.GetName()) - if err == nil { - log.Debugf("Migrations: trusted cluster resource exists for cert authority %q.", certAuthority.GetName()) - continue - } if !trace.IsNotFound(err) { return trace.Wrap(err) } - remoteCluster, err := services.NewRemoteCluster(certAuthority.GetName()) - if err != nil { - return trace.Wrap(err) - } - err = asrv.CreateRemoteCluster(remoteCluster) - if err != nil { - if !trace.IsAlreadyExists(err) { - return trace.Wrap(err) - } - } - log.Infof("Migrations: added remote cluster resource for cert authority %q.", certAuthority.GetName()) } - + log.Infof("Identity %v has been migrated to new on-disk format.", role) return nil } -// DELETE IN: 2.6.0 -// migrateTrustedClusters renames the trusted cluster resource names -// and certificate authorities names to equal to actual remote cluster name -func migrateTrustedClusters(asrv *AuthServer) 
error { - trustedClusters, err := asrv.GetTrustedClusters() +func migrateUsers(asrv *AuthServer) error { + users, err := asrv.GetUsers() if err != nil { return trace.Wrap(err) } - // loop over all roles and make sure any v3 roles have permit port - // forward and forward agent allowed - for i := range trustedClusters { - trustedCluster := trustedClusters[i] - - hostCA, err := asrv.GetAnyCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - - userCA, err := asrv.GetAnyCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - - if hostCA.GetClusterName() == trustedCluster.GetName() { - log.Debugf("Migrations: skipping trusted cluster %q with name that already matches main cluster.", trustedCluster.GetName()) + for i := range users { + user := users[i] + raw, ok := (user.GetRawObject()).(services.UserV1) + if !ok { continue } + log.Infof("Migrating legacy user: %v.", user.GetName()) - var reverseTunnel services.ReverseTunnel - if trustedCluster.GetEnabled() { - reverseTunnel, err = asrv.GetReverseTunnel(trustedCluster.GetName()) - if err != nil { - return trace.Wrap(err) - } - } - - log.Debugf("Migrations: renaming trusted cluster %q to %q.", trustedCluster.GetName(), hostCA.GetClusterName()) - - oldName := trustedCluster.GetName() - - trustedCluster.SetName(hostCA.GetClusterName()) - _, err = asrv.Presence.UpsertTrustedCluster(trustedCluster) - if err != nil { - return trace.Wrap(err) - } - - hostCA.SetName(hostCA.GetClusterName()) - err = asrv.UpsertCertAuthority(hostCA) - if err != nil { - return trace.Wrap(err) - } - - userCA.SetName(hostCA.GetClusterName()) - err = asrv.UpsertCertAuthority(userCA) + // create role for user and upsert to backend + role := services.RoleForUser(user) + role.SetLogins(services.Allow, raw.AllowedLogins) + err = asrv.UpsertRole(role, backend.Forever) if err != 
nil { return trace.Wrap(err) } - if reverseTunnel != nil { - reverseTunnel.SetName(hostCA.GetClusterName()) - reverseTunnel.SetClusterName(hostCA.GetClusterName()) - if err := asrv.UpsertReverseTunnel(reverseTunnel); err != nil { - return trace.Wrap(err) - } - } - - if !trustedCluster.GetEnabled() { - log.Debugf("Migrations: trusted cluster %q is deactivated, deactivating updated authorities for %q", oldName, trustedCluster.GetName()) - err = asrv.DeactivateCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - err = asrv.DeactivateCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - } - if err := asrv.DeleteTrustedCluster(oldName); err != nil { + // upsert new user to backend + user.AddRole(role.GetName()) + if err := asrv.UpsertUser(user); err != nil { return trace.Wrap(err) } } @@ -604,71 +477,18 @@ func isFirstStart(authServer *AuthServer, cfg InitConfig) (bool, error) { return false, nil } -// initKeys initializes a nodes host certificate. If the certificate does not exist, a request -// is made to the certificate authority to generate a host certificate and it's written to disk. -// If a certificate exists on disk, it is read in and returned. 
-func initKeys(a *AuthServer, dataDir string, id IdentityID) (*Identity, error) { - path := keysPath(dataDir, id) - - keyExists, err := pathExists(path.key) - if err != nil { - return nil, trace.Wrap(err) - } - - sshCertExists, err := pathExists(path.sshCert) - if err != nil { - return nil, trace.Wrap(err) - } - - tlsCertExists, err := pathExists(path.tlsCert) - if err != nil { - return nil, trace.Wrap(err) - } - - if !keyExists || !sshCertExists || !tlsCertExists { - packedKeys, err := a.GenerateServerKeys(GenerateServerKeysRequest{ - HostID: id.HostUUID, - NodeName: id.NodeName, - Roles: teleport.Roles{id.Role}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - log.Debugf("Writing keys to disk for %v.", id) - if len(packedKeys.TLSCACerts) != 1 { - return nil, trace.BadParameter("expecting one CA cert, got %v instead", len(packedKeys.TLSCACerts)) - } - err = writeKeys(dataDir, id, packedKeys.Key, packedKeys.Cert, packedKeys.TLSCert, packedKeys.TLSCACerts[0]) - if err != nil { - return nil, trace.Wrap(err) - } - } - i, err := ReadIdentity(dataDir, id) +// GenerateIdentity generates identity for the auth server +func GenerateIdentity(a *AuthServer, id IdentityID, additionalPrincipals []string) (*Identity, error) { + keys, err := a.GenerateServerKeys(GenerateServerKeysRequest{ + HostID: id.HostUUID, + NodeName: id.NodeName, + Roles: teleport.Roles{id.Role}, + AdditionalPrincipals: additionalPrincipals, + }) if err != nil { return nil, trace.Wrap(err) } - return i, nil -} - -// writeKeys saves the key/cert pair for a given domain onto disk. 
This usually means the -// domain trusts us (signed our public key) -func writeKeys(dataDir string, id IdentityID, key []byte, sshCert []byte, tlsCert []byte, tlsCACert []byte) error { - path := keysPath(dataDir, id) - - if err := ioutil.WriteFile(path.key, key, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.sshCert, sshCert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.tlsCert, tlsCert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.tlsCACert, tlsCACert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - return nil + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // Identity is collection of certificates and signers that represent server identity @@ -681,9 +501,9 @@ type Identity struct { CertBytes []byte // TLSCertBytes is a PEM encoded TLS x509 client certificate TLSCertBytes []byte - // TLSCACertBytes is a PEM encoded TLS x509 certificate of certificate authority + // TLSCACertsBytes is a list of PEM encoded TLS x509 certificates of certificate authorities // associated with auth server services - TLSCACertBytes []byte + TLSCACertsBytes [][]byte // KeySigner is an SSH host certificate signer KeySigner ssh.Signer // Cert is a parsed SSH certificate @@ -692,9 +512,29 @@ type Identity struct { ClusterName string } +// String returns user-friendly representation of the identity.
+func (i *Identity) String() string { + var out []string + cert, err := tlsca.ParseCertificatePEM(i.TLSCertBytes) + if err != nil { + out = append(out, err.Error()) + } else { + out = append(out, fmt.Sprintf("cert(%v issued by %v:%v)", cert.Subject.CommonName, cert.Issuer.CommonName, cert.Issuer.SerialNumber)) + } + for j := range i.TLSCACertsBytes { + cert, err := tlsca.ParseCertificatePEM(i.TLSCACertsBytes[j]) + if err != nil { + out = append(out, err.Error()) + } else { + out = append(out, fmt.Sprintf("trust root(%v:%v)", cert.Subject.CommonName, cert.Subject.SerialNumber)) + } + } + return fmt.Sprintf("Identity(%v, %v)", i.ID.Role, strings.Join(out, ",")) +} + // HasTSLConfig returns true if this identity has TLS certificate and private key func (i *Identity) HasTLSConfig() bool { - return len(i.TLSCACertBytes) != 0 && len(i.TLSCertBytes) != 0 && len(i.TLSCACertBytes) != 0 + return len(i.TLSCACertsBytes) != 0 && len(i.TLSCertBytes) != 0 } // HasPrincipals returns whether identity has principals @@ -719,13 +559,14 @@ func (i *Identity) TLSConfig() (*tls.Config, error) { if err != nil { return nil, trace.BadParameter("failed to parse private key: %v", err) } - certPool := x509.NewCertPool() - parsedCert, err := tlsca.ParseCertificatePEM(i.TLSCACertBytes) - if err != nil { - return nil, trace.Wrap(err, "failed to parse CA certificate") + for j := range i.TLSCACertsBytes { + parsedCert, err := tlsca.ParseCertificatePEM(i.TLSCACertsBytes[j]) + if err != nil { + return nil, trace.Wrap(err, "failed to parse CA certificate") + } + certPool.AddCert(parsedCert) } - certPool.AddCert(parsedCert) tlsConfig.Certificates = []tls.Certificate{tlsCert} tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool @@ -743,7 +584,7 @@ type IdentityID struct { func (id *IdentityID) HostID() (string, error) { parts := strings.Split(id.HostUUID, ".") if len(parts) < 2 { - return "", trace.BadParameter("expected 2 parts in %v", id.HostUUID) + return "", trace.BadParameter("expected 2 
parts in %q", id.HostUUID) } return parts[0], nil } @@ -759,25 +600,25 @@ func (id *IdentityID) String() string { } // ReadIdentityFromKeyPair reads TLS identity from key pair -func ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, tlsCACertBytes []byte) (*Identity, error) { +func ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes []byte, tlsCACertsBytes [][]byte) (*Identity, error) { identity, err := ReadSSHIdentityFromKeyPair(keyBytes, sshCertBytes) if err != nil { return nil, trace.Wrap(err) } if len(tlsCertBytes) != 0 { // just to verify that identity parses properly for future use - _, err := ReadTLSIdentityFromKeyPair(keyBytes, tlsCertBytes, tlsCACertBytes) + _, err := ReadTLSIdentityFromKeyPair(keyBytes, tlsCertBytes, tlsCACertsBytes) if err != nil { return nil, trace.Wrap(err) } identity.TLSCertBytes = tlsCertBytes - identity.TLSCACertBytes = tlsCACertBytes + identity.TLSCACertsBytes = tlsCACertsBytes } return identity, nil } // ReadTLSIdentityFromKeyPair reads TLS identity from key pair -func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertBytes []byte) (*Identity, error) { +func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertsBytes [][]byte) (*Identity, error) { if len(keyBytes) == 0 { return nil, trace.BadParameter("missing private key") } @@ -799,19 +640,18 @@ func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertBytes []byte) if len(cert.Issuer.Organization) == 0 { return nil, trace.BadParameter("missing CA organization") } + clusterName := cert.Issuer.Organization[0] if clusterName == "" { return nil, trace.BadParameter("misssing cluster name") } - identity := &Identity{ - ID: IdentityID{HostUUID: id.Username, Role: teleport.Role(id.Groups[0])}, - ClusterName: clusterName, - KeyBytes: keyBytes, - TLSCertBytes: certBytes, - TLSCACertBytes: caCertBytes, + ID: IdentityID{HostUUID: id.Username, Role: teleport.Role(id.Groups[0])}, + ClusterName: clusterName, + KeyBytes: keyBytes, + TLSCertBytes: 
certBytes, + TLSCACertsBytes: caCertsBytes, } - _, err = identity.TLSConfig() if err != nil { return nil, trace.Wrap(err) @@ -893,11 +733,95 @@ func ReadSSHIdentityFromKeyPair(keyBytes, certBytes []byte) (*Identity, error) { }, nil } -// ReadIdentity reads, parses and returns the given pub/pri key + cert from the +// ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the // key storage (dataDir). -func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { +func ReadLocalIdentity(dataDir string, id IdentityID) (*Identity, error) { + storage, err := NewProcessStorage(dataDir) + if err != nil { + return nil, trace.Wrap(err) + } + defer storage.Close() + return storage.ReadIdentity(IdentityCurrent, id.Role) +} + +// DELETE IN(2.7.0) +// removeIdentityCompat removes identity from disk +func removeIdentityCompat(dataDir string, id IdentityID) error { + path := keysPath(dataDir, id) + for _, filePath := range []string{path.key, path.sshCert, path.tlsCert, path.tlsCACert} { + err := trace.ConvertSystemError(os.Remove(filePath)) + if err != nil { + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + } + return nil +} + +// DELETE IN: 2.6.0 +// NOTE: Sadly, our integration tests depend on this logic +// because they create remote cluster resource. Our integration +// tests should be migrated to use trusted clusters instead of manually +// creating tunnels. 
+// This migration adds remote cluster resource migrating from 2.5.0 +// where the presence of remote cluster was identified only by presence +// of host certificate authority with cluster name not equal local cluster name +func migrateRemoteClusters(asrv *AuthServer) error { + clusterName, err := asrv.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + certAuthorities, err := asrv.GetCertAuthorities(services.HostCA, false) + if err != nil { + return trace.Wrap(err) + } + // loop over all host certificate authorities and make sure every remote + // cluster's authority has a matching remote cluster resource + for _, certAuthority := range certAuthorities { + if certAuthority.GetName() == clusterName.GetClusterName() { + log.Debugf("Migrations: skipping local cluster cert authority %q.", certAuthority.GetName()) + continue + } + // remote cluster already exists + _, err = asrv.GetRemoteCluster(certAuthority.GetName()) + if err == nil { + log.Debugf("Migrations: remote cluster already exists for cert authority %q.", certAuthority.GetName()) + continue + } + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + // the cert authority is associated with trusted cluster + _, err = asrv.GetTrustedCluster(certAuthority.GetName()) + if err == nil { + log.Debugf("Migrations: trusted cluster resource exists for cert authority %q.", certAuthority.GetName()) + continue + } + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + remoteCluster, err := services.NewRemoteCluster(certAuthority.GetName()) + if err != nil { + return trace.Wrap(err) + } + err = asrv.CreateRemoteCluster(remoteCluster) + if err != nil { + if !trace.IsAlreadyExists(err) { + return trace.Wrap(err) + } + } + log.Infof("Migrations: added remote cluster resource for cert authority %q.", certAuthority.GetName()) + } + + return nil +} + +// DELETE IN(2.7.0) +// readIdentityCompat reads, parses and returns the given pub/pri key + cert from the +// key storage (dataDir).
Used for data migrations +func readIdentityCompat(dataDir string, id IdentityID) (i *Identity, err error) { path := keysPath(dataDir, id) - log.Debugf("Reading keys from disk: %v.", path) keyBytes, err := utils.ReadPath(path.key) if err != nil { @@ -924,7 +848,7 @@ func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { } } - identity, err := ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, tlsCACertBytes) + identity, err := ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, [][]byte{tlsCACertBytes}) if err != nil { return nil, trace.Wrap(err) } @@ -935,34 +859,7 @@ func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { return identity, nil } -// WriteIdentity writes identity keypair to disk -func WriteIdentity(dataDir string, identity *Identity) error { - return trace.Wrap( - writeKeys(dataDir, identity.ID, identity.KeyBytes, identity.CertBytes, identity.TLSCertBytes, identity.TLSCACertBytes)) -} - -// HaveHostKeys checks that host keys are in place -func HaveHostKeys(dataDir string, id IdentityID) (bool, error) { - path := keysPath(dataDir, id) - - exists, err := pathExists(path.key) - if !exists || err != nil { - return exists, err - } - - exists, err = pathExists(path.sshCert) - if !exists || err != nil { - return exists, err - } - - exists, err = pathExists(path.tlsCert) - if !exists || err != nil { - return exists, err - } - - return true, nil -} - +// DELETE IN(2.7.0) type paths struct { dataDir string key string @@ -971,6 +868,7 @@ type paths struct { tlsCACert string } +// DELETE IN(2.7.0) // keysPath returns two full file paths: to the host.key and host.cert func keysPath(dataDir string, id IdentityID) paths { return paths{ @@ -980,14 +878,3 @@ func keysPath(dataDir string, id IdentityID) paths { tlsCACert: filepath.Join(dataDir, fmt.Sprintf("%s.tlscacert", strings.ToLower(string(id.Role)))), } } - -func pathExists(path string) (bool, error) { - _, err := os.Stat(path) - if err != nil { - if 
os.IsNotExist(err) { - return false, nil - } - return false, err - } - return true, nil -} diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 3f7d5cb291af7..808037985b4df 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -28,7 +28,6 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/boltbk" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -199,7 +198,7 @@ func (s *AuthInitSuite) TestAuthPreference(c *C) { StaticTokens: staticTokens, AuthPreference: ap, } - as, _, err := Init(ac) + as, err := Init(ac) c.Assert(err, IsNil) cap, err := as.GetAuthPreference() @@ -221,7 +220,7 @@ func (s *AuthInitSuite) TestClusterID(c *C) { }) c.Assert(err, IsNil) - authServer, _, err := Init(InitConfig{ + authServer, err := Init(InitConfig{ DataDir: c.MkDir(), HostUUID: "00000000-0000-0000-0000-000000000000", NodeName: "foo", @@ -238,7 +237,7 @@ func (s *AuthInitSuite) TestClusterID(c *C) { c.Assert(clusterID, Not(Equals), "") // do it again and make sure cluster ID hasn't changed - authServer, _, err = Init(InitConfig{ + authServer, err = Init(InitConfig{ DataDir: c.MkDir(), HostUUID: "00000000-0000-0000-0000-000000000000", NodeName: "foo", @@ -253,54 +252,3 @@ func (s *AuthInitSuite) TestClusterID(c *C) { c.Assert(err, IsNil) c.Assert(cc.GetClusterID(), Equals, clusterID) } - -// DELETE IN: 2.6.0 -// Migration of cert_format will be done in Teleport 2.5.0, so this test can -// be removed in Teleport 2.6.0. 
-func (s *AuthInitSuite) TestOptions(c *C) { - bk, err := boltbk.New(backend.Params{"path": c.MkDir()}) - c.Assert(err, IsNil) - - clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ - ClusterName: "me.localhost", - }) - c.Assert(err, IsNil) - - authServer, _, err := Init(InitConfig{ - DataDir: c.MkDir(), - HostUUID: "00000000-0000-0000-0000-000000000000", - NodeName: "foo", - Backend: bk, - Authority: testauthority.New(), - ClusterConfig: services.DefaultClusterConfig(), - ClusterName: clusterName, - }) - c.Assert(err, IsNil) - - // upsert role with no values for certificate format - role := services.NewAdminRole() - role.SetOptions(services.RoleOptions{ - services.MaxSessionTTL: services.NewDuration(defaults.MaxCertDuration), - }) - err = authServer.UpsertRole(role, backend.Forever) - c.Assert(err, IsNil) - - // do it again and make sure the options have been populated - authServer, _, err = Init(InitConfig{ - DataDir: c.MkDir(), - HostUUID: "00000000-0000-0000-0000-000000000000", - NodeName: "foo", - Backend: bk, - Authority: testauthority.New(), - ClusterConfig: services.DefaultClusterConfig(), - ClusterName: clusterName, - }) - c.Assert(err, IsNil) - - role, err = authServer.GetRole(teleport.AdminRoleName) - c.Assert(err, IsNil) - - certificateFormat, err := role.GetOptions().GetString(services.CertificateFormat) - c.Assert(err, IsNil) - c.Assert(certificateFormat, Equals, teleport.CertificateFormatStandard) -} diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 7f9f3d3373e4e..519978c7c4de0 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -17,8 +17,6 @@ limitations under the License. 
package auth import ( - "bytes" - "crypto/rsa" "time" "golang.org/x/crypto/ssh" @@ -26,7 +24,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -307,142 +304,3 @@ func (s *AuthServer) AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLoginR HostSigners: AuthoritiesToTrustedCerts(hostCertAuthorities), }, nil } - -// DELETE IN: 2.6.0 -// This method is used only for upgrades from 2.4.0 to 2.5.0 -// ExchangeCertsRequest is a request to exchange TLS certificates -// for clusters that already trust each other -type ExchangeCertsRequest struct { - // PublicKey is public key of the trusted certificate authority - PublicKey []byte `json:"public_key"` - // TLSCert is TLS certificate associated with the public key - TLSCert []byte `json:"tls_cert"` -} - -// CheckAndSetDefaults checks and sets default values -func (req *ExchangeCertsRequest) CheckAndSetDefaults() error { - if len(req.PublicKey) == 0 { - return trace.BadParameter("missing parameter 'public_key'") - } - if len(req.TLSCert) == 0 { - return trace.BadParameter("missing parameter 'tls_cert'") - } - return nil -} - -// DELETE IN: 2.6.0 -// ExchangeCertsResponse is a resposne to exchange certificates request -type ExchangeCertsResponse struct { - // TLSCert is a PEM encoded certificate of a local certificate authority - TLSCert []byte `json:"tls_cert"` -} - -// DELETE IN: 2.6.0 -// This method is used to ugprade from 2.4.0 to 2.5.0 -// ExchangeCerts is a method to exchange TLS certificates between certificate authorities -// of the trusted clusters. A remote auth server that wishes to exchange TLS certs with a local auth server -// sends a request that consists of a public key already trusted by the local server and -// TLS certificate for the public key. 
The local server ensures that the TLS certificate -// was issued to the public key that is already trusted preventing random certificates -// to be injected by the remote server. This is a minor security enforcement. -func (s *AuthServer) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - remoteCA, err := s.findCertAuthorityByPublicKey(req.PublicKey) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := CheckPublicKeysEqual(req.PublicKey, req.TLSCert); err != nil { - return nil, trace.Wrap(err) - } - - // make sure that cluster name in TLS cert is not the same as cluster name - cert, err := tlsca.ParseCertificatePEM(req.TLSCert) - if err != nil { - return nil, trace.Wrap(err) - } - remoteClusterName, err := tlsca.ClusterName(cert.Subject) - if err != nil { - return nil, trace.Wrap(err) - } - clusterName, err := s.GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } - - if remoteClusterName == clusterName.GetName() { - return nil, trace.BadParameter("remote cluster name can not be the same as local cluster name") - } - - remoteCA.SetTLSKeyPairs([]services.TLSKeyPair{ - { - Cert: req.TLSCert, - }, - }) - - err = s.UpsertCertAuthority(remoteCA) - if err != nil { - return nil, trace.Wrap(err) - } - - thisHostCA, err := s.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: clusterName.GetClusterName()}, false) - if err != nil { - return nil, trace.Wrap(err) - } - - return &ExchangeCertsResponse{ - TLSCert: thisHostCA.GetTLSKeyPairs()[0].Cert, - }, nil - -} - -func (s *AuthServer) findCertAuthorityByPublicKey(publicKey []byte) (services.CertAuthority, error) { - authorities, err := s.GetCertAuthorities(services.HostCA, false) - if err != nil { - return nil, trace.Wrap(err) - } - for _, ca := range authorities { - for _, key := range ca.GetCheckingKeys() { - if bytes.Equal(key, publicKey) { - return ca, nil - } - 
} - } - return nil, trace.NotFound("certificate authority with public key is not found") -} - -// CheckPublicKeysEqual compares RSA based SSH certificate with the -// TLS certificate, returns nil if both certificates are using the same public -// key and refer to the same cluster name, error otherwise -func CheckPublicKeysEqual(sshKeyBytes []byte, certBytes []byte) error { - cert, err := tlsca.ParseCertificatePEM(certBytes) - if err != nil { - return trace.Wrap(err) - } - certPublicKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return trace.BadParameter("expected RSA public key, got %T", cert.PublicKey) - } - publicKey, _, _, _, err := ssh.ParseAuthorizedKey(sshKeyBytes) - if err != nil { - return trace.Wrap(err) - } - cryptoPubKey, ok := publicKey.(ssh.CryptoPublicKey) - if !ok { - return trace.BadParameter("unexpected key type: %T", publicKey) - } - rsaPublicKey, ok := cryptoPubKey.CryptoPublicKey().(*rsa.PublicKey) - if !ok { - return trace.BadParameter("unexpected key type: %T", publicKey) - } - if certPublicKey.E != rsaPublicKey.E { - return trace.CompareFailed("different public keys") - } - if certPublicKey.N.Cmp(rsaPublicKey.N) != 0 { - return trace.CompareFailed("different public keys") - } - return nil -} diff --git a/lib/auth/methods_test.go b/lib/auth/methods_test.go deleted file mode 100644 index 63b7b57fb8e42..0000000000000 --- a/lib/auth/methods_test.go +++ /dev/null @@ -1,52 +0,0 @@ -/* -Copyright 2017 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package auth - -import ( - "github.com/gravitational/teleport/lib/fixtures" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/services/suite" - "github.com/gravitational/teleport/lib/utils" - - check "gopkg.in/check.v1" -) - -type MethodsSuite struct{} - -var _ = check.Suite(&MethodsSuite{}) - -func (s *MethodsSuite) SetUpSuite(c *check.C) { - utils.InitLoggerForTests() -} - -// TestCheckPublicKeys tests checking of SSH and TLS certificates -func (s *MethodsSuite) TestCheckPublicKeys(c *check.C) { - - // same public keys match - ca1 := suite.NewTestCA(services.HostCA, "localhost") - err := CheckPublicKeysEqual(ca1.GetCheckingKeys()[0], ca1.GetTLSKeyPairs()[0].Cert) - c.Assert(err, check.IsNil) - - ca2 := suite.NewTestCA(services.HostCA, "other") - err = CheckPublicKeysEqual(ca1.GetCheckingKeys()[0], ca2.GetTLSKeyPairs()[0].Cert) - c.Assert(err, check.IsNil) - - // different public keys don't match - ca3 := suite.NewTestCA(services.HostCA, "localhost", fixtures.PEMBytes["rsa2"]) - err = CheckPublicKeysEqual(ca1.GetCheckingKeys()[0], ca3.GetTLSKeyPairs()[0].Cert) - fixtures.ExpectCompareFailed(c, err) -} diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 564f0576ccf00..e343b2cf9fc3a 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -39,6 +39,8 @@ type TLSServerConfig struct { LimiterConfig limiter.LimiterConfig // AccessPoint is caching access point AccessPoint AccessPoint + // Component is used for debugging purposes + Component string } // CheckAndSetDefaults checks and sets default values @@ -89,6 +91,7 @@ func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) { limiter.WrapHandle(authMiddleware) // force client auth if given cfg.TLS.ClientAuth = tls.VerifyClientCertIfGiven + server := &TLSServer{ TLSServerConfig: cfg, Server: &http.Server{ @@ -146,6 +149,10 @@ func (a 
*AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { // https://github.com/kubernetes/kubernetes/pull/34524/files#diff-2b283dde198c92424df5355f39544aa4R59 return nil, trace.AccessDenied("access denied: intermediaries are not supported") } + localClusterName, err := a.AccessPoint.GetDomainName() + if err != nil { + return nil, trace.Wrap(err) + } // with no client authentication in place, middleware // assumes not-privileged Nop role. // it theoretically possible to use bearer token auth even @@ -156,6 +163,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { GetClusterConfig: a.AccessPoint.GetClusterConfig, Role: teleport.RoleNop, Username: string(teleport.RoleNop), + ClusterName: localClusterName, }, nil } clientCert := peers[0] @@ -164,10 +172,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { log.Warning("Failed to parse client certificate %v.", err) return nil, trace.AccessDenied("access denied: invalid client certificate") } - localClusterName, err := a.AccessPoint.GetDomainName() - if err != nil { - return nil, trace.Wrap(err) - } + identity, err := tlsca.FromSubject(clientCert.Subject) if err != nil { return nil, trace.Wrap(err) @@ -209,6 +214,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { GetClusterConfig: a.AccessPoint.GetClusterConfig, Role: *systemRole, Username: identity.Username, + ClusterName: localClusterName, }, nil } // otherwise assume that is a local role, no need to pass the roles diff --git a/lib/auth/password_test.go b/lib/auth/password_test.go index 4e48124079f56..fb19db7a49638 100644 --- a/lib/auth/password_test.go +++ b/lib/auth/password_test.go @@ -63,9 +63,10 @@ func (s *PasswordSuite) SetUpTest(c *C) { }) c.Assert(err, IsNil) authConfig := &InitConfig{ - ClusterName: clusterName, - Backend: s.bk, - Authority: authority.New(), + ClusterName: clusterName, + Backend: s.bk, + Authority: authority.New(), + SkipPeriodicOperations: true, } s.a, err = 
NewAuthServer(authConfig) c.Assert(err, IsNil) @@ -119,7 +120,7 @@ func (s *PasswordSuite) TestChangePassword(c *C) { c.Assert(err, IsNil) fakeClock := clockwork.NewFakeClock() - s.a.clock = fakeClock + s.a.SetClock(fakeClock) req.NewPassword = []byte("abce456") err = s.a.ChangePassword(req) @@ -144,9 +145,9 @@ func (s *PasswordSuite) TestChangePasswordWithOTP(c *C) { c.Assert(err, IsNil) fakeClock := clockwork.NewFakeClock() - s.a.clock = fakeClock + s.a.SetClock(fakeClock) - validToken, err := totp.GenerateCode(otpSecret, s.a.clock.Now()) + validToken, err := totp.GenerateCode(otpSecret, s.a.GetClock().Now()) c.Assert(err, IsNil) // change password @@ -160,7 +161,7 @@ func (s *PasswordSuite) TestChangePasswordWithOTP(c *C) { // advance time and make sure we can login again fakeClock.Advance(defaults.AccountLockInterval + time.Second) - validToken, _ = totp.GenerateCode(otpSecret, s.a.clock.Now()) + validToken, _ = totp.GenerateCode(otpSecret, s.a.GetClock().Now()) req.OldPassword = req.NewPassword req.NewPassword = []byte("abc5555") req.SecondFactorToken = validToken diff --git a/lib/auth/permissions.go b/lib/auth/permissions.go index 86e076eaacbcc..dd17abcaa6e10 100644 --- a/lib/auth/permissions.go +++ b/lib/auth/permissions.go @@ -24,11 +24,12 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/trace" + "github.com/vulcand/predicate/builder" ) // NewAdminContext returns new admin auth context func NewAdminContext() (*AuthContext, error) { - authContext, err := contextForBuiltinRole(nil, teleport.RoleAdmin, fmt.Sprintf("%v", teleport.RoleAdmin)) + authContext, err := contextForBuiltinRole("", nil, teleport.RoleAdmin, fmt.Sprintf("%v", teleport.RoleAdmin)) if err != nil { return nil, trace.Wrap(err) } @@ -36,8 +37,8 @@ func NewAdminContext() (*AuthContext, error) { } // NewRoleAuthorizer authorizes everyone as predefined role, used in tests -func NewRoleAuthorizer(clusterConfig services.ClusterConfig, r teleport.Role) 
(Authorizer, error) { - authContext, err := contextForBuiltinRole(clusterConfig, r, fmt.Sprintf("%v", r)) +func NewRoleAuthorizer(clusterName string, clusterConfig services.ClusterConfig, r teleport.Role) (Authorizer, error) { + authContext, err := contextForBuiltinRole(clusterName, clusterConfig, r, fmt.Sprintf("%v", r)) if err != nil { return nil, trace.Wrap(err) } @@ -159,15 +160,13 @@ func (a *authorizer) authorizeBuiltinRole(r BuiltinRole) (*AuthContext, error) { if err != nil { return nil, trace.Wrap(err) } - return contextForBuiltinRole(config, r.Role, r.Username) + return contextForBuiltinRole(r.ClusterName, config, r.Role, r.Username) } func (a *authorizer) authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthContext, error) { if r.Role != teleport.RoleProxy { return nil, trace.AccessDenied("access denied for remote %v connecting to cluster", r.Role) } - // TODO(klizhentas): allow remote proxy to update the cluster's certificate authorities - // during certificates renewal roles, err := services.FromSpec( string(teleport.RoleRemoteProxy), services.RoleSpecV3{ @@ -183,6 +182,17 @@ func (a *authorizer) authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthConte services.NewRule(services.KindReverseTunnel, services.RO()), services.NewRule(services.KindTunnelConnection, services.RO()), services.NewRule(services.KindClusterConfig, services.RO()), + // this rule allows remote proxy to update the cluster's certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + // It is important that remote proxy can only rotate + // existing certificate authority, and not create or update new ones + Verbs: []string{services.VerbRead, services.VerbRotate}, + // allow administrative access to the certificate authority names + // matching the cluster name only + Where: builder.Equals(services.ResourceNameExpr, builder.String(r.ClusterName)).String(), + }, }, }, }) @@ -201,7 +211,7 @@ func (a *authorizer) 
authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthConte } // GetCheckerForBuiltinRole returns checkers for embedded builtin role -func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role teleport.Role) (services.AccessChecker, error) { +func GetCheckerForBuiltinRole(clusterName string, clusterConfig services.ClusterConfig, role teleport.Role) (services.AccessChecker, error) { switch role { case teleport.RoleAuth: return services.FromSpec( @@ -272,6 +282,23 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor services.NewRule(services.KindTunnelConnection, services.RW()), services.NewRule(services.KindHostCert, services.RW()), services.NewRule(services.KindRemoteCluster, services.RO()), + // this rule allows local proxy to update the remote cluster's host certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + Verbs: []string{services.VerbCreate, services.VerbRead, services.VerbUpdate}, + // allow administrative access to the host certificate authorities + // matching any cluster name except local + Where: builder.And( + builder.Equals(services.CertAuthorityTypeExpr, builder.String(string(services.HostCA))), + builder.Not( + builder.Equals( + services.ResourceNameExpr, + builder.String(clusterName), + ), + ), + ).String(), + }, }, }, }) @@ -305,6 +332,23 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor services.NewRule(services.KindStaticTokens, services.RO()), services.NewRule(services.KindTunnelConnection, services.RW()), services.NewRule(services.KindRemoteCluster, services.RO()), + // this rule allows local proxy to update the remote cluster's host certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + Verbs: []string{services.VerbCreate, services.VerbRead, services.VerbUpdate}, + // allow administrative access to the certificate authority names + // matching any 
cluster name except local + Where: builder.And( + builder.Equals(services.CertAuthorityTypeExpr, builder.String(string(services.HostCA))), + builder.Not( + builder.Equals( + services.ResourceNameExpr, + builder.String(clusterName), + ), + ), + ).String(), + }, }, }, }) @@ -367,8 +411,8 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor return nil, trace.NotFound("%v is not reconginzed", role.String()) } -func contextForBuiltinRole(clusterConfig services.ClusterConfig, r teleport.Role, username string) (*AuthContext, error) { - checker, err := GetCheckerForBuiltinRole(clusterConfig, r) +func contextForBuiltinRole(clusterName string, clusterConfig services.ClusterConfig, r teleport.Role, username string) (*AuthContext, error) { + checker, err := GetCheckerForBuiltinRole(clusterName, clusterConfig, r) if err != nil { return nil, trace.Wrap(err) } @@ -417,6 +461,9 @@ type BuiltinRole struct { // Username is for authentication tracking purposes Username string + + // ClusterName is the name of the local cluster + ClusterName string } // RemoteBuiltinRole is the role of the remote (service connecting via trusted cluster link) @@ -428,17 +475,17 @@ type RemoteBuiltinRole struct { // Username is for authentication tracking purposes Username string - // ClusterName is the name of the remote cluster + // ClusterName is the name of the remote cluster. ClusterName string } -// RemoteUser defines encoded remote user +// RemoteUser defines encoded remote user. type RemoteUser struct { // Username is a name of the remote user Username string `json:"username"` - // ClusterName is a name of the remote cluster - // of the user + // ClusterName is the name of the remote cluster + // of the user. 
ClusterName string `json:"cluster_name"` // RemoteRoles is optional list of remote roles diff --git a/lib/auth/register.go b/lib/auth/register.go index 811f81e6def54..4172a152b7316 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -33,7 +33,7 @@ import ( // LocalRegister is used to generate host keys when a node or proxy is running within the same process // as the auth server. This method does not need to use provisioning tokens. -func LocalRegister(dataDir string, id IdentityID, authServer *AuthServer, additionalPrincipals []string) error { +func LocalRegister(id IdentityID, authServer *AuthServer, additionalPrincipals []string) (*Identity, error) { keys, err := authServer.GenerateServerKeys(GenerateServerKeysRequest{ HostID: id.HostUUID, NodeName: id.NodeName, @@ -41,27 +41,26 @@ func LocalRegister(dataDir string, id IdentityID, authServer *AuthServer, additi AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // Register is used to generate host keys when a node or proxy are running on different hosts // than the auth server. This method requires provisioning tokens to prove a valid auth server // was used to issue the joining request. 
-func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, additionalPrincipals []string) error { +func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, additionalPrincipals []string) (*Identity, error) { tok, err := readToken(token) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } tlsConfig := utils.TLSConfig() certPath := filepath.Join(dataDir, defaults.CACertFile) certBytes, err := utils.ReadPath(certPath) if err != nil { - // DELETE IN: 2.6.0 // Only support secure cluster joins in the next releases if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } message := fmt.Sprintf(`Your configuration is insecure! Registering without TLS certificate authority, to fix this warning add ca.cert to %v, you can get ca.cert using 'tctl auth export --type=tls > ca.cert'`, dataDir) log.Warning(message) @@ -69,7 +68,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add } else { cert, err := tlsca.ParseCertificatePEM(certBytes) if err != nil { - return trace.Wrap(err, "failed to parse certificate at %v", certPath) + return nil, trace.Wrap(err, "failed to parse certificate at %v", certPath) } log.Infof("Joining remote cluster %v.", cert.Subject.CommonName) certPool := x509.NewCertPool() @@ -78,7 +77,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add } client, err := NewTLSClient(servers, tlsConfig) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer client.Close() @@ -91,18 +90,17 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } -// ReRegister renews the certificates 
and private keys based on the existing -// identity ID -func ReRegister(dataDir string, clt ClientI, id IdentityID, additionalPrincipals []string) error { +// ReRegister renews the certificates and private keys based on the client's existing identity. +func ReRegister(clt ClientI, id IdentityID, additionalPrincipals []string) (*Identity, error) { hostID, err := id.HostID() if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } keys, err := clt.GenerateServerKeys(GenerateServerKeysRequest{ HostID: hostID, @@ -111,9 +109,9 @@ func ReRegister(dataDir string, clt ClientI, id IdentityID, additionalPrincipals AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } func readToken(token string) (string, error) { diff --git a/lib/auth/rotate.go b/lib/auth/rotate.go new file mode 100644 index 0000000000000..e3d21b1c55011 --- /dev/null +++ b/lib/auth/rotate.go @@ -0,0 +1,575 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package auth + +import ( + "crypto/rsa" + "crypto/x509/pkix" + "time" + + "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/pborman/uuid" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// RotateRequest is a request to start rotation of the certificate authority. +type RotateRequest struct { + // Type is a certificate authority type, if omitted, both user and host CA + // will be rotated. + Type services.CertAuthType `json:"type"` + // GracePeriod is used to generate cert rotation schedule that defines + // times at which different rotation phases will be applied by the auth server + // in auto mode. It is not used in manual rotation mode. + // If omitted, default value is set, if 0 is supplied, it is interpreted as + // forcing rotation of all certificate authorities with no grace period, + // all existing users and hosts will have to re-login and re-added + // into the cluster. + GracePeriod *time.Duration `json:"grace_period,omitempty"` + // TargetPhase sets desired rotation phase to move to, if not set + // will be set automatically, it is a required argument + // for manual rotation. + TargetPhase string `json:"target_phase,omitempty"` + // Mode sets manual or auto rotation mode. + Mode string `json:"mode"` + // Schedule is an optional rotation schedule, + // autogenerated based on GracePeriod parameter if not set. + Schedule *services.RotationSchedule `json:"schedule"` +} + +// Types returns cert authority types requested to be rotated. 
+func (r *RotateRequest) Types() []services.CertAuthType { + switch r.Type { + case "": + return []services.CertAuthType{services.HostCA, services.UserCA} + case services.HostCA: + return []services.CertAuthType{services.HostCA} + case services.UserCA: + return []services.CertAuthType{services.UserCA} + } + return nil +} + +// CheckAndSetDefaults checks and sets default values. +func (r *RotateRequest) CheckAndSetDefaults(clock clockwork.Clock) error { + if r.TargetPhase == "" { + // if phase if not set, imply that the first meaningful phase + // is set as a target phase + r.TargetPhase = services.RotationPhaseUpdateClients + } + // if mode is not set, default to manual (as it's safer) + if r.Mode == "" { + r.Mode = services.RotationModeManual + } + switch r.Type { + case "", services.HostCA, services.UserCA: + default: + return trace.BadParameter("unsupported certificate authority type: %q", r.Type) + } + if r.GracePeriod == nil { + period := defaults.RotationGracePeriod + r.GracePeriod = &period + } + if r.Schedule == nil { + var err error + r.Schedule, err = services.GenerateSchedule(clock, *r.GracePeriod) + if err != nil { + return trace.Wrap(err) + } + } else { + if err := r.Schedule.CheckAndSetDefaults(clock); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +// rotationReq is an internal rotation requrest +type rotationReq struct { + // clock implements test or real wall clock + clock clockwork.Clock + // ca is a certificate authority to rotate + ca services.CertAuthority + // targetPhase is a target rotation phase to set + targetPhase string + // mode is a rotation mode + mode string + // gracePeriod is a rotation grace period + gracePeriod time.Duration + // schedule is a schedule to set + schedule services.RotationSchedule + // privateKey is passed by tests to supply private key for cert authorities + // instead of generating them on each iteration + privateKey []byte +} + +// RotateCertAuthority starts or restarts certificate authority 
rotation process. +// +// Rotation procedure is based on the state machine approach. +// +// Here are the supported rotation states: +// +// * Standby - the cluster is in standby mode and ready to take action. +// * In-progress - cluster CA rotation is in progress. +// +// In-progress state is split into multiple phases and the cluster +// can traverse between phases using supported transitions. +// +// Here are the supported phases: +// +// * Standby - no action is taken. +// +// * Update Clients - new CA is issued, all internal system clients +// have to reconnect and receive the new credentials, but all servers +// TLS, SSH and Proxies will still use old credentials. +// Certs from old CA and new CA are trusted within the system. +// This phase is necessary because old clients should receive new credentials +// from the auth servers. If this phase did not exist, old clients could not +// trust servers serving new credentials, because old clients did not receive +// new information yet. It is possible to transition from this phase to phase +// "Update servers" or "Rollback". +// +// * Update Servers - triggers all internal system components to reload and use +// new credentials both in the internal clients and servers, however +// old CA issued credentials are still trusted. This is done to make it possible +// for old components to be trusted within the system, to make rollback possible. +// It is possible to transition from this phase to "Rollback" or "Standby". +// When transitioning to "Standby" phase, the rotation is considered completed, +// old CA is removed from the system and components reload again, +// but this time they don't trust old CA any more. +// +// * Rollback phase is used to revert any changes. When going to rollback phase +// the newly issued CA is no longer used, but set up as trusted, +// so components can reload and receive credentials issued by "old" CA back. 
+// This phase is useful when an administrator makes a mistake, or there are some +// offline components that will lose the connection if the rotation +// completes. It is only possible to transition from this phase to "Standby". +// When transitioning to "Standby" phase from "Rollback" phase, all components +// reload again, but the "new" CA is discarded and is no longer trusted, +// cluster goes back to the original state. +// +// Rotation modes +// +// There are two rotation modes supported - manual or automatic. +// +// * Manual mode allows administrators to transition between +// phases explicitly setting a phase on every request. +// +// * Automatic mode performs automatic transition between phases +// on a given schedule. Schedule is a time table +// that specifies exact date when the next phase should take place. If automatic +// transition between any phase fails, the rotation switches back to the manual +// mode and stops executing phases on the schedule. If schedule is not specified, +// it will be auto generated based on the "grace period" duration parameter, +// and time between all phases will be evenly split over the grace period duration. +// +// It is possible to switch from automatic to manual by setting the phase +// to the rollback phase.
+// +func (a *AuthServer) RotateCertAuthority(req RotateRequest) error { + if err := req.CheckAndSetDefaults(a.clock); err != nil { + return trace.Wrap(err) + } + clusterName := a.clusterName.GetClusterName() + + caTypes := req.Types() + for _, caType := range caTypes { + existing, err := a.GetCertAuthority(services.CertAuthID{ + Type: caType, + DomainName: clusterName, + }, true) + if err != nil { + return trace.Wrap(err) + } + rotated, err := processRotationRequest(rotationReq{ + ca: existing, + clock: a.clock, + targetPhase: req.TargetPhase, + schedule: *req.Schedule, + gracePeriod: *req.GracePeriod, + mode: req.Mode, + privateKey: a.privateKey, + }) + if err != nil { + return trace.Wrap(err) + } + if err := a.CompareAndSwapCertAuthority(rotated, existing); err != nil { + return trace.Wrap(err) + } + rotation := rotated.GetRotation() + switch rotation.State { + case services.RotationStateInProgress: + log.WithFields(logrus.Fields{"type": caType}).Infof("Updated rotation state, set current phase to: %q.", rotation.Phase) + case services.RotationStateStandby: + log.WithFields(logrus.Fields{"type": caType}).Infof("Updated and completed rotation.") + } + } + return nil +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is called by remote trusted cluster and is used to update +// only public keys and certificates of the certificate authority. 
+func (a *AuthServer) RotateExternalCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + // this is just an extra precaution against local admins, + // because this is additionally enforced by RBAC as well + if ca.GetClusterName() == a.clusterName.GetClusterName() { + return trace.BadParameter("can not rotate local certificate authority") + } + + existing, err := a.GetCertAuthority(services.CertAuthID{ + Type: ca.GetType(), + DomainName: ca.GetClusterName(), + }, false) + if err != nil { + return trace.Wrap(err) + } + + updated := existing.Clone() + updated.SetCheckingKeys(ca.GetCheckingKeys()) + updated.SetTLSKeyPairs(ca.GetTLSKeyPairs()) + updated.SetRotation(ca.GetRotation()) + + // use compare and swap to protect from concurrent updates + // by trusted cluster API + if err := a.CompareAndSwapCertAuthority(updated, existing); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// autoRotateCertAuthorities automatically rotates cert authorities, +// does nothing if no rotation parameters were set up +// or it is too early to rotate per schedule +func (a *AuthServer) autoRotateCertAuthorities() error { + clusterName := a.clusterName.GetClusterName() + for _, caType := range []services.CertAuthType{services.HostCA, services.UserCA} { + ca, err := a.GetCertAuthority(services.CertAuthID{ + Type: caType, + DomainName: clusterName, + }, true) + if err != nil { + return trace.Wrap(err) + } + if err := a.autoRotate(ca); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +func (a *AuthServer) autoRotate(ca services.CertAuthority) error { + rotation := ca.GetRotation() + // rotation mode is not automatic, nothing to do + if rotation.Mode != services.RotationModeAuto { + return nil + } + // rotation is not in progress, there is nothing to do + if rotation.State != services.RotationStateInProgress { + return nil + } + logger := log.WithFields(logrus.Fields{"type": 
ca.GetType()}) + var req *rotationReq + switch rotation.Phase { + case services.RotationPhaseUpdateClients: + if rotation.Schedule.UpdateServers.After(a.clock.Now()) { + return nil + } + req = &rotationReq{ + clock: a.clock, + ca: ca, + targetPhase: services.RotationPhaseUpdateServers, + mode: services.RotationModeAuto, + gracePeriod: rotation.GracePeriod.Duration, + schedule: rotation.Schedule, + } + case services.RotationPhaseUpdateServers: + if rotation.Schedule.Standby.After(a.clock.Now()) { + return nil + } + req = &rotationReq{ + clock: a.clock, + ca: ca, + targetPhase: services.RotationPhaseStandby, + mode: services.RotationModeAuto, + gracePeriod: rotation.GracePeriod.Duration, + schedule: rotation.Schedule, + } + default: + return trace.BadParameter("phase is not supported: %q", rotation.Phase) + } + logger.Infof("Setting rotation phase %q", req.targetPhase) + rotated, err := processRotationRequest(*req) + if err != nil { + return trace.Wrap(err) + } + if err := a.CompareAndSwapCertAuthority(rotated, ca); err != nil { + return trace.Wrap(err) + } + logger.Infof("Cert authority rotation request is completed") + return nil +} + +// processRotationRequest processes rotation request based on the target and +// current phase and state. +func processRotationRequest(req rotationReq) (services.CertAuthority, error) { + rotation := req.ca.GetRotation() + ca := req.ca.Clone() + + switch req.targetPhase { + // This is the first stage of the rotation - new certificate authorities + // are being generated, clients will start using new credentials + // and servers will use the existing credentials, but will trust clients + // with both old and new credentials. 
+ case services.RotationPhaseUpdateClients: + switch rotation.State { + case services.RotationStateStandby, "": + default: + return nil, trace.BadParameter("can not initate rotation while another is in progress") + } + if err := startNewRotation(req, ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + // Update server phase uses the new credentials both for servers + // and clients, but still trusts clients with old credentials. + case services.RotationPhaseUpdateServers: + if rotation.Phase != services.RotationPhaseUpdateClients { + return nil, trace.BadParameter( + "can only switch to phase %v from %v, current phase is %v", + services.RotationPhaseUpdateServers, + services.RotationPhaseUpdateClients, + rotation.Phase) + } + // Signal nodes to restart and start serving new signatures + // by updating the phase. + rotation.Phase = req.targetPhase + rotation.Mode = req.mode + ca.SetRotation(rotation) + return ca, nil + // Rollback moves back both clients and servers to use the old credentials + // but will trust new credentials. + case services.RotationPhaseRollback: + switch rotation.Phase { + case services.RotationPhaseUpdateClients, services.RotationPhaseUpdateServers: + if err := startRollingBackRotation(ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + default: + return nil, trace.BadParameter("can not transition to phase %q from %q phase.", req.targetPhase, rotation.Phase) + } + // Transition to the standby phase moves rotation process + // to standby, servers will only trust one certificate authority. 
+ case services.RotationPhaseStandby: + switch rotation.Phase { + case services.RotationPhaseUpdateServers, services.RotationPhaseRollback: + if err := completeRotation(req.clock, ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + default: + return nil, trace.BadParameter( + "can only switch to phase %v from %v, current phase is %v", + services.RotationPhaseUpdateServers, + services.RotationPhaseUpdateClients, + rotation.Phase) + } + default: + return nil, trace.BadParameter("unsupported phase: %q", req.targetPhase) + } +} + +// startNewRotation starts new rotation and updates the certificate +// authority with new CA keys. +func startNewRotation(req rotationReq, ca services.CertAuthority) error { + clock := req.clock + gracePeriod := req.gracePeriod + + rotation := ca.GetRotation() + id := uuid.New() + + rotation.Mode = req.mode + rotation.Schedule = req.schedule + + var sshPrivPEM, sshPubPEM []byte + var keyPEM, certPEM []byte + + // generate keys and certificates: + if len(req.privateKey) != 0 { + log.Infof("Generating CA, using pregenerated test private key.") + rsaKey, err := ssh.ParseRawPrivateKey(req.privateKey) + if err != nil { + return trace.Wrap(err) + } + + signer, err := ssh.NewSignerFromKey(rsaKey) + if err != nil { + return trace.Wrap(err) + } + + sshPubPEM = ssh.MarshalAuthorizedKey(signer.PublicKey()) + sshPrivPEM = req.privateKey + + keyPEM, certPEM, err = tlsca.GenerateSelfSignedCAWithPrivateKey(rsaKey.(*rsa.PrivateKey), pkix.Name{ + CommonName: ca.GetClusterName(), + Organization: []string{ca.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return trace.Wrap(err) + } + } else { + var err error + sshPrivPEM, sshPubPEM, err = native.GenerateKeyPair("") + if err != nil { + return trace.Wrap(err) + } + + keyPEM, certPEM, err = tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: ca.GetClusterName(), + Organization: []string{ca.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return trace.Wrap(err) + 
} + } + + tlsKeyPair := &services.TLSKeyPair{ + Cert: certPEM, + Key: keyPEM, + } + + // rotate the certificate authority: + rotation.Started = clock.Now().UTC() + rotation.GracePeriod = services.NewDuration(gracePeriod) + rotation.CurrentID = id + + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + // Drop old certificate authority without keeping it as trusted. + if gracePeriod == 0 { + signingKeys = [][]byte{sshPrivPEM} + checkingKeys = [][]byte{sshPubPEM} + keyPairs = []services.TLSKeyPair{*tlsKeyPair} + // In case of forced rotation, rotation has been started and completed + // in the same step moving it to standby state. + rotation.State = services.RotationStateStandby + rotation.Phase = services.RotationPhaseStandby + } else { + // Rotation sets the first key to be the new key + // and keep only public keys/certs for the new CA. + signingKeys = [][]byte{sshPrivPEM, signingKeys[0]} + checkingKeys = [][]byte{sshPubPEM, checkingKeys[0]} + oldKeyPair := keyPairs[0] + keyPairs = []services.TLSKeyPair{*tlsKeyPair, oldKeyPair} + rotation.State = services.RotationStateInProgress + rotation.Phase = services.RotationPhaseUpdateClients + } + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// startRollingBackRotation starts roll back to the original state. +func startRollingBackRotation(ca services.CertAuthority) error { + rotation := ca.GetRotation() + + // Rollback always sets rotation to manual mode. + rotation.Mode = services.RotationModeManual + + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + // Rotation sets the first key to be the new key + // and keep only public keys/certs for the new CA. 
+ signingKeys = [][]byte{signingKeys[1]} + checkingKeys = [][]byte{checkingKeys[1]} + + // Keep the new certificate as trusted + // as during the rollback phase, both types of clients may be present in the cluster. + keyPairs = []services.TLSKeyPair{keyPairs[1], services.TLSKeyPair{Cert: keyPairs[0].Cert}} + rotation.State = services.RotationStateInProgress + rotation.Phase = services.RotationPhaseRollback + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// completeRollingBackRotation completes rollback of the rotation and sets it to the standby state +func completeRollingBackRotation(clock clockwork.Clock, ca services.CertAuthority) error { + rotation := ca.GetRotation() + + // clean up the state + rotation.Started = time.Time{} + rotation.State = services.RotationStateStandby + rotation.Phase = services.RotationPhaseStandby + rotation.Mode = "" + rotation.Schedule = services.RotationSchedule{} + + keyPairs := ca.GetTLSKeyPairs() + // only keep the original certificate authority as trusted + // and remove everything else. + keyPairs = []services.TLSKeyPair{keyPairs[0]} + + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// completeRotation completes the certificate authority rotation. 
+func completeRotation(clock clockwork.Clock, ca services.CertAuthority) error { + rotation := ca.GetRotation() + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + signingKeys = signingKeys[:1] + checkingKeys = checkingKeys[:1] + keyPairs = keyPairs[:1] + + rotation.Started = time.Time{} + rotation.State = services.RotationStateStandby + rotation.Phase = services.RotationPhaseStandby + rotation.LastRotated = clock.Now() + rotation.Mode = "" + rotation.Schedule = services.RotationSchedule{} + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} diff --git a/lib/auth/state.go b/lib/auth/state.go new file mode 100644 index 0000000000000..66fec360dfb85 --- /dev/null +++ b/lib/auth/state.go @@ -0,0 +1,246 @@ +package auth + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/dir" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" +) + +// ProcessStorage is a backend for local process state, +// it helps to manage rotation for certificate authorities +// and keeps local process credentials - x509 and SSH certs and keys. +type ProcessStorage struct { + b backend.Backend +} + +// NewProcessStorage returns a new instance of the process storage. +func NewProcessStorage(path string) (*ProcessStorage, error) { + if path == "" { + return nil, trace.BadParameter("missing parameter path") + } + backend, err := dir.New(backend.Params{"path": path}) + if err != nil { + return nil, trace.Wrap(err) + } + return &ProcessStorage{b: backend}, nil +} + +// Close closes all resources used by process storage backend. 
+func (p *ProcessStorage) Close() error { + return p.b.Close() +} + +const ( + // IdentityNameCurrent is a name for the identity credentials that are + // currently used by the process. + IdentityCurrent = "current" + // IdentityReplacement is a name for the identity crdentials that are + // replacing current identity credentials during CA rotation. + IdentityReplacement = "replacement" +) + +// stateName is an internal resource object name +const stateName = "state" + +// GetState reads rotation state from disk. +func (p *ProcessStorage) GetState(role teleport.Role) (*StateV2, error) { + data, err := p.b.GetVal([]string{"states", strings.ToLower(role.String())}, stateName) + if err != nil { + return nil, trace.Wrap(err) + } + var res StateV2 + if err := utils.UnmarshalWithSchema(GetStateSchema(), &res, data); err != nil { + return nil, trace.BadParameter(err.Error()) + } + return &res, nil +} + +// CreateState creates process state if it does not exist yet. +func (p *ProcessStorage) CreateState(role teleport.Role, state StateV2) error { + if err := state.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(state) + if err != nil { + return trace.Wrap(err) + } + err = p.b.CreateVal([]string{"states", strings.ToLower(role.String())}, stateName, data, backend.Forever) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// WriteState writes local cluster state to the backend. +func (p *ProcessStorage) WriteState(role teleport.Role, state StateV2) error { + if err := state.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(state) + if err != nil { + return trace.Wrap(err) + } + err = p.b.UpsertVal([]string{"states", strings.ToLower(role.String())}, stateName, data, backend.Forever) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// ReadIdentity reads identity using identity name and role. 
+func (p *ProcessStorage) ReadIdentity(name string, role teleport.Role) (*Identity, error) { + if name == "" { + return nil, trace.BadParameter("missing parameter name") + } + data, err := p.b.GetVal([]string{"ids", strings.ToLower(role.String())}, name) + if err != nil { + return nil, trace.Wrap(err) + } + var res IdentityV2 + if err := utils.UnmarshalWithSchema(GetIdentitySchema(), &res, data); err != nil { + return nil, trace.BadParameter(err.Error()) + } + return ReadIdentityFromKeyPair(res.Spec.Key, res.Spec.SSHCert, res.Spec.TLSCert, res.Spec.TLSCACerts) +} + +// WriteIdentity writes identity to the backend. +func (p *ProcessStorage) WriteIdentity(name string, id Identity) error { + res := IdentityV2{ + ResourceHeader: services.ResourceHeader{ + Kind: services.KindIdentity, + Version: services.V2, + Metadata: services.Metadata{ + Name: name, + }, + }, + Spec: IdentitySpecV2{ + Key: id.KeyBytes, + SSHCert: id.CertBytes, + TLSCert: id.TLSCertBytes, + TLSCACerts: id.TLSCACertsBytes, + }, + } + if err := res.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(res) + if err != nil { + return trace.Wrap(err) + } + return p.b.UpsertVal([]string{"ids", strings.ToLower(id.ID.Role.String())}, name, data, backend.Forever) +} + +// StateV2 is a local process state. +type StateV2 struct { + // ResourceHeader is a common resource header. + services.ResourceHeader + // Spec is a process spec. + Spec StateSpecV2 `json:"spec"` +} + +// CheckAndSetDefaults checks and sets defaults values. +func (s *StateV2) CheckAndSetDefaults() error { + s.Kind = services.KindState + s.Version = services.V2 + // for state resource name does not matter + if s.Metadata.Name == "" { + s.Metadata.Name = stateName + } + if err := s.Metadata.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + return nil +} + +// StateSpecV2 is a state spec. +type StateSpecV2 struct { + // Rotation holds local process rotation state. 
+ Rotation services.Rotation `json:"rotation"` +} + +// IdentityV2 specifies local host identity. +type IdentityV2 struct { + // ResourceHeader is a common resource header. + services.ResourceHeader + // Spec is the identity spec. + Spec IdentitySpecV2 `json:"spec"` +} + +// CheckAndSetDefaults checks and sets defaults values. +func (s *IdentityV2) CheckAndSetDefaults() error { + s.Kind = services.KindIdentity + s.Version = services.V2 + if err := s.Metadata.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if len(s.Spec.Key) == 0 { + return trace.BadParameter("missing parameter Key") + } + if len(s.Spec.SSHCert) == 0 { + return trace.BadParameter("missing parameter SSHCert") + } + if len(s.Spec.TLSCert) == 0 { + return trace.BadParameter("missing parameter TLSCert") + } + if len(s.Spec.TLSCACerts) == 0 { + return trace.BadParameter("missing parameter TLSCACerts") + } + return nil +} + +// IdentitySpecV2 specifies credentials used by local process. +type IdentitySpecV2 struct { + // Key is a PEM encoded private key. + Key []byte `json:"key,omitempty"` + // SSHCert is a PEM encoded SSH host cert. + SSHCert []byte `json:"ssh_cert,omitempty"` + // TLSCert is a PEM encoded x509 client certificate. + TLSCert []byte `json:"tls_cert,omitempty"` + // TLSCACert is a list of PEM encoded x509 certificate of the + // certificate authority of the cluster. + TLSCACerts [][]byte `json:"tls_ca_certs,omitempty"` +} + +// IdentitySpecV2Schema is a schema for identity spec. +const IdentitySpecV2Schema = `{ + "type": "object", + "additionalProperties": false, + "required": ["key", "ssh_cert", "tls_cert", "tls_ca_certs"], + "properties": { + "key": {"type": "string"}, + "ssh_cert": {"type": "string"}, + "tls_cert": {"type": "string"}, + "tls_ca_certs": { + "type": "array", + "items": {"type": "string"} + } + } +}` + +// GetIdentitySchema returns JSON Schema for cert authorities. 
+func GetIdentitySchema() string { + return fmt.Sprintf(services.V2SchemaTemplate, services.MetadataSchema, IdentitySpecV2Schema, services.DefaultDefinitions) +} + +// StateSpecV2Schema is a schema for local server state. +const StateSpecV2Schema = `{ + "type": "object", + "additionalProperties": false, + "required": ["rotation"], + "properties": { + "rotation": %v + } +}` + +// GetStateSchema returns JSON Schema for cert authorities. +func GetStateSchema() string { + return fmt.Sprintf(services.V2SchemaTemplate, services.MetadataSchema, fmt.Sprintf(StateSpecV2Schema, services.RotationSchema), services.DefaultDefinitions) +} diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 2cca155e88f28..a1d1613605e6e 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" + "github.com/jonboulle/clockwork" "github.com/pquerna/otp/totp" "gopkg.in/check.v1" ) @@ -101,6 +102,414 @@ func (s *TLSSuite) TestRemoteBuiltinRole(c *check.C) { fixtures.ExpectAccessDenied(c, err) } +// TestRemoteRotation tests remote builtin role +// that attempts certificate authority rotation +func (s *TLSSuite) TestRemoteRotation(c *check.C) { + remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ + Dir: c.MkDir(), + ClusterName: "remote", + }) + c.Assert(err, check.IsNil) + + certPool, err := s.server.CertPool() + c.Assert(err, check.IsNil) + + // after trust is established, things are good + err = s.server.AuthServer.Trust(remoteServer, nil) + + remoteProxy, err := remoteServer.NewRemoteClient( + TestBuiltin(teleport.RoleProxy), s.server.Addr(), certPool) + c.Assert(err, check.IsNil) + + remoteAuth, err := remoteServer.NewRemoteClient( + TestBuiltin(teleport.RoleAuth), s.server.Addr(), certPool) + c.Assert(err, check.IsNil) + + // remote cluster starts rotation + gracePeriod := time.Hour + remoteServer.AuthServer.privateKey = fixtures.PEMBytes["rsa2"] + err = 
remoteServer.AuthServer.RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + remoteCA, err := remoteServer.AuthServer.GetCertAuthority(services.CertAuthID{ + DomainName: remoteServer.ClusterName, + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + + // remote proxy should be rejected when trying to rotate ca + // that is not associated with the remote cluster + clone := remoteCA.Clone() + clone.SetName(s.server.ClusterName()) + err = remoteProxy.RotateExternalCertAuthority(clone) + fixtures.ExpectAccessDenied(c, err) + + // remote proxy can't upsert the certificate authority, + // only to rotate it (in remote rotation only certain fields are updated) + err = remoteProxy.UpsertCertAuthority(remoteCA) + fixtures.ExpectAccessDenied(c, err) + + // remote auth server will get rejected + err = remoteAuth.RotateExternalCertAuthority(remoteCA) + fixtures.ExpectAccessDenied(c, err) + + // remote proxy should be able to perform remote cert authority + // rotation + err = remoteProxy.RotateExternalCertAuthority(remoteCA) + c.Assert(err, check.IsNil) + + // newRemoteProxy should be trusted by the auth server + newRemoteProxy, err := remoteServer.NewRemoteClient( + TestBuiltin(teleport.RoleProxy), s.server.Addr(), certPool) + c.Assert(err, check.IsNil) + + _, err = newRemoteProxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // old proxy client is still trusted + _, err = s.server.CloneClient(remoteProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) +} + +// TestLocalProxyPermissions tests new local proxy permissions +// as it's now allowed to update host cert authorities of remote clusters +func (s *TLSSuite) TestLocalProxyPermissions(c *check.C) { + remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ + Dir: c.MkDir(), + ClusterName: "remote", + }) + c.Assert(err, 
check.IsNil) + + // after trust is established, things are good + err = s.server.AuthServer.Trust(remoteServer, nil) + c.Assert(err, check.IsNil) + + ca, err := s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: s.server.ClusterName(), + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + + proxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + // local proxy can't update local cert authorities + err = proxy.UpsertCertAuthority(ca) + fixtures.ExpectAccessDenied(c, err) + + // local proxy is allowed to update host CA of remote cert authorities + remoteCA, err := s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: remoteServer.ClusterName, + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + + err = proxy.UpsertCertAuthority(remoteCA) + c.Assert(err, check.IsNil) +} + +// TestAutoRotation tests local automatic rotation +func (s *TLSSuite) TestAutoRotation(c *check.C) { + clock := clockwork.NewFakeClockAt(time.Now()) + s.server.Auth().SetClock(clock) + + // create proxy client + proxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + // client works before rotation is initiated + _, err = proxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // starts rotation + s.server.Auth().privateKey = fixtures.PEMBytes["rsa2"] + gracePeriod := time.Hour + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + Mode: services.RotationModeAuto, + }) + c.Assert(err, check.IsNil) + + // old clients should work + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // new clients work as well + newProxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + _, err = newProxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // advance rotation by clock + 
clock.Advance(gracePeriod/2 + time.Minute) + err = s.server.Auth().autoRotateCertAuthorities() + c.Assert(err, check.IsNil) + + ca, err := s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: s.server.ClusterName(), + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + c.Assert(ca.GetRotation().Phase, check.Equals, services.RotationPhaseUpdateServers) + + // old clients should work + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // new clients work as well + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // complete rotation - advance rotation by clock + clock.Advance(gracePeriod/2 + time.Minute) + err = s.server.Auth().autoRotateCertAuthorities() + c.Assert(err, check.IsNil) + ca, err = s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: s.server.ClusterName(), + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + c.Assert(ca.GetRotation().Phase, check.Equals, services.RotationPhaseStandby) + + // old clients should no longer work + // new client has to be created here to force re-create the new + // connection instead of re-using the one from pool + // this is not going to be a problem in real teleport + // as it reloads the full server after reload + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + + // new clients work + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) +} + +// TestAutoFallback tests local automatic rotation fallback, +// when user intervenes with rollback and rotation gets switched +// to manual mode +func (s *TLSSuite) TestAutoFallback(c *check.C) { + clock := clockwork.NewFakeClockAt(time.Now()) + s.server.Auth().SetClock(clock) + + // create proxy client just for test purposes + proxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) +
c.Assert(err, check.IsNil) + + // client works before rotation is initiated + _, err = proxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // starts rotation + s.server.Auth().privateKey = fixtures.PEMBytes["rsa2"] + gracePeriod := time.Hour + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + Mode: services.RotationModeAuto, + }) + c.Assert(err, check.IsNil) + + ca, err := s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: s.server.ClusterName(), + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + c.Assert(ca.GetRotation().Phase, check.Equals, services.RotationPhaseUpdateClients) + c.Assert(ca.GetRotation().Mode, check.Equals, services.RotationModeAuto) + + // rollback rotation + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseRollback, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + ca, err = s.server.Auth().GetCertAuthority(services.CertAuthID{ + DomainName: s.server.ClusterName(), + Type: services.HostCA, + }, false) + c.Assert(err, check.IsNil) + c.Assert(ca.GetRotation().Phase, check.Equals, services.RotationPhaseRollback) + c.Assert(ca.GetRotation().Mode, check.Equals, services.RotationModeManual) +} + +// TestManualRotation tests local manual rotation +// that performs full-cycle certificate authority rotation +func (s *TLSSuite) TestManualRotation(c *check.C) { + // create proxy client just for test purposes + proxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + // client works before rotation is initiated + _, err = proxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // can't jump to mid-phase + gracePeriod := time.Hour + s.server.Auth().privateKey = fixtures.PEMBytes["rsa2"] + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + 
GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + fixtures.ExpectBadParameter(c, err) + + // starts rotation + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // old clients should work + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // new clients work as well + newProxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + _, err = newProxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // can't jump to standby + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseStandby, + Mode: services.RotationModeManual, + }) + fixtures.ExpectBadParameter(c, err) + + // advance rotation: + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // old clients should work + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // new clients work as well + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // complete rotation + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseStandby, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // old clients should no longer work + // new client has to be created here to force re-create the new + // connection instead of re-using the one from pool + // this is not going to be a
problem in real teleport + // as it reloads the full server after reload + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + + // new clients work + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) +} + +// TestRollback tests local manual rotation rollback +func (s *TLSSuite) TestRollback(c *check.C) { + // create proxy client just for test purposes + proxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + // client works before rotation is initiated + _, err = proxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // starts rotation + gracePeriod := time.Hour + s.server.Auth().privateKey = fixtures.PEMBytes["rsa2"] + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // new clients work + newProxy, err := s.server.NewClient(TestBuiltin(teleport.RoleProxy)) + c.Assert(err, check.IsNil) + + _, err = newProxy.GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // advance rotation: + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateServers, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // rollback rotation + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseRollback, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // new clients work, server still accepts the creds + // because new clients should re-register and receive new certs + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) + + // can't 
jump to other phases + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseUpdateClients, + Mode: services.RotationModeManual, + }) + fixtures.ExpectBadParameter(c, err) + + // complete rollback + err = s.server.Auth().RotateCertAuthority(RotateRequest{ + Type: services.HostCA, + GracePeriod: &gracePeriod, + TargetPhase: services.RotationPhaseStandby, + Mode: services.RotationModeManual, + }) + c.Assert(err, check.IsNil) + + // clients with new creds will no longer work + _, err = s.server.CloneClient(newProxy).GetNodes(defaults.Namespace) + c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + + // clients with old creds will still work + _, err = s.server.CloneClient(proxy).GetNodes(defaults.Namespace) + c.Assert(err, check.IsNil) +} + // TestRemoteUser tests scenario when remote user connects to the local // auth server and some edge cases. func (s *TLSSuite) TestRemoteUser(c *check.C) { diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 619bfaa03061f..ee6b00ddddc7a 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -47,6 +47,11 @@ type Backend interface { UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error // GetVal return a value for a given key in the bucket GetVal(path []string, key string) ([]byte, error) + // CompareAndSwapVal compares and swaps values in atomic operation, + // succeeds if prevVal matches the value stored in the database, + // requires prevVal as a non-empty value. Returns trace.CompareFailed + // in case if value did not match. 
+ CompareAndSwapVal(bucket []string, key string, val []byte, prevVal []byte, ttl time.Duration) error // DeleteKey deletes a key in a bucket DeleteKey(bucket []string, key string) error // DeleteBucket deletes the bucket by a given path diff --git a/lib/backend/boltbk/boltbk.go b/lib/backend/boltbk/boltbk.go index 044e0f14f7b26..bd07cc2eeb1d3 100644 --- a/lib/backend/boltbk/boltbk.go +++ b/lib/backend/boltbk/boltbk.go @@ -17,6 +17,7 @@ limitations under the License. package boltbk import ( + "bytes" "encoding/json" "path/filepath" "sort" @@ -141,6 +142,51 @@ func (b *BoltBackend) CreateVal(bucket []string, key string, val []byte, ttl tim return trace.Wrap(err) } +// CompareAndSwapVal compares and swaps values in an atomic operation, +// succeeds if prevData matches the value stored in the database, +// requires prevData as a non-empty value. Returns trace.CompareFailed +// if the stored value did not match or the key does not exist +func (b *BoltBackend) CompareAndSwapVal(bucket []string, key string, newData []byte, prevData []byte, ttl time.Duration) error { + if len(prevData) == 0 { + return trace.BadParameter("missing prevData parameter, to atomically create item, use CreateVal method") + } + v := &kv{ + Created: b.clock.Now().UTC(), + Value: newData, + TTL: ttl, + } + newEncodedData, err := json.Marshal(v) + if err != nil { + return trace.Wrap(err) + } + err = b.db.Update(func(tx *bolt.Tx) error { + bkt, err := GetBucket(tx, bucket) + if err != nil { + if trace.IsNotFound(err) { + return trace.CompareFailed("key %q is not found", key) + } + return trace.Wrap(err) + } + currentData := bkt.Get([]byte(key)) + if currentData == nil { + _, err := GetBucket(tx, append(bucket, key)) + if err == nil { + return trace.BadParameter("key %q is a bucket", key) + } + return trace.CompareFailed("%v %v is not found", bucket, key) + } + var currentVal kv + if err := json.Unmarshal(currentData, &currentVal); err != nil { + return trace.Wrap(err) + } + if !bytes.Equal(prevData, currentVal.Value) { + return
trace.CompareFailed("%q is not matching expected value", key) + } + return boltErr(bkt.Put([]byte(key), newEncodedData)) + }) + return trace.Wrap(err) +} + func (b *BoltBackend) upsertVal(path []string, key string, val []byte, ttl time.Duration) error { v := &kv{ Created: b.clock.Now().UTC(), diff --git a/lib/backend/boltbk/boltbk_test.go b/lib/backend/boltbk/boltbk_test.go index 355729c521e4d..e938af0ec5d93 100644 --- a/lib/backend/boltbk/boltbk_test.go +++ b/lib/backend/boltbk/boltbk_test.go @@ -60,6 +60,10 @@ func (s *BoltSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *BoltSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *BoltSuite) TestExpiration(c *C) { s.suite.Expiration(c) } diff --git a/lib/backend/dir/impl.go b/lib/backend/dir/impl.go index 845738643385d..c5c997c1bd5a9 100644 --- a/lib/backend/dir/impl.go +++ b/lib/backend/dir/impl.go @@ -17,6 +17,7 @@ limitations under the License. package dir import ( + "bytes" "io" "io/ioutil" "os" @@ -174,6 +175,57 @@ func (bk *Backend) CreateVal(bucket []string, key string, val []byte, ttl time.D return trace.Wrap(bk.applyTTL(dirPath, key, ttl)) } +// CompareAndSwapVal compares and swap values in atomic operation +func (bk *Backend) CompareAndSwapVal(bucket []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + // do not allow keys that start with a dot + if key[0] == reservedPrefix { + return trace.BadParameter("invalid key: '%s'. 
Key names cannot start with '.'", key) + } + // create the directory: + dirPath := path.Join(bk.RootDir, path.Join(bucket...)) + err := os.MkdirAll(dirPath, defaultDirMode) + if err != nil { + return trace.ConvertSystemError(err) + } + // open the existing file (AKA "key"), it is never created here: + filename := path.Join(dirPath, key) + f, err := os.OpenFile(filename, os.O_RDWR|os.O_EXCL, defaultFileMode) + if err != nil { + err = trace.ConvertSystemError(err) + if trace.IsNotFound(err) { + return trace.CompareFailed("%v/%v did not match expected value", dirPath, key) + } + return trace.Wrap(err) + } + defer f.Close() + if err := utils.FSWriteLock(f); err != nil { + return trace.Wrap(err) + } + defer utils.FSUnlock(f) + // before writing, make sure the values are equal + oldVal, err := ioutil.ReadAll(f) + if err != nil { + return trace.ConvertSystemError(err) + } + if !bytes.Equal(oldVal, prevVal) { + return trace.CompareFailed("%v/%v did not match expected value", dirPath, key) + } + if _, err := f.Seek(0, 0); err != nil { + return trace.ConvertSystemError(err) + } + if err := f.Truncate(0); err != nil { + return trace.ConvertSystemError(err) + } + // os.File.Write returns a non-nil error whenever it writes fewer than len(val) bytes + if _, err := f.Write(val); err != nil { + return trace.ConvertSystemError(err) + } + return trace.Wrap(bk.applyTTL(dirPath, key, ttl)) +} + // UpsertVal updates or inserts value with a given TTL into a bucket // ForeverTTL for no TTL func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error { @@ -280,11 +332,7 @@ func (bk *Backend) DeleteBucket(parent []string, bucket string) error { func removeFiles(dir string) error { d, err := os.Open(dir) if err != nil { - err = trace.ConvertSystemError(err) - if !trace.IsNotFound(err) { - return err - } - return nil + return trace.ConvertSystemError(err) } defer d.Close() names, err := d.Readdirnames(-1) @@ -308,6 +356,10 @@ func removeFiles(dir string) error { if err != nil { return err } + } else if fi.IsDir() { + if err := removeFiles(path); err !=
nil { + return err + } } } return nil diff --git a/lib/backend/dir/impl_test.go b/lib/backend/dir/impl_test.go index 5248a7cb81a7b..823b16bd5643f 100644 --- a/lib/backend/dir/impl_test.go +++ b/lib/backend/dir/impl_test.go @@ -88,6 +88,14 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { c.Assert(err, check.IsNil) }(i) + go func(cnt int) { + err := s.bk.CompareAndSwapVal(bucket, "key", []byte(value2), []byte(value1), time.Hour) + resultsC <- struct{}{} + if err != nil && !trace.IsCompareFailed(err) { + c.Assert(err, check.IsNil) + } + }(i) + go func(cnt int) { err := s.bk.CreateVal(bucket, "key", []byte(value2), time.Hour) resultsC <- struct{}{} @@ -113,12 +121,14 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { go func(cnt int) { err := s.bk.DeleteBucket([]string{"concurrent"}, "bucket") + if err != nil && !trace.IsNotFound(err) { + c.Assert(err, check.IsNil) + } resultsC <- struct{}{} - c.Assert(err, check.IsNil) }(i) } timeoutC := time.After(3 * time.Second) - for i := 0; i < attempts*4; i++ { + for i := 0; i < attempts*5; i++ { select { case <-resultsC: case <-timeoutC: @@ -127,6 +137,10 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { } } +func (s *Suite) TestCompareAndSwap(c *check.C) { + s.suite.CompareAndSwap(c) +} + func (s *Suite) TestCreateAndRead(c *check.C) { bucket := []string{"one", "two"} diff --git a/lib/backend/dynamo/dynamodbbk.go b/lib/backend/dynamo/dynamodbbk.go index 0b89e95a3dad2..0a679c06c73ce 100644 --- a/lib/backend/dynamo/dynamodbbk.go +++ b/lib/backend/dynamo/dynamodbbk.go @@ -471,6 +471,51 @@ func (b *DynamoDBBackend) UpsertVal(path []string, key string, val []byte, ttl t return b.createKey(fullPath, val, ttl, true) } +// CompareAndSwapVal compares and swap values in atomic operation +func (b *DynamoDBBackend) CompareAndSwapVal(path []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create 
item, use CreateVal method") + } + fullPath := b.fullPath(append(path, key)...) + r := record{ + HashKey: hashKey, + FullPath: fullPath, + Value: val, + TTL: ttl, + Timestamp: time.Now().UTC().Unix(), + } + if ttl != backend.Forever { + r.Expires = aws.Int64(b.clock.Now().UTC().Add(ttl).Unix()) + } + av, err := dynamodbattribute.MarshalMap(r) + if err != nil { + return trace.Wrap(err) + } + input := dynamodb.PutItemInput{ + Item: av, + TableName: aws.String(b.Tablename), + } + input.SetConditionExpression("#v = :prev") + input.SetExpressionAttributeNames(map[string]*string{ + "#v": aws.String("Value"), + }) + input.SetExpressionAttributeValues(map[string]*dynamodb.AttributeValue{ + ":prev": &dynamodb.AttributeValue{ + B: prevVal, + }, + }) + _, err = b.svc.PutItem(&input) + err = convertError(err) + if err != nil { + // in this case let's use more specific compare failed error + if trace.IsAlreadyExists(err) { + return trace.CompareFailed(err.Error()) + } + return trace.Wrap(err) + } + return nil +} + const delayBetweenLockAttempts = 100 * time.Millisecond // AcquireLock for a token diff --git a/lib/backend/dynamo/dynamodbbk_test.go b/lib/backend/dynamo/dynamodbbk_test.go index 04153c60ebbe4..11fc2af7eef94 100644 --- a/lib/backend/dynamo/dynamodbbk_test.go +++ b/lib/backend/dynamo/dynamodbbk_test.go @@ -66,6 +66,10 @@ func (s *DynamoDBSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *DynamoDBSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *DynamoDBSuite) TestBatchCRUD(c *C) { s.suite.BatchCRUD(c) } diff --git a/lib/backend/etcdbk/etcd.go b/lib/backend/etcdbk/etcd.go index b2ed02802a507..e41bac2345430 100644 --- a/lib/backend/etcdbk/etcd.go +++ b/lib/backend/etcdbk/etcd.go @@ -166,6 +166,26 @@ func (b *bk) CreateVal(path []string, key string, val []byte, ttl time.Duration) return trace.Wrap(convertErr(err)) } +// CompareAndSwapVal compares and swap values in atomic operation, +// succeeds if prevVal matches the value stored in 
the databases, +// requires prevVal as a non-empty value. Returns trace.CompareFailed +// in case if value did not match +func (b *bk) CompareAndSwapVal(path []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + encodedPrev := base64.StdEncoding.EncodeToString(prevVal) + _, err := b.api.Set( + context.Background(), + b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val), + &client.SetOptions{PrevValue: encodedPrev, PrevExist: client.PrevExist, TTL: ttl}) + err = convertErr(err) + if trace.IsNotFound(err) { + return trace.CompareFailed(err.Error()) + } + return trace.Wrap(err) +} + // maxOptimisticAttempts is the number of attempts optimistic locking const maxOptimisticAttempts = 5 diff --git a/lib/backend/etcdbk/etcd_test.go b/lib/backend/etcdbk/etcd_test.go index 7bb7855a7ae1c..f959ebc4c2e27 100644 --- a/lib/backend/etcdbk/etcd_test.go +++ b/lib/backend/etcdbk/etcd_test.go @@ -103,6 +103,10 @@ func (s *EtcdSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *EtcdSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *EtcdSuite) TestExpiration(c *C) { s.suite.Expiration(c) } diff --git a/lib/backend/test/suite.go b/lib/backend/test/suite.go index 57ff274abe88c..56f5c5b4aad06 100644 --- a/lib/backend/test/suite.go +++ b/lib/backend/test/suite.go @@ -24,6 +24,7 @@ import ( "time" "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/trace" . 
"gopkg.in/check.v1" @@ -115,6 +116,35 @@ func (s *BackendSuite) BasicCRUD(c *C) { c.Assert(trace.IsNotFound(err), Equals, true, Commentf("%#v", err)) } +// CompareAndSwap tests compare and swap functionality +func (s *BackendSuite) CompareAndSwap(c *C) { + bucket := []string{"test", "cas"} + + // compare and swap on non existing operation will fail + err := s.B.CompareAndSwapVal(bucket, "one", []byte("1"), []byte("2"), backend.Forever) + fixtures.ExpectCompareFailed(c, err) + + err = s.B.CreateVal(bucket, "one", []byte("1"), backend.Forever) + c.Assert(err, IsNil) + + // success CAS! + err = s.B.CompareAndSwapVal(bucket, "one", []byte("2"), []byte("1"), backend.Forever) + c.Assert(err, IsNil) + + val, err := s.B.GetVal(bucket, "one") + c.Assert(err, IsNil) + c.Assert(string(val), Equals, "2") + + // value has been updated - not '1' any more + err = s.B.CompareAndSwapVal(bucket, "one", []byte("3"), []byte("1"), backend.Forever) + fixtures.ExpectCompareFailed(c, err) + + // existing value has not been changed by the failed CAS operation + val, err = s.B.GetVal(bucket, "one") + c.Assert(err, IsNil) + c.Assert(string(val), Equals, "2") +} + // BatchCRUD tests batch CRUD operations if supported by the backend func (s *BackendSuite) BatchCRUD(c *C) { getter, ok := s.B.(backend.ItemsGetter) diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 9b4d0759d3994..29acd28f4ac8b 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -764,7 +764,7 @@ func (k *KeyPair) Identity() (*auth.Identity, error) { } else { certBytes = []byte(k.Cert) } - return auth.ReadIdentityFromKeyPair(keyBytes, certBytes, []byte(k.TLSCert), []byte(k.TLSCACert)) + return auth.ReadIdentityFromKeyPair(keyBytes, certBytes, []byte(k.TLSCert), [][]byte{[]byte(k.TLSCACert)}) } // Authority is a host or user certificate authority that diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 6d299aa20789b..12ca7176ebbfb 100644 --- a/lib/defaults/defaults.go +++ 
b/lib/defaults/defaults.go @@ -242,6 +242,9 @@ var ( // ReportingPeriod is a period for reports in logs ReportingPeriod = 5 * time.Minute + + // HighResPollingPeriod is a default high resolution polling period + HighResPollingPeriod = 10 * time.Second ) // Default connection limits, they can be applied separately on any of the Teleport @@ -270,6 +273,10 @@ const ( // CertDuration is a default certificate duration // 12 is default as it' longer than average working day (I hope so) CertDuration = 12 * time.Hour + // RotationGracePeriod is a default rotation period for graceful + // certificate rotations, by default to set to maximum allowed user + // cert duration + RotationGracePeriod = MaxCertDuration ) // list of roles teleport service can run as: diff --git a/lib/fixtures/fixtures.go b/lib/fixtures/fixtures.go index 73cbaa9709d92..d1233b356cf1c 100644 --- a/lib/fixtures/fixtures.go +++ b/lib/fixtures/fixtures.go @@ -11,27 +11,27 @@ import ( // ExpectNotFound expects not found error func ExpectNotFound(c *check.C, err error) { - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected NotFound, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected NotFound, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectBadParameter expects bad parameter error func ExpectBadParameter(c *check.C, err error) { - c.Assert(trace.IsBadParameter(err), check.Equals, true, check.Commentf("expected BadParameter, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsBadParameter(err), check.Equals, true, check.Commentf("expected BadParameter, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectCompareFailed expects compare failed error func ExpectCompareFailed(c *check.C, err error) { - c.Assert(trace.IsCompareFailed(err), check.Equals, true, check.Commentf("expected CompareFailed, got %T %#v at %v", err, err, 
string(debug.Stack()))) + c.Assert(trace.IsCompareFailed(err), check.Equals, true, check.Commentf("expected CompareFailed, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectAccessDenied expects error to be access denied func ExpectAccessDenied(c *check.C, err error) { - c.Assert(trace.IsAccessDenied(err), check.Equals, true, check.Commentf("expected AccessDenied, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsAccessDenied(err), check.Equals, true, check.Commentf("expected AccessDenied, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectAlreadyExists expects already exists error func ExpectAlreadyExists(c *check.C, err error) { - c.Assert(trace.IsAlreadyExists(err), check.Equals, true, check.Commentf("expected AlreadyExists, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsAlreadyExists(err), check.Equals, true, check.Commentf("expected AlreadyExists, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // DeepCompare uses gocheck DeepEquals but provides nice diff if things are not equal diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index ee7d76ee33f84..5ba19cb15e666 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -104,9 +104,9 @@ func ReadJSON(r *http.Request, val interface{}) error { func ConvertResponse(re *roundtrip.Response, err error) (*roundtrip.Response, error) { if err != nil { if uerr, ok := err.(*url.Error); ok && uerr != nil && uerr.Err != nil { - return nil, trace.Wrap(uerr.Err) + return nil, trace.ConnectionProblem(uerr.Err, uerr.Error()) } - return nil, trace.Wrap(err) + return nil, trace.ConvertSystemError(err) } return re, trace.ReadError(re.Code(), re.Bytes()) } diff --git a/lib/multiplexer/multiplexer.go b/lib/multiplexer/multiplexer.go index 6b380b9b179cc..8b902417df998 100644 --- a/lib/multiplexer/multiplexer.go +++ b/lib/multiplexer/multiplexer.go @@ -32,6 +32,7 @@ import ( "sync" "time" + 
"github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/trace" @@ -57,6 +58,8 @@ type Config struct { DisableSSH bool // DisableTLS disables TLS socket DisableTLS bool + // ID is an identifier used for debugging purposes + ID string } // CheckAndSetDefaults verifies configuration and sets defaults @@ -86,7 +89,7 @@ func New(cfg Config) (*Mux, error) { waitContext, waitCancel := context.WithCancel(context.TODO()) return &Mux{ Entry: log.WithFields(log.Fields{ - trace.Component: "mux", + trace.Component: teleport.Component("mx", cfg.ID), }), Config: cfg, context: ctx, diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index fb8c63d9ece8a..9ac7e9b05d530 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -76,6 +76,12 @@ type remoteSite struct { // remoteAccessPoint provides access to a cached subset of the Auth Server API of // the remote cluster this site belongs to. remoteAccessPoint auth.AccessPoint + + // remoteCA is the last remote certificate authority recorded by the client. + // It is used to detect CA rotation status changes. If the rotation + // state has been changed, the tunnel will reconnect to re-create the client + // with new settings. 
+ remoteCA services.CertAuthority } func (s *remoteSite) getRemoteClient() (auth.ClientI, bool, error) { @@ -287,6 +293,24 @@ func (s *remoteSite) GetLastConnected() time.Time { return connInfo.GetLastHeartbeat() } +func (s *remoteSite) compareAndSwapCertAuthority(ca services.CertAuthority) error { + s.Lock() + defer s.Unlock() + + if s.remoteCA == nil { + s.remoteCA = ca + return nil + } + + rotation := s.remoteCA.GetRotation() + if rotation.Matches(ca.GetRotation()) { + s.remoteCA = ca + return nil + } + s.remoteCA = ca + return trace.CompareFailed("remote certificate authority rotation has been updated") +} + func (s *remoteSite) periodicSendDiscoveryRequests() { ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() @@ -307,68 +331,86 @@ func (s *remoteSite) periodicSendDiscoveryRequests() { } } -// DELETE IN: 2.6.0 -// attemptCertExchange tries to exchange TLS certificates with remote -// clusters that have upgraded to 2.5.0 -func (s *remoteSite) attemptCertExchange() error { - // this logic is explicitly using the local non cached - // client as it has to have write access to the auth server - localCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ +// updateCertAuthorities updates local and remote cert authorities +func (s *remoteSite) updateCertAuthorities() error { + // update main cluster cert authorities on the remote side + // remote side makes sure that only relevant fields + // are updated + hostCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: s.srv.ClusterName, }, false) if err != nil { return trace.Wrap(err) } - re, err := s.remoteClient.ExchangeCerts(auth.ExchangeCertsRequest{ - PublicKey: localCA.GetCheckingKeys()[0], - TLSCert: localCA.GetTLSKeyPairs()[0].Cert, - }) + err = s.remoteClient.RotateExternalCertAuthority(hostCA) + if err != nil { + return trace.Wrap(err) + } + + userCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ + Type: 
services.UserCA, + DomainName: s.srv.ClusterName, + }, false) + if err != nil { + return trace.Wrap(err) + } + err = s.remoteClient.RotateExternalCertAuthority(userCA) if err != nil { return trace.Wrap(err) } - remoteCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ + + // update remote cluster's host cert authoritiy on a local cluster + // local proxy is authorized to perform this operation only for + // host authorities of remote clusters. + remoteCA, err := s.remoteClient.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: s.domainName, }, false) if err != nil { return trace.Wrap(err) } - _, err = s.localClient.ExchangeCerts(auth.ExchangeCertsRequest{ - PublicKey: remoteCA.GetCheckingKeys()[0], - TLSCert: re.TLSCert, - }) - return trace.Wrap(err) -} -// DELETE IN: 2.6.0 -// This logic is only relevant for upgrades from 2.5.0 to 2.6.0 -func (s *remoteSite) periodicAttemptCertExchange() { - ticker := time.NewTicker(defaults.NetworkBackoffDuration) - defer ticker.Stop() - if err := s.attemptCertExchange(); err != nil { - s.Warningf("Attempt at cert exchange failed: %v.", err) - } else { - s.Debugf("Certificate exchange has completed, going to force reconnect.") - s.srv.RemoveSite(s.domainName) - s.Close() - return + if remoteCA.GetClusterName() != s.domainName { + return trace.BadParameter( + "remote cluster sent different cluster name %v instead of expected one %v", + remoteCA.GetClusterName(), s.domainName) } + err = s.localClient.UpsertCertAuthority(remoteCA) + if err != nil { + return trace.Wrap(err) + } + + return s.compareAndSwapCertAuthority(remoteCA) +} +func (s *remoteSite) periodicUpdateCertAuthorities() { + s.Debugf("Ticking with period %v", s.srv.PollingPeriod) + ticker := time.NewTicker(s.srv.PollingPeriod) + defer ticker.Stop() for { select { case <-s.ctx.Done(): s.Debugf("Context is closing.") return case <-ticker.C: - err := s.attemptCertExchange() + err := s.updateCertAuthorities() if err != nil { - 
s.Warningf("Could not perform certificate exchange: %v.", trace.DebugReport(err)) + switch { + case trace.IsNotFound(err): + s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) + case trace.IsCompareFailed(err): + s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") + s.srv.RemoveSite(s.domainName) + s.Close() + return + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.Warningf("Could not perform cert authorities updated: %v.", trace.DebugReport(err)) + } } else { - s.Debugf("Certificate exchange has completed, going to force reconnect.") - s.srv.RemoveSite(s.domainName) - s.Close() - return + s.Debugf("Certificate authorities updated.") } } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index df29b50441709..1a45d6eb84f5f 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -147,6 +147,9 @@ type Config struct { MACAlgorithms []string // DataDir is a local server data directory DataDir string + // PollingPeriod specifies polling period for internal sync + // goroutines, used to speed up sync-ups in tests. 
+ PollingPeriod time.Duration } // CheckAndSetDefaults checks parameters and sets default values @@ -169,6 +172,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.Context == nil { cfg.Context = context.TODO() } + if cfg.PollingPeriod == 0 { + cfg.PollingPeriod = defaults.HighResPollingPeriod + } if cfg.Limiter == nil { var err error cfg.Limiter, err = limiter.NewLimiter(limiter.LimiterConfig{}) @@ -827,7 +833,7 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { remoteSite.localClient = srv.localAuthClient remoteSite.localAccessPoint = srv.localAccessPoint - clt, isLegacyRemoteCluster, err := remoteSite.getRemoteClient() + clt, _, err := remoteSite.getRemoteClient() if err != nil { return nil, trace.Wrap(err) } @@ -852,11 +858,7 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { remoteSite.certificateCache = certificateCache go remoteSite.periodicSendDiscoveryRequests() - - // if remote cluster is legacy, attempt periodic certificate exchanges - if isLegacyRemoteCluster { - go remoteSite.periodicAttemptCertExchange() - } + go remoteSite.periodicUpdateCertAuthorities() return remoteSite, nil } diff --git a/lib/service/cfg.go b/lib/service/cfg.go index 8844e9d421551..db24a9641eb8c 100644 --- a/lib/service/cfg.go +++ b/lib/service/cfg.go @@ -140,6 +140,21 @@ type Config struct { // UploadEventsC is a channel for upload events // used in tests UploadEventsC chan *events.UploadEvent `json:"-"` + + // FileDescriptors is an optional list of file descriptors for the process + // to inherit and use for listeners, used for in-process updates. + FileDescriptors []FileDescriptor + + // PollingPeriod is set to override default internal polling periods + // of sync agents, used to speed up integration tests. + PollingPeriod time.Duration + + // ClientTimeout is set to override default client timeouts + // used by internal clients, used to speed up integration tests. 
+ ClientTimeout time.Duration + + // ShutdownTimeout is set to override default shutdown timeout. + ShutdownTimeout time.Duration } // ApplyToken assigns a given token to all internal services but only if token diff --git a/lib/service/cfg_test.go b/lib/service/cfg_test.go index d8e5b11038f9f..f61891af22d46 100644 --- a/lib/service/cfg_test.go +++ b/lib/service/cfg_test.go @@ -22,56 +22,56 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" - . "gopkg.in/check.v1" + "gopkg.in/check.v1" ) -func TestConfig(t *testing.T) { TestingT(t) } +func TestConfig(t *testing.T) { check.TestingT(t) } type ConfigSuite struct { } -var _ = Suite(&ConfigSuite{}) +var _ = check.Suite(&ConfigSuite{}) -func (s *ConfigSuite) SetUpSuite(c *C) { +func (s *ConfigSuite) SetUpSuite(c *check.C) { utils.InitLoggerForTests() } -func (s *ConfigSuite) TestDefaultConfig(c *C) { +func (s *ConfigSuite) TestDefaultConfig(c *check.C) { config := MakeDefaultConfig() - c.Assert(config, NotNil) + c.Assert(config, check.NotNil) // all 3 services should be enabled by default - c.Assert(config.Auth.Enabled, Equals, true) - c.Assert(config.SSH.Enabled, Equals, true) - c.Assert(config.Proxy.Enabled, Equals, true) + c.Assert(config.Auth.Enabled, check.Equals, true) + c.Assert(config.SSH.Enabled, check.Equals, true) + c.Assert(config.Proxy.Enabled, check.Equals, true) localAuthAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: "0.0.0.0:3025"} localProxyAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: "0.0.0.0:3023"} localSSHAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: "0.0.0.0:3022"} // data dir, hostname and auth server - c.Assert(config.DataDir, Equals, defaults.DataDir) + c.Assert(config.DataDir, check.Equals, defaults.DataDir) if len(config.Hostname) < 2 { c.Error("default hostname wasn't properly set") } // auth section auth := config.Auth - c.Assert(auth.SSHAddr, DeepEquals, localAuthAddr) - c.Assert(auth.Limiter.MaxConnections, Equals, 
int64(defaults.LimiterMaxConnections)) - c.Assert(auth.Limiter.MaxNumberOfUsers, Equals, defaults.LimiterMaxConcurrentUsers) - c.Assert(auth.StorageConfig.Type, Equals, "bolt") - c.Assert(auth.StorageConfig.Params["path"], Equals, config.DataDir) + c.Assert(auth.SSHAddr, check.DeepEquals, localAuthAddr) + c.Assert(auth.Limiter.MaxConnections, check.Equals, int64(defaults.LimiterMaxConnections)) + c.Assert(auth.Limiter.MaxNumberOfUsers, check.Equals, defaults.LimiterMaxConcurrentUsers) + c.Assert(auth.StorageConfig.Type, check.Equals, "bolt") + c.Assert(auth.StorageConfig.Params["path"], check.Equals, config.DataDir) // SSH section ssh := config.SSH - c.Assert(ssh.Addr, DeepEquals, localSSHAddr) - c.Assert(ssh.Limiter.MaxConnections, Equals, int64(defaults.LimiterMaxConnections)) - c.Assert(ssh.Limiter.MaxNumberOfUsers, Equals, defaults.LimiterMaxConcurrentUsers) + c.Assert(ssh.Addr, check.DeepEquals, localSSHAddr) + c.Assert(ssh.Limiter.MaxConnections, check.Equals, int64(defaults.LimiterMaxConnections)) + c.Assert(ssh.Limiter.MaxNumberOfUsers, check.Equals, defaults.LimiterMaxConcurrentUsers) // proxy section proxy := config.Proxy - c.Assert(proxy.SSHAddr, DeepEquals, localProxyAddr) - c.Assert(proxy.Limiter.MaxConnections, Equals, int64(defaults.LimiterMaxConnections)) - c.Assert(proxy.Limiter.MaxNumberOfUsers, Equals, defaults.LimiterMaxConcurrentUsers) + c.Assert(proxy.SSHAddr, check.DeepEquals, localProxyAddr) + c.Assert(proxy.Limiter.MaxConnections, check.Equals, int64(defaults.LimiterMaxConnections)) + c.Assert(proxy.Limiter.MaxNumberOfUsers, check.Equals, defaults.LimiterMaxConcurrentUsers) } diff --git a/lib/service/connect.go b/lib/service/connect.go new file mode 100644 index 0000000000000..642bf82c36326 --- /dev/null +++ b/lib/service/connect.go @@ -0,0 +1,434 @@ +package service + +import ( + "time" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/services" + 
"github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" +) + +// connectToAuthService attempts to login into the auth servers specified in the +// configuration and receive credentials. +func (process *TeleportProcess) connectToAuthService(role teleport.Role) (*Connector, error) { + connector, err := process.connect(role) + if err != nil { + return nil, trace.Wrap(err) + } + process.Debugf("Connected client: %v", connector.ClientIdentity) + process.Debugf("Connected server: %v", connector.ServerIdentity) + process.addConnector(connector) + return connector, nil +} + +func (process *TeleportProcess) connect(role teleport.Role) (*Connector, error) { + state, err := process.storage.GetState(role) + if err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + // no state recorded - this is the first connect + // process will try to connect with the security token. + return process.firstTimeConnect(role) + } + process.Debugf("Connected state: %v.", state.Spec.Rotation.String()) + + identity, err := process.GetIdentity(role) + if err != nil { + return nil, trace.Wrap(err) + } + + rotation := state.Spec.Rotation + + switch rotation.State { + // rotation is on standby, so just use whatever is current + case "", services.RotationStateStandby: + // The roles of admin and auth are treated in a special way, as in this case + // the process does not need TLS clients and can use local auth directly. 
+ if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + log.Infof("Connecting to the cluster %v with TLS client certificate.", identity.ClusterName) + client, err := process.newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{Client: client, ClientIdentity: identity, ServerIdentity: identity}, nil + case services.RotationStateInProgress: + switch rotation.Phase { + case services.RotationPhaseUpdateClients: + // Clients should use updated credentials, + // while servers should use old credentials to answer auth requests. + newIdentity, err := process.storage.ReadIdentity(auth.IdentityReplacement, role) + if err != nil { + return nil, trace.Wrap(err) + } + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: newIdentity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := process.newClient(process.Config.AuthServers, newIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: newIdentity, + ServerIdentity: identity, + }, nil + case services.RotationPhaseUpdateServers: + // Servers and clients are using new identity credentials, but the + // identity is still set up to trust the old certificate authority certificates. 
+ newIdentity, err := process.storage.ReadIdentity(auth.IdentityReplacement, role) + if err != nil { + return nil, trace.Wrap(err) + } + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: newIdentity, + ServerIdentity: newIdentity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := process.newClient(process.Config.AuthServers, newIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: newIdentity, + ServerIdentity: newIdentity, + }, nil + case services.RotationPhaseRollback: + // In rollback phase, clients and servers should switch back + // to the old certificate authority-issued credentials, + // but the new certificate authority should be trusted + // because not all clients can update at the same time. + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := process.newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: identity, + ServerIdentity: identity, + }, nil + default: + return nil, trace.BadParameter("unsupported rotation phase: %q", rotation.Phase) + } + default: + return nil, trace.BadParameter("unsupported rotation state: %q", rotation.State) + } +} + +func (process *TeleportProcess) firstTimeConnect(role teleport.Role) (*Connector, error) { + id := auth.IdentityID{ + Role: role, + HostUUID: process.Config.HostUUID, + NodeName: process.Config.Hostname, + } + additionalPrincipals, err := process.getAdditionalPrincipals(role) + if err != nil { + return nil, trace.Wrap(err) + } + var identity *auth.Identity + if process.getLocalAuth() != nil { + // Auth service is on the same host, no need to go though the invitation + // procedure. 
+ process.Debugf("This server has local Auth server started, using it to add role to the cluster.") + identity, err = auth.LocalRegister(id, process.getLocalAuth(), additionalPrincipals) + } else { + // Auth server is remote, so we need a provisioning token. + if process.Config.Token == "" { + return nil, trace.BadParameter("%v must join a cluster and needs a provisioning token", role) + } + process.Infof("Joining the cluster with a token %v.", process.Config.Token) + identity, err = auth.Register(process.Config.DataDir, process.Config.Token, id, process.Config.AuthServers, additionalPrincipals) + } + if err != nil { + return nil, trace.Wrap(err) + } + + log.Infof("%v has successfully registered with the cluster.", role) + var connector *Connector + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + connector = &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + } + } else { + client, err := process.newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + connector = &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + Client: client, + } + } + + // Sync local rotation state to match the remote rotation state. 
+ ca, err := connector.GetCertAuthority(services.CertAuthID{ + DomainName: connector.ClientIdentity.ClusterName, + Type: services.HostCA, + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + + err = process.storage.WriteIdentity(auth.IdentityCurrent, *identity) + if err != nil { + process.Warningf("Failed to write %v identity: %v.", role, err) + } + + err = process.storage.WriteState(role, auth.StateV2{ + Spec: auth.StateSpecV2{ + Rotation: ca.GetRotation(), + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + process.Infof("The process has successfully wrote credentials and state of %v to disk.", role) + return connector, nil +} + +// periodicSyncRotationState checks rotation state periodically and +// takes action if necessary +func (process *TeleportProcess) periodicSyncRotationState() error { + // start rotation only after teleport process has started + eventC := make(chan Event, 1) + process.WaitForEvent(process.ExitContext(), TeleportReadyEvent, eventC) + select { + case <-eventC: + process.Infof("The new service has started successfully. Starting syncing rotation status.") + case <-process.ExitContext().Done(): + process.Infof("Periodic rotation sync has exited.") + return nil + } + + t := time.NewTicker(process.Config.PollingPeriod) + defer t.Stop() + for { + select { + case <-t.C: + needsReload, err := process.syncRotationState() + if err != nil { + if trace.IsConnectionProblem(err) { + process.Warningf("Connection problem: sync rotation state: %v.", err) + } else { + process.Warningf("Failed to sync rotation state: %v.", err) + } + } else if needsReload { + process.Debugf("Sync rotation state detected cert authority reload. 
Triggering reload process.") + process.BroadcastEvent(Event{Name: TeleportReloadEvent}) + return nil + } + case <-process.ExitContext().Done(): + process.Infof("Periodic rotation sync has exited because the process is shutting down.") + return nil + } + } +} + +// syncRotationState compares cluster rotation state with the state of +// internal services and performs the rotation if necessary. +func (process *TeleportProcess) syncRotationState() (bool, error) { + connectors := process.getConnectors() + if len(connectors) == 0 { + return false, trace.BadParameter("no connectors found") + } + // it is important to use the same view of the certificate authority + // for all internal services at the same time, so that the same + // procedure will be applied at the same time for multiple service process + // and no internal services is left behind. + conn := connectors[0] + ca, err := conn.GetCertAuthority(services.CertAuthID{ + DomainName: conn.ClientIdentity.ClusterName, + Type: services.HostCA, + }, false) + if err != nil { + return false, trace.Wrap(err) + } + var needsReload bool + for _, conn := range connectors { + reload, err := process.syncServiceRotationState(ca, conn) + if err != nil { + return false, trace.Wrap(err) + } + if reload { + needsReload = true + } + } + return needsReload, nil +} + +// syncServiceRotationState syncs up rotation state for internal services (Auth, Proxy, Node) and +// if necessary, updates credentials. Returns true if the service will need to reload. +func (process *TeleportProcess) syncServiceRotationState(ca services.CertAuthority, conn *Connector) (bool, error) { + state, err := process.storage.GetState(conn.ClientIdentity.ID.Role) + if err != nil { + return false, trace.Wrap(err) + } + ret, err := process.rotate(conn, *state, ca.GetRotation()) + return ret, err +} + +// rotate is called to check if rotation should be triggered. 
+func (process *TeleportProcess) rotate(conn *Connector, localState auth.StateV2, remote services.Rotation) (bool, error) { + id := conn.ClientIdentity.ID + local := localState.Spec.Rotation + if local.Matches(remote) { + // nothing to do, local state and rotation state are in sync + return false, nil + } + + additionalPrincipals, err := process.getAdditionalPrincipals(id.Role) + if err != nil { + return false, trace.Wrap(err) + } + + storage := process.storage + + const outOfSync = "%v and cluster rotation state (%v) is out of sync with local (%v). Clear local state and re-register this %v." + + writeStateAndIdentity := func(name string, identity *auth.Identity) error { + err = storage.WriteIdentity(name, *identity) + if err != nil { + return trace.Wrap(err) + } + localState.Spec.Rotation = remote + err = storage.WriteState(id.Role, localState) + if err != nil { + return trace.Wrap(err) + } + return nil + } + + switch remote.State { + case "", services.RotationStateStandby: + switch local.State { + // There is nothing to do, it could happen + // that the old node came up and missed the whole rotation + // rollback cycle. + case "", services.RotationStateStandby: + if len(additionalPrincipals) != 0 && !conn.ServerIdentity.HasPrincipals(additionalPrincipals) { + process.Infof("%v has updated principals to %q, going to request new principals and update") + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = storage.WriteIdentity(auth.IdentityCurrent, *identity) + if err != nil { + return false, trace.Wrap(err) + } + return true, nil + } + return false, nil + case services.RotationStateInProgress: + // Rollback phase has been completed, all services + // will receive new identities. 
+ if local.Phase != services.RotationPhaseRollback && local.CurrentID != remote.CurrentID { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityCurrent, identity) + if err != nil { + return false, trace.Wrap(err) + } + return true, nil + default: + return false, trace.BadParameter("unsupported state: %q", localState) + } + case services.RotationStateInProgress: + switch remote.Phase { + case services.RotationPhaseStandby, "": + // There is nothing to do. + return false, nil + case services.RotationPhaseUpdateClients: + // Only allow transition in case if local rotation state is standby + // so this server is in the "clean" state. + if local.State != services.RotationStateStandby && local.State != "" { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityReplacement, identity) + if err != nil { + return false, trace.Wrap(err) + } + // Require reload of teleport process to update client and servers. + return true, nil + case services.RotationPhaseUpdateServers: + // Allow transition to this phase only if the previous + // phase was "Update clients". + if local.Phase != services.RotationPhaseUpdateClients && local.CurrentID != remote.CurrentID { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + // Write the replacement identity as a current identity and reload the server. 
+ replacement, err := storage.ReadIdentity(auth.IdentityReplacement, id.Role) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityCurrent, replacement) + if err != nil { + return false, trace.Wrap(err) + } + // Require reload of teleport process to update servers. + return true, nil + case services.RotationPhaseRollback: + // Allow transition to this phase from any other local phase + // because it will be widely used to recover cluster state to + // the previously valid state, client will re-register to receive + // credentials signed by the "old" CA. + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityCurrent, identity) + if err != nil { + return false, trace.Wrap(err) + } + // Require reload of teleport process to update servers. + return true, nil + default: + return false, trace.BadParameter("unsupported phase: %q", remote.Phase) + } + default: + return false, trace.BadParameter("unsupported state: %q", remote.State) + } +} + +func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity *auth.Identity) (*auth.Client, error) { + tlsConfig, err := identity.TLSConfig() + if err != nil { + return nil, trace.Wrap(err) + } + if process.Config.ClientTimeout != 0 { + return auth.NewTLSClient(authServers, tlsConfig, auth.ClientTimeout(process.Config.ClientTimeout)) + } + return auth.NewTLSClient(authServers, tlsConfig) +} diff --git a/lib/service/service.go b/lib/service/service.go index de2cbda6c19de..5bbaa884e8c44 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -32,6 +32,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "golang.org/x/crypto/ssh" @@ -113,6 +114,15 @@ const ( // TeleportExitEvent is generated when the Teleport process begins closing // all listening sockets and exiting. 
TeleportExitEvent = "TeleportExit" + + // TeleportReloadEvent is generated to trigger in-process teleport + // service reload - all servers and clients will be re-created + // in a graceful way. + TeleportReloadEvent = "TeleportReload" + + // TeleportReadyEvent is generated to signal that all teleport + // internal components have started successfully. + TeleportReadyEvent = "TeleportReady" ) // RoleConfig is a configuration for a server role (either proxy or node) @@ -128,8 +138,38 @@ type RoleConfig struct { // Connector has all resources process needs to connect // to other parts of the cluster: client and identity type Connector struct { - Identity *auth.Identity - Client *auth.Client + // ClientIdentity is the identity to be used in internal cluster + // clients to the auth service. + ClientIdentity *auth.Identity + // ServerIdentity is the identity to be used in servers - serving SSH + // and x509 certificates to clients. + ServerIdentity *auth.Identity + // Client is authenticated client with credentials from ClientIdenity. + Client *auth.Client + // AuthServer is auth server, used in connectors created for auth + // service components. + AuthServer *auth.AuthServer +} + +// ReRegister receives new identity credentials for proxy, node and auth. +// In case if auth servers, the role is 'TeleportAdmin' and instead of using +// TLS client this method uses the local auth server. +func (c *Connector) ReRegister(additionalPrincipals []string) (*auth.Identity, error) { + if c.ClientIdentity.ID.Role == teleport.RoleAdmin || c.ClientIdentity.ID.Role == teleport.RoleAuth { + return auth.GenerateIdentity(c.AuthServer, c.ClientIdentity.ID, additionalPrincipals) + } + return auth.ReRegister(c.Client, c.ClientIdentity.ID, additionalPrincipals) +} + +// GetCertAuthority returns cert authority by ID. +// In case if auth servers, the role is 'TeleportAdmin' and instead of using +// TLS client this method uses the local auth server. 
+func (c *Connector) GetCertAuthority(id services.CertAuthID, loadPrivateKeys bool) (services.CertAuthority, error) { + if c.ClientIdentity.ID.Role == teleport.RoleAdmin || c.ClientIdentity.ID.Role == teleport.RoleAuth { + return c.AuthServer.GetCertAuthority(id, loadPrivateKeys) + } else { + return c.Client.GetCertAuthority(id, loadPrivateKeys) + } } // TeleportProcess structure holds the state of the Teleport daemon, controlling @@ -149,6 +189,10 @@ type TeleportProcess struct { // identities of this process (credentials to auth sever, basically) Identities map[teleport.Role]*auth.Identity + + // connectors is a list of connected clients and their identities + connectors map[teleport.Role]*Connector + // registeredListeners keeps track of all listeners created by the process // used to pass listeners to child processes during live reload registeredListeners []RegisteredListener @@ -160,6 +204,25 @@ type TeleportProcess struct { // during restart used to collect their status in case if the // child process crashed. forkedPIDs []int + + // storage is a server local storage + storage *auth.ProcessStorage + + // id is a process id - used to identify different processes + // during in-process reloads. + id string + + // Entry is a process-local log entry. + *logrus.Entry +} + +// processIndex is an internal process index +// to help differentiate between two different teleport processes +// during in-process reload. 
+var processID int32 = 0 + +func nextProcessID() int32 { + return atomic.AddInt32(&processID, 1) } // GetAuthServer returns the process' auth server @@ -196,22 +259,25 @@ func (process *TeleportProcess) findStaticIdentity(id auth.IdentityID) (*auth.Id return nil, trace.NotFound("identity %v not found", &id) } -// readIdentity reads identity from disk and resets the local state -func (process *TeleportProcess) readIdentity(role teleport.Role) (*auth.Identity, error) { +// getConnectors returns a copy of the identities registered for auth server +func (process *TeleportProcess) getConnectors() []*Connector { process.Lock() defer process.Unlock() - id := auth.IdentityID{ - Role: role, - HostUUID: process.Config.HostUUID, - NodeName: process.Config.Hostname, - } - identity, err := auth.ReadIdentity(process.Config.DataDir, id) - if err != nil { - return nil, trace.Wrap(err) + out := make([]*Connector, 0, len(process.connectors)) + for role := range process.connectors { + out = append(out, process.connectors[role]) } - process.Identities[role] = identity - return identity, nil + return out +} + +// addConnector adds connector to registered connectors list, +// it will overwrite the connector for the same role +func (process *TeleportProcess) addConnector(connector *Connector) { + process.Lock() + defer process.Unlock() + + process.connectors[connector.ClientIdentity.ID.Role] = connector } // GetIdentity returns the process identity (credentials to the auth server) for a given @@ -227,67 +293,180 @@ func (process *TeleportProcess) GetIdentity(role teleport.Role) (i *auth.Identit if found { return i, nil } - + i, err = process.storage.ReadIdentity(auth.IdentityCurrent, role) id := auth.IdentityID{ Role: role, HostUUID: process.Config.HostUUID, NodeName: process.Config.Hostname, } - i, err = auth.ReadIdentity(process.Config.DataDir, id) if err != nil { - if trace.IsNotFound(err) { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + if role == 
teleport.RoleAdmin { + // for admin identity use local auth server + // because admin identity is requested by auth server + // itself + principals, err := process.getAdditionalPrincipals(role) + if err != nil { + return nil, trace.Wrap(err) + } + i, err = auth.GenerateIdentity(process.localAuth, id, principals) + } else { // try to locate static identity provided in the file i, err = process.findStaticIdentity(id) if err != nil { return nil, trace.Wrap(err) } - log.Infof("Found static identity %v in the config file, writing to disk.", &id) - if err = auth.WriteIdentity(process.Config.DataDir, i); err != nil { + process.Infof("Found static identity %v in the config file, writing to disk.", &id) + if err = process.storage.WriteIdentity(auth.IdentityCurrent, *i); err != nil { return nil, trace.Wrap(err) } - } else { - return nil, trace.Wrap(err) } } process.Identities[role] = i return i, nil } -// connectToAuthService attempts to login into the auth servers specified in the -// configuration. Returns 'true' if successful -func (process *TeleportProcess) connectToAuthService(role teleport.Role, additionalPrincipals []string) (*Connector, error) { - identity, err := process.GetIdentity(role) - if err != nil { - return nil, trace.Wrap(err) +// Process is a interface for processes +type Process interface { + // Closer closes all resources used by the process + io.Closer + // Start starts the process in a non-blocking way + Start() error + // WaitForSignals waits for and handles system process signals. + WaitForSignals(context.Context) error + // ExportFileDescriptors exports service listeners + // file descriptors used by the process. + ExportFileDescriptors() ([]FileDescriptor, error) + // Shutdown starts graceful shutdown of the process, + // blocks until all resources are freed and go-routines are + // shut down. + Shutdown(context.Context) + // WaitForEvent waits for event to occur, sends event to the channel, + // this is a non-blocking function. 
+ WaitForEvent(ctx context.Context, name string, eventC chan Event) + // WaitWithContext waits for the service to stop. This is a blocking + // function. + WaitWithContext(ctx context.Context) +} + +// NewProcess is a function that creates new teleport from config +type NewProcess func(cfg *Config) (Process, error) + +func newTeleportProcess(cfg *Config) (Process, error) { + return NewTeleport(cfg) +} + +// Run starts teleport processes, waits for signals +// and handles internal process reloads. +func Run(ctx context.Context, cfg Config, newTeleport NewProcess) error { + if newTeleport == nil { + newTeleport = newTeleportProcess } - tlsConfig, err := identity.TLSConfig() + copyCfg := cfg + srv, err := newTeleport(&copyCfg) if err != nil { + return trace.Wrap(err, "initialization failed") + } + if srv == nil { + return trace.BadParameter("process has returned nil server") + } + if err := srv.Start(); err != nil { + return trace.Wrap(err, "startup failed") + } + // Wait and reload until called exit. + for { + srv, err = waitAndReload(ctx, cfg, srv, newTeleport) + if err != nil { + // This error means that this was a clean shutdown + // and no reload is necessary. 
+ if err == ErrTeleportExited { + return nil + } + return trace.Wrap(err) + } + } +} + +func waitAndReload(ctx context.Context, cfg Config, srv Process, newTeleport NewProcess) (Process, error) { + err := srv.WaitForSignals(ctx) + if err == nil { + return nil, ErrTeleportExited + } + if err != ErrTeleportReloading { return nil, trace.Wrap(err) } - log.Infof("Connecting to the cluster %v with TLS client certificate.", identity.ClusterName) - client, err := auth.NewTLSClient(process.Config.AuthServers, tlsConfig) + log.Infof("Started in-process service reload.") + fileDescriptors, err := srv.ExportFileDescriptors() if err != nil { + warnOnErr(srv.Close()) return nil, trace.Wrap(err) } - if len(additionalPrincipals) != 0 && !identity.HasPrincipals(additionalPrincipals) { - log.Infof("Identity %v needs principals %v, going to re-register.", identity.ID, additionalPrincipals) - if err := auth.ReRegister(process.Config.DataDir, client, identity.ID, additionalPrincipals); err != nil { - return nil, trace.Wrap(err) - } - if identity, err = process.readIdentity(role); err != nil { - return nil, trace.Wrap(err) - } - tlsConfig, err = identity.TLSConfig() - if err != nil { - return nil, trace.Wrap(err) + newCfg := cfg + newCfg.FileDescriptors = fileDescriptors + newSrv, err := newTeleport(&newCfg) + if err != nil { + warnOnErr(srv.Close()) + return nil, trace.Wrap(err, "failed to create a new service") + } + log.Infof("Created new process.") + if err := newSrv.Start(); err != nil { + warnOnErr(srv.Close()) + return nil, trace.Wrap(err, "failed to start a new service") + } + // Wait for the new server to report that it has started + // before shutting down the old one. 
+ startTimeoutCtx, startCancel := context.WithTimeout(ctx, signalPipeTimeout) + defer startCancel() + eventC := make(chan Event, 1) + newSrv.WaitForEvent(startTimeoutCtx, TeleportReadyEvent, eventC) + select { + case <-eventC: + log.Infof("New service has started successfully.") + case <-startTimeoutCtx.Done(): + warnOnErr(newSrv.Close()) + warnOnErr(srv.Close()) + return nil, trace.BadParameter("the new service has failed to start") + } + shutdownTimeout := cfg.ShutdownTimeout + if shutdownTimeout == 0 { + // The default shutdown timeout is very generous to avoid disrupting + // longer running connections. + shutdownTimeout = defaults.DefaultIdleConnectionDuration + } + log.Infof("Shutting down the old service with timeout %v.", shutdownTimeout) + // After the new process has started, initiate the graceful shutdown of the old process + // new process could have generated connections to the new process's server + // so not all connections can be kept forever. + timeoutCtx, cancel := context.WithTimeout(ctx, shutdownTimeout) + defer cancel() + srv.Shutdown(timeoutCtx) + if timeoutCtx.Err() == context.DeadlineExceeded { + // The new service can start initiating connections to the old service + // keeping it from shutting down gracefully, or some external + // connections can keep hanging the old auth service and prevent + // the services from shutting down, so abort the graceful way + // after some time to keep going. + log.Infof("Some connections to the old service were aborted after timeout of %v.", shutdownTimeout) + // Make sure that all parts of the service have exited, this function + // can not allow execution to continue if the shutdown is not complete, + // otherwise subsequent Run executions will hold system resources in case + // if old versions of the service are not exiting completely. 
+ timeoutCtx, cancel := context.WithTimeout(ctx, shutdownTimeout) + defer cancel() + srv.WaitWithContext(timeoutCtx) + if timeoutCtx.Err() == context.DeadlineExceeded { + return nil, trace.BadParameter("the old service has failed to exit.") } + } else { + log.Infof("The old service was successfully shut down gracefully.") } - // success ? we're logged in! - return &Connector{Client: client, Identity: identity}, nil + return newSrv, nil } // NewTeleport takes the daemon configuration, instantiates all required services -// and starts them under a supervisor, returning the supervisor object +// and starts them under a supervisor, returning the supervisor object. func NewTeleport(cfg *Config) (*TeleportProcess, error) { // before we do anything reset the SIGINT handler back to the default system.ResetInterruptSignalHandler() @@ -305,9 +484,11 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { } } - importedDescriptors, err := importFileDescriptors() - if err != nil { - return nil, trace.Wrap(err) + if len(cfg.FileDescriptors) == 0 { + cfg.FileDescriptors, err = importFileDescriptors() + if err != nil { + return nil, trace.Wrap(err) + } } // if there's no host uuid initialized yet, try to read one from the @@ -345,14 +526,27 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { } } + storage, err := auth.NewProcessStorage(filepath.Join(cfg.DataDir, teleport.ComponentProcess)) + if err != nil { + return nil, trace.Wrap(err) + } + + processID := fmt.Sprintf("%v", nextProcessID()) process := &TeleportProcess{ Clock: clockwork.NewRealClock(), - Supervisor: NewSupervisor(), + Supervisor: NewSupervisor(processID), Config: cfg, Identities: make(map[teleport.Role]*auth.Identity), - importedDescriptors: importedDescriptors, + connectors: make(map[teleport.Role]*Connector), + importedDescriptors: cfg.FileDescriptors, + storage: storage, + id: processID, } + process.Entry = logrus.WithFields(logrus.Fields{ + trace.Component: 
teleport.Component(teleport.ComponentProcess, process.id), + }) + serviceStarted := false if !cfg.DiagnosticAddr.IsEmpty() { @@ -369,6 +563,22 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { cfg.Keygen = native.New() } + // Produce global TeleportReadyEvent + // when all components have started + eventMapping := EventMapping{ + Out: TeleportReadyEvent, + } + if cfg.Auth.Enabled { + eventMapping.In = append(eventMapping.In, AuthTLSReady) + } + if cfg.SSH.Enabled { + eventMapping.In = append(eventMapping.In, NodeSSHReady) + } + if cfg.Proxy.Enabled { + eventMapping.In = append(eventMapping.In, ProxySSHReady) + } + process.RegisterEventMapping(eventMapping) + if cfg.Auth.Enabled { if err := process.initAuthService(); err != nil { return nil, trace.Wrap(err) @@ -388,6 +598,7 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { } if cfg.Proxy.Enabled { + eventMapping.In = append(eventMapping.In, ProxySSHReady) if err := process.initProxy(); err != nil { return nil, err } @@ -396,6 +607,8 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { warnOnErr(process.closeImportedDescriptors(teleport.ComponentProxy)) } + process.RegisterFunc("common.rotate", process.periodicSyncRotationState) + if !serviceStarted { return nil, trace.BadParameter("all services failed to start") } @@ -410,14 +623,34 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { defer f.Close() } + // notify parent process that this process has started + go process.notifyParent() + + return process, nil +} + +// notifyParent notifies parent process that this process has started +// by writing to in-memory pipe used by communication channel. 
+func (process *TeleportProcess) notifyParent() { + ctx, cancel := context.WithTimeout(process.ExitContext(), signalPipeTimeout) + defer cancel() + + eventC := make(chan Event, 1) + process.WaitForEvent(ctx, TeleportReadyEvent, eventC) + select { + case <-eventC: + process.Infof("New service has started successfully.") + case <-ctx.Done(): + process.Errorf("Timeout waiting for process to start: %v", ctx.Err()) + return + } + if err := process.writeToSignalPipe(fmt.Sprintf("Process %v has started.", os.Getpid())); err != nil { - log.Warningf("Failed to write to signal pipe: %v", err) + process.Warningf("Failed to write to signal pipe: %v", err) // despite the failure, it's ok to proceed, // it could mean that the parent process has crashed and the pipe // is no longer valid. } - - return process, nil } func (process *TeleportProcess) setLocalAuth(a *auth.AuthServer) { @@ -527,7 +760,7 @@ func (process *TeleportProcess) initAuthService() error { warningMessage := "Warning: Teleport audit and session recording have been " + "turned off. This is dangerous, you will not be able to view audit events " + "or save and playback recorded sessions." - log.Warn(warningMessage) + process.Warn(warningMessage) } else { // check if session recording has been disabled. note, we will continue // logging audit events, we just won't record sessions. @@ -537,7 +770,7 @@ func (process *TeleportProcess) initAuthService() error { warningMessage := "Warning: Teleport session recording have been turned off. " + "This is dangerous, you will not be able to save and playback sessions." 
- log.Warn(warningMessage) + process.Warn(warningMessage) } auditConfig := cfg.Auth.ClusterConfig.GetAuditConfig() @@ -573,7 +806,7 @@ func (process *TeleportProcess) initAuthService() error { } // first, create the AuthServer - authServer, identity, err := auth.Init(auth.InitConfig{ + authServer, err := auth.Init(auth.InitConfig{ Backend: b, Authority: cfg.Keygen, ClusterConfiguration: cfg.ClusterConfiguration, @@ -602,6 +835,11 @@ func (process *TeleportProcess) initAuthService() error { process.setLocalAuth(authServer) + connector, err := process.connectToAuthService(teleport.RoleAdmin) + if err != nil { + return trace.Wrap(err) + } + // second, create the API Server: it's actually a collection of API servers, // each serving requests for a "role" which is assigned to every connected // client based on their certificate (user, server, admin, etc) @@ -632,11 +870,11 @@ func (process *TeleportProcess) initAuthService() error { } log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentAuth, + trace.Component: teleport.Component(teleport.ComponentAuth, process.id), }) // Register TLS endpoint of the auth service - tlsConfig, err := identity.TLSConfig() + tlsConfig, err := connector.ServerIdentity.TLSConfig() if err != nil { return trace.Wrap(err) } @@ -645,11 +883,11 @@ func (process *TeleportProcess) initAuthService() error { APIConfig: *apiConf, LimiterConfig: cfg.Auth.Limiter, AccessPoint: adminAccessPoint, + Component: teleport.Component(teleport.ComponentAuth, process.id), }) if err != nil { return trace.Wrap(err) } - // auth server listens on SSH and TLS, reusing the same socket listener, err := process.importOrCreateListener(teleport.ComponentAuth, cfg.Auth.SSHAddr.Addr) if err != nil { @@ -664,31 +902,29 @@ func (process *TeleportProcess) initAuthService() error { mux, err := multiplexer.New(multiplexer.Config{ EnableProxyProtocol: cfg.Auth.EnableProxyProtocol, Listener: listener, + ID: teleport.Component(process.id), }) if err != nil { 
listener.Close() return trace.Wrap(err) } go mux.Serve() - process.RegisterFunc("auth.tls", func() error { utils.Consolef(cfg.Console, teleport.ComponentAuth, "Auth service is starting on %v.", cfg.Auth.SSHAddr.Addr) // since tlsServer.Serve is a blocking call, we emit this even right before // the service has started process.BroadcastEvent(Event{Name: AuthTLSReady, Payload: nil}) - err := tlsServer.Serve(mux.TLS()) if err != nil && err != http.ErrServerClosed { log.Warningf("TLS server exited with error: %v.", err) } return nil }) - process.RegisterFunc("auth.heartbeat.broadcast", func() error { // Heart beat auth server presence, this is not the best place for this // logic, consolidate it into auth package later - connector, err := process.connectToAuthService(teleport.RoleAdmin, nil) + connector, err := process.connectToAuthService(teleport.RoleAdmin) if err != nil { return trace.Wrap(err) } @@ -699,9 +935,6 @@ func (process *TeleportProcess) initAuthService() error { }) return nil }) - - closeContext, signalClose := context.WithCancel(context.TODO()) - process.RegisterFunc("auth.heartbeat", func() error { srv := services.ServerV2{ Kind: services.KindAuthServer, @@ -740,13 +973,21 @@ func (process *TeleportProcess) initAuthService() error { defer ticker.Stop() announce: for { + state, err := process.storage.GetState(teleport.RoleAdmin) + if err != nil { + if !trace.IsNotFound(err) { + log.Warningf("Failed to get rotation state: %v.", err) + } + } else { + srv.Spec.Rotation = state.Spec.Rotation + } srv.SetTTL(process, defaults.ServerHeartbeatTTL) - err := authServer.UpsertAuthServer(&srv) + err = authServer.UpsertAuthServer(&srv) if err != nil { log.Warningf("Failed to announce presence: %v.", err) } select { - case <-closeContext.Done(): + case <-process.ExitContext().Done(): break announce case <-ticker.C: } @@ -754,17 +995,18 @@ func (process *TeleportProcess) initAuthService() error { log.Infof("Heartbeat to other auth servers exited.") return nil }) - // 
execute this when process is asked to exit: process.onExit("auth.shutdown", func(payload interface{}) { - // as a last resort, at least close listeners (e.g. panic) - if listener != nil { - defer listener.Close() - } + // The listeners have to be closed here, because if shutdown + // was called before the start of the http server, + // the http server would have not started tracking the listeners + // and http.Shutdown will do nothing. if mux != nil { - defer mux.Close() + warnOnErr(mux.Close()) + } + if listener != nil { + warnOnErr(listener.Close()) } - signalClose() if payload == nil { log.Info("Shutting down immediately.") warnOnErr(tlsServer.Close()) @@ -793,7 +1035,7 @@ func payloadContext(payload interface{}) context.Context { func (process *TeleportProcess) onExit(serviceName string, callback func(interface{})) { process.RegisterFunc(serviceName, func() error { eventC := make(chan Event) - process.WaitForEvent(TeleportExitEvent, eventC, make(chan struct{})) + process.WaitForEvent(context.TODO(), TeleportExitEvent, eventC) select { case event := <-eventC: callback(event.Payload) @@ -825,22 +1067,36 @@ func (process *TeleportProcess) newLocalCache(clt auth.ClientI, cacheName []stri }) } +func (process *TeleportProcess) getRotation(role teleport.Role) (*services.Rotation, error) { + state, err := process.storage.GetState(role) + if err != nil { + return nil, trace.Wrap(err) + } + return &state.Spec.Rotation, nil +} + // initSSH initializes the "node" role, i.e. a simple SSH server connected to the auth server. 
func (process *TeleportProcess) initSSH() error { - process.RegisterWithAuthServer( - process.Config.Token, teleport.RoleNode, SSHIdentityEvent, nil) + process.registerWithAuthServer(teleport.RoleNode, SSHIdentityEvent) eventsC := make(chan Event) - process.WaitForEvent(SSHIdentityEvent, eventsC, make(chan struct{})) + process.WaitForEvent(process.ExitContext(), SSHIdentityEvent, eventsC) var s *regular.Server log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentNode, + trace.Component: teleport.Component(teleport.ComponentNode, process.id), }) process.RegisterFunc("ssh.node", func() error { - event := <-eventsC - log.Infof("Received event %q.", event.Name) + var event Event + select { + case event = <-eventsC: + log.Debugf("Received event %q.", event.Name) + case <-process.ExitContext().Done(): + log.Debugf("Process is exiting.") + return nil + } + conn, ok := (event.Payload).(*Connector) if !ok { return trace.BadParameter("unsupported connector type: %T", event.Payload) @@ -878,7 +1134,7 @@ func (process *TeleportProcess) initSSH() error { s, err = regular.New(cfg.SSH.Addr, cfg.Hostname, - []ssh.Signer{conn.Identity.KeySigner}, + []ssh.Signer{conn.ServerIdentity.KeySigner}, authClient, cfg.DataDir, cfg.AdvertiseIP, @@ -894,6 +1150,7 @@ func (process *TeleportProcess) initSSH() error { regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), regular.SetPAMConfig(cfg.SSH.PAM), + regular.SetRotationGetter(process.getRotation), ) if err != nil { return trace.Wrap(err) @@ -938,55 +1195,35 @@ func (process *TeleportProcess) initSSH() error { return nil } -// RegisterWithAuthServer uses one time provisioning token obtained earlier +// registerWithAuthServer uses one time provisioning token obtained earlier // from the server to get a pair of SSH keys signed by Auth server host // certificate authority -func (process *TeleportProcess) RegisterWithAuthServer(token string, role teleport.Role, eventName string, 
additionalPrincipals []string) { - cfg := process.Config - identityID := auth.IdentityID{Role: role, HostUUID: cfg.HostUUID, NodeName: cfg.Hostname} - - // this means the server has not been initialized yet, we are starting - // the registering client that attempts to connect to the auth server - // and provision the keys +func (process *TeleportProcess) registerWithAuthServer(role teleport.Role, eventName string) { var authClient *auth.Client process.RegisterFunc(fmt.Sprintf("register.%v", strings.ToLower(role.String())), func() error { retryTime := defaults.ServerHeartbeatTTL / 3 for { - connector, err := process.connectToAuthService(role, additionalPrincipals) + connector, err := process.connectToAuthService(role) if err == nil { process.BroadcastEvent(Event{Name: eventName, Payload: connector}) authClient = connector.Client return nil } + // in between attempts, check if teleport is shutting down + select { + case <-process.ExitContext().Done(): + process.Infof("%v stopping connection attempts, teleport is shutting down.", role) + return ErrTeleportExited + default: + } if trace.IsConnectionProblem(err) { - log.Infof("%v failed attempt connecting to auth server: %v", role, err) + process.Infof("%v failed attempt connecting to auth server: %v.", role, err) time.Sleep(retryTime) continue } if !trace.IsNotFound(err) { return trace.Wrap(err) } - // we haven't connected yet, so we expect the token to exist - if process.getLocalAuth() != nil { - // Auth service is on the same host, no need to go though the invitation - // procedure - log.Debugf("This server has local Auth server started, using it to add role to the cluster.") - err = auth.LocalRegister(cfg.DataDir, identityID, process.getLocalAuth(), additionalPrincipals) - } else { - // Auth server is remote, so we need a provisioning token - if token == "" { - return trace.BadParameter("%v must join a cluster and needs a provisioning token", role) - } - log.Infof("Joining the cluster with a token %v.", token) - err 
= auth.Register(cfg.DataDir, token, identityID, cfg.AuthServers, additionalPrincipals) - } - if err != nil { - log.Errorf("Failed to join the cluster: %v.", err) - time.Sleep(retryTime) - } else { - log.Infof("%v has successfully registered with the cluster.", role) - continue - } } }) @@ -999,7 +1236,7 @@ func (process *TeleportProcess) RegisterWithAuthServer(token string, role telepo func (process *TeleportProcess) initUploaderService(accessPoint auth.AccessPoint, auditLog events.IAuditLog) error { log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentAuditLog, + trace.Component: teleport.Component(teleport.ComponentAuditLog, process.id), }) // create folder for uploads uid, gid, err := adminCreds() @@ -1084,7 +1321,7 @@ func (process *TeleportProcess) initDiagnosticService() error { } log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentDiagnostic, + trace.Component: teleport.Component(teleport.ComponentDiagnostic, process.id), }) log.Infof("Starting diagnostic service on %v.", process.Config.DiagnosticAddr.Addr) @@ -1111,6 +1348,23 @@ func (process *TeleportProcess) initDiagnosticService() error { return nil } +// getAdditionalPrincipals returns a list of additional principals to add +// to role's service certificate. +func (process *TeleportProcess) getAdditionalPrincipals(role teleport.Role) ([]string, error) { + var principals []string + if process.Config.Hostname != "" { + principals = append(principals, process.Config.Hostname) + } + if process.Config.Proxy.PublicAddr.Addr != "" { + host, err := utils.Host(process.Config.Proxy.PublicAddr.Addr) + if err != nil { + return nil, trace.Wrap(err) + } + principals = append(principals, host) + } + return principals, nil +} + // initProxy gets called if teleport runs with 'proxy' role enabled. // this means it will do two things: // 1. 
serve a web UI @@ -1124,23 +1378,20 @@ func (process *TeleportProcess) initProxy() error { return trace.Wrap(err) } } - - var additionalPrincipals []string - if process.Config.Proxy.PublicAddr.Addr != "" { - host, err := utils.Host(process.Config.Proxy.PublicAddr.Addr) - if err != nil { - return trace.Wrap(err) - } - additionalPrincipals = []string{host} - } - - process.RegisterWithAuthServer(process.Config.Token, teleport.RoleProxy, ProxyIdentityEvent, additionalPrincipals) + process.registerWithAuthServer(teleport.RoleProxy, ProxyIdentityEvent) process.RegisterFunc("proxy.init", func() error { eventsC := make(chan Event) - process.WaitForEvent(ProxyIdentityEvent, eventsC, make(chan struct{})) + process.WaitForEvent(process.ExitContext(), ProxyIdentityEvent, eventsC) + + var event Event + select { + case event = <-eventsC: + process.Debugf("Received event %q.", event.Name) + case <-process.ExitContext().Done(): + process.Debugf("Process is exiting.") + return nil + } - event := <-eventsC - log.Debugf("Received event %q.", event.Name) conn, ok := (event.Payload).(*Connector) if !ok { return trace.BadParameter("unsupported connector type: %T", event.Payload) @@ -1177,15 +1428,15 @@ func (l *proxyListeners) Close() { // setupProxyListeners sets up web proxy listeners based on the configuration func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { cfg := process.Config - log.Debugf("Setup Proxy: Web Proxy Address: %v, Reverse Tunnel Proxy Address: %v", cfg.Proxy.WebAddr.Addr, cfg.Proxy.ReverseTunnelListenAddr.Addr) + process.Debugf("Setup Proxy: Web Proxy Address: %v, Reverse Tunnel Proxy Address: %v", cfg.Proxy.WebAddr.Addr, cfg.Proxy.ReverseTunnelListenAddr.Addr) var err error var listeners proxyListeners switch { case cfg.Proxy.DisableWebService && cfg.Proxy.DisableReverseTunnel: - log.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy are disabled.") + process.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy are disabled.") 
return &listeners, nil case cfg.Proxy.ReverseTunnelListenAddr.Equals(cfg.Proxy.WebAddr) && !cfg.Proxy.DisableTLS: - log.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy listen on the same port, multiplexing is on.") + process.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy listen on the same port, multiplexing is on.") listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel", "web"), cfg.Proxy.WebAddr.Addr) if err != nil { return nil, trace.Wrap(err) @@ -1195,6 +1446,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { Listener: listener, DisableTLS: cfg.Proxy.DisableWebService, DisableSSH: cfg.Proxy.DisableReverseTunnel, + ID: teleport.Component(teleport.ComponentProxy, "tunnel", "web", process.id), }) if err != nil { listener.Close() @@ -1205,7 +1457,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { go listeners.mux.Serve() return &listeners, nil case cfg.Proxy.EnableProxyProtocol && !cfg.Proxy.DisableWebService && !cfg.Proxy.DisableTLS: - log.Debugf("Setup Proxy: Proxy protocol is enabled for web service, multiplexing is on.") + process.Debugf("Setup Proxy: Proxy protocol is enabled for web service, multiplexing is on.") listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "web"), cfg.Proxy.WebAddr.Addr) if err != nil { return nil, trace.Wrap(err) @@ -1215,6 +1467,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { Listener: listener, DisableTLS: false, DisableSSH: true, + ID: teleport.Component(teleport.ComponentProxy, "web", process.id), }) if err != nil { listener.Close() @@ -1230,7 +1483,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { go listeners.mux.Serve() return &listeners, nil default: - log.Debugf("Proxy reverse tunnel are listening on the separate ports") + process.Debugf("Proxy reverse tunnel are listening on the 
separate ports.") if !cfg.Proxy.DisableReverseTunnel { listeners.reverseTunnel, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel"), cfg.Proxy.ReverseTunnelListenAddr.Addr) if err != nil { @@ -1271,7 +1524,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } - tlsConfig, err := conn.Identity.TLSConfig() + clientTLSConfig, err := conn.ClientIdentity.TLSConfig() if err != nil { return trace.Wrap(err) } @@ -1283,18 +1536,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // Register reverse tunnel agents pool agentPool, err := reversetunnel.NewAgentPool(reversetunnel.AgentPoolConfig{ - HostUUID: conn.Identity.ID.HostUUID, + HostUUID: conn.ServerIdentity.ID.HostUUID, Client: conn.Client, AccessPoint: accessPoint, - HostSigners: []ssh.Signer{conn.Identity.KeySigner}, - Cluster: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], + HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner}, + Cluster: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], }) if err != nil { return trace.Wrap(err) } log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentReverseTunnelServer, + trace.Component: teleport.Component(teleport.ComponentReverseTunnelServer, process.id), }) // register SSH reverse tunnel server that accepts connections @@ -1304,17 +1557,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { tsrv, err = reversetunnel.NewServer( reversetunnel.Config{ ID: process.Config.HostUUID, - ClusterName: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], - ClientTLS: tlsConfig, + ClusterName: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], + ClientTLS: clientTLSConfig, Listener: listeners.reverseTunnel, - HostSigners: []ssh.Signer{conn.Identity.KeySigner}, + HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner}, LocalAuthClient: conn.Client, LocalAccessPoint: 
accessPoint, NewCachingAccessPoint: process.newLocalCache, Limiter: reverseTunnelLimiter, DirectClusters: []reversetunnel.DirectCluster{ { - Name: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], + Name: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], Client: conn.Client, }, }, @@ -1323,6 +1576,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { KEXAlgorithms: cfg.KEXAlgorithms, MACAlgorithms: cfg.MACAlgorithms, DataDir: process.Config.DataDir, + PollingPeriod: process.Config.PollingPeriod, }) if err != nil { return trace.Wrap(err) @@ -1392,7 +1646,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } sshProxy, err := regular.New(cfg.Proxy.SSHAddr, cfg.Hostname, - []ssh.Signer{conn.Identity.KeySigner}, + []ssh.Signer{conn.ServerIdentity.KeySigner}, accessPoint, cfg.DataDir, nil, @@ -1405,6 +1659,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), regular.SetNamespace(defaults.Namespace), + regular.SetRotationGetter(process.getRotation), ) if err != nil { return trace.Wrap(err) @@ -1421,7 +1676,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { process.RegisterFunc("proxy.reversetunnel.agent", func() error { log := logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentReverseTunnelAgent, + trace.Component: teleport.Component(teleport.ComponentReverseTunnelAgent, process.id), }) log.Infof("Starting reverse tunnel agent pool.") if err := agentPool.Start(); err != nil { @@ -1466,7 +1721,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { func warnOnErr(err error) { if err != nil { - log.Errorf("Error while performing operation: %v", err) + // don't warn on double close, happens sometimes when closing + // calling accept on a closed listener + if strings.Contains(err.Error(), "use of closed 
network connection") { + return + } + log.Warningf("Got error while cleaning up: %v.", err) } } @@ -1496,6 +1756,19 @@ func (process *TeleportProcess) initAuthStorage() (bk backend.Backend, err error return bk, nil } +// WaitWithContext waits until all internal services stop. +func (process *TeleportProcess) WaitWithContext(ctx context.Context) { + local, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + process.Supervisor.Wait() + }() + select { + case <-local.Done(): + return + } +} + // StartShutdown launches non-blocking graceful shutdown process that signals // completion, returns context that will be closed once the shutdown is done func (process *TeleportProcess) StartShutdown(ctx context.Context) context.Context { @@ -1504,11 +1777,11 @@ func (process *TeleportProcess) StartShutdown(ctx context.Context) context.Conte go func() { defer cancel() process.Supervisor.Wait() - log.Debugf("All supervisor functions are completed.") + process.Debugf("All supervisor functions are completed.") localAuth := process.getLocalAuth() if localAuth != nil { if err := process.localAuth.Close(); err != nil { - log.Warningf("Failed closing auth server: %v", trace.DebugReport(err)) + process.Warningf("Failed closing auth server: %v.", err) } } }() @@ -1522,7 +1795,7 @@ func (process *TeleportProcess) Shutdown(ctx context.Context) { // wait until parent context closes select { case <-localCtx.Done(): - log.Debugf("Process completed.") + process.Debugf("Process completed.") } } @@ -1532,12 +1805,17 @@ func (process *TeleportProcess) Close() error { process.Config.Keygen.Close() + var errors []error localAuth := process.getLocalAuth() if localAuth != nil { - return trace.Wrap(process.localAuth.Close()) + errors = append(errors, process.localAuth.Close()) } - return nil + if process.storage != nil { + errors = append(errors, process.storage.Close()) + } + + return trace.NewAggregate(errors...) 
} func validateConfig(cfg *Config) error { @@ -1572,6 +1850,10 @@ func validateConfig(cfg *Config) error { } } + if cfg.PollingPeriod == 0 { + cfg.PollingPeriod = defaults.HighResPollingPeriod + } + cfg.SSH.Namespace = services.ProcessNamespace(cfg.SSH.Namespace) return nil diff --git a/lib/service/signals.go b/lib/service/signals.go index bbcc23590ef8e..c8f6ed2fc6b5b 100644 --- a/lib/service/signals.go +++ b/lib/service/signals.go @@ -75,10 +75,10 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { case syscall.SIGQUIT: go process.printShutdownStatus(doneContext) process.Shutdown(ctx) - log.Infof("All services stopped, exiting.") + process.Infof("All services stopped, exiting.") return nil case syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT: - log.Infof("Got signal %q, exiting immediately.", signal) + process.Infof("Got signal %q, exiting immediately.", signal) process.Close() return nil case syscall.SIGUSR1: @@ -91,30 +91,30 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { // That was not quite enough. With pipelines diagnostics could come from any of several programs running simultaneously. // Diagnostics needed to identify themselves. 
// - Doug McIllroy, "A Research UNIX Reader: Annotated Excerpts from the Programmer’s Manual, 1971-1986" - log.Infof("Got signal %q, logging diagostic info to stderr.", signal) + process.Infof("Got signal %q, logging diagostic info to stderr.", signal) writeDebugInfo(os.Stderr) case syscall.SIGUSR2: if !process.backendSupportsForks() { - log.Warningf("Process is using backend that does not support multiple processes, switch to another backend to use USR2.") + process.Warningf("Process is using backend that does not support multiple processes, switch to another backend to use USR2.") continue } log.Infof("Got signal %q, forking a new process.", signal) if err := process.forkChild(); err != nil { - log.Warningf("Failed to fork: %v", err) + process.Warningf("Failed to fork: %v", err) } else { - log.Infof("Successfully started new process.") + process.Infof("Successfully started new process.") } case syscall.SIGHUP: if !process.backendSupportsForks() { - log.Warningf("Process is using backend that does not support multiple processes, switch to another backend to use HUP.") + process.Warningf("Process is using backend that does not support multiple processes, switch to another backend to use HUP.") continue } - log.Infof("Got signal %q, performing graceful restart.", signal) + process.Infof("Got signal %q, performing graceful restart.", signal) if err := process.forkChild(); err != nil { - log.Warningf("Failed to fork: %v", err) + process.Warningf("Failed to fork: %v", err) continue } - log.Infof("Successfully started new process, shutting down gracefully.") + process.Infof("Successfully started new process, shutting down gracefully.") go process.printShutdownStatus(doneContext) process.Shutdown(ctx) log.Infof("All services stopped, exiting.") @@ -122,24 +122,37 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { case syscall.SIGCHLD: process.collectStatuses() default: - log.Infof("Ignoring %q.", signal) + process.Infof("Ignoring %q.", signal) 
} + case <-process.ReloadContext().Done(): + process.Infof("Exiting signal handler: process has started internal reload.") + return ErrTeleportReloading + case <-process.ExitContext().Done(): + process.Infof("Someone else has closed context, exiting.") + return nil case <-ctx.Done(): process.Close() process.Wait() - log.Info("Got request to shutdown, context is closing") + process.Info("Got request to shutdown, context is closing") return nil } } } +// ErrTeleportReloading is returned when signal waiter exits +// because the teleport process has initiated an internal reload +var ErrTeleportReloading = &trace.CompareFailedError{Message: "teleport process is reloading"} + +// ErrTeleportExited means that teleport has exited +var ErrTeleportExited = &trace.CompareFailedError{Message: "teleport process has shutdown"} + func (process *TeleportProcess) writeToSignalPipe(message string) error { signalPipe, err := process.importSignalPipe() if err != nil { if !trace.IsNotFound(err) { return trace.Wrap(err) } - log.Debugf("No signal pipe to import, must be first Teleport process.") + process.Debugf("No signal pipe to import, must be first Teleport process.") return nil } defer signalPipe.Close() @@ -151,7 +164,7 @@ func (process *TeleportProcess) writeToSignalPipe(message string) error { go func() { _, err := signalPipe.Write([]byte(message)) if err != nil { - log.Debugf("Failed to write to pipe: %v", trace.DebugReport(err)) + process.Debugf("Failed to write to pipe: %v", trace.DebugReport(err)) return } cancel() @@ -161,7 +174,7 @@ func (process *TeleportProcess) writeToSignalPipe(message string) error { case <-time.After(signalPipeTimeout): return trace.BadParameter("Failed to write to parent process pipe") case <-messageSignalled.Done(): - log.Infof("Signalled success to parent process") + process.Infof("Signalled success to parent process") } return nil } @@ -176,7 +189,7 @@ func (process *TeleportProcess) closeImportedDescriptors(prefix string) error { for i := range 
process.importedDescriptors { d := process.importedDescriptors[i] if strings.HasPrefix(d.Type, prefix) { - log.Infof("Closing imported but unused descriptor %v %v.", d.Type, d.Address) + process.Infof("Closing imported but unused descriptor %v %v.", d.Type, d.Address) errors = append(errors, d.File.Close()) } } @@ -188,13 +201,13 @@ func (process *TeleportProcess) closeImportedDescriptors(prefix string) error { func (process *TeleportProcess) importOrCreateListener(listenerType, address string) (net.Listener, error) { l, err := process.importListener(listenerType, address) if err == nil { - log.Infof("Using file descriptor %v %v passed by the parent process.", listenerType, address) + process.Infof("Using file descriptor %v %v passed by the parent process.", listenerType, address) return l, nil } if !trace.IsNotFound(err) { return nil, trace.Wrap(err) } - log.Infof("Service %v is creating new listener on %v.", listenerType, address) + process.Infof("Service %v is creating new listener on %v.", listenerType, address) return process.createListener(listenerType, address) } @@ -248,8 +261,8 @@ func (process *TeleportProcess) createListener(listenerType, address string) (ne return listener, nil } -// exportFileDescriptors exports file descriptors to be passed to child process -func (process *TeleportProcess) exportFileDescriptors() ([]FileDescriptor, error) { +// ExportFileDescriptors exports file descriptors to be passed to child process +func (process *TeleportProcess) ExportFileDescriptors() ([]FileDescriptor, error) { var out []FileDescriptor process.Lock() defer process.Unlock() @@ -377,7 +390,7 @@ const ( // signalPipeTimeout is a time parent process is expecting // the child process to initialize and write back, // or child process is blocked on write to the pipe - signalPipeTimeout = 5 * time.Second + signalPipeTimeout = 2 * time.Minute ) func (process *TeleportProcess) forkChild() error { @@ -402,7 +415,7 @@ func (process *TeleportProcess) forkChild() error { 
log.Info("Forking child.") - listenerFiles, err := process.exportFileDescriptors() + listenerFiles, err := process.ExportFileDescriptors() if err != nil { return trace.Wrap(err) } @@ -482,12 +495,12 @@ func (process *TeleportProcess) collectStatuses() { var wait syscall.WaitStatus rpid, err := syscall.Wait4(pid, &wait, syscall.WNOHANG, nil) if err != nil { - log.Errorf("Wait call failed: %v.", err) + process.Errorf("Wait call failed: %v.", err) continue } if rpid == pid { process.popForkedPID(pid) - log.Warningf("Forked teleport process %v has exited with status: %v.", pid, wait.ExitStatus()) + process.Warningf("Forked teleport process %v has exited with status: %v.", pid, wait.ExitStatus()) } } } diff --git a/lib/service/supervisor.go b/lib/service/supervisor.go index fe469759d43c6..7bd6cf43a7336 100644 --- a/lib/service/supervisor.go +++ b/lib/service/supervisor.go @@ -18,10 +18,10 @@ package service import ( "context" + "fmt" "sync" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -57,15 +57,58 @@ type Supervisor interface { Services() []string // BroadcastEvent generates event and broadcasts it to all - // interested parties + // subscribed parties. BroadcastEvent(Event) // WaitForEvent waits for event to be broadcasted, if the event - // was already broadcasted, payloadC will receive current event immediately - // CLose 'cancelC' channel to force WaitForEvent to return prematurely - WaitForEvent(name string, eventC chan Event, cancelC chan struct{}) + // was already broadcasted, eventC will receive current event immediately. + WaitForEvent(ctx context.Context, name string, eventC chan Event) + + // RegisterEventMapping registers event mapping - + // when the sequence in the event mapping triggers, the + // outbound event will be generated. 
+ RegisterEventMapping(EventMapping) + + // ExitContext returns context that will be closed when + // TeleportExitEvent is broadcasted. + ExitContext() context.Context + + // ReloadContext returns context that will be closed when + // TeleportReloadEvent is broadcasted. + ReloadContext() context.Context } +// EventMapping maps a sequence of incoming +// events and if triggered, generates an out event. +type EventMapping struct { + // In is the incoming event sequence. + In []string + // Out is the outbound event to generate. + Out string +} + +// String returns user-friendly representation of the mapping. +func (e EventMapping) String() string { + return fmt.Sprintf("EventMapping(in=%v, out=%v)", e.In, e.Out) +} + +func (e EventMapping) matches(currentEvent string, m map[string]Event) bool { + // existing events that have been fired should match + for _, in := range e.In { + if _, ok := m[in]; !ok { + return false + } + } + // current event that is firing should match one of the expected events + for _, in := range e.In { + if currentEvent == in { + return true + } + } + return false +} + +// LocalSupervisor is a Teleport's implementation of the Supervisor interface. 
type LocalSupervisor struct { state int sync.Mutex @@ -75,14 +118,30 @@ type LocalSupervisor struct { events map[string]Event eventsC chan Event eventWaiters map[string][]*waiter + closeContext context.Context signalClose context.CancelFunc + + // exitContext is closed when someone emits Exit event + exitContext context.Context + signalExit context.CancelFunc + + reloadContext context.Context + signalReload context.CancelFunc + + eventMappings []EventMapping + id string } // NewSupervisor returns new instance of initialized supervisor -func NewSupervisor() Supervisor { +func NewSupervisor(id string) Supervisor { closeContext, cancel := context.WithCancel(context.TODO()) + + exitContext, signalExit := context.WithCancel(context.TODO()) + reloadContext, signalReload := context.WithCancel(context.TODO()) + srv := &LocalSupervisor{ + id: id, services: []Service{}, wg: &sync.WaitGroup{}, events: map[string]Event{}, @@ -90,6 +149,12 @@ func NewSupervisor() Supervisor { eventWaiters: make(map[string][]*waiter), closeContext: closeContext, signalClose: cancel, + + exitContext: exitContext, + signalExit: signalExit, + + reloadContext: reloadContext, + signalReload: signalReload, } go srv.fanOut() return srv @@ -132,7 +197,7 @@ func (s *LocalSupervisor) RegisterFunc(name string, fn ServiceFunc) { // RemoveService removes service from supervisor tracking list func (s *LocalSupervisor) RemoveService(srv Service) error { - l := logrus.WithFields(logrus.Fields{"service": srv.Name(), trace.Component: teleport.ComponentProcess}) + l := logrus.WithFields(logrus.Fields{"service": srv.Name(), trace.Component: teleport.Component(teleport.ComponentProcess, s.id)}) s.Lock() defer s.Unlock() for i, el := range s.services { @@ -151,10 +216,15 @@ func (s *LocalSupervisor) serve(srv Service) { go func() { defer s.wg.Done() defer s.RemoveService(srv) - log.WithFields(logrus.Fields{"service": srv.Name()}).Debugf("Service has started.") + l := log.WithFields(logrus.Fields{"service": 
srv.Name(), trace.Component: teleport.Component(teleport.ComponentProcess, s.id)}) + l.Debugf("Service has started.") err := srv.Serve() if err != nil { - utils.FatalError(err) + if err == ErrTeleportExited { + l.Infof("Teleport process has shut down.") + } else { + l.Warningf("Teleport process has exited with error: %v", err) + } } }() } @@ -201,22 +271,75 @@ func (s *LocalSupervisor) Run() error { return s.Wait() } +// ExitContext returns context that will be closed when +// TeleportExitEvent is broadcasted. +func (s *LocalSupervisor) ExitContext() context.Context { + return s.exitContext +} + +// ReloadContext returns context that will be closed when +// TeleportReloadEvent is broadcasted. +func (s *LocalSupervisor) ReloadContext() context.Context { + return s.reloadContext +} + +// BroadcastEvent generates event and broadcasts it to all +// subscribed parties. func (s *LocalSupervisor) BroadcastEvent(event Event) { s.Lock() defer s.Unlock() + + switch event.Name { + case TeleportExitEvent: + s.signalExit() + case TeleportReloadEvent: + s.signalReload() + } + s.events[event.Name] = event - log.WithFields(logrus.Fields{"event": event.String()}).Debugf("Broadcasting event.") + log.WithFields(logrus.Fields{"event": event.String(), trace.Component: teleport.Component(teleport.ComponentProcess, s.id)}).Debugf("Broadcasting event.") go func() { - s.eventsC <- event + select { + case s.eventsC <- event: + case <-s.closeContext.Done(): + return + } }() + + for _, m := range s.eventMappings { + if m.matches(event.Name, s.events) { + mappedEvent := Event{Name: m.Out} + s.events[mappedEvent.Name] = mappedEvent + go func(e Event) { + select { + case s.eventsC <- e: + case <-s.closeContext.Done(): + return + } + }(mappedEvent) + log.WithFields(logrus.Fields{"in": event.String(), "out": m.String(), trace.Component: teleport.Component(teleport.ComponentProcess, s.id)}).Debugf("Broadcasting mapped event.") + } + } +} + +// RegisterEventMapping registers event mapping - +// when 
the sequence in the event mapping triggers, the +// outbound event will be generated. +func (s *LocalSupervisor) RegisterEventMapping(m EventMapping) { + s.Lock() + defer s.Unlock() + + s.eventMappings = append(s.eventMappings, m) } -func (s *LocalSupervisor) WaitForEvent(name string, eventC chan Event, cancelC chan struct{}) { +// WaitForEvent waits for event to be broadcasted, if the event +// was already broadcasted, eventC will receive current event immediately. +func (s *LocalSupervisor) WaitForEvent(ctx context.Context, name string, eventC chan Event) { s.Lock() defer s.Unlock() - waiter := &waiter{eventC: eventC, cancelC: cancelC} + waiter := &waiter{eventC: eventC, context: ctx} event, ok := s.events[name] if ok { go s.notifyWaiter(waiter, event) @@ -240,7 +363,7 @@ func (s *LocalSupervisor) getWaiters(name string) []*waiter { func (s *LocalSupervisor) notifyWaiter(w *waiter, event Event) { select { case w.eventC <- event: - case <-w.cancelC: + case <-w.context.Done(): } } @@ -260,7 +383,7 @@ func (s *LocalSupervisor) fanOut() { type waiter struct { eventC chan Event - cancelC chan struct{} + context context.Context } // Service is a running teleport service function diff --git a/lib/services/authority.go b/lib/services/authority.go index 9defe4fc40a0b..313490415d39c 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -169,6 +169,8 @@ type CertAuthority interface { CheckAndSetDefaults() error // SetSigningKeys sets signing keys SetSigningKeys([][]byte) error + // SetCheckingKeys sets signing keys + SetCheckingKeys([][]byte) error // AddRole adds a role to ca role list AddRole(name string) // Checkers returns public keys that can be used to check cert authorities @@ -187,6 +189,12 @@ type CertAuthority interface { SetTLSKeyPairs(keyPairs []TLSKeyPair) // GetTLSKeyPairs returns first PEM encoded TLS cert GetTLSKeyPairs() []TLSKeyPair + // GetRotation returns rotation state. + GetRotation() Rotation + // SetRotation sets rotation state. 
+ SetRotation(Rotation) + // Clone returns a copy of the cert authority object. + Clone() CertAuthority } // CertPoolFromCertAuthorities returns certificate pools from TLS certificates @@ -296,6 +304,32 @@ type CertAuthorityV2 struct { rawObject interface{} } +// Clone returns a copy of the cert authority object. +func (c *CertAuthorityV2) Clone() CertAuthority { + out := *c + out.rawObject = nil + out.Spec.CheckingKeys = utils.CopyByteSlices(c.Spec.CheckingKeys) + out.Spec.SigningKeys = utils.CopyByteSlices(c.Spec.SigningKeys) + for i, kp := range c.Spec.TLSKeyPairs { + out.Spec.TLSKeyPairs[i] = TLSKeyPair{ + Key: utils.CopyByteSlice(kp.Key), + Cert: utils.CopyByteSlice(kp.Cert), + } + } + out.Spec.Roles = utils.CopyStrings(c.Spec.Roles) + return &out +} + +// GetRotation returns rotation state. +func (c *CertAuthorityV2) GetRotation() Rotation { + return c.Spec.Rotation +} + +// SetRotation sets rotation state. +func (c *CertAuthorityV2) SetRotation(r Rotation) { + c.Spec.Rotation = r +} + // TLSCA returns TLS certificate authority func (c *CertAuthorityV2) TLSCA() (*tlsca.CertAuthority, error) { if len(c.Spec.TLSKeyPairs) == 0 { @@ -375,6 +409,12 @@ func (ca *CertAuthorityV2) SetSigningKeys(keys [][]byte) error { return nil } +// SetCheckingKeys sets SSH public keys +func (ca *CertAuthorityV2) SetCheckingKeys(keys [][]byte) error { + ca.Spec.CheckingKeys = keys + return nil +} + // GetID returns certificate authority ID - // combined type and name func (ca *CertAuthorityV2) GetID() CertAuthID { @@ -397,7 +437,7 @@ func (ca *CertAuthorityV2) GetType() CertAuthType { } // GetClusterName returns cluster name this cert authority -// is associated with +// is associated with. func (ca *CertAuthorityV2) GetClusterName() string { return ca.Spec.ClusterName } @@ -520,11 +560,198 @@ func (ca *CertAuthorityV2) CheckAndSetDefaults() error { return nil } +const ( + // RotationStateStandby is initial status of the rotation - + // nothing is being rotated. 
+ RotationStateStandby = "standby" + // RotationStateInProgress means that rotation is in progress. + RotationStateInProgress = "in_progress" + // RotationPhaseStandby is the initial phase of the rotation + // it means no operations have started. + RotationPhaseStandby = "standby" + // RotationPhaseUpdateClients is a phase of the rotation + // when client credentials will have to be updated and reloaded + // but servers will use and respond with old credentials + // because clients have no idea about new credentials at first. + RotationPhaseUpdateClients = "update_clients" + // RotationPhaseUpdateServers is a phase of the rotation + // when servers will have to reload and should start serving + // TLS and SSH certificates signed by new CA. + RotationPhaseUpdateServers = "update_servers" + // RotationPhaseRollback means that rotation is rolling + // back to the old certificate authority. + RotationPhaseRollback = "rollback" + // RotationModeManual is a manual rotation mode when all phases + // are set by the operator. + RotationModeManual = "manual" + // RotationModeAuto is set to go through all phases by the schedule. + RotationModeAuto = "auto" +) + +// RotatePhases lists all supported rotation phases +var RotatePhases = []string{ + RotationPhaseStandby, + RotationPhaseUpdateClients, + RotationPhaseUpdateServers, + RotationPhaseRollback, +} + +// Rotation is a status of the rotation of the certificate authority +type Rotation struct { + // State could be one of "standby" or "in_progress". + State string `json:"state,omitempty"` + // Phase is the current rotation phase. + Phase string `json:"phase,omitempty"` + // Mode sets manual or automatic rotation mode. + Mode string `json:"mode,omitempty"` + // CurrentID is the ID of the rotation operation + // to differentiate between rotation attempts. + CurrentID string `json:"current_id"` + // Started is set to the time when rotation has been started + // in case if the state of the rotation is "in_progress". 
+ Started time.Time `json:"started,omitempty"` + // GracePeriod is a period during which old and new CA + // are valid for checking purposes, but only new CA is issuing certificates. + GracePeriod Duration `json:"grace_period,omitempty"` + // LastRotated specifies the last time of the completed rotation. + LastRotated time.Time `json:"last_rotated,omitempty"` + // Schedule is a rotation schedule - used in + // automatic mode to switch between phases. + Schedule RotationSchedule `json:"schedule,omitempty"` +} + +// Matches returns true if this state rotation matches +// external rotation state, phase and rotation ID should match, +// notice that matches does not behave like Equals because it does not require +// all fields to be the same. +func (s *Rotation) Matches(rotation Rotation) bool { + return s.CurrentID == rotation.CurrentID && s.State == rotation.State && s.Phase == rotation.Phase +} + +// LastRotatedDescription returns human friendly description. +func (r *Rotation) LastRotatedDescription() string { + if r.LastRotated.IsZero() { + return "never updated" + } + return fmt.Sprintf("last rotated %v", r.LastRotated.Format(teleport.HumanDateFormatSeconds)) +} + +// PhaseDescription returns human friendly description of a current rotation phase. +func (r *Rotation) PhaseDescription() string { + switch r.Phase { + case RotationPhaseStandby, "": + return "on standby" + case RotationPhaseUpdateClients: + return "rotating clients" + case RotationPhaseUpdateServers: + return "rotating servers" + case RotationPhaseRollback: + return "rolling back" + default: + return fmt.Sprintf("unknown phase: %q", r.Phase) + } +} + +// String returns user friendly information about the rotation state. 
+func (r *Rotation) String() string { + switch r.State { + case "", RotationStateStandby: + if r.LastRotated.IsZero() { + return "never updated" + } + return fmt.Sprintf("rotated %v", r.LastRotated.Format(teleport.HumanDateFormatSeconds)) + case RotationStateInProgress: + return fmt.Sprintf("%v (mode: %v, started: %v, ending: %v)", + r.PhaseDescription(), + r.Mode, + r.Started.Format(teleport.HumanDateFormatSeconds), + r.Started.Add(r.GracePeriod.Duration).Format(teleport.HumanDateFormatSeconds), + ) + default: + return "unknown" + } +} + +// CheckAndSetDefaults checks and sets default rotation parameters. +func (r *Rotation) CheckAndSetDefaults(clock clockwork.Clock) error { + switch r.Phase { + case "", RotationPhaseRollback, RotationPhaseUpdateClients, RotationPhaseUpdateServers: + default: + return trace.BadParameter("unsupported phase: %q", r.Phase) + } + switch r.Mode { + case "", RotationModeAuto, RotationModeManual: + default: + return trace.BadParameter("unsupported mode: %q", r.Mode) + } + switch r.State { + case "": + r.State = RotationStateStandby + case RotationStateStandby: + case RotationStateInProgress: + if r.CurrentID == "" { + return trace.BadParameter("set 'current_id' parameter for in progress rotation") + } + if r.Started.IsZero() { + return trace.BadParameter("set 'started' parameter for in progress rotation") + } + default: + return trace.BadParameter( + "unsupported rotation 'state': %q, supported states are: %q, %q", + r.State, RotationStateStandby, RotationStateInProgress) + } + return nil +} + +// GenerateSchedule generates schedule based on the time period, using +// even time periods between rotation phases. 
+func GenerateSchedule(clock clockwork.Clock, gracePeriod time.Duration) (*RotationSchedule, error) { + if gracePeriod <= 0 { + return nil, trace.BadParameter("bad grace period %q, provide value >= 0", gracePeriod) + } + return &RotationSchedule{ + UpdateServers: clock.Now().UTC().Add(gracePeriod / 2).UTC(), + Standby: clock.Now().UTC().Add(gracePeriod).UTC(), + }, nil +} + +// RotationSchedule is a rotation schedule setting time switches +// for different phases. +type RotationSchedule struct { + // UpdateServers specifies time to switch to the "Update servers" phase. + UpdateServers time.Time `json:"update_servers,omitempty"` + // Standby specifies time to switch to the "Standby" phase. + Standby time.Time `json:"standby,omitempty"` +} + +// CheckAndSetDefaults checks and sets default values of the rotation schedule. +func (s *RotationSchedule) CheckAndSetDefaults(clock clockwork.Clock) error { + if s.UpdateServers.IsZero() { + return trace.BadParameter("phase %q has no time switch scheduled", RotationPhaseUpdateServers) + } + if s.Standby.IsZero() { + return trace.BadParameter("phase %q has no time switch scheduled", RotationPhaseStandby) + } + if s.Standby.Before(s.UpdateServers) { + return trace.BadParameter("phase %q can not be scheduled before %q", RotationPhaseStandby, RotationPhaseUpdateServers) + } + if s.UpdateServers.Before(clock.Now()) { + return trace.BadParameter("phase %q can not be scheduled in the past", RotationPhaseUpdateServers) + } + if s.Standby.Before(clock.Now()) { + return trace.BadParameter("phase %q can not be scheduled in the past", RotationPhaseStandby) + } + return nil +} + // CertAuthoritySpecV2 is a host or user certificate authority that // can check and if it has private key stored as well, sign it too type CertAuthoritySpecV2 struct { // Type is either user or host certificate authority Type CertAuthType `json:"type"` + // DELETE IN(2.7.0) this field is deprecated, + // as resource name matches cluster name after migrations. 
+ // and this property is enforced by the auth server code. // ClusterName identifies cluster name this authority serves, // for host authorities that means base hostname of all servers, // for user authorities that means organization name @@ -540,6 +767,8 @@ type CertAuthoritySpecV2 struct { RoleMap RoleMap `json:"role_map,omitempty"` // TLS is a list of TLS key pairs TLSKeyPairs []TLSKeyPair `json:"tls_key_pairs,omitempty"` + // Rotation is a status of the certificate authority rotation + Rotation Rotation `json:"rotation,omitempty"` } // CertAuthoritySpecV2Schema is JSON schema for cert authority V2 @@ -579,10 +808,33 @@ const CertAuthoritySpecV2Schema = `{ } } }, + "rotation": %v, "role_map": %v } }` +// RotationSchema is a JSON validation schema of the CA rotation state object. +const RotationSchema = `{ + "type": "object", + "additionalProperties": false, + "properties": { + "state": {"type": "string"}, + "phase": {"type": "string"}, + "mode": {"type": "string"}, + "current_id": {"type": "string"}, + "started": {"type": "string"}, + "grace_period": {"type": "string"}, + "last_rotated": {"type": "string"}, + "schedule": { + "type": "object", + "properties": { + "update_servers": {"type": "string"}, + "standby": {"type": "string"} + } + } + } +}` + // CertAuthorityV1 is a host or user certificate authority that // can check and if it has private key stored as well, sign it too type CertAuthorityV1 struct { @@ -677,7 +929,7 @@ type CertAuthorityMarshaler interface { // GetCertAuthoritySchema returns JSON Schema for cert authorities func GetCertAuthoritySchema() string { - return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(CertAuthoritySpecV2Schema, RoleMapSchema), DefaultDefinitions) + return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(CertAuthoritySpecV2Schema, RotationSchema, RoleMapSchema), DefaultDefinitions) } type TeleportCertAuthorityMarshaler struct{} diff --git a/lib/services/local/configuration_test.go 
b/lib/services/local/configuration_test.go index c142681a06dce..5ae2f050a34f6 100644 --- a/lib/services/local/configuration_test.go +++ b/lib/services/local/configuration_test.go @@ -22,7 +22,7 @@ import ( "os" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/boltbk" + "github.com/gravitational/teleport/lib/backend/dir" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -50,7 +50,7 @@ func (s *ClusterConfigurationSuite) SetUpTest(c *check.C) { s.tempDir, err = ioutil.TempDir("", "preference-test-") c.Assert(err, check.IsNil) - s.bk, err = boltbk.New(backend.Params{"path": s.tempDir}) + s.bk, err = dir.New(backend.Params{"path": s.tempDir}) c.Assert(err, check.IsNil) } diff --git a/lib/services/local/presence_test.go b/lib/services/local/presence_test.go index e02c0492946a4..5bb45fd083125 100644 --- a/lib/services/local/presence_test.go +++ b/lib/services/local/presence_test.go @@ -22,7 +22,7 @@ import ( "os" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/boltbk" + "github.com/gravitational/teleport/lib/backend/dir" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -52,7 +52,7 @@ func (s *PresenceSuite) SetUpTest(c *check.C) { s.tempDir, err = ioutil.TempDir("", "trusted-clusters-") c.Assert(err, check.IsNil) - s.bk, err = boltbk.New(backend.Params{"path": s.tempDir}) + s.bk, err = dir.New(backend.Params{"path": s.tempDir}) c.Assert(err, check.IsNil) } diff --git a/lib/services/local/trust.go b/lib/services/local/trust.go index 9e7ddfcebae3b..d4b1bf71e767f 100644 --- a/lib/services/local/trust.go +++ b/lib/services/local/trust.go @@ -66,6 +66,32 @@ func (s *CA) UpsertCertAuthority(ca services.CertAuthority) error { return nil } +// CompareAndSwapCertAuthority updates the cert authority value +// if the existing value matches existing parameter, returns nil if succeeds, +// 
trace.CompareFailed otherwise. +func (s *CA) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + if err := new.Check(); err != nil { + return trace.Wrap(err) + } + newData, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(new) + if err != nil { + return trace.Wrap(err) + } + existingData, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(existing) + if err != nil { + return trace.Wrap(err) + } + ttl := backend.TTL(s.Clock(), new.Expiry()) + err = s.CompareAndSwapVal([]string{"authorities", string(new.GetType())}, new.GetName(), newData, existingData, ttl) + if err != nil { + if trace.IsCompareFailed(err) { + return trace.CompareFailed("cluster %v settings have been updated, try again", new.GetName()) + } + return trace.Wrap(err) + } + return nil +} + // DeleteCertAuthority deletes particular certificate authority func (s *CA) DeleteCertAuthority(id services.CertAuthID) error { if err := id.Check(); err != nil { @@ -172,26 +198,6 @@ func setSigningKeys(ca services.CertAuthority, loadSigningKeys bool) { ca.SetTLSKeyPairs(keyPairs) } -// DELETE IN: 2.6.0 -// GetAnyCertAuthority returns activated or deactivated certificate authority -// by given id whether it is activated or not. This method is used in migrations. 
-func (s *CA) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - if err := id.Check(); err != nil { - return nil, trace.Wrap(err) - } - data, err := s.GetVal([]string{"authorities", string(id.Type)}, id.DomainName) - if err != nil { - if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - data, err = s.GetVal([]string{"authorities", "deactivated", string(id.Type)}, id.DomainName) - if err != nil { - return nil, trace.Wrap(err) - } - } - return services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(data) -} - // GetCertAuthorities returns a list of authorities of a given type // loadSigningKeys controls whether signing keys should be loaded or not func (s *CA) GetCertAuthorities(caType services.CertAuthType, loadSigningKeys bool) ([]services.CertAuthority, error) { diff --git a/lib/services/local/users.go b/lib/services/local/users.go index b16960bfc32b3..3d3798efdf783 100644 --- a/lib/services/local/users.go +++ b/lib/services/local/users.go @@ -178,7 +178,7 @@ func (s *IdentityService) DeleteUser(user string) error { err := s.DeleteBucket([]string{"web", "users"}, user) if err != nil { if trace.IsNotFound(err) { - return trace.NotFound(fmt.Sprintf("user '%v' is not found", user)) + return trace.NotFound("user %q is not found", user) } } return trace.Wrap(err) diff --git a/lib/services/parser.go b/lib/services/parser.go index c1673a5dc976a..3cbed8a1c81d5 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -28,6 +28,7 @@ import ( "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "github.com/vulcand/predicate" + "github.com/vulcand/predicate/builder" ) // RuleContext specifies context passed to the @@ -38,18 +39,47 @@ type RuleContext interface { GetIdentifier(fields []string) (interface{}, error) // String returns human friendly representation of a context String() string + // GetResource returns resource if specified in the context, + // if unpecified, returns error. 
+ GetResource() (Resource, error) } -// NewWhereParser returns standard parser for `where` section in access rules +var ( + // ResourceNameExpr is the identifier that specifies resource name. + ResourceNameExpr = builder.Identifier("resource.metadata.name") + // CertAuthorityTypeExpr is a function call that returns + // cert authority type. + CertAuthorityTypeExpr = builder.Identifier(`system.catype()`) +) + +// NewWhereParser returns standard parser for `where` section in access rules. func NewWhereParser(ctx RuleContext) (predicate.Parser, error) { return predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: predicate.And, OR: predicate.Or, + NOT: predicate.Not, }, Functions: map[string]interface{}{ "equals": predicate.Equals, "contains": predicate.Contains, + // system.catype is a function that returns cert authority type, + // it returns empty values for unrecognized values to + // pass static rule checks. + "system.catype": func() (interface{}, error) { + resource, err := ctx.GetResource() + if err != nil { + if trace.IsNotFound(err) { + return "", nil + } + return nil, trace.Wrap(err) + } + ca, ok := resource.(CertAuthority) + if !ok { + return "", nil + } + return string(ca.GetType()), nil + }, }, GetIdentifier: ctx.GetIdentifier, GetProperty: predicate.GetStringMapValue, @@ -124,6 +154,15 @@ const ( ResourceIdentifier = "resource" ) +// GetResource returns resource specified in the context, +// returns error if not specified. 
+func (ctx *Context) GetResource() (Resource, error) { + if ctx.Resource == nil { + return nil, trace.NotFound("resource is not set in the context") + } + return ctx.Resource, nil +} + // GetIdentifier returns identifier defined in a context func (ctx *Context) GetIdentifier(fields []string) (interface{}, error) { switch fields[0] { diff --git a/lib/services/resource.go b/lib/services/resource.go index be87f3979f0b8..357ec7bc802a2 100644 --- a/lib/services/resource.go +++ b/lib/services/resource.go @@ -153,6 +153,12 @@ const ( // to proxy KindRemoteCluster = "remote_cluster" + // KindIdentity is local on disk identity resource + KindIdentity = "identity" + + // KindState is local on disk process state + KindState = "state" + // V3 is the third version of resources. V3 = "v3" @@ -182,6 +188,10 @@ const ( // VerbDelete is used to remove an object. VerbDelete = "delete" + + // VerbRotate is used to rotate certificate authorities + // used only internally + VerbRotate = "rotate" ) func collectOptions(opts []MarshalOption) (*MarshalConfig, error) { diff --git a/lib/services/role.go b/lib/services/role.go index 9883636616fae..60788c245a85b 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -1637,7 +1637,7 @@ func (set RoleSet) CheckAccessToRule(ctx RuleContext, namespace string, resource } log.Infof("[RBAC] %s access to %s [namespace %s] denied for %v: no allow rule matched", verb, resource, namespace, set) - return trace.AccessDenied("access denied to perform action '%s' on %s", verb, resource) + return trace.AccessDenied("access denied to perform action %q on %q", verb, resource) } // ProcessNamespace sets default namespace in case if namespace is empty diff --git a/lib/services/server.go b/lib/services/server.go index c83f376ee8735..4aef105113d48 100644 --- a/lib/services/server.go +++ b/lib/services/server.go @@ -249,6 +249,8 @@ type ServerSpecV2 struct { Hostname string `json:"hostname"` // CmdLabels is server dynamic labels CmdLabels 
map[string]CommandLabelV2 `json:"cmd_labels,omitempty"` + // Rotation specifies server rotation status + Rotation Rotation `json:"rotation,omitempty"` } // ServerSpecV2Schema is JSON schema for server @@ -279,7 +281,8 @@ const ServerSpecV2Schema = `{ } } } - } + }, + "rotation": %v } }` @@ -423,7 +426,7 @@ func (c *CommandLabels) SetEnv(v string) error { // GetServerSchema returns role schema with optionally injected // schema for extensions func GetServerSchema() string { - return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, ServerSpecV2Schema, DefaultDefinitions) + return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(ServerSpecV2Schema, RotationSchema), DefaultDefinitions) } // UnmarshalServerResource unmarshals role from JSON or YAML, diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 0404a5e95221a..095bde88e23b9 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -174,11 +174,11 @@ func (s *ServicesTestSuite) UsersCRUD(c *C) { userSlicesEqual(c, u, []services.User{newUser("user2", nil)}) err = s.WebS.DeleteUser("user1") - c.Assert(trace.IsNotFound(err), Equals, true, Commentf("unexpected %T %#v", err, err)) + fixtures.ExpectNotFound(c, err) // bad username err = s.WebS.UpsertUser(newUser("", nil)) - c.Assert(trace.IsBadParameter(err), Equals, true, Commentf("expected bad parameter error, got %T", err)) + fixtures.ExpectBadParameter(c, err) } func (s *ServicesTestSuite) LoginAttempts(c *C) { @@ -226,6 +226,26 @@ func (s *ServicesTestSuite) CertAuthCRUD(c *C) { err = s.CAS.DeleteCertAuthority(*ca.ID()) c.Assert(err, IsNil) + + // test compare and swap + ca = NewTestCA(services.UserCA, "example.com") + c.Assert(s.CAS.CreateCertAuthority(ca), IsNil) + + clock := clockwork.NewFakeClock() + newCA := *ca + rotation := services.Rotation{ + State: services.RotationStateInProgress, + CurrentID: "id1", + GracePeriod: services.NewDuration(time.Hour), + Started: clock.Now(), + } + newCA.SetRotation(rotation) + 
+ err = s.CAS.CompareAndSwapCertAuthority(&newCA, ca) + + out, err = s.CAS.GetCertAuthority(ca.GetID(), true) + c.Assert(err, IsNil) + fixtures.DeepCompare(c, &newCA, out) } func newServer(kind, name, addr, namespace string) *services.ServerV2 { diff --git a/lib/services/trust.go b/lib/services/trust.go index 7145e995c6160..8d379cc93a0b8 100644 --- a/lib/services/trust.go +++ b/lib/services/trust.go @@ -39,6 +39,11 @@ type Trust interface { // UpsertCertAuthority updates or inserts a new certificate authority UpsertCertAuthority(ca CertAuthority) error + // CompareAndSwapCertAuthority updates the cert authority value + // if existing value matches existing parameter, + // returns nil if succeeds, trace.CompareFailed otherwise + CompareAndSwapCertAuthority(new, existing CertAuthority) error + // DeleteCertAuthority deletes particular certificate authority DeleteCertAuthority(id CertAuthID) error @@ -49,12 +54,6 @@ type Trust interface { // controls if signing keys are loaded GetCertAuthority(id CertAuthID, loadSigningKeys bool) (CertAuthority, error) - // DELETE IN: 2.6.0 - // GetAnyCertAuthority returns activated or deactivated certificate authority - // by given id whether it is activated or not. Signing keys are never loaded. - // This method is used in migrations. 
- GetAnyCertAuthority(id CertAuthID) (ca CertAuthority, error error) - // GetCertAuthorities returns a list of authorities of a given type // loadSigningKeys controls whether signing keys should be loaded or not GetCertAuthorities(caType CertAuthType, loadSigningKeys bool) ([]CertAuthority, error) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index fd4fa6654bbb8..97323a9d5f596 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -66,6 +66,7 @@ type Server struct { srv *sshutils.Server hostSigner ssh.Signer shell string + getRotation RotationGetter authService auth.AccessPoint reg *srv.SessionRegistry sessionServer rsession.Service @@ -219,6 +220,17 @@ func (s *Server) Wait() { s.srv.Wait(context.TODO()) } +// RotationGetter returns rotation state +type RotationGetter func(role teleport.Role) (*services.Rotation, error) + +// SetRotationGetter sets rotation state getter +func SetRotationGetter(getter RotationGetter) ServerOption { + return func(s *Server) error { + s.getRotation = getter + return nil + } +} + // SetShell sets default shell that will be executed for interactive // sessions func SetShell(shell string) ServerOption { @@ -239,7 +251,10 @@ func SetSessionServer(sessionServer rsession.Service) ServerOption { // SetProxyMode starts this server in SSH proxying mode func SetProxyMode(tsrv reversetunnel.Server) ServerOption { return func(s *Server) error { - s.proxyMode = (tsrv != nil) + // always set proxy mode to true, + // because in some tests reverse tunnel is disabled, + // but proxy is still used without it. 
+ s.proxyMode = true s.proxyTun = tsrv return nil } @@ -466,7 +481,14 @@ func (s *Server) AdvertiseAddr() string { return net.JoinHostPort(s.getAdvertiseIP().String(), port) } -func (s *Server) getInfo() services.Server { +func (s *Server) getRole() teleport.Role { + if s.proxyMode { + return teleport.RoleProxy + } + return teleport.RoleNode +} + +func (s *Server) getInfo() *services.ServerV2 { return &services.ServerV2{ Kind: services.KindNode, Version: services.V2, @@ -486,6 +508,16 @@ func (s *Server) getInfo() services.Server { // registerServer attempts to register server in the cluster func (s *Server) registerServer() error { server := s.getInfo() + if s.getRotation != nil { + rotation, err := s.getRotation(s.getRole()) + if err != nil { + if !trace.IsNotFound(err) { + log.Warningf("Failed to get rotation state: %v", err) + } + } else { + server.Spec.Rotation = *rotation + } + } server.SetTTL(s.clock, defaults.ServerHeartbeatTTL) if !s.proxyMode { return trace.Wrap(s.authService.UpsertNode(server)) diff --git a/lib/state/cachingaccesspoint.go b/lib/state/cachingaccesspoint.go index f1767c26575a7..74cb5f39ea925 100644 --- a/lib/state/cachingaccesspoint.go +++ b/lib/state/cachingaccesspoint.go @@ -19,6 +19,7 @@ package state import ( "fmt" + "io" "strings" "sync" "time" @@ -819,7 +820,15 @@ func (cs *CachingAuthClient) try(f func() error) error { return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors") } accessPointRequests.Inc() - err := trace.ConvertSystemError(f()) + err := f() + if err != nil { + // EOF in this context means connection problem + if trace.Unwrap(err) == io.EOF { + err = trace.ConnectionProblem(trace.Unwrap(err), "EOF") + } else { + err = trace.ConvertSystemError(err) + } + } accessPointLatencies.Observe(mdiff(start)) if trace.IsConnectionProblem(err) { cs.setLastErrorTime(time.Now()) diff --git a/lib/utils/copy.go b/lib/utils/copy.go index 7901e4dd10880..00378a66c31c2 100644 --- a/lib/utils/copy.go +++ 
b/lib/utils/copy.go @@ -16,6 +16,28 @@ limitations under the License. package utils +// CopyByteSlice returns a copy of the byte slice. +func CopyByteSlice(in []byte) []byte { + if in == nil { + return nil + } + out := make([]byte, len(in)) + copy(out, in) + return out +} + +// CopyByteSlices returns a copy of the byte slices. +func CopyByteSlices(in [][]byte) [][]byte { + if in == nil { + return nil + } + out := make([][]byte, len(in)) + for i := range in { + out[i] = CopyByteSlice(in[i]) + } + return out +} + // CopyStrings makes a deep copy of the passed in string slice and returns // the copy. func CopyStrings(in []string) []string { diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index 66507ae5e1d57..d5fc142744a36 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -17,9 +17,9 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" "github.com/gravitational/kingpin" + "github.com/gravitational/trace" ) // AuthCommand implements `tctl auth` group of commands @@ -38,9 +38,15 @@ type AuthCommand struct { compatVersion string compatibility string + rotateGracePeriod time.Duration + rotateType string + rotateManualMode bool + rotateTargetPhase string + authGenerate *kingpin.CmdClause authExport *kingpin.CmdClause authSign *kingpin.CmdClause + authRotate *kingpin.CmdClause } // Initialize allows TokenCommand to plug itself into the CLI parser @@ -66,6 +72,12 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi a.authSign.Flag("format", "identity format: 'file' (default) or 'dir'").Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat)) a.authSign.Flag("ttl", "TTL (time to live) for the generated certificate").Default(fmt.Sprintf("%v", defaults.CertDuration)).DurationVar(&a.genTTL) a.authSign.Flag("compat", 
"OpenSSH compatibility flag").StringVar(&a.compatibility) + + a.authRotate = auth.Command("rotate", "Rotate certificate authorities in the cluster") + a.authRotate.Flag("grace-period", "Grace period keeps previous certificate authorities signatures valid, if set to 0 will force users to relogin and nodes to re-register.").Default(fmt.Sprintf("%v", defaults.RotationGracePeriod)).DurationVar(&a.rotateGracePeriod) + a.authRotate.Flag("manual", "Activate manual rotation , set rotation phases manually").BoolVar(&a.rotateManualMode) + a.authRotate.Flag("type", "Certificate authority to rotate, rotates both host and user CA by default").StringVar(&a.rotateType) + a.authRotate.Flag("phase", fmt.Sprintf("Target rotation phase to set, used in manual rotation, one of: %v", strings.Join(services.RotatePhases, ", "))).StringVar(&a.rotateTargetPhase) } // TryRun takes the CLI command as an argument (like "auth gen") and executes it @@ -78,7 +90,8 @@ func (a *AuthCommand) TryRun(cmd string, client auth.ClientI) (match bool, err e err = a.ExportAuthorities(client) case a.authSign.FullCommand(): err = a.GenerateAndSignKeys(client) - + case a.authRotate.FullCommand(): + err = a.RotateCertAuthority(client) default: return false, nil } @@ -236,6 +249,30 @@ func (a *AuthCommand) GenerateAndSignKeys(clusterApi auth.ClientI) error { } } +// RotateCertAuthority starts or restarts certificate authority rotation process +func (a *AuthCommand) RotateCertAuthority(client auth.ClientI) error { + req := auth.RotateRequest{ + Type: services.CertAuthType(a.rotateType), + GracePeriod: &a.rotateGracePeriod, + TargetPhase: a.rotateTargetPhase, + } + if a.rotateManualMode { + req.Mode = services.RotationModeManual + } else { + req.Mode = services.RotationModeAuto + } + if err := client.RotateCertAuthority(req); err != nil { + return err + } + if a.rotateTargetPhase != "" { + fmt.Printf("Updated rotation phase to %q. 
To check status use 'tctl status'\n", a.rotateTargetPhase) + } else { + fmt.Printf("Initiated certificate authority rotation. To check status use 'tctl status'\n") + } + + return nil +} + func (a *AuthCommand) generateHostKeys(clusterApi auth.ClientI) error { // only format=openssh is supported if a.outputFormat != client.IdentityFormatOpenSSH { diff --git a/tool/tctl/common/status_command.go b/tool/tctl/common/status_command.go new file mode 100644 index 0000000000000..a026efefb3ae5 --- /dev/null +++ b/tool/tctl/common/status_command.go @@ -0,0 +1,118 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + "strings" + + "github.com/gravitational/kingpin" + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/trace" +) + +// StatusCommand implements `tctl status` command. +type StatusCommand struct { + config *service.Config + + // CLI clauses (subcommands) + status *kingpin.CmdClause +} + +// Initialize allows StatusCommand to plug itself into the CLI parser. 
+func (c *StatusCommand) Initialize(app *kingpin.Application, config *service.Config) { + c.config = config + c.status = app.Command("status", "Report cluster status") +} + +// TryRun takes the CLI command as an argument (like "nodes ls") and executes it. +func (c *StatusCommand) TryRun(cmd string, client auth.ClientI) (match bool, err error) { + switch cmd { + case c.status.FullCommand(): + err = c.Status(client) + default: + return false, nil + } + return true, trace.Wrap(err) +} + +// Status is called to execute "status" CLI command. +func (c *StatusCommand) Status(client auth.ClientI) error { + clusterNameResource, err := client.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + clusterName := clusterNameResource.GetClusterName() + + hostCAs, err := client.GetCertAuthorities(services.HostCA, false) + if err != nil { + return trace.Wrap(err) + } + + userCAs, err := client.GetCertAuthorities(services.UserCA, false) + if err != nil { + return trace.Wrap(err) + } + + authorities := append(userCAs, hostCAs...) 
+ view := func() string { + table := asciitable.MakeHeadlessTable(2) + table.AddRow([]string{"Cluster", clusterName}) + for _, ca := range authorities { + if ca.GetClusterName() != clusterName { + continue + } + info := fmt.Sprintf("%v CA ", strings.Title(string(ca.GetType()))) + rotation := ca.GetRotation() + if c.config.Debug { + table.AddRow([]string{info, + fmt.Sprintf("%v, update_servers: %v, complete: %v", + rotation.String(), + rotation.Schedule.UpdateServers.Format(teleport.HumanDateFormatSeconds), + rotation.Schedule.Standby.Format(teleport.HumanDateFormatSeconds), + )}) + } else { + table.AddRow([]string{info, rotation.String()}) + } + + } + return table.AsBuffer().String() + } + fmt.Printf(view()) + + // in debug mode, output mode of remote certificate authorities + if c.config.Debug { + view := func() string { + table := asciitable.MakeHeadlessTable(2) + for _, ca := range authorities { + if ca.GetClusterName() == clusterName { + continue + } + info := fmt.Sprintf("Remote %v CA %q", strings.Title(string(ca.GetType())), ca.GetClusterName()) + rotation := ca.GetRotation() + table.AddRow([]string{info, rotation.String()}) + } + return "Remote clusters\n\n" + table.AsBuffer().String() + } + fmt.Printf(view()) + } + return nil +} diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 16acb75d8b6f5..8e1de1d1b2a8a 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -19,6 +19,7 @@ package common import ( "fmt" "os" + "path/filepath" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" @@ -86,7 +87,7 @@ func Run(commands []CLICommand) { "Base64 encoded configuration string").Hidden().Envar(defaults.ConfigEnvar).StringVar(&ccf.ConfigString) // "version" command is always available: - ver := app.Command("version", "Print the version.") + ver := app.Command("version", "Print cluster version") app.HelpFlag.Short('h') // parse CLI commands+flags: @@ -133,7 +134,7 @@ func connectToAuthService(cfg 
*service.Config) (client auth.ClientI, err error) } } // read the host SSH keys and use them to open an SSH connection to the auth service - i, err := auth.ReadIdentity(cfg.DataDir, auth.IdentityID{Role: teleport.RoleAdmin, HostUUID: cfg.HostUUID}) + i, err := auth.ReadLocalIdentity(filepath.Join(cfg.DataDir, teleport.ComponentProcess), auth.IdentityID{Role: teleport.RoleAdmin, HostUUID: cfg.HostUUID}) if err != nil { // the "admin" identity is not present? this means the tctl is running NOT on the auth server. if trace.IsNotFound(err) { @@ -182,8 +183,9 @@ func applyConfig(ccf *GlobalCLIFlags, cfg *service.Config) error { } // --debug flag if ccf.Debug { + cfg.Debug = ccf.Debug utils.InitLogger(utils.LoggingForCLI, logrus.DebugLevel) - logrus.Debugf("DEBUG loggign enabled") + logrus.Debugf("DEBUG logging enabled") } // read a host UUID for this node diff --git a/tool/tctl/main.go b/tool/tctl/main.go index 45fcc59c4dc8e..5b5974bf1cef3 100644 --- a/tool/tctl/main.go +++ b/tool/tctl/main.go @@ -27,6 +27,7 @@ func main() { &common.TokenCommand{}, &common.AuthCommand{}, &common.ResourceCommand{}, + &common.StatusCommand{}, } common.Run(commands) } diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 4593b96f8502a..0c168dfa5e91d 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -165,16 +165,7 @@ func Run(options Options) (executedCommand string, conf *service.Config) { // OnStart is the handler for "start" CLI command func OnStart(config *service.Config) error { - srv, err := service.NewTeleport(config) - if err != nil { - return trace.Wrap(err, "Initialization failed") - } - - if err := srv.Start(); err != nil { - return trace.Wrap(err, "Startup Failed") - } - - return srv.WaitForSignals(context.TODO()) + return service.Run(context.TODO(), *config, nil) } // onStatus is the handler for "status" CLI command diff --git a/vendor/github.com/cenkalti/backoff/.gitignore 
b/vendor/github.com/cenkalti/backoff/.gitignore deleted file mode 100644 index 00268614f0456..0000000000000 --- a/vendor/github.com/cenkalti/backoff/.gitignore +++ /dev/null @@ -1,22 +0,0 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe diff --git a/vendor/github.com/cenkalti/backoff/.travis.yml b/vendor/github.com/cenkalti/backoff/.travis.yml deleted file mode 100644 index 1040404bfbc06..0000000000000 --- a/vendor/github.com/cenkalti/backoff/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go -go: - - 1.3.3 - - tip -before_install: - - go get github.com/mattn/goveralls - - go get golang.org/x/tools/cmd/cover -script: - - $HOME/gopath/bin/goveralls -service=travis-ci diff --git a/vendor/github.com/cenkalti/backoff/LICENSE b/vendor/github.com/cenkalti/backoff/LICENSE deleted file mode 100644 index 89b8179965581..0000000000000 --- a/vendor/github.com/cenkalti/backoff/LICENSE +++ /dev/null @@ -1,20 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2014 Cenk Altı - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/cenkalti/backoff/README.md b/vendor/github.com/cenkalti/backoff/README.md deleted file mode 100644 index 13b347fb95179..0000000000000 --- a/vendor/github.com/cenkalti/backoff/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls] - -This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client]. - -[Exponential backoff][exponential backoff wiki] -is an algorithm that uses feedback to multiplicatively decrease the rate of some process, -in order to gradually find an acceptable rate. -The retries exponentially increase and stop increasing when a certain threshold is met. - -## Usage - -See https://godoc.org/github.com/cenkalti/backoff#pkg-examples - -## Contributing - -* I would like to keep this library as small as possible. -* Please don't send a PR without opening an issue and discussing it first. -* If proposed change is not a common use case, I will probably not accept it. 
- -[godoc]: https://godoc.org/github.com/cenkalti/backoff -[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png -[travis]: https://travis-ci.org/cenkalti/backoff -[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master -[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master -[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master - -[google-http-java-client]: https://github.com/google/google-http-java-client -[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff - -[advanced example]: https://godoc.org/github.com/cenkalti/backoff#example_ diff --git a/vendor/github.com/cenkalti/backoff/backoff.go b/vendor/github.com/cenkalti/backoff/backoff.go deleted file mode 100644 index 2102c5f2de96b..0000000000000 --- a/vendor/github.com/cenkalti/backoff/backoff.go +++ /dev/null @@ -1,66 +0,0 @@ -// Package backoff implements backoff algorithms for retrying operations. -// -// Use Retry function for retrying operations that may fail. -// If Retry does not meet your needs, -// copy/paste the function into your project and modify as you wish. -// -// There is also Ticker type similar to time.Ticker. -// You can use it if you need to work with channels. -// -// See Examples section below for usage examples. -package backoff - -import "time" - -// BackOff is a backoff policy for retrying an operation. -type BackOff interface { - // NextBackOff returns the duration to wait before retrying the operation, - // or backoff.Stop to indicate that no more retries should be made. - // - // Example usage: - // - // duration := backoff.NextBackOff(); - // if (duration == backoff.Stop) { - // // Do not retry operation. - // } else { - // // Sleep for duration and retry operation. - // } - // - NextBackOff() time.Duration - - // Reset to initial state. - Reset() -} - -// Stop indicates that no more retries should be made for use in NextBackOff(). 
-const Stop time.Duration = -1 - -// ZeroBackOff is a fixed backoff policy whose backoff time is always zero, -// meaning that the operation is retried immediately without waiting, indefinitely. -type ZeroBackOff struct{} - -func (b *ZeroBackOff) Reset() {} - -func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 } - -// StopBackOff is a fixed backoff policy that always returns backoff.Stop for -// NextBackOff(), meaning that the operation should never be retried. -type StopBackOff struct{} - -func (b *StopBackOff) Reset() {} - -func (b *StopBackOff) NextBackOff() time.Duration { return Stop } - -// ConstantBackOff is a backoff policy that always returns the same backoff delay. -// This is in contrast to an exponential backoff policy, -// which returns a delay that grows longer as you call NextBackOff() over and over again. -type ConstantBackOff struct { - Interval time.Duration -} - -func (b *ConstantBackOff) Reset() {} -func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval } - -func NewConstantBackOff(d time.Duration) *ConstantBackOff { - return &ConstantBackOff{Interval: d} -} diff --git a/vendor/github.com/cenkalti/backoff/backoff_test.go b/vendor/github.com/cenkalti/backoff/backoff_test.go deleted file mode 100644 index 91f27c4f19010..0000000000000 --- a/vendor/github.com/cenkalti/backoff/backoff_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package backoff - -import ( - "testing" - "time" -) - -func TestNextBackOffMillis(t *testing.T) { - subtestNextBackOff(t, 0, new(ZeroBackOff)) - subtestNextBackOff(t, Stop, new(StopBackOff)) -} - -func subtestNextBackOff(t *testing.T, expectedValue time.Duration, backOffPolicy BackOff) { - for i := 0; i < 10; i++ { - next := backOffPolicy.NextBackOff() - if next != expectedValue { - t.Errorf("got: %d expected: %d", next, expectedValue) - } - } -} - -func TestConstantBackOff(t *testing.T) { - backoff := NewConstantBackOff(time.Second) - if backoff.NextBackOff() != time.Second { - t.Error("invalid 
interval") - } -} diff --git a/vendor/github.com/cenkalti/backoff/context.go b/vendor/github.com/cenkalti/backoff/context.go deleted file mode 100644 index 5d157092544fd..0000000000000 --- a/vendor/github.com/cenkalti/backoff/context.go +++ /dev/null @@ -1,60 +0,0 @@ -package backoff - -import ( - "time" - - "golang.org/x/net/context" -) - -// BackOffContext is a backoff policy that stops retrying after the context -// is canceled. -type BackOffContext interface { - BackOff - Context() context.Context -} - -type backOffContext struct { - BackOff - ctx context.Context -} - -// WithContext returns a BackOffContext with context ctx -// -// ctx must not be nil -func WithContext(b BackOff, ctx context.Context) BackOffContext { - if ctx == nil { - panic("nil context") - } - - if b, ok := b.(*backOffContext); ok { - return &backOffContext{ - BackOff: b.BackOff, - ctx: ctx, - } - } - - return &backOffContext{ - BackOff: b, - ctx: ctx, - } -} - -func ensureContext(b BackOff) BackOffContext { - if cb, ok := b.(BackOffContext); ok { - return cb - } - return WithContext(b, context.Background()) -} - -func (b *backOffContext) Context() context.Context { - return b.ctx -} - -func (b *backOffContext) NextBackOff() time.Duration { - select { - case <-b.Context().Done(): - return Stop - default: - return b.BackOff.NextBackOff() - } -} diff --git a/vendor/github.com/cenkalti/backoff/context_test.go b/vendor/github.com/cenkalti/backoff/context_test.go deleted file mode 100644 index 993fa6149db77..0000000000000 --- a/vendor/github.com/cenkalti/backoff/context_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package backoff - -import ( - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestContext(t *testing.T) { - b := NewConstantBackOff(time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cb := WithContext(b, ctx) - - if cb.Context() != ctx { - t.Error("invalid context") - } - - cancel() - - if cb.NextBackOff() != Stop { - 
t.Error("invalid next back off") - } -} diff --git a/vendor/github.com/cenkalti/backoff/example_test.go b/vendor/github.com/cenkalti/backoff/example_test.go deleted file mode 100644 index d97a8db8e34c3..0000000000000 --- a/vendor/github.com/cenkalti/backoff/example_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package backoff - -import ( - "log" - - "golang.org/x/net/context" -) - -func ExampleRetry() { - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - err := Retry(operation, NewExponentialBackOff()) - if err != nil { - // Handle error. - return - } - - // Operation is successful. -} - -func ExampleRetryContext() { - // A context - ctx := context.Background() - - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - b := WithContext(NewExponentialBackOff(), ctx) - - err := Retry(operation, b) - if err != nil { - // Handle error. - return - } - - // Operation is successful. -} - -func ExampleTicker() { - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - ticker := NewTicker(NewExponentialBackOff()) - - var err error - - // Ticks will continue to arrive when the previous operation is still running, - // so operations that take a while to fail could run in quick succession. - for _ = range ticker.C { - if err = operation(); err != nil { - log.Println(err, "will retry...") - continue - } - - ticker.Stop() - break - } - - if err != nil { - // Operation has failed. - return - } - - // Operation is successful. 
- return -} diff --git a/vendor/github.com/cenkalti/backoff/exponential.go b/vendor/github.com/cenkalti/backoff/exponential.go deleted file mode 100644 index 9a6addf075016..0000000000000 --- a/vendor/github.com/cenkalti/backoff/exponential.go +++ /dev/null @@ -1,156 +0,0 @@ -package backoff - -import ( - "math/rand" - "time" -) - -/* -ExponentialBackOff is a backoff implementation that increases the backoff -period for each retry attempt using a randomization function that grows exponentially. - -NextBackOff() is calculated using the following formula: - - randomized interval = - RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor]) - -In other words NextBackOff() will range between the randomization factor -percentage below and above the retry interval. - -For example, given the following parameters: - - RetryInterval = 2 - RandomizationFactor = 0.5 - Multiplier = 2 - -the actual backoff period used in the next retry attempt will range between 1 and 3 seconds, -multiplied by the exponential, that is, between 2 and 6 seconds. - -Note: MaxInterval caps the RetryInterval and not the randomized interval. - -If the time elapsed since an ExponentialBackOff instance is created goes past the -MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop. - -The elapsed time can be reset by calling Reset(). - -Example: Given the following default arguments, for 10 tries the sequence will be, -and assuming we go over the MaxElapsedTime on the 10th try: - - Request # RetryInterval (seconds) Randomized Interval (seconds) - - 1 0.5 [0.25, 0.75] - 2 0.75 [0.375, 1.125] - 3 1.125 [0.562, 1.687] - 4 1.687 [0.8435, 2.53] - 5 2.53 [1.265, 3.795] - 6 3.795 [1.897, 5.692] - 7 5.692 [2.846, 8.538] - 8 8.538 [4.269, 12.807] - 9 12.807 [6.403, 19.210] - 10 19.210 backoff.Stop - -Note: Implementation is not thread-safe. 
-*/ -type ExponentialBackOff struct { - InitialInterval time.Duration - RandomizationFactor float64 - Multiplier float64 - MaxInterval time.Duration - // After MaxElapsedTime the ExponentialBackOff stops. - // It never stops if MaxElapsedTime == 0. - MaxElapsedTime time.Duration - Clock Clock - - currentInterval time.Duration - startTime time.Time - random *rand.Rand -} - -// Clock is an interface that returns current time for BackOff. -type Clock interface { - Now() time.Time -} - -// Default values for ExponentialBackOff. -const ( - DefaultInitialInterval = 500 * time.Millisecond - DefaultRandomizationFactor = 0.5 - DefaultMultiplier = 1.5 - DefaultMaxInterval = 60 * time.Second - DefaultMaxElapsedTime = 15 * time.Minute -) - -// NewExponentialBackOff creates an instance of ExponentialBackOff using default values. -func NewExponentialBackOff() *ExponentialBackOff { - b := &ExponentialBackOff{ - InitialInterval: DefaultInitialInterval, - RandomizationFactor: DefaultRandomizationFactor, - Multiplier: DefaultMultiplier, - MaxInterval: DefaultMaxInterval, - MaxElapsedTime: DefaultMaxElapsedTime, - Clock: SystemClock, - random: rand.New(rand.NewSource(time.Now().UnixNano())), - } - b.Reset() - return b -} - -type systemClock struct{} - -func (t systemClock) Now() time.Time { - return time.Now() -} - -// SystemClock implements Clock interface that uses time.Now(). -var SystemClock = systemClock{} - -// Reset the interval back to the initial retry interval and restarts the timer. -func (b *ExponentialBackOff) Reset() { - b.currentInterval = b.InitialInterval - b.startTime = b.Clock.Now() -} - -// NextBackOff calculates the next backoff interval using the formula: -// Randomized interval = RetryInterval +/- (RandomizationFactor * RetryInterval) -func (b *ExponentialBackOff) NextBackOff() time.Duration { - // Make sure we have not gone over the maximum elapsed time. 
- if b.MaxElapsedTime != 0 && b.GetElapsedTime() > b.MaxElapsedTime { - return Stop - } - defer b.incrementCurrentInterval() - if b.random == nil { - b.random = rand.New(rand.NewSource(time.Now().UnixNano())) - } - return getRandomValueFromInterval(b.RandomizationFactor, b.random.Float64(), b.currentInterval) -} - -// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance -// is created and is reset when Reset() is called. -// -// The elapsed time is computed using time.Now().UnixNano(). -func (b *ExponentialBackOff) GetElapsedTime() time.Duration { - return b.Clock.Now().Sub(b.startTime) -} - -// Increments the current interval by multiplying it with the multiplier. -func (b *ExponentialBackOff) incrementCurrentInterval() { - // Check for overflow, if overflow is detected set the current interval to the max interval. - if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier { - b.currentInterval = b.MaxInterval - } else { - b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier) - } -} - -// Returns a random value from the following interval: -// [randomizationFactor * currentInterval, randomizationFactor * currentInterval]. -func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration { - var delta = randomizationFactor * float64(currentInterval) - var minInterval = float64(currentInterval) - delta - var maxInterval = float64(currentInterval) + delta - - // Get a random value from the range [minInterval, maxInterval]. - // The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then - // we want a 33% chance for selecting either 1, 2 or 3. 
- return time.Duration(minInterval + (random * (maxInterval - minInterval + 1))) -} diff --git a/vendor/github.com/cenkalti/backoff/exponential_test.go b/vendor/github.com/cenkalti/backoff/exponential_test.go deleted file mode 100644 index 11b95e4f61d43..0000000000000 --- a/vendor/github.com/cenkalti/backoff/exponential_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package backoff - -import ( - "math" - "testing" - "time" -) - -func TestBackOff(t *testing.T) { - var ( - testInitialInterval = 500 * time.Millisecond - testRandomizationFactor = 0.1 - testMultiplier = 2.0 - testMaxInterval = 5 * time.Second - testMaxElapsedTime = 15 * time.Minute - ) - - exp := NewExponentialBackOff() - exp.InitialInterval = testInitialInterval - exp.RandomizationFactor = testRandomizationFactor - exp.Multiplier = testMultiplier - exp.MaxInterval = testMaxInterval - exp.MaxElapsedTime = testMaxElapsedTime - exp.Reset() - - var expectedResults = []time.Duration{500, 1000, 2000, 4000, 5000, 5000, 5000, 5000, 5000, 5000} - for i, d := range expectedResults { - expectedResults[i] = d * time.Millisecond - } - - for _, expected := range expectedResults { - assertEquals(t, expected, exp.currentInterval) - // Assert that the next backoff falls in the expected range. - var minInterval = expected - time.Duration(testRandomizationFactor*float64(expected)) - var maxInterval = expected + time.Duration(testRandomizationFactor*float64(expected)) - var actualInterval = exp.NextBackOff() - if !(minInterval <= actualInterval && actualInterval <= maxInterval) { - t.Error("error") - } - } -} - -func TestGetRandomizedInterval(t *testing.T) { - // 33% chance of being 1. - assertEquals(t, 1, getRandomValueFromInterval(0.5, 0, 2)) - assertEquals(t, 1, getRandomValueFromInterval(0.5, 0.33, 2)) - // 33% chance of being 2. - assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.34, 2)) - assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.66, 2)) - // 33% chance of being 3. 
- assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.67, 2)) - assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.99, 2)) -} - -type TestClock struct { - i time.Duration - start time.Time -} - -func (c *TestClock) Now() time.Time { - t := c.start.Add(c.i) - c.i += time.Second - return t -} - -func TestGetElapsedTime(t *testing.T) { - var exp = NewExponentialBackOff() - exp.Clock = &TestClock{} - exp.Reset() - - var elapsedTime = exp.GetElapsedTime() - if elapsedTime != time.Second { - t.Errorf("elapsedTime=%d", elapsedTime) - } -} - -func TestMaxElapsedTime(t *testing.T) { - var exp = NewExponentialBackOff() - exp.Clock = &TestClock{start: time.Time{}.Add(10000 * time.Second)} - // Change the currentElapsedTime to be 0 ensuring that the elapsed time will be greater - // than the max elapsed time. - exp.startTime = time.Time{} - assertEquals(t, Stop, exp.NextBackOff()) -} - -func TestBackOffOverflow(t *testing.T) { - var ( - testInitialInterval time.Duration = math.MaxInt64 / 2 - testMaxInterval time.Duration = math.MaxInt64 - testMultiplier = 2.1 - ) - - exp := NewExponentialBackOff() - exp.InitialInterval = testInitialInterval - exp.Multiplier = testMultiplier - exp.MaxInterval = testMaxInterval - exp.Reset() - - exp.NextBackOff() - // Assert that when an overflow is possible the current varerval time.Duration is set to the max varerval time.Duration . - assertEquals(t, testMaxInterval, exp.currentInterval) -} - -func assertEquals(t *testing.T, expected, value time.Duration) { - if expected != value { - t.Errorf("got: %d, expected: %d", value, expected) - } -} diff --git a/vendor/github.com/cenkalti/backoff/retry.go b/vendor/github.com/cenkalti/backoff/retry.go deleted file mode 100644 index 5dbd825b5c8b5..0000000000000 --- a/vendor/github.com/cenkalti/backoff/retry.go +++ /dev/null @@ -1,78 +0,0 @@ -package backoff - -import "time" - -// An Operation is executing by Retry() or RetryNotify(). 
-// The operation will be retried using a backoff policy if it returns an error. -type Operation func() error - -// Notify is a notify-on-error function. It receives an operation error and -// backoff delay if the operation failed (with an error). -// -// NOTE that if the backoff policy stated to stop retrying, -// the notify function isn't called. -type Notify func(error, time.Duration) - -// Retry the operation o until it does not return error or BackOff stops. -// o is guaranteed to be run at least once. -// It is the caller's responsibility to reset b after Retry returns. -// -// If o returns a *PermanentError, the operation is not retried, and the -// wrapped error is returned. -// -// Retry sleeps the goroutine for the duration returned by BackOff after a -// failed operation returns. -func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) } - -// RetryNotify calls notify function with the error and wait duration -// for each failed attempt before sleep. -func RetryNotify(operation Operation, b BackOff, notify Notify) error { - var err error - var next time.Duration - - cb := ensureContext(b) - - b.Reset() - for { - if err = operation(); err == nil { - return nil - } - - if permanent, ok := err.(*PermanentError); ok { - return permanent.Err - } - - if next = b.NextBackOff(); next == Stop { - return err - } - - if notify != nil { - notify(err, next) - } - - t := time.NewTimer(next) - - select { - case <-cb.Context().Done(): - t.Stop() - return err - case <-t.C: - } - } -} - -// PermanentError signals that the operation should not be retried. -type PermanentError struct { - Err error -} - -func (e *PermanentError) Error() string { - return e.Err.Error() -} - -// Permanent wraps the given err in a *PermanentError. 
-func Permanent(err error) *PermanentError { - return &PermanentError{ - Err: err, - } -} diff --git a/vendor/github.com/cenkalti/backoff/retry_test.go b/vendor/github.com/cenkalti/backoff/retry_test.go deleted file mode 100644 index 5af288897df6a..0000000000000 --- a/vendor/github.com/cenkalti/backoff/retry_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package backoff - -import ( - "errors" - "fmt" - "log" - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestRetry(t *testing.T) { - const successOn = 3 - var i = 0 - - // This function is successful on "successOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - if i == successOn { - log.Println("OK") - return nil - } - - log.Println("error") - return errors.New("error") - } - - err := Retry(f, NewExponentialBackOff()) - if err != nil { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != successOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestRetryContext(t *testing.T) { - var cancelOn = 3 - var i = 0 - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // This function cancels context on "cancelOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - // cancelling the context in the operation function is not a typical - // use-case, however it allows to get predictable test results. 
- if i == cancelOn { - cancel() - } - - log.Println("error") - return fmt.Errorf("error (%d)", i) - } - - err := Retry(f, WithContext(NewConstantBackOff(time.Millisecond), ctx)) - if err == nil { - t.Errorf("error is unexpectedly nil") - } - if err.Error() != "error (3)" { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != cancelOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestRetryPermenent(t *testing.T) { - const permanentOn = 3 - var i = 0 - - // This function fails permanently after permanentOn tries - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - if i == permanentOn { - log.Println("permanent error") - return Permanent(errors.New("permanent error")) - } - - log.Println("error") - return errors.New("error") - } - - err := Retry(f, NewExponentialBackOff()) - if err == nil || err.Error() != "permanent error" { - t.Errorf("unexpected error: %s", err) - } - if i != permanentOn { - t.Errorf("invalid number of retries: %d", i) - } -} diff --git a/vendor/github.com/cenkalti/backoff/ticker.go b/vendor/github.com/cenkalti/backoff/ticker.go deleted file mode 100644 index 49a99718d7425..0000000000000 --- a/vendor/github.com/cenkalti/backoff/ticker.go +++ /dev/null @@ -1,81 +0,0 @@ -package backoff - -import ( - "runtime" - "sync" - "time" -) - -// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff. -// -// Ticks will continue to arrive when the previous operation is still running, -// so operations that take a while to fail could run in quick succession. -type Ticker struct { - C <-chan time.Time - c chan time.Time - b BackOffContext - stop chan struct{} - stopOnce sync.Once -} - -// NewTicker returns a new Ticker containing a channel that will send the time at times -// specified by the BackOff argument. Ticker is guaranteed to tick at least once. -// The channel is closed when Stop method is called or BackOff stops. 
-func NewTicker(b BackOff) *Ticker { - c := make(chan time.Time) - t := &Ticker{ - C: c, - c: c, - b: ensureContext(b), - stop: make(chan struct{}), - } - go t.run() - runtime.SetFinalizer(t, (*Ticker).Stop) - return t -} - -// Stop turns off a ticker. After Stop, no more ticks will be sent. -func (t *Ticker) Stop() { - t.stopOnce.Do(func() { close(t.stop) }) -} - -func (t *Ticker) run() { - c := t.c - defer close(c) - t.b.Reset() - - // Ticker is guaranteed to tick at least once. - afterC := t.send(time.Now()) - - for { - if afterC == nil { - return - } - - select { - case tick := <-afterC: - afterC = t.send(tick) - case <-t.stop: - t.c = nil // Prevent future ticks from being sent to the channel. - return - case <-t.b.Context().Done(): - return - } - } -} - -func (t *Ticker) send(tick time.Time) <-chan time.Time { - select { - case t.c <- tick: - case <-t.stop: - return nil - } - - next := t.b.NextBackOff() - if next == Stop { - t.Stop() - return nil - } - - return time.After(next) -} diff --git a/vendor/github.com/cenkalti/backoff/ticker_test.go b/vendor/github.com/cenkalti/backoff/ticker_test.go deleted file mode 100644 index 085828cca5276..0000000000000 --- a/vendor/github.com/cenkalti/backoff/ticker_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package backoff - -import ( - "errors" - "fmt" - "log" - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestTicker(t *testing.T) { - const successOn = 3 - var i = 0 - - // This function is successful on "successOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. 
time\n", i) - - if i == successOn { - log.Println("OK") - return nil - } - - log.Println("error") - return errors.New("error") - } - - b := NewExponentialBackOff() - ticker := NewTicker(b) - - var err error - for _ = range ticker.C { - if err = f(); err != nil { - t.Log(err) - continue - } - - break - } - if err != nil { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != successOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestTickerContext(t *testing.T) { - const cancelOn = 3 - var i = 0 - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // This function cancels context on "cancelOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - // cancelling the context in the operation function is not a typical - // use-case, however it allows to get predictable test results. - if i == cancelOn { - cancel() - } - - log.Println("error") - return fmt.Errorf("error (%d)", i) - } - - b := WithContext(NewConstantBackOff(time.Millisecond), ctx) - ticker := NewTicker(b) - - var err error - for _ = range ticker.C { - if err = f(); err != nil { - t.Log(err) - continue - } - - break - } - if err == nil { - t.Errorf("error is unexpectedly nil") - } - if err.Error() != "error (3)" { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != cancelOn { - t.Errorf("invalid number of retries: %d", i) - } -} diff --git a/vendor/github.com/cenkalti/backoff/tries.go b/vendor/github.com/cenkalti/backoff/tries.go deleted file mode 100644 index d2da7308b6aa9..0000000000000 --- a/vendor/github.com/cenkalti/backoff/tries.go +++ /dev/null @@ -1,35 +0,0 @@ -package backoff - -import "time" - -/* -WithMaxTries creates a wrapper around another BackOff, which will -return Stop if NextBackOff() has been called too many times since -the last time Reset() was called - -Note: Implementation is not thread-safe. 
-*/ -func WithMaxTries(b BackOff, max uint64) BackOff { - return &backOffTries{delegate: b, maxTries: max} -} - -type backOffTries struct { - delegate BackOff - maxTries uint64 - numTries uint64 -} - -func (b *backOffTries) NextBackOff() time.Duration { - if b.maxTries > 0 { - if b.maxTries <= b.numTries { - return Stop - } - b.numTries++ - } - return b.delegate.NextBackOff() -} - -func (b *backOffTries) Reset() { - b.numTries = 0 - b.delegate.Reset() -} diff --git a/vendor/github.com/cenkalti/backoff/tries_test.go b/vendor/github.com/cenkalti/backoff/tries_test.go deleted file mode 100644 index bd1021143ed45..0000000000000 --- a/vendor/github.com/cenkalti/backoff/tries_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package backoff - -import ( - "math/rand" - "testing" - "time" -) - -func TestMaxTriesHappy(t *testing.T) { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - max := 17 + r.Intn(13) - bo := WithMaxTries(&ZeroBackOff{}, uint64(max)) - - // Load up the tries count, but reset should clear the record - for ix := 0; ix < max/2; ix++ { - bo.NextBackOff() - } - bo.Reset() - - // Now fill the tries count all the way up - for ix := 0; ix < max; ix++ { - d := bo.NextBackOff() - if d == Stop { - t.Errorf("returned Stop on try %d", ix) - } - } - - // We have now called the BackOff max number of times, we expect - // the next result to be Stop, even if we try it multiple times - for ix := 0; ix < 7; ix++ { - d := bo.NextBackOff() - if d != Stop { - t.Error("invalid next back off") - } - } - - // Reset makes it all work again - bo.Reset() - d := bo.NextBackOff() - if d == Stop { - t.Error("returned Stop after reset") - } - -} - -func TestMaxTriesZero(t *testing.T) { - // It might not make sense, but its okay to send a zero - bo := WithMaxTries(&ZeroBackOff{}, uint64(0)) - for ix := 0; ix < 11; ix++ { - d := bo.NextBackOff() - if d == Stop { - t.Errorf("returned Stop on try %d", ix) - } - } -} diff --git a/vendor/github.com/vulcand/predicate/builder/builder.go 
b/vendor/github.com/vulcand/predicate/builder/builder.go new file mode 100644 index 0000000000000..148673966e373 --- /dev/null +++ b/vendor/github.com/vulcand/predicate/builder/builder.go @@ -0,0 +1,169 @@ +/* +Copyright 2014-2018 Vulcand Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +// package builder is used to construct predicate +// expressions using builder functions. +package builder + +import ( + "fmt" + "strings" +) + +// Expr is an expression builder, +// used to create expressions in rules definitions +type Expr interface { + // String serializes expression into format parsed by rules engine + // (golang based syntax) + String() string +} + +// IdentiferExpr is identifer expression +type IdentifierExpr string + +// String serializes identifer expression into format parsed by rules engine +func (i IdentifierExpr) String() string { + return string(i) +} + +// Identifer returns identifier expression +func Identifier(v string) IdentifierExpr { + return IdentifierExpr(v) +} + +// String returns string expression +func String(v string) StringExpr { + return StringExpr(v) +} + +// StringExpr is a string expression +type StringExpr string + +func (s StringExpr) String() string { + return fmt.Sprintf("%q", string(s)) +} + +// StringsExpr is a slice of strings +type StringsExpr []string + +func (s StringsExpr) String() string { + var out []string + for _, val := range s { + out = append(out, fmt.Sprintf("%q", val)) + } + return strings.Join(out, ",") +} + +// Equals 
returns equals expression +func Equals(left, right Expr) EqualsExpr { + return EqualsExpr{Left: left, Right: right} +} + +// EqualsExpr constructs function expression used in rules specifications +// that checks if one value is equal to another +// e.g. equals("a", "b") where Left is "a" and right is "b" +type EqualsExpr struct { + // Left is a left argument of Equals expression + Left Expr + // Value to check + Right Expr +} + +// String returns function call expression used in rules +func (i EqualsExpr) String() string { + return fmt.Sprintf("equals(%v, %v)", i.Left, i.Right) +} + +// Not returns ! expression +func Not(expr Expr) NotExpr { + return NotExpr{Expr: expr} +} + +// NotExpr constructs function expression used in rules specifications +// that negates the result of the boolean predicate +// e.g. ! equals"a", "b") where Left is "a" and right is "b" +type NotExpr struct { + // Expr is an expression to negate + Expr Expr +} + +// String returns function call expression used in rules +func (n NotExpr) String() string { + return fmt.Sprintf("!%v", n.Expr) +} + +// Contains returns contains function call expression +func Contains(a, b Expr) ContainsExpr { + return ContainsExpr{Left: a, Right: b} +} + +// ContainsExpr constructs function expression used in rules specifications +// that checks if one value contains the other, e.g. 
+// contains([]string{"a"}, "b") where left is []string{"a"} and right is "b" +type ContainsExpr struct { + // Left is a left argument of Contains expression + Left Expr + // Right is a right argument of Contains expression + Right Expr +} + +// String rturns function call expression used in rules +func (i ContainsExpr) String() string { + return fmt.Sprintf("contains(%v, %v)", i.Left, i.Right) +} + +// And returns && expression +func And(left, right Expr) AndExpr { + return AndExpr{ + Left: left, + Right: right, + } +} + +// AndExpr returns && expression +type AndExpr struct { + // Left is a left argument of && operator expression + Left Expr + // Right is a right argument of && operator expression + Right Expr +} + +// String returns expression text used in rules +func (a AndExpr) String() string { + return fmt.Sprintf("%v && %v", a.Left, a.Right) +} + +// Or returns || expression +func Or(left, right Expr) OrExpr { + return OrExpr{ + Left: left, + Right: right, + } +} + +// OrExpr returns || expression +type OrExpr struct { + // Left is a left argument of || operator expression + Left Expr + // Right is a right argument of || operator expression + Right Expr +} + +// String returns expression text used in rules +func (a OrExpr) String() string { + return fmt.Sprintf("%v || %v", a.Left, a.Right) +} diff --git a/vendor/github.com/vulcand/predicate/lib.go b/vendor/github.com/vulcand/predicate/lib.go index 013e16c5068b5..c313c1bae71a4 100644 --- a/vendor/github.com/vulcand/predicate/lib.go +++ b/vendor/github.com/vulcand/predicate/lib.go @@ -119,6 +119,14 @@ func Or(a, b BoolPredicate) BoolPredicate { } } +// Not is a boolean predicate that calls a boolean predicate +// and returns negated result +func Not(a BoolPredicate) BoolPredicate { + return func() bool { + return !a() + } +} + // GetFieldByTag returns a field from the object based on the tag func GetFieldByTag(ival interface{}, tagName string, fieldNames []string) (interface{}, error) { if len(fieldNames) == 
0 { diff --git a/vendor/github.com/vulcand/predicate/parse.go b/vendor/github.com/vulcand/predicate/parse.go index 03518cf164045..b80e7361a4bd7 100644 --- a/vendor/github.com/vulcand/predicate/parse.go +++ b/vendor/github.com/vulcand/predicate/parse.go @@ -59,6 +59,16 @@ func (p *predicateParser) parseNode(node ast.Node) (interface{}, error) { return callFunction(fn, arguments) case *ast.ParenExpr: return p.parseNode(n.X) + case *ast.UnaryExpr: + joinFn, err := p.getJoinFunction(n.Op) + if err != nil { + return nil, err + } + node, err := p.parseNode(n.X) + if err != nil { + return nil, err + } + return callFunction(joinFn, []interface{}{node}) } return nil, trace.BadParameter("unsupported %T", node) } @@ -122,6 +132,20 @@ func (p *predicateParser) evaluateExpr(n ast.Expr) (interface{}, error) { return nil, trace.Wrap(err) } return val, nil + case *ast.CallExpr: + name, err := getIdentifier(l.Fun) + if err != nil { + return nil, err + } + fn, err := p.getFunction(name) + if err != nil { + return nil, err + } + arguments, err := p.evaluateArguments(l.Args) + if err != nil { + return nil, err + } + return callFunction(fn, arguments) default: return nil, trace.BadParameter("%T is not supported", n) } @@ -161,6 +185,8 @@ func (p *predicateParser) joinPredicates(op token.Token, a, b interface{}) (inte func (p *predicateParser) getJoinFunction(op token.Token) (interface{}, error) { var fn interface{} switch op { + case token.NOT: + fn = p.d.Operators.NOT case token.LAND: fn = p.d.Operators.AND case token.LOR: diff --git a/vendor/github.com/vulcand/predicate/parse_test.go b/vendor/github.com/vulcand/predicate/parse_test.go index 24392d41d41eb..dd11b816b76cb 100644 --- a/vendor/github.com/vulcand/predicate/parse_test.go +++ b/vendor/github.com/vulcand/predicate/parse_test.go @@ -30,6 +30,7 @@ func (s *PredicateSuite) getParserWithOpts(c *check.C, getID GetIdentifierFn, ge NEQ: numberNEQ, LE: numberLE, GE: numberGE, + NOT: numberNOT, }, Functions: map[string]interface{}{ 
"DivisibleBy": divisibleBy, @@ -38,6 +39,12 @@ func (s *PredicateSuite) getParserWithOpts(c *check.C, getID GetIdentifierFn, ge "number.DivisibleBy": divisibleBy, "Equals": Equals, "Contains": Contains, + "fnreturn": func(arg interface{}) (interface{}, error) { + return arg, nil + }, + "fnerr": func(arg interface{}) (interface{}, error) { + return nil, trace.BadParameter("don't like this parameter") + }, }, GetIdentifier: getID, GetProperty: getProperty, @@ -58,6 +65,35 @@ func (s *PredicateSuite) TestSinglePredicate(c *check.C) { c.Assert(fn(3), check.Equals, false) } +func (s *PredicateSuite) TestSinglePredicateNot(c *check.C) { + p := s.getParser(c) + + pr, err := p.Parse("!DivisibleBy(2)") + c.Assert(err, check.IsNil) + c.Assert(pr, check.FitsTypeOf, divisibleBy(2)) + fn := pr.(numberPredicate) + c.Assert(fn(2), check.Equals, false) + c.Assert(fn(3), check.Equals, true) +} + +func (s *PredicateSuite) TestSinglePredicateWithFunc(c *check.C) { + p := s.getParser(c) + + pr, err := p.Parse("DivisibleBy(fnreturn(2))") + c.Assert(err, check.IsNil) + c.Assert(pr, check.FitsTypeOf, divisibleBy(2)) + fn := pr.(numberPredicate) + c.Assert(fn(2), check.Equals, true) + c.Assert(fn(3), check.Equals, false) +} + +func (s *PredicateSuite) TestSinglePredicateWithFuncErr(c *check.C) { + p := s.getParser(c) + + _, err := p.Parse("DivisibleBy(fnerr(2))") + c.Assert(err, check.NotNil) +} + func (s *PredicateSuite) TestModulePredicate(c *check.C) { p := s.getParser(c) @@ -380,7 +416,6 @@ func (s *PredicateSuite) TestUnhappyCases(c *check.C) { "Remainder(banana)", // unsupported argument "Remainder(1, 2)", // unsupported arguments count "Remainder(Len)", // unsupported argument - `Remainder(Len("Ho"))`, // unsupported argument "Bla(1)", // unknown method call "0.2 && Remainder(1)", // unsupported value `Len("Ho") && 0.2`, // unsupported value @@ -405,6 +440,12 @@ func divisibleBy(divisor int) numberPredicate { } } +func numberNOT(a numberPredicate) numberPredicate { + return func(v 
int) bool { + return !a(v) + } +} + func numberAND(a, b numberPredicate) numberPredicate { return func(v int) bool { return a(v) && b(v) diff --git a/vendor/github.com/vulcand/predicate/predicate.go b/vendor/github.com/vulcand/predicate/predicate.go index 256e701461851..c5aacb4881671 100644 --- a/vendor/github.com/vulcand/predicate/predicate.go +++ b/vendor/github.com/vulcand/predicate/predicate.go @@ -1,3 +1,20 @@ +/* +Copyright 2014-2018 Vulcand Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + /* Predicate package used to create interpreted mini languages with Go syntax - mostly to define various predicates for configuration, e.g. Latency() > 40 || ErrorRate() > 0.5. @@ -76,6 +93,7 @@ type Operators struct { OR interface{} AND interface{} + NOT interface{} } // Parser takes the string with expression and calls the operators and functions.