From 3bf28c50dce16319c15dee1384bd5aec05387b1b Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Apr 2018 14:37:33 -0700 Subject: [PATCH] Start work on rotation, updates #1860 --- Gopkg.lock | 242 ++------- Gopkg.toml | 2 +- constants.go | 9 + e | 2 +- lib/auth/apiserver.go | 46 +- lib/auth/auth.go | 57 ++- lib/auth/auth_with_roles.go | 66 ++- lib/auth/clt.go | 66 ++- lib/auth/helpers.go | 2 +- lib/auth/init.go | 455 +++++------------ lib/auth/methods.go | 142 ------ lib/auth/middleware.go | 11 +- lib/auth/permissions.go | 65 ++- lib/auth/register.go | 32 +- lib/auth/rotate.go | 475 ++++++++++++++++++ lib/auth/state.go | 237 +++++++++ lib/backend/backend.go | 5 + lib/backend/boltbk/boltbk.go | 46 ++ lib/backend/boltbk/boltbk_test.go | 4 + lib/backend/dir/impl.go | 62 ++- lib/backend/dir/impl_test.go | 18 +- lib/backend/dynamo/dynamodbbk.go | 45 ++ lib/backend/dynamo/dynamodbbk_test.go | 4 + lib/backend/etcdbk/etcd.go | 20 + lib/backend/etcdbk/etcd_test.go | 4 + lib/backend/test/suite.go | 30 ++ lib/config/fileconf.go | 2 +- lib/defaults/defaults.go | 7 + lib/fixtures/fixtures.go | 10 +- lib/httplib/httplib.go | 4 +- lib/reversetunnel/remotesite.go | 113 +++-- lib/reversetunnel/srv.go | 8 +- lib/service/cfg.go | 4 + lib/service/connect.go | 409 +++++++++++++++ lib/service/service.go | 328 ++++++++---- lib/service/signals.go | 11 +- lib/service/supervisor.go | 44 +- lib/services/authority.go | 253 +++++++++- lib/services/local/configuration_test.go | 4 +- lib/services/local/presence_test.go | 4 +- lib/services/local/trust.go | 46 +- lib/services/local/users.go | 2 +- lib/services/parser.go | 41 ++ lib/services/resource.go | 10 + lib/services/role.go | 2 +- lib/services/server.go | 7 +- lib/services/suite/suite.go | 24 +- lib/services/trust.go | 11 +- lib/srv/regular/sshserver.go | 31 +- lib/utils/copy.go | 22 + tool/tctl/common/auth_command.go | 41 +- tool/tctl/common/status_command.go | 118 +++++ tool/tctl/common/tctl.go | 8 +- tool/tctl/main.go | 1 + tool/teleport/common/teleport.go | 11 +- vendor/github.com/cenkalti/backoff/.gitignore | 22 - .../github.com/cenkalti/backoff/.travis.yml | 9 - vendor/github.com/cenkalti/backoff/LICENSE | 20 - vendor/github.com/cenkalti/backoff/README.md | 30 -- vendor/github.com/cenkalti/backoff/backoff.go | 66 --- .../cenkalti/backoff/backoff_test.go | 27 - vendor/github.com/cenkalti/backoff/context.go | 60 --- .../cenkalti/backoff/context_test.go | 26 - .../cenkalti/backoff/example_test.go | 73 --- .../cenkalti/backoff/exponential.go | 156 ------ .../cenkalti/backoff/exponential_test.go | 108 ---- vendor/github.com/cenkalti/backoff/retry.go | 78 --- .../github.com/cenkalti/backoff/retry_test.go | 99 ---- vendor/github.com/cenkalti/backoff/ticker.go | 81 --- .../cenkalti/backoff/ticker_test.go | 94 ---- vendor/github.com/cenkalti/backoff/tries.go | 35 -- .../github.com/cenkalti/backoff/tries_test.go | 55 -- .../vulcand/predicate/builder/builder.go | 169 +++++++ vendor/github.com/vulcand/predicate/lib.go | 8 + vendor/github.com/vulcand/predicate/parse.go | 26 + .../vulcand/predicate/parse_test.go | 43 +- .../github.com/vulcand/predicate/predicate.go | 18 + 77 files changed, 2991 insertions(+), 2035 deletions(-) create mode 100644 lib/auth/rotate.go create mode 100644 lib/auth/state.go create mode 100644 lib/service/connect.go create mode 100644 tool/tctl/common/status_command.go delete mode 100644 vendor/github.com/cenkalti/backoff/.gitignore delete mode 100644 vendor/github.com/cenkalti/backoff/.travis.yml delete mode 100644 vendor/github.com/cenkalti/backoff/LICENSE delete mode 100644 vendor/github.com/cenkalti/backoff/README.md delete mode 100644 vendor/github.com/cenkalti/backoff/backoff.go delete mode 100644 vendor/github.com/cenkalti/backoff/backoff_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/context.go delete mode 100644 vendor/github.com/cenkalti/backoff/context_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/example_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/exponential.go delete mode 100644 vendor/github.com/cenkalti/backoff/exponential_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/retry.go delete mode 100644 vendor/github.com/cenkalti/backoff/retry_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/ticker.go delete mode 100644 vendor/github.com/cenkalti/backoff/ticker_test.go delete mode 100644 vendor/github.com/cenkalti/backoff/tries.go delete mode 100644 vendor/github.com/cenkalti/backoff/tries_test.go create mode 100644 vendor/github.com/vulcand/predicate/builder/builder.go diff --git a/Gopkg.lock b/Gopkg.lock index fc284566a2db8..061ed07991ec8 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -4,19 +4,13 @@ [[projects]] branch = "master" name = "github.com/Azure/go-ansiterm" - packages = [ - ".", - "winterm" - ] + packages = [".","winterm"] revision = "19f72df4d05d31cbe1c56bfc8045c96babff6c7e" [[projects]] branch = "master" name = "github.com/alecthomas/template" - packages = [ - ".", - "parse" - ] + packages = [".","parse"] revision = "a0175ee3bccc567396460bf5acd36800cb10c49c" [[projects]] @@ -27,39 +21,7 @@ [[projects]] name = "github.com/aws/aws-sdk-go" - packages = [ - "aws", - "aws/awserr", - "aws/awsutil", - "aws/client", - "aws/client/metadata", - "aws/corehandlers", - "aws/credentials", - "aws/credentials/ec2rolecreds", - "aws/credentials/endpointcreds", - "aws/credentials/stscreds", - "aws/defaults", - "aws/ec2metadata", - "aws/endpoints", - "aws/request", - "aws/session", - "aws/signer/v4", - "internal/shareddefaults", - "private/protocol", - "private/protocol/json/jsonutil", - "private/protocol/jsonrpc", - "private/protocol/query", - "private/protocol/query/queryutil", - "private/protocol/rest", - "private/protocol/restxml", - "private/protocol/xml/xmlutil", - "service/dynamodb", - "service/dynamodb/dynamodbattribute", - "service/s3", - "service/s3/s3iface", - "service/s3/s3manager", - "service/sts" - ] + packages = ["aws","aws/awserr","aws/awsutil","aws/client","aws/client/metadata","aws/corehandlers","aws/credentials","aws/credentials/ec2rolecreds","aws/credentials/endpointcreds","aws/credentials/stscreds","aws/defaults","aws/ec2metadata","aws/endpoints","aws/request","aws/session","aws/signer/v4","internal/shareddefaults","private/protocol","private/protocol/json/jsonutil","private/protocol/jsonrpc","private/protocol/query","private/protocol/query/queryutil","private/protocol/rest","private/protocol/restxml","private/protocol/xml/xmlutil","service/dynamodb","service/dynamodb/dynamodbattribute","service/s3","service/s3/s3iface","service/s3/s3manager","service/sts"] revision = "a201bf33b18ad4ab54344e4bc26b87eb6ad37b8e" version = "v1.12.25" @@ -81,19 +43,9 @@ [[projects]] name = "github.com/boombuler/barcode" - packages = [ - ".", - "qr", - "utils" - ] + packages = [".","qr","utils"] revision = "fe0f26ff6d26693948ee8189aa064ee8c54141fa" -[[projects]] - name = "github.com/cenkalti/backoff" - packages = ["."] - revision = "61153c768f31ee5f130071d08fc82b85208528de" - version = "v1.1.0" - [[projects]] name = "github.com/codahale/hdrhistogram" packages = ["."] @@ -101,28 +53,14 @@ [[projects]] name = "github.com/coreos/etcd" - packages = [ - "client", - "pkg/pathutil", - "pkg/srv", - "pkg/tlsutil", - "pkg/transport", - "pkg/types", - "version" - ] + packages = ["client","pkg/pathutil","pkg/srv","pkg/tlsutil","pkg/transport","pkg/types","version"] revision = "9d43462d174c664f5edf313dec0de31e1ef4ed47" version = "v3.2.6" [[projects]] branch = "master" name = "github.com/coreos/go-oidc" - packages = [ - "http", - "jose", - "key", - "oauth2", - "oidc" - ] + packages = ["http","jose","key","oauth2","oidc"] revision = "e51edf2e47e65e5708600d4da6fca1388ee437b4" source = "github.com/gravitational/go-oidc" @@ -134,11 +72,7 @@ [[projects]] name = "github.com/coreos/pkg" - packages = [ - "health", - "httputil", - "timeutil" - ] + packages = ["health","httputil","timeutil"] revision = "1914e367e85eaf0c25d495b48e060dfe6190f8d0" [[projects]] @@ -177,31 +111,18 @@ [[projects]] name = "github.com/golang/protobuf" - packages = [ - "jsonpb", - "proto", - "protoc-gen-go/descriptor", - "ptypes/any", - "ptypes/empty" - ] + packages = ["jsonpb","proto","protoc-gen-go/descriptor","ptypes/any","ptypes/empty"] revision = "8ee79997227bf9b34611aee7946ae64735e6fd93" [[projects]] name = "github.com/google/gops" - packages = [ - "agent", - "internal", - "signal" - ] + packages = ["agent","internal","signal"] revision = "fa6968806ca68b7db113256f300d82b5206a2c3c" version = "v0.3.1" [[projects]] name = "github.com/gravitational/configure" - packages = [ - "cstrings", - "jsonschema" - ] + packages = ["cstrings","jsonschema"] revision = "1db4b84fe9dbbbaf40827aa714dcca17b368de2c" [[projects]] @@ -218,20 +139,13 @@ [[projects]] name = "github.com/gravitational/license" - packages = [ - ".", - "constants" - ] + packages = [".","constants"] revision = "102213511ace56c97ccf1eef645835e16f84d130" version = "0.0.4" [[projects]] name = "github.com/gravitational/reporting" - packages = [ - ".", - "client", - "types" - ] + packages = [".","client","types"] revision = "3c4a4e96fb5896e14fe29da7fcce14b8d93f3965" version = "0.0.4" @@ -255,12 +169,7 @@ [[projects]] name = "github.com/grpc-ecosystem/grpc-gateway" - packages = [ - "runtime", - "runtime/internal", - "third_party/googleapis/google/api", - "utilities" - ] + packages = ["runtime","runtime/internal","third_party/googleapis/google/api","utilities"] revision = "a8f25bd1ab549f8b87afd48aa9181221e9d439bb" version = "v1.1.0" @@ -299,10 +208,7 @@ [[projects]] name = "github.com/mailgun/lemma" - packages = [ - "random", - "secret" - ] + packages = ["random","secret"] revision = "e8b0cd607f5855f9a4a33f8ae5d033178f559964" version = "0.0.2" @@ -319,10 +225,7 @@ [[projects]] branch = "alexander/copy" name = "github.com/mailgun/oxy" - packages = [ - "forward", - "utils" - ] + packages = ["forward","utils"] revision = "0c3e45a1f7b20e1f818612ad4f42abfc0923f3d5" [[projects]] @@ -343,11 +246,7 @@ [[projects]] branch = "master" name = "github.com/mdp/rsc" - packages = [ - "gf256", - "qr", - "qr/coding" - ] + packages = ["gf256","qr","qr/coding"] revision = "90f07065088deccf50b28eb37c93dad3078c0f3c" [[projects]] @@ -364,11 +263,7 @@ [[projects]] name = "github.com/pquerna/otp" - packages = [ - ".", - "hotp", - "totp" - ] + packages = [".","hotp","totp"] revision = "54653902c20e47f3417541d35435cb6d6162e28a" [[projects]] @@ -385,11 +280,7 @@ [[projects]] name = "github.com/prometheus/common" - packages = [ - "expfmt", - "internal/bitbucket.org/ww/goautoneg", - "model" - ] + packages = ["expfmt","internal/bitbucket.org/ww/goautoneg","model"] revision = "50022896a67062a54a54f268fb6fe4bf90b34859" [[projects]] @@ -399,20 +290,13 @@ [[projects]] name = "github.com/russellhaering/gosaml2" - packages = [ - ".", - "types" - ] + packages = [".","types"] revision = "8908227c114abe0b63b1f0606abae72d11bf632a" [[projects]] branch = "master" name = "github.com/russellhaering/goxmldsig" - packages = [ - ".", - "etreeutils", - "types" - ] + packages = [".","etreeutils","types"] revision = "605161228693b2efadce55323c9c661a40c5fbaa" [[projects]] @@ -422,10 +306,7 @@ [[projects]] name = "github.com/sirupsen/logrus" - packages = [ - ".", - "hooks/syslog" - ] + packages = [".","hooks/syslog"] revision = "8ab1e1b91d5f1a6124287906f8b0402844d3a2b3" source = "github.com/gravitational/logrus" version = "1.0.0" @@ -443,18 +324,14 @@ [[projects]] name = "github.com/vulcand/oxy" - packages = [ - "connlimit", - "ratelimit", - "utils" - ] + packages = ["connlimit","ratelimit","utils"] revision = "5725fecc9a4f3aa6fdc3ffd29cef771241809add" [[projects]] name = "github.com/vulcand/predicate" - packages = ["."] - revision = "939c094524d124c55fa8afe0e077701db4a865e2" - version = "v1.0.0" + packages = [".","builder"] + revision = "8fbfb3ab0e94276b6b58bec378600829adc7a203" + version = "v1.1.0" [[projects]] name = "github.com/xeipuuv/gojsonpointer" @@ -475,64 +352,23 @@ [[projects]] branch = "master" name = "golang.org/x/crypto" - packages = [ - "bcrypt", - "blowfish", - "curve25519", - "ed25519", - "ed25519/internal/edwards25519", - "internal/chacha20", - "nacl/secretbox", - "poly1305", - "salsa20/salsa", - "ssh", - "ssh/agent", - "ssh/terminal" - ] + packages = ["bcrypt","blowfish","curve25519","ed25519","ed25519/internal/edwards25519","internal/chacha20","nacl/secretbox","poly1305","salsa20/salsa","ssh","ssh/agent","ssh/terminal"] revision = "b2aa35443fbc700ab74c586ae79b81c171851023" [[projects]] name = "golang.org/x/net" - packages = [ - "context", - "http2", - "http2/hpack", - "idna", - "internal/timeseries", - "lex/httplex", - "trace", - "websocket" - ] + packages = ["context","http2","http2/hpack","idna","internal/timeseries","lex/httplex","trace","websocket"] revision = "48359f4f600b3a2d5cf657458e3f940021631a56" [[projects]] branch = "master" name = "golang.org/x/sys" - packages = [ - "unix", - "windows" - ] + packages = ["unix","windows"] revision = "1d206c9fa8975fb4cf00df1dc8bf3283dc24ba0e" [[projects]] name = "golang.org/x/text" - packages = [ - "encoding", - "encoding/internal", - "encoding/internal/identifier", - "encoding/unicode", - "internal/gen", - "internal/triegen", - "internal/ucd", - "internal/utf8internal", - "runes", - "secure/bidirule", - "transform", - "unicode/bidi", - "unicode/cldr", - "unicode/norm", - "unicode/rangetable" - ] + packages = ["encoding","encoding/internal","encoding/internal/identifier","encoding/unicode","internal/gen","internal/triegen","internal/ucd","internal/utf8internal","runes","secure/bidirule","transform","unicode/bidi","unicode/cldr","unicode/norm","unicode/rangetable"] revision = "19e51611da83d6be54ddafce4a4af510cb3e9ea4" [[projects]] @@ -543,23 +379,7 @@ [[projects]] name = "google.golang.org/grpc" - packages = [ - ".", - "codes", - "connectivity", - "credentials", - "grpclb/grpc_lb_v1", - "grpclog", - "internal", - "keepalive", - "metadata", - "naming", - "peer", - "stats", - "status", - "tap", - "transport" - ] + packages = [".","codes","connectivity","credentials","grpclb/grpc_lb_v1","grpclog","internal","keepalive","metadata","naming","peer","stats","status","tap","transport"] revision = "b3ddf786825de56a4178401b7e174ee332173b66" version = "v1.5.2" @@ -581,6 +401,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "43348de969afeba8e1ab0c22221ee1a15b09c40f24aa8b7b96945a0c6bc030b3" + inputs-digest = "600efc6f221f0a0c99270cf1caa4d7bf3e607974c5e52c8b7ba3210bee185ec9" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 3f2cb24b5f292..7cda049addfac 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -36,7 +36,7 @@ ignored = ["github.com/Sirupsen/logrus"] [[constraint]] name = "github.com/vulcand/predicate" - version = "v1.0.0" + version = "v1.1.0" [[constraint]] name = "github.com/docker/docker" diff --git a/constants.go b/constants.go index d687e19bc0fd9..575687ff3c067 100644 --- a/constants.go +++ b/constants.go @@ -234,6 +234,15 @@ const ( // Syslog is a mode for syslog logging Syslog = "syslog" + + // HumanDateFormat is a human readable date formatting + HumanDateFormat = "Jan _2 15:04 UTC" + + // HumanDateFormatSeconds is a human readable date formatting with seconds + HumanDateFormatSeconds = "Jan _2 15:04:05 UTC" + + // HumanDateFormatMilli is a human readable date formatting with milliseconds + HumanDateFormatMilli = "Jan _2 15:04:05.000 UTC" ) // Component generates "component:subcomponent1:subcomponent2" strings used diff --git a/e b/e index be25b74542948..b438b4da56561 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit be25b7454294810170d7b77a13d88755ff690e53 +Subproject commit b438b4da56561bbfe2b9c1320d861685c4d097e8 diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 83429bd38e2fd..121d8144df3ec 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -64,16 +64,14 @@ func NewAPIServer(config *APIConfig) http.Handler { // Operations on certificate authorities srv.GET("/:version/domain", srv.withAuth(srv.getDomainName)) + srv.POST("/:version/authorities/:type", srv.withAuth(srv.upsertCertAuthority)) + srv.POST("/:version/authorities/:type/rotate", srv.withAuth(srv.rotateCertAuthority)) + srv.POST("/:version/authorities/:type/rotate/external", srv.withAuth(srv.rotateExternalCertAuthority)) srv.DELETE("/:version/authorities/:type/:domain", srv.withAuth(srv.deleteCertAuthority)) srv.GET("/:version/authorities/:type/:domain", srv.withAuth(srv.getCertAuthority)) srv.GET("/:version/authorities/:type", srv.withAuth(srv.getCertAuthorities)) - // DELETE IN: 2.6.0 - // Certificate exchange for cluster upgrades used to upgrade from 2.4.0 - // to 2.5.0 clusters. - srv.POST("/:version/exchangecerts", srv.withAuth(srv.exchangeCerts)) - // Generating certificates for user and host authorities srv.POST("/:version/ca/host/certs", srv.withAuth(srv.generateHostCert)) srv.POST("/:version/ca/user/certs", srv.withAuth(srv.generateUserCert)) @@ -647,14 +645,6 @@ func (s *APIServer) authenticateSSHUser(auth ClientI, w http.ResponseWriter, r * return auth.AuthenticateSSHUser(req) } -func (s *APIServer) exchangeCerts(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - var req ExchangeCertsRequest - if err := httplib.ReadJSON(r, &req); err != nil { - return nil, trace.Wrap(err) - } - return auth.ExchangeCerts(req) -} - func (s *APIServer) changePassword(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { var req services.ChangePasswordReq if err := httplib.ReadJSON(r, &req); err != nil { @@ -931,6 +921,17 @@ func (s *APIServer) generateServerKeys(auth ClientI, w http.ResponseWriter, r *h return keys, nil } +func (s *APIServer) rotateCertAuthority(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req RotateRequest + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + if err := auth.RotateCertAuthority(req); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + type upsertCertAuthorityRawReq struct { CA json.RawMessage `json:"ca"` TTL time.Duration `json:"ttl"` @@ -954,6 +955,25 @@ func (s *APIServer) upsertCertAuthority(auth ClientI, w http.ResponseWriter, r * return message("ok"), nil } +type rotateExternalCertAuthorityRawReq struct { + CA json.RawMessage `json:"ca"` +} + +func (s *APIServer) rotateExternalCertAuthority(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req rotateExternalCertAuthorityRawReq + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + ca, err := services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(req.CA) + if err != nil { + return nil, trace.Wrap(err) + } + if err := auth.RotateExternalCertAuthority(ca); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + func (s *APIServer) getCertAuthorities(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { loadKeys, _, err := httplib.ParseBool(r.URL.Query(), "load_keys") if err != nil { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 4a04cc93ad737..11d2d1de2829e 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -24,8 +24,10 @@ limitations under the License. package auth import ( + "context" "crypto/x509" "fmt" + "math/rand" "net/url" "sync" "time" @@ -78,6 +80,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro if cfg.AuditLog == nil { cfg.AuditLog = events.NewDiscardAuditLog() } + closeCtx, cancelFunc := context.WithCancel(context.TODO()) as := AuthServer{ clusterName: cfg.ClusterName, bk: cfg.Backend, @@ -93,6 +96,8 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro oidcClients: make(map[string]*oidcClient), samlProviders: make(map[string]*samlProvider), githubClients: make(map[string]*githubClient), + cancelFunc: cancelFunc, + closeCtx: closeCtx, } for _, o := range opts { o(&as) @@ -100,6 +105,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro if as.clock == nil { as.clock = clockwork.NewRealClock() } + go as.runPeriodicOperations() return &as, nil } @@ -119,6 +125,9 @@ type AuthServer struct { clock clockwork.Clock bk backend.Backend + closeCtx context.Context + cancelFunc context.CancelFunc + sshca.Authority // AuthServiceName is a human-readable name of this CA. If several Auth services are running @@ -137,7 +146,37 @@ type AuthServer struct { clusterName services.ClusterName } +// runPeriodicOperations runs some periodic bookkeeping operations +// performed by auth server +func (a *AuthServer) runPeriodicOperations() { + // run periodic functions with a semi-random period + // to avoid contention on the database in case if there are multiple + // auth servers running - so they don't compete trying + // to update the same resources. + r := rand.New(rand.NewSource(a.clock.Now().UnixNano())) + period := defaults.HighResPollingPeriod + time.Duration(r.Intn(int(defaults.HighResPollingPeriod/time.Second)))*time.Second + log.Debugf("Ticking with period: %v", period) + ticker := time.NewTicker(period) + defer ticker.Stop() + for { + select { + case <-a.closeCtx.Done(): + return + case <-ticker.C: + err := a.autoRotateCertAuthorities() + if err != nil { + if trace.IsCompareFailed(err) { + log.Debugf("Cert authority has been updated concurrently: %v", err) + } else { + log.Errorf("Failed to perform cert rotation check: %v", err) + } + } + } + } +} + func (a *AuthServer) Close() error { + a.cancelFunc() if a.bk != nil { return trace.Wrap(a.bk.Close()) } @@ -303,7 +342,11 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) { if err != nil { return nil, trace.Wrap(err) } - hostCA, err := s.Trust.GetCertAuthority(services.CertAuthID{ + // CHANGE IN (2.7.0) Use User CA and not host CA + // currently it is used for backwards compatibility + // as pre 2.6.0 remote clusters don't have TLS CAs stored + // for user certificate authorities. + userCA, err := s.Trust.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: clusterName, }, true) @@ -311,7 +354,7 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) { return nil, trace.Wrap(err) } // generate TLS certificate - tlsAuthority, err := hostCA.TLSCA() + tlsAuthority, err := userCA.TLSCA() if err != nil { return nil, trace.Wrap(err) } @@ -588,10 +631,18 @@ func (s *AuthServer) GenerateToken(req GenerateTokenRequest) (string, error) { // ClientCertPool returns trusted x509 cerificate authority pool func (s *AuthServer) ClientCertPool() (*x509.CertPool, error) { pool := x509.NewCertPool() - authorities, err := s.GetCertAuthorities(services.HostCA, false) + var authorities []services.CertAuthority + hostCAs, err := s.GetCertAuthorities(services.HostCA, false) if err != nil { return nil, trace.Wrap(err) } + userCAs, err := s.GetCertAuthorities(services.UserCA, false) + if err != nil { + return nil, trace.Wrap(err) + } + authorities = append(authorities, hostCAs...) + authorities = append(authorities, userCAs...) + for _, auth := range authorities { for _, keyPair := range auth.GetTLSKeyPairs() { cert, err := tlsca.ParseCertificatePEM(keyPair.Cert) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 95ca76dacbe09..3862c81b73bfe 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -41,6 +41,10 @@ type AuthWithRoles struct { alog events.IAuditLog } +func (a *AuthWithRoles) actionWithContext(ctx *services.Context, namespace string, resource string, action string) error { + return a.checker.CheckAccessToRule(ctx, namespace, resource, action) +} + func (a *AuthWithRoles) action(namespace string, resource string, action string) error { return a.checker.CheckAccessToRule(&services.Context{User: a.user}, namespace, resource, action) } @@ -91,16 +95,6 @@ func (a *AuthWithRoles) AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLog return a.authServer.AuthenticateSSHUser(req) } -// ExchangeCerts exchanges TLS certificates for established host certificate authorities -func (a *AuthWithRoles) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - // exchange request has it's own authentication, however this limits the requests - // types to proxies to make it harder to break - if !a.checker.HasRole(string(teleport.RoleProxy)) { - return nil, trace.AccessDenied("this request can be only executed by proxy") - } - return a.authServer.ExchangeCerts(req) -} - func (a *AuthWithRoles) GetSessions(namespace string) ([]session.Session, error) { if err := a.action(namespace, services.KindSSHSession, services.VerbList); err != nil { return nil, trace.Wrap(err) @@ -134,16 +128,58 @@ func (a *AuthWithRoles) CreateCertAuthority(ca services.CertAuthority) error { return trace.BadParameter("not implemented") } -func (a *AuthWithRoles) UpsertCertAuthority(ca services.CertAuthority) error { +// Rotate starts or restarts certificate rotation process +func (a *AuthWithRoles) RotateCertAuthority(req RotateRequest) error { + if err := req.CheckAndSetDefaults(a.authServer.clock); err != nil { + return trace.Wrap(err) + } if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { return trace.Wrap(err) } if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { return trace.Wrap(err) } + return a.authServer.RotateCertAuthority(req) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is called by remote trusted cluster and is used to update +// only public keys and certificates of the certificate authority. +func (a *AuthWithRoles) RotateExternalCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + ctx := &services.Context{User: a.user, Resource: ca} + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbRotate); err != nil { + return trace.Wrap(err) + } + return a.authServer.RotateExternalCertAuthority(ca) +} + +func (a *AuthWithRoles) UpsertCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + ctx := &services.Context{User: a.user, Resource: ca} + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { + return trace.Wrap(err) + } + if err := a.actionWithContext(ctx, defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { + return trace.Wrap(err) + } return a.authServer.UpsertCertAuthority(ca) } +func (a *AuthWithRoles) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbCreate); err != nil { + return trace.Wrap(err) + } + if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbUpdate); err != nil { + return trace.Wrap(err) + } + return a.authServer.CompareAndSwapCertAuthority(new, existing) +} + func (a *AuthWithRoles) GetCertAuthorities(caType services.CertAuthType, loadKeys bool) ([]services.CertAuthority, error) { if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbList); err != nil { return nil, trace.Wrap(err) @@ -156,7 +192,6 @@ func (a *AuthWithRoles) GetCertAuthorities(caType services.CertAuthType, loadKey return nil, trace.Wrap(err) } } - return a.authServer.GetCertAuthorities(caType, loadKeys) } @@ -172,13 +207,6 @@ func (a *AuthWithRoles) GetCertAuthority(id services.CertAuthID, loadKeys bool) return a.authServer.GetCertAuthority(id, loadKeys) } -func (a *AuthWithRoles) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - if err := a.action(defaults.Namespace, services.KindCertAuthority, services.VerbReadNoSecrets); err != nil { - return nil, trace.Wrap(err) - } - return a.authServer.GetAnyCertAuthority(id) -} - func (a *AuthWithRoles) GetDomainName() (string, error) { // anyone can read it, no harm in that return a.authServer.GetDomainName() diff --git a/lib/auth/clt.go b/lib/auth/clt.go index c394a3b919387..fd1a4a4966385 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -303,6 +303,32 @@ func (c *Client) CreateCertAuthority(ca services.CertAuthority) error { return trace.BadParameter("not implemented") } +// Rotate starts or restart certificate authority rotation request +func (c *Client) RotateCertAuthority(req RotateRequest) error { + caType := "all" + if req.Type != "" { + caType = string(req.Type) + } + _, err := c.PostJSON(c.Endpoint("authorities", caType, "rotate"), req) + return trace.Wrap(err) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is called by remote trusted cluster and is used to update +// only public keys and certificates of the certificate authority. +func (c *Client) RotateExternalCertAuthority(ca services.CertAuthority) error { + if err := ca.Check(); err != nil { + return trace.Wrap(err) + } + data, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), + &rotateExternalCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + // UpsertCertAuthority updates or inserts new cert authority func (c *Client) UpsertCertAuthority(ca services.CertAuthority) error { if err := ca.Check(); err != nil { @@ -317,6 +343,12 @@ func (c *Client) UpsertCertAuthority(ca services.CertAuthority) error { return trace.Wrap(err) } +// CompareAndSwapCertAuthority updates certificate authority if existing certificate +// authority matches +func (c *Client) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + return trace.BadParameter("this function is not supported on the client") +} + // GetCertAuthorities returns a list of certificate authorities func (c *Client) GetCertAuthorities(caType services.CertAuthType, loadKeys bool) ([]services.CertAuthority, error) { if err := caType.Check(); err != nil { @@ -358,11 +390,6 @@ func (c *Client) GetCertAuthority(id services.CertAuthID, loadSigningKeys bool) return services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(out.Bytes()) } -// GetAnyCertAuthority returns certificate authority by given id whether it's activated or not -func (c *Client) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - return nil, trace.BadParameter("not implemented") -} - // DeleteCertAuthority deletes cert authority by ID func (c *Client) DeleteCertAuthority(id services.CertAuthID) error { if err := id.Check(); err != nil { @@ -925,23 +952,6 @@ func (c *Client) CreateWebSession(user string) (services.WebSession, error) { return services.GetWebSessionMarshaler().UnmarshalWebSession(out.Bytes()) } -// DELETE IN: 2.6.0 -// ExchangeCerts exchanges TLS certificates for established host certificate authorities -func (c *Client) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - out, err := c.PostJSON( - c.Endpoint("exchangecerts"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - var re ExchangeCertsResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - // AuthenticateWebUser authenticates web user, creates and returns web session // in case if authentication is successfull func (c *Client) AuthenticateWebUser(req AuthenticateUserRequest) (services.WebSession, error) { @@ -2266,6 +2276,14 @@ type ClientI interface { session.Service services.ClusterConfiguration + // RotateCertAuthority starts or restarts certificate authority rotation procedure + RotateCertAuthority(req RotateRequest) error + + // RotateExternalCertAuthority rotates external certificate authority, + // this method is called by remote trusted cluster and is used to update + // only public keys and certificates of the certificate authority. + RotateExternalCertAuthority(ca services.CertAuthority) error + // ValidateTrustedCluster validates trusted cluster token with // main cluster, in case if validation is successfull, main cluster // adds remote cluster @@ -2281,8 +2299,4 @@ type ClientI interface { // AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH // short lived certificates as a result AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLoginResponse, error) - - // DELETE IN: 2.6.0 - // ExchangeCerts exchanges TLS certificates between host certificate authorities of trusted clusters - ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) } diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 87ea47eac6438..8e6833a591121 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -549,7 +549,7 @@ func NewServerIdentity(clt *AuthServer, hostID string, role teleport.Role) (*Ide if err != nil { return nil, trace.Wrap(err) } - return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // clt limits required interface to the necessary methods diff --git a/lib/auth/init.go b/lib/auth/init.go index d16738d4fb2c6..0da69b09244be 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -22,7 +22,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "io/ioutil" "os" "path/filepath" "strings" @@ -123,25 +122,25 @@ type InitConfig struct { } // Init instantiates and configures an instance of AuthServer -func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, error) { +func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, error) { if cfg.DataDir == "" { - return nil, nil, trace.BadParameter("DataDir: data dir can not be empty") + return nil, trace.BadParameter("DataDir: data dir can not be empty") } if cfg.HostUUID == "" { - return nil, nil, trace.BadParameter("HostUUID: host UUID can not be empty") + return nil, trace.BadParameter("HostUUID: host UUID can not be empty") } domainName := cfg.ClusterName.GetClusterName() err := cfg.Backend.AcquireLock(domainName, 30*time.Second) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } defer cfg.Backend.ReleaseLock(domainName) // check that user CA and host CA are present and set the certs if needed asrv, err := NewAuthServer(&cfg, opts...) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // INTERNAL: Authorities (plus Roles) and ReverseTunnels don't follow the @@ -149,7 +148,7 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err // singletons). However, we need to keep them around while Telekube uses them. for _, role := range cfg.Roles { if err := asrv.UpsertRole(role, backend.Forever); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created role: %v.", role) } @@ -157,17 +156,17 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err ca := cfg.Authorities[i] ca, err = services.GetCertAuthorityMarshaler().GenerateCertAuthority(ca) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if err := asrv.Trust.UpsertCertAuthority(ca); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created trusted certificate authority: %q, type: %q.", ca.GetName(), ca.GetType()) } for _, tunnel := range cfg.ReverseTunnels { if err := asrv.UpsertReverseTunnel(tunnel); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created reverse tunnel: %v.", tunnel) } @@ -175,7 +174,7 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err // set cluster level config on the backend and then force a sync of the cache. clusterConfig, err := asrv.GetClusterConfig() if err != nil && !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // init a unique cluster ID, it must be set once only during the first // start so if it's already there, reuse it @@ -186,42 +185,42 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err } err = asrv.SetClusterConfig(cfg.ClusterConfig) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } // cluster name can only be set once. if it has already been set and we are // trying to update it to something else, hard fail. err = asrv.SetClusterName(cfg.ClusterName) if err != nil && !trace.IsAlreadyExists(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if trace.IsAlreadyExists(err) { cn, err := asrv.GetClusterName() if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if cn.GetClusterName() != cfg.ClusterName.GetClusterName() { - return nil, nil, trace.BadParameter("cannot rename cluster %q to %q: clusters cannot be renamed", cn.GetClusterName(), cfg.ClusterName.GetClusterName()) + return nil, trace.BadParameter("cannot rename cluster %q to %q: clusters cannot be renamed", cn.GetClusterName(), cfg.ClusterName.GetClusterName()) } } log.Debugf("Cluster configuration: %v.", cfg.ClusterName) err = asrv.SetStaticTokens(cfg.StaticTokens) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Updating cluster configuration: %v.", cfg.StaticTokens) err = asrv.SetAuthPreference(cfg.AuthPreference) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Updating cluster configuration: %v.", cfg.AuthPreference) // always create the default namespace err = asrv.UpsertNamespace(services.NewNamespace(defaults.Namespace)) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("Created namespace: %q.", defaults.Namespace) @@ -229,22 +228,31 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err defaultRole := services.NewAdminRole() err = asrv.CreateRole(defaultRole, backend.Forever) if err != nil && !trace.IsAlreadyExists(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } if !trace.IsAlreadyExists(err) { log.Infof("Created default admin role: %q.", defaultRole.GetName()) } // generate a user certificate authority if it doesn't exist - if _, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.UserCA}, false); err != nil { + userCA, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.UserCA}, false) + if err != nil { if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("First start: generating user certificate authority.") priv, pub, err := asrv.GenerateKeyPair("") if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) + } + + keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: cfg.ClusterName.GetClusterName(), + Organization: []string{cfg.ClusterName.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return nil, trace.Wrap(err) } userCA := &services.CertAuthorityV2{ @@ -259,11 +267,25 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err Type: services.UserCA, SigningKeys: [][]byte{priv}, CheckingKeys: [][]byte{pub}, + TLSKeyPairs: []services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}, }, } if err := asrv.Trust.UpsertCertAuthority(userCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) + } + } else if len(userCA.GetTLSKeyPairs()) == 0 { + log.Infof("Migrate: generating TLS CA for existing host CA.") + keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: cfg.ClusterName.GetClusterName(), + Organization: []string{cfg.ClusterName.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return nil, trace.Wrap(err) + } + userCA.SetTLSKeyPairs([]services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}) + if err := asrv.Trust.UpsertCertAuthority(userCA); err != nil { + return nil, trace.Wrap(err) } } @@ -271,13 +293,13 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err hostCA, err := asrv.GetCertAuthority(services.CertAuthID{DomainName: cfg.ClusterName.GetClusterName(), Type: services.HostCA}, true) if err != nil { if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } log.Infof("First start: generating host certificate authority.") priv, pub, err := asrv.GenerateKeyPair("") if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ @@ -285,10 +307,9 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err Organization: []string{cfg.ClusterName.GetClusterName()}, }, nil, defaults.CATTL) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - - hostCA := &services.CertAuthorityV2{ + hostCA = &services.CertAuthorityV2{ Kind: services.KindCertAuthority, Version: services.V2, Metadata: services.Metadata{ @@ -304,28 +325,28 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err }, } if err := asrv.Trust.UpsertCertAuthority(hostCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } } else if len(hostCA.GetTLSKeyPairs()) == 0 { log.Infof("Migrate: generating TLS CA for existing host CA.") privateKey, err := ssh.ParseRawPrivateKey(hostCA.GetSigningKeys()[0]) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } privateKeyRSA, ok := privateKey.(*rsa.PrivateKey) if !ok { - return nil, nil, trace.BadParameter("expected RSA private key, got %T", privateKey) + return nil, trace.BadParameter("expected RSA private key, got %T", privateKey) } keyPEM, certPEM, err := tlsca.GenerateSelfSignedCAWithPrivateKey(privateKeyRSA, pkix.Name{ CommonName: cfg.ClusterName.GetClusterName(), Organization: []string{cfg.ClusterName.GetClusterName()}, }, nil, defaults.CATTL) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } hostCA.SetTLSKeyPairs([]services.TLSKeyPair{{Cert: certPEM, Key: keyPEM}}) if err := asrv.Trust.UpsertCertAuthority(hostCA); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -336,24 +357,13 @@ func Init(cfg InitConfig, opts ...AuthServerOption) (*AuthServer, *Identity, err log.Warn(warningMessage) } - // read host keys from disk or create them if they don't exist - iid := IdentityID{ - HostUUID: cfg.HostUUID, - NodeName: cfg.NodeName, - Role: teleport.RoleAdmin, - } - identity, err := initKeys(asrv, cfg.DataDir, iid) - if err != nil { - return nil, nil, trace.Wrap(err) - } - // migrate any legacy resources to new format err = migrateLegacyResources(cfg, asrv) if err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - return asrv, identity, nil + return asrv, nil } func migrateLegacyResources(cfg InitConfig, asrv *AuthServer) error { @@ -362,21 +372,6 @@ func migrateLegacyResources(cfg InitConfig, asrv *AuthServer) error { return trace.Wrap(err) } - err = migrateRoles(asrv) - if err != nil { - return trace.Wrap(err) - } - - err = migrateRemoteClusters(asrv) - if err != nil { - return trace.Wrap(err) - } - - err = migrateTrustedClusters(asrv) - if err != nil { - return trace.Wrap(err) - } - return nil } @@ -412,179 +407,6 @@ func migrateUsers(asrv *AuthServer) error { return nil } -// DELETE IN: 2.6.0 -// All users will be migrated to the new roles in Teleport 2.5.0, which means -// this entire function can be removed in Teleport 2.6.0. -func migrateRoles(asrv *AuthServer) error { - roles, err := asrv.GetRoles() - if err != nil { - return trace.Wrap(err) - } - - // loop over all roles and make sure any v3 roles have the default value for - // certificate format - for i, _ := range roles { - role := roles[i] - - roleOptions := role.GetOptions() - - _, err = roleOptions.GetString(services.CertificateFormat) - if err != nil { - roleOptions.Set(services.CertificateFormat, teleport.CertificateFormatStandard) - role.SetOptions(roleOptions) - } - - err = asrv.UpsertRole(role, backend.Forever) - if err != nil { - return trace.Wrap(err) - } - log.Infof("Migrating role: %v to include default for the cert_format option.", role.GetName()) - } - - return nil -} - -// DELETE IN: 2.6.0 -// This migration adds remote cluster resource migrating from 2.5.0 -// where the presence of remote cluster was identified only by presence -// of host certificate authority with cluster name not equal local cluster name -func migrateRemoteClusters(asrv *AuthServer) error { - clusterName, err := asrv.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - certAuthorities, err := asrv.GetCertAuthorities(services.HostCA, false) - if err != nil { - return trace.Wrap(err) - } - - // loop over all roles and make sure any v3 roles have permit port - // forward and forward agent allowed - for _, certAuthority := range certAuthorities { - if certAuthority.GetName() == clusterName.GetClusterName() { - log.Debugf("Migrations: skipping local cluster cert authority %q.", certAuthority.GetName()) - continue - } - // remote cluster already exists - _, err = asrv.GetRemoteCluster(certAuthority.GetName()) - if err == nil { - log.Debugf("Migrations: remote cluster already exists for cert authority %q.", certAuthority.GetName()) - continue - } - if !trace.IsNotFound(err) { - return trace.Wrap(err) - } - // the cert authority is associated with trusted cluster - _, err = asrv.GetTrustedCluster(certAuthority.GetName()) - if err == nil { - log.Debugf("Migrations: trusted cluster resource exists for cert authority %q.", certAuthority.GetName()) - continue - } - if !trace.IsNotFound(err) { - return trace.Wrap(err) - } - remoteCluster, err := services.NewRemoteCluster(certAuthority.GetName()) - if err != nil { - return trace.Wrap(err) - } - err = asrv.CreateRemoteCluster(remoteCluster) - if err != nil { - if !trace.IsAlreadyExists(err) { - return trace.Wrap(err) - } - } - log.Infof("Migrations: added remote cluster resource for cert authority %q.", certAuthority.GetName()) - } - - return nil -} - -// DELETE IN: 2.6.0 -// migrateTrustedClusters renames the trusted cluster resource names -// and certificate authorities names to equal to actual remote cluster name -func migrateTrustedClusters(asrv *AuthServer) error { - trustedClusters, err := asrv.GetTrustedClusters() - if err != nil { - return trace.Wrap(err) - } - - // loop over all roles and make sure any v3 roles have permit port - // forward and forward agent allowed - for i := range trustedClusters { - trustedCluster := trustedClusters[i] - - hostCA, err := asrv.GetAnyCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - - userCA, err := asrv.GetAnyCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - - if hostCA.GetClusterName() == trustedCluster.GetName() { - log.Debugf("Migrations: skipping trusted cluster %q with name that already matches main cluster.", trustedCluster.GetName()) - continue - } - - var reverseTunnel services.ReverseTunnel - if trustedCluster.GetEnabled() { - reverseTunnel, err = asrv.GetReverseTunnel(trustedCluster.GetName()) - if err != nil { - return trace.Wrap(err) - } - } - - log.Debugf("Migrations: renaming trusted cluster %q to %q.", trustedCluster.GetName(), hostCA.GetClusterName()) - - oldName := trustedCluster.GetName() - - trustedCluster.SetName(hostCA.GetClusterName()) - _, err = asrv.Presence.UpsertTrustedCluster(trustedCluster) - if err != nil { - return trace.Wrap(err) - } - - hostCA.SetName(hostCA.GetClusterName()) - err = asrv.UpsertCertAuthority(hostCA) - if err != nil { - return trace.Wrap(err) - } - - userCA.SetName(hostCA.GetClusterName()) - err = asrv.UpsertCertAuthority(userCA) - if err != nil { - return trace.Wrap(err) - } - - if reverseTunnel != nil { - reverseTunnel.SetName(hostCA.GetClusterName()) - reverseTunnel.SetClusterName(hostCA.GetClusterName()) - if err := asrv.UpsertReverseTunnel(reverseTunnel); err != nil { - return trace.Wrap(err) - } - } - - if !trustedCluster.GetEnabled() { - log.Debugf("Migrations: trusted cluster %q is deactivated, deactivating updated authorities for %q", oldName, trustedCluster.GetName()) - err = asrv.DeactivateCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - err = asrv.DeactivateCertAuthority(services.CertAuthID{Type: services.UserCA, DomainName: trustedCluster.GetName()}) - if err != nil { - return trace.Wrap(err) - } - } - if err := asrv.DeleteTrustedCluster(oldName); err != nil { - return trace.Wrap(err) - } - } - - return nil -} - // isFirstStart returns 'true' if the auth server is starting for the 1st time // on this server. func isFirstStart(authServer *AuthServer, cfg InitConfig) (bool, error) { @@ -604,71 +426,18 @@ func isFirstStart(authServer *AuthServer, cfg InitConfig) (bool, error) { return false, nil } -// initKeys initializes a nodes host certificate. If the certificate does not exist, a request -// is made to the certificate authority to generate a host certificate and it's written to disk. -// If a certificate exists on disk, it is read in and returned. -func initKeys(a *AuthServer, dataDir string, id IdentityID) (*Identity, error) { - path := keysPath(dataDir, id) - - keyExists, err := pathExists(path.key) - if err != nil { - return nil, trace.Wrap(err) - } - - sshCertExists, err := pathExists(path.sshCert) +// GenerateIdentity generates identity for the auth server +func GenerateIdentity(a *AuthServer, id IdentityID, additionalPrincipals []string) (*Identity, error) { + keys, err := a.GenerateServerKeys(GenerateServerKeysRequest{ + HostID: id.HostUUID, + NodeName: id.NodeName, + Roles: teleport.Roles{id.Role}, + AdditionalPrincipals: additionalPrincipals, + }) if err != nil { return nil, trace.Wrap(err) } - - tlsCertExists, err := pathExists(path.tlsCert) - if err != nil { - return nil, trace.Wrap(err) - } - - if !keyExists || !sshCertExists || !tlsCertExists { - packedKeys, err := a.GenerateServerKeys(GenerateServerKeysRequest{ - HostID: id.HostUUID, - NodeName: id.NodeName, - Roles: teleport.Roles{id.Role}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - log.Debugf("Writing keys to disk for %v.", id) - if len(packedKeys.TLSCACerts) != 1 { - return nil, trace.BadParameter("expecting one CA cert, got %v instead", len(packedKeys.TLSCACerts)) - } - err = writeKeys(dataDir, id, packedKeys.Key, packedKeys.Cert, packedKeys.TLSCert, packedKeys.TLSCACerts[0]) - if err != nil { - return nil, trace.Wrap(err) - } - } - i, err := ReadIdentity(dataDir, id) - if err != nil { - return nil, trace.Wrap(err) - } - return i, nil -} - -// writeKeys saves the key/cert pair for a given domain onto disk. This usually means the -// domain trusts us (signed our public key) -func writeKeys(dataDir string, id IdentityID, key []byte, sshCert []byte, tlsCert []byte, tlsCACert []byte) error { - path := keysPath(dataDir, id) - - if err := ioutil.WriteFile(path.key, key, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.sshCert, sshCert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.tlsCert, tlsCert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - if err := ioutil.WriteFile(path.tlsCACert, tlsCACert, teleport.FileMaskOwnerOnly); err != nil { - return trace.Wrap(err) - } - return nil + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // Identity is collection of certificates and signers that represent server identity @@ -681,9 +450,9 @@ type Identity struct { CertBytes []byte // TLSCertBytes is a PEM encoded TLS x509 client certificate TLSCertBytes []byte - // TLSCACertBytes is a PEM encoded TLS x509 certificate of certificate authority + // TLSCACertBytes is a list of PEM encoded TLS x509 certificate of certificate authority // associated with auth server services - TLSCACertBytes []byte + TLSCACertsBytes [][]byte // KeySigner is an SSH host certificate signer KeySigner ssh.Signer // Cert is a parsed SSH certificate @@ -692,9 +461,39 @@ type Identity struct { ClusterName string } +// NeedsUpdate checks if host identity +// is issued by the current certificate authority (last current issuing +// certificate authority, not the one that is being rotated) +func (i *Identity) NeedsUpdate(ca services.CertAuthority) (bool, error) { + if len(i.TLSCertBytes) == 0 { + return true, nil + } + keyPair := ca.GetTLSKeyPairs()[0] + roots := x509.NewCertPool() + + // add the provided CA certificate to the roots + ok := roots.AppendCertsFromPEM(keyPair.Cert) + if !ok { + return false, trace.BadParameter("could not parse certificate PEM") + } + + cert, err := tlsca.ParseCertificatePEM(i.TLSCertBytes) + if err != nil { + return false, trace.Wrap(err) + } + + _, err = cert.Verify(x509.VerifyOptions{Roots: roots}) + if err != nil { + log.Warningf("Failed to verify cert: %#v", err) + return true, nil + } + + return false, nil +} + // HasTSLConfig returns true if this identity has TLS certificate and private key func (i *Identity) HasTLSConfig() bool { - return len(i.TLSCACertBytes) != 0 && len(i.TLSCertBytes) != 0 && len(i.TLSCACertBytes) != 0 + return len(i.TLSCACertsBytes) != 0 && len(i.TLSCertBytes) != 0 } // HasPrincipals returns whether identity has principals @@ -719,13 +518,14 @@ func (i *Identity) TLSConfig() (*tls.Config, error) { if err != nil { return nil, trace.BadParameter("failed to parse private key: %v", err) } - certPool := x509.NewCertPool() - parsedCert, err := tlsca.ParseCertificatePEM(i.TLSCACertBytes) - if err != nil { - return nil, trace.Wrap(err, "failed to parse CA certificate") + for j := range i.TLSCACertsBytes { + parsedCert, err := tlsca.ParseCertificatePEM(i.TLSCACertsBytes[j]) + if err != nil { + return nil, trace.Wrap(err, "failed to parse CA certificate") + } + certPool.AddCert(parsedCert) } - certPool.AddCert(parsedCert) tlsConfig.Certificates = []tls.Certificate{tlsCert} tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool @@ -743,7 +543,7 @@ type IdentityID struct { func (id *IdentityID) HostID() (string, error) { parts := strings.Split(id.HostUUID, ".") if len(parts) < 2 { - return "", trace.BadParameter("expected 2 parts in %v", id.HostUUID) + return "", trace.BadParameter("expected 2 parts in %q", id.HostUUID) } return parts[0], nil } @@ -759,25 +559,25 @@ func (id *IdentityID) String() string { } // ReadIdentityFromKeyPair reads TLS identity from key pair -func ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, tlsCACertBytes []byte) (*Identity, error) { +func ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes []byte, tlsCACertsBytes [][]byte) (*Identity, error) { identity, err := ReadSSHIdentityFromKeyPair(keyBytes, sshCertBytes) if err != nil { return nil, trace.Wrap(err) } if len(tlsCertBytes) != 0 { // just to verify that identity parses properly for future use - _, err := ReadTLSIdentityFromKeyPair(keyBytes, tlsCertBytes, tlsCACertBytes) + _, err := ReadTLSIdentityFromKeyPair(keyBytes, tlsCertBytes, tlsCACertsBytes) if err != nil { return nil, trace.Wrap(err) } identity.TLSCertBytes = tlsCertBytes - identity.TLSCACertBytes = tlsCACertBytes + identity.TLSCACertsBytes = tlsCACertsBytes } return identity, nil } // ReadTLSIdentityFromKeyPair reads TLS identity from key pair -func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertBytes []byte) (*Identity, error) { +func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertsBytes [][]byte) (*Identity, error) { if len(keyBytes) == 0 { return nil, trace.BadParameter("missing private key") } @@ -799,19 +599,18 @@ func ReadTLSIdentityFromKeyPair(keyBytes, certBytes []byte, caCertBytes []byte) if len(cert.Issuer.Organization) == 0 { return nil, trace.BadParameter("missing CA organization") } + clusterName := cert.Issuer.Organization[0] if clusterName == "" { return nil, trace.BadParameter("misssing cluster name") } - identity := &Identity{ - ID: IdentityID{HostUUID: id.Username, Role: teleport.Role(id.Groups[0])}, - ClusterName: clusterName, - KeyBytes: keyBytes, - TLSCertBytes: certBytes, - TLSCACertBytes: caCertBytes, + ID: IdentityID{HostUUID: id.Username, Role: teleport.Role(id.Groups[0])}, + ClusterName: clusterName, + KeyBytes: keyBytes, + TLSCertBytes: certBytes, + TLSCACertsBytes: caCertsBytes, } - _, err = identity.TLSConfig() if err != nil { return nil, trace.Wrap(err) @@ -893,9 +692,21 @@ func ReadSSHIdentityFromKeyPair(keyBytes, certBytes []byte) (*Identity, error) { }, nil } -// ReadIdentity reads, parses and returns the given pub/pri key + cert from the +// ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the // key storage (dataDir). -func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { +func ReadLocalIdentity(dataDir string, id IdentityID) (*Identity, error) { + storage, err := NewProcessStorage(dataDir) + if err != nil { + return nil, trace.Wrap(err) + } + defer storage.Close() + return storage.ReadIdentity(IdentityCurrent, id.Role) +} + +// DELETE IN(2.7.0) +// readIdentityCompat reads, parses and returns the given pub/pri key + cert from the +// key storage (dataDir). Used for data migrations +func readIdentityCompat(dataDir string, id IdentityID) (i *Identity, err error) { path := keysPath(dataDir, id) log.Debugf("Reading keys from disk: %v.", path) @@ -924,7 +735,7 @@ func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { } } - identity, err := ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, tlsCACertBytes) + identity, err := ReadIdentityFromKeyPair(keyBytes, sshCertBytes, tlsCertBytes, [][]byte{tlsCACertBytes}) if err != nil { return nil, trace.Wrap(err) } @@ -935,12 +746,6 @@ func ReadIdentity(dataDir string, id IdentityID) (i *Identity, err error) { return identity, nil } -// WriteIdentity writes identity keypair to disk -func WriteIdentity(dataDir string, identity *Identity) error { - return trace.Wrap( - writeKeys(dataDir, identity.ID, identity.KeyBytes, identity.CertBytes, identity.TLSCertBytes, identity.TLSCACertBytes)) -} - // HaveHostKeys checks that host keys are in place func HaveHostKeys(dataDir string, id IdentityID) (bool, error) { path := keysPath(dataDir, id) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 7f9f3d3373e4e..519978c7c4de0 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -17,8 +17,6 @@ limitations under the License. package auth import ( - "bytes" - "crypto/rsa" "time" "golang.org/x/crypto/ssh" @@ -26,7 +24,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -307,142 +304,3 @@ func (s *AuthServer) AuthenticateSSHUser(req AuthenticateSSHRequest) (*SSHLoginR HostSigners: AuthoritiesToTrustedCerts(hostCertAuthorities), }, nil } - -// DELETE IN: 2.6.0 -// This method is used only for upgrades from 2.4.0 to 2.5.0 -// ExchangeCertsRequest is a request to exchange TLS certificates -// for clusters that already trust each other -type ExchangeCertsRequest struct { - // PublicKey is public key of the trusted certificate authority - PublicKey []byte `json:"public_key"` - // TLSCert is TLS certificate associated with the public key - TLSCert []byte `json:"tls_cert"` -} - -// CheckAndSetDefaults checks and sets default values -func (req *ExchangeCertsRequest) CheckAndSetDefaults() error { - if len(req.PublicKey) == 0 { - return trace.BadParameter("missing parameter 'public_key'") - } - if len(req.TLSCert) == 0 { - return trace.BadParameter("missing parameter 'tls_cert'") - } - return nil -} - -// DELETE IN: 2.6.0 -// ExchangeCertsResponse is a resposne to exchange certificates request -type ExchangeCertsResponse struct { - // TLSCert is a PEM encoded certificate of a local certificate authority - TLSCert []byte `json:"tls_cert"` -} - -// DELETE IN: 2.6.0 -// This method is used to ugprade from 2.4.0 to 2.5.0 -// ExchangeCerts is a method to exchange TLS certificates between certificate authorities -// of the trusted clusters. A remote auth server that wishes to exchange TLS certs with a local auth server -// sends a request that consists of a public key already trusted by the local server and -// TLS certificate for the public key. The local server ensures that the TLS certificate -// was issued to the public key that is already trusted preventing random certificates -// to be injected by the remote server. This is a minor security enforcement. -func (s *AuthServer) ExchangeCerts(req ExchangeCertsRequest) (*ExchangeCertsResponse, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - remoteCA, err := s.findCertAuthorityByPublicKey(req.PublicKey) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := CheckPublicKeysEqual(req.PublicKey, req.TLSCert); err != nil { - return nil, trace.Wrap(err) - } - - // make sure that cluster name in TLS cert is not the same as cluster name - cert, err := tlsca.ParseCertificatePEM(req.TLSCert) - if err != nil { - return nil, trace.Wrap(err) - } - remoteClusterName, err := tlsca.ClusterName(cert.Subject) - if err != nil { - return nil, trace.Wrap(err) - } - clusterName, err := s.GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } - - if remoteClusterName == clusterName.GetName() { - return nil, trace.BadParameter("remote cluster name can not be the same as local cluster name") - } - - remoteCA.SetTLSKeyPairs([]services.TLSKeyPair{ - { - Cert: req.TLSCert, - }, - }) - - err = s.UpsertCertAuthority(remoteCA) - if err != nil { - return nil, trace.Wrap(err) - } - - thisHostCA, err := s.GetCertAuthority(services.CertAuthID{Type: services.HostCA, DomainName: clusterName.GetClusterName()}, false) - if err != nil { - return nil, trace.Wrap(err) - } - - return &ExchangeCertsResponse{ - TLSCert: thisHostCA.GetTLSKeyPairs()[0].Cert, - }, nil - -} - -func (s *AuthServer) findCertAuthorityByPublicKey(publicKey []byte) (services.CertAuthority, error) { - authorities, err := s.GetCertAuthorities(services.HostCA, false) - if err != nil { - return nil, trace.Wrap(err) - } - for _, ca := range authorities { - for _, key := range ca.GetCheckingKeys() { - if bytes.Equal(key, publicKey) { - return ca, nil - } - } - } - return nil, trace.NotFound("certificate authority with public key is not found") -} - -// CheckPublicKeysEqual compares RSA based SSH certificate with the -// TLS certificate, returns nil if both certificates are using the same public -// key and refer to the same cluster name, error otherwise -func CheckPublicKeysEqual(sshKeyBytes []byte, certBytes []byte) error { - cert, err := tlsca.ParseCertificatePEM(certBytes) - if err != nil { - return trace.Wrap(err) - } - certPublicKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return trace.BadParameter("expected RSA public key, got %T", cert.PublicKey) - } - publicKey, _, _, _, err := ssh.ParseAuthorizedKey(sshKeyBytes) - if err != nil { - return trace.Wrap(err) - } - cryptoPubKey, ok := publicKey.(ssh.CryptoPublicKey) - if !ok { - return trace.BadParameter("unexpected key type: %T", publicKey) - } - rsaPublicKey, ok := cryptoPubKey.CryptoPublicKey().(*rsa.PublicKey) - if !ok { - return trace.BadParameter("unexpected key type: %T", publicKey) - } - if certPublicKey.E != rsaPublicKey.E { - return trace.CompareFailed("different public keys") - } - if certPublicKey.N.Cmp(rsaPublicKey.N) != 0 { - return trace.CompareFailed("different public keys") - } - return nil -} diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 564f0576ccf00..bd82098e42967 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -146,6 +146,10 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { // https://github.com/kubernetes/kubernetes/pull/34524/files#diff-2b283dde198c92424df5355f39544aa4R59 return nil, trace.AccessDenied("access denied: intermediaries are not supported") } + localClusterName, err := a.AccessPoint.GetDomainName() + if err != nil { + return nil, trace.Wrap(err) + } // with no client authentication in place, middleware // assumes not-privileged Nop role. // it theoretically possible to use bearer token auth even @@ -156,6 +160,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { GetClusterConfig: a.AccessPoint.GetClusterConfig, Role: teleport.RoleNop, Username: string(teleport.RoleNop), + ClusterName: localClusterName, }, nil } clientCert := peers[0] @@ -164,10 +169,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { log.Warning("Failed to parse client certificate %v.", err) return nil, trace.AccessDenied("access denied: invalid client certificate") } - localClusterName, err := a.AccessPoint.GetDomainName() - if err != nil { - return nil, trace.Wrap(err) - } + identity, err := tlsca.FromSubject(clientCert.Subject) if err != nil { return nil, trace.Wrap(err) @@ -209,6 +211,7 @@ func (a *AuthMiddleware) GetUser(r *http.Request) (interface{}, error) { GetClusterConfig: a.AccessPoint.GetClusterConfig, Role: *systemRole, Username: identity.Username, + ClusterName: localClusterName, }, nil } // otherwise assume that is a local role, no need to pass the roles diff --git a/lib/auth/permissions.go b/lib/auth/permissions.go index 86e076eaacbcc..4f92f44adbeb0 100644 --- a/lib/auth/permissions.go +++ b/lib/auth/permissions.go @@ -24,11 +24,12 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/trace" + "github.com/vulcand/predicate/builder" ) // NewAdminContext returns new admin auth context func NewAdminContext() (*AuthContext, error) { - authContext, err := contextForBuiltinRole(nil, teleport.RoleAdmin, fmt.Sprintf("%v", teleport.RoleAdmin)) + authContext, err := contextForBuiltinRole("", nil, teleport.RoleAdmin, fmt.Sprintf("%v", teleport.RoleAdmin)) if err != nil { return nil, trace.Wrap(err) } @@ -36,8 +37,8 @@ func NewAdminContext() (*AuthContext, error) { } // NewRoleAuthorizer authorizes everyone as predefined role, used in tests -func NewRoleAuthorizer(clusterConfig services.ClusterConfig, r teleport.Role) (Authorizer, error) { - authContext, err := contextForBuiltinRole(clusterConfig, r, fmt.Sprintf("%v", r)) +func NewRoleAuthorizer(clusterName string, clusterConfig services.ClusterConfig, r teleport.Role) (Authorizer, error) { + authContext, err := contextForBuiltinRole(clusterName, clusterConfig, r, fmt.Sprintf("%v", r)) if err != nil { return nil, trace.Wrap(err) } @@ -159,15 +160,13 @@ func (a *authorizer) authorizeBuiltinRole(r BuiltinRole) (*AuthContext, error) { if err != nil { return nil, trace.Wrap(err) } - return contextForBuiltinRole(config, r.Role, r.Username) + return contextForBuiltinRole(r.ClusterName, config, r.Role, r.Username) } func (a *authorizer) authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthContext, error) { if r.Role != teleport.RoleProxy { return nil, trace.AccessDenied("access denied for remote %v connecting to cluster", r.Role) } - // TODO(klizhentas): allow remote proxy to update the cluster's certificate authorities - // during certificates renewal roles, err := services.FromSpec( string(teleport.RoleRemoteProxy), services.RoleSpecV3{ @@ -183,6 +182,17 @@ func (a *authorizer) authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthConte services.NewRule(services.KindReverseTunnel, services.RO()), services.NewRule(services.KindTunnelConnection, services.RO()), services.NewRule(services.KindClusterConfig, services.RO()), + // this rule allows remote proxy to update the cluster's certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + // It is important that remote proxy can only rotate + // existing certificate authority, and not create or update new ones + Verbs: []string{services.VerbRead, services.VerbRotate}, + // allow administrative access to the certificate authority names + // matching the cluster name only + Where: builder.Equals(services.ResourceNameExpr, builder.String(r.ClusterName)).String(), + }, }, }, }) @@ -201,7 +211,7 @@ func (a *authorizer) authorizeRemoteBuiltinRole(r RemoteBuiltinRole) (*AuthConte } // GetCheckerForBuiltinRole returns checkers for embedded builtin role -func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role teleport.Role) (services.AccessChecker, error) { +func GetCheckerForBuiltinRole(clusterName string, clusterConfig services.ClusterConfig, role teleport.Role) (services.AccessChecker, error) { switch role { case teleport.RoleAuth: return services.FromSpec( @@ -272,6 +282,23 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor services.NewRule(services.KindTunnelConnection, services.RW()), services.NewRule(services.KindHostCert, services.RW()), services.NewRule(services.KindRemoteCluster, services.RO()), + // this rule allows local proxy to update the remote cluster's host certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + Verbs: []string{services.VerbCreate, services.VerbRead, services.VerbUpdate}, + // allow administrative access to the host certificate authorities + // matching any cluster name except local + Where: builder.And( + builder.Equals(services.CertAuthorityTypeExpr, builder.String(string(services.HostCA))), + builder.Not( + builder.Equals( + services.ResourceNameExpr, + builder.String(clusterName), + ), + ), + ).String(), + }, }, }, }) @@ -305,6 +332,23 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor services.NewRule(services.KindStaticTokens, services.RO()), services.NewRule(services.KindTunnelConnection, services.RW()), services.NewRule(services.KindRemoteCluster, services.RO()), + // this rule allows local proxy to update the remote cluster's host certificate authorities + // during certificates renewal + { + Resources: []string{services.KindCertAuthority}, + Verbs: []string{services.VerbCreate, services.VerbRead, services.VerbUpdate}, + // allow administrative access to the certificate authority names + // matching any cluster name except local + Where: builder.And( + builder.Equals(services.CertAuthorityTypeExpr, builder.String(string(services.HostCA))), + builder.Not( + builder.Equals( + services.ResourceNameExpr, + builder.String(clusterName), + ), + ), + ).String(), + }, }, }, }) @@ -367,8 +411,8 @@ func GetCheckerForBuiltinRole(clusterConfig services.ClusterConfig, role telepor return nil, trace.NotFound("%v is not reconginzed", role.String()) } -func contextForBuiltinRole(clusterConfig services.ClusterConfig, r teleport.Role, username string) (*AuthContext, error) { - checker, err := GetCheckerForBuiltinRole(clusterConfig, r) +func contextForBuiltinRole(clusterName string, clusterConfig services.ClusterConfig, r teleport.Role, username string) (*AuthContext, error) { + checker, err := GetCheckerForBuiltinRole(clusterName, clusterConfig, r) if err != nil { return nil, trace.Wrap(err) } @@ -417,6 +461,9 @@ type BuiltinRole struct { // Username is for authentication tracking purposes Username string + + // ClusterName is the name of the local cluster + ClusterName string } // RemoteBuiltinRole is the role of the remote (service connecting via trusted cluster link) diff --git a/lib/auth/register.go b/lib/auth/register.go index 811f81e6def54..cca140afe8617 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -33,7 +33,7 @@ import ( // LocalRegister is used to generate host keys when a node or proxy is running within the same process // as the auth server. This method does not need to use provisioning tokens. -func LocalRegister(dataDir string, id IdentityID, authServer *AuthServer, additionalPrincipals []string) error { +func LocalRegister(id IdentityID, authServer *AuthServer, additionalPrincipals []string) (*Identity, error) { keys, err := authServer.GenerateServerKeys(GenerateServerKeysRequest{ HostID: id.HostUUID, NodeName: id.NodeName, @@ -41,27 +41,26 @@ func LocalRegister(dataDir string, id IdentityID, authServer *AuthServer, additi AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } // Register is used to generate host keys when a node or proxy are running on different hosts // than the auth server. This method requires provisioning tokens to prove a valid auth server // was used to issue the joining request. -func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, additionalPrincipals []string) error { +func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, additionalPrincipals []string) (*Identity, error) { tok, err := readToken(token) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } tlsConfig := utils.TLSConfig() certPath := filepath.Join(dataDir, defaults.CACertFile) certBytes, err := utils.ReadPath(certPath) if err != nil { - // DELETE IN: 2.6.0 // Only support secure cluster joins in the next releases if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } message := fmt.Sprintf(`Your configuration is insecure! Registering without TLS certificate authority, to fix this warning add ca.cert to %v, you can get ca.cert using 'tctl auth export --type=tls > ca.cert'`, dataDir) log.Warning(message) @@ -69,7 +68,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add } else { cert, err := tlsca.ParseCertificatePEM(certBytes) if err != nil { - return trace.Wrap(err, "failed to parse certificate at %v", certPath) + return nil, trace.Wrap(err, "failed to parse certificate at %v", certPath) } log.Infof("Joining remote cluster %v.", cert.Subject.CommonName) certPool := x509.NewCertPool() @@ -78,7 +77,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add } client, err := NewTLSClient(servers, tlsConfig) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer client.Close() @@ -91,18 +90,17 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr, add AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } -// ReRegister renews the certificates and private keys based on the existing -// identity ID -func ReRegister(dataDir string, clt ClientI, id IdentityID, additionalPrincipals []string) error { +// ReRegister renews the certificates and private keys based on the existing identity ID +func ReRegister(clt ClientI, id IdentityID, additionalPrincipals []string) (*Identity, error) { hostID, err := id.HostID() if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } keys, err := clt.GenerateServerKeys(GenerateServerKeysRequest{ HostID: hostID, @@ -111,9 +109,9 @@ func ReRegister(dataDir string, clt ClientI, id IdentityID, additionalPrincipals AdditionalPrincipals: additionalPrincipals, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return writeKeys(dataDir, id, keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts[0]) + return ReadIdentityFromKeyPair(keys.Key, keys.Cert, keys.TLSCert, keys.TLSCACerts) } func readToken(token string) (string, error) { diff --git a/lib/auth/rotate.go b/lib/auth/rotate.go new file mode 100644 index 0000000000000..6b1f935bdd782 --- /dev/null +++ b/lib/auth/rotate.go @@ -0,0 +1,475 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "crypto/x509/pkix" + "time" + + "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/pborman/uuid" + "github.com/sirupsen/logrus" +) + +// RotateRequest is a request to start rotation of the certificate authority +type RotateRequest struct { + // Type is certificate authority type, if omitted, both will be rotated + Type services.CertAuthType `json:"type"` + // GracePeriod is optional grace period, if omitted, default is set, + // if 0 is supplied, means force rotate all certificate authorities + // right away. + GracePeriod *time.Duration `json:"grace_period,omitempty"` + // TargetPhase sets desired rotation phase to move to, if not set + // will be set automatically, is a required argument + // for manual rotation. + TargetPhase string `json:"target_phase,omitempty"` + // Mode sets manual mode with manually updated phases, + // otherwise phases are set automatically + Mode string `json:"mode"` + // Schedule is an optional rotation schedule, + // autogenerated if not set + Schedule *services.RotationSchedule `json:"schedule"` +} + +// Types returns cert authority types requested to rotate +func (r *RotateRequest) Types() []services.CertAuthType { + switch r.Type { + case "": + return []services.CertAuthType{services.HostCA, services.UserCA} + case services.HostCA: + return []services.CertAuthType{services.HostCA} + case services.UserCA: + return []services.CertAuthType{services.UserCA} + } + return nil +} + +// CheckAndSetDefaults checks and sets defaults +func (r *RotateRequest) CheckAndSetDefaults(clock clockwork.Clock) error { + if r.TargetPhase == "" { + // if phase if not set, imply that the first meaningful phase + // is set as a target phase + r.TargetPhase = services.RotationPhaseUpdateClients + } + // if mode is not set, default to manual (as it's safer) + if r.Mode == "" { + r.Mode = services.RotationModeManual + } + switch r.Type { + case "", services.HostCA, services.UserCA: + default: + return trace.BadParameter("unsupported certificate authority type: %q", r.Type) + } + if r.GracePeriod == nil { + period := defaults.RotationGracePeriod + r.GracePeriod = &period + } + if r.Schedule == nil { + var err error + r.Schedule, err = services.GenerateSchedule(clock, *r.GracePeriod) + if err != nil { + return trace.Wrap(err) + } + } else { + if err := r.Schedule.CheckAndSetDefaults(clock); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +// rotationReq is an internal rotation requrest +type rotationReq struct { + // clock is set by the auth server internally + clock clockwork.Clock + // ca is a certificate authority to rotate, set by the auth server internally + ca services.CertAuthority + // targetPhase is a target rotation phase to set + targetPhase string + // mode is a rotation mode + mode string + // gracePeriod is a rotation grace period + gracePeriod time.Duration + // schedule is a schedule to set + schedule services.RotationSchedule +} + +// RotateCertAuthority starts or restarts certificate rotation process +func (a *AuthServer) RotateCertAuthority(req RotateRequest) error { + // TODO: For whatever reason rotation does not work on DynamoDB - getting error. + if err := req.CheckAndSetDefaults(a.clock); err != nil { + return trace.Wrap(err) + } + clusterName := a.clusterName.GetClusterName() + + caTypes := req.Types() + for _, caType := range caTypes { + existing, err := a.GetCertAuthority(services.CertAuthID{ + Type: caType, + DomainName: clusterName, + }, true) + if err != nil { + return trace.Wrap(err) + } + rotated, err := processRotationRequest(rotationReq{ + ca: existing, + clock: a.clock, + targetPhase: req.TargetPhase, + schedule: *req.Schedule, + gracePeriod: *req.GracePeriod, + mode: req.Mode, + }) + if err != nil { + return trace.Wrap(err) + } + if err := a.CompareAndSwapCertAuthority(rotated, existing); err != nil { + return trace.Wrap(err) + } + rotation := rotated.GetRotation() + switch rotation.State { + case services.RotationStateInProgress: + log.WithFields(logrus.Fields{"type": caType}).Infof("Rotation is in progress, current phase: %q.", rotation.Phase) + case services.RotationStateStandby: + log.WithFields(logrus.Fields{"type": caType}).Infof("Rotation has been completed.") + } + } + return nil +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is called by remote trusted cluster and is used to update +// only public keys and certificates of the certificate authority. +func (a *AuthServer) RotateExternalCertAuthority(ca services.CertAuthority) error { + if ca == nil { + return trace.BadParameter("missing certificate authority") + } + // this is just an extra precaution against local admins, + // because this is additionally enforced by RBAC as well + if ca.GetClusterName() == a.clusterName.GetClusterName() { + return trace.BadParameter("can not rotate local certificate authority") + } + + existing, err := a.GetCertAuthority(services.CertAuthID{ + Type: ca.GetType(), + DomainName: ca.GetClusterName(), + }, false) + if err != nil { + return trace.Wrap(err) + } + + updated := existing.Clone() + updated.SetCheckingKeys(ca.GetCheckingKeys()) + updated.SetTLSKeyPairs(ca.GetTLSKeyPairs()) + updated.SetRotation(ca.GetRotation()) + + // use compare and swap to protect from concurrent updates + // by trusted cluster API + if err := a.CompareAndSwapCertAuthority(updated, existing); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// autoRotateCertAuthorities automatically rotates cert authorities, +// does nothing if no rotation parameters were set up +// or it is too early to rotate per schedule +func (a *AuthServer) autoRotateCertAuthorities() error { + clusterName := a.clusterName.GetClusterName() + for _, caType := range []services.CertAuthType{services.HostCA, services.UserCA} { + ca, err := a.GetCertAuthority(services.CertAuthID{ + Type: caType, + DomainName: clusterName, + }, true) + if err != nil { + return trace.Wrap(err) + } + if err := a.autoRotate(ca); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +func (a *AuthServer) autoRotate(ca services.CertAuthority) error { + rotation := ca.GetRotation() + // rotation mode is not manual, nothing to do + if rotation.Mode != services.RotationModeAuto { + return nil + } + // rotation is not in progress, there is nothing to do + if rotation.State != services.RotationStateInProgress { + return nil + } + logger := log.WithFields(logrus.Fields{"type": ca.GetType()}) + var req *rotationReq + switch rotation.Phase { + case services.RotationPhaseUpdateClients: + if rotation.Schedule.UpdateServers.After(a.clock.Now()) { + return nil + } + req = &rotationReq{ + clock: a.clock, + ca: ca, + targetPhase: services.RotationPhaseUpdateServers, + mode: services.RotationModeAuto, + gracePeriod: rotation.GracePeriod.Duration, + schedule: rotation.Schedule, + } + case services.RotationPhaseUpdateServers: + if rotation.Schedule.Standby.After(a.clock.Now()) { + return nil + } + req = &rotationReq{ + clock: a.clock, + ca: ca, + targetPhase: services.RotationPhaseStandby, + mode: services.RotationModeAuto, + gracePeriod: rotation.GracePeriod.Duration, + schedule: rotation.Schedule, + } + default: + return trace.BadParameter("phase is not supported: %q", rotation.Phase) + } + logger.Infof("Setting rotation phase %q", req.targetPhase) + rotated, err := processRotationRequest(*req) + if err != nil { + return trace.Wrap(err) + } + if err := a.CompareAndSwapCertAuthority(rotated, ca); err != nil { + return trace.Wrap(err) + } + logger.Infof("Cert authority rotation request is completed") + return nil +} + +// processRotationRequest processes rotation request FSM-style +// switches Phase and State +func processRotationRequest(req rotationReq) (services.CertAuthority, error) { + rotation := req.ca.GetRotation() + ca := req.ca.Clone() + + switch req.targetPhase { + // this is the first stage of the rotation - new certificate authorities + // are being generated, clients will start using new credentials + // and servers will use existing credentials, but will trust clients + // using new credentials + case services.RotationPhaseUpdateClients: + switch rotation.State { + case services.RotationStateStandby, "": + default: + return nil, trace.BadParameter("can not initate rotation while another is in progress") + } + if err := startNewRotation(req, ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + // update server phase trusts uses the new credentials both for servers + // and clients, but still trusts clients with old credentials + case services.RotationPhaseUpdateServers: + if rotation.Phase != services.RotationPhaseUpdateClients { + return nil, trace.BadParameter( + "can only switch to phase %v from %v, current phase is %v", + services.RotationPhaseUpdateServers, + services.RotationPhaseUpdateClients, + rotation.Phase) + } + // this is simply update of the phase to signal nodes to restart + // and start serving new signatures + rotation.Phase = req.targetPhase + rotation.Mode = req.mode + ca.SetRotation(rotation) + return ca, nil + // rollback moves back both clients and servers to use old credentials + // but will trust new credentials + case services.RotationPhaseRollback: + switch rotation.Phase { + case services.RotationPhaseUpdateClients, services.RotationPhaseUpdateServers: + if err := startRollingBackRotation(ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + } + // this is to complete rotation, moves overall rotation + // to standby, servers will only trust one CA + case services.RotationPhaseStandby: + switch rotation.Phase { + case services.RotationPhaseUpdateServers, services.RotationPhaseRollback: + if err := completeRotation(req.clock, ca); err != nil { + return nil, trace.Wrap(err) + } + return ca, nil + default: + return nil, trace.BadParameter( + "can only switch to phase %v from %v, current phase is %v", + services.RotationPhaseUpdateServers, + services.RotationPhaseUpdateClients, + rotation.Phase) + } + default: + return nil, trace.BadParameter("unsupported phase: %q", req.targetPhase) + } + return nil, trace.BadParameter("internal error") +} + +// startNewRotation starts new rotation and in place updates the certificate +// authority with new CA keys +func startNewRotation(req rotationReq, ca services.CertAuthority) error { + clock := req.clock + gracePeriod := req.gracePeriod + + rotation := ca.GetRotation() + id := uuid.New() + + rotation.Mode = req.mode + rotation.Schedule = req.schedule + + // first part of the function generates credentials + sshPrivPEM, sshPubPEM, err := native.GenerateKeyPair("") + if err != nil { + return trace.Wrap(err) + } + + keyPEM, certPEM, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: ca.GetClusterName(), + Organization: []string{ca.GetClusterName()}, + }, nil, defaults.CATTL) + if err != nil { + return trace.Wrap(err) + } + tlsKeyPair := &services.TLSKeyPair{ + Cert: certPEM, + Key: keyPEM, + } + + // second part of the function rotates the certificate authority + rotation.Started = clock.Now().UTC() + rotation.GracePeriod = services.NewDuration(gracePeriod) + rotation.CurrentID = id + + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + // drop old certificate authority without keeping it as trusted + if gracePeriod == 0 { + signingKeys = [][]byte{sshPrivPEM} + checkingKeys = [][]byte{sshPubPEM} + keyPairs = []services.TLSKeyPair{*tlsKeyPair} + // in case of force rotation, rotation has been started and completed + // in the same step moving it to standby state + rotation.State = services.RotationStateStandby + } else { + // rotation sets the first key to be the new key + // and keep only public keys/certs for the new CA + signingKeys = [][]byte{sshPrivPEM, signingKeys[0]} + checkingKeys = [][]byte{sshPubPEM, checkingKeys[0]} + oldKeyPair := keyPairs[0] + keyPairs = []services.TLSKeyPair{*tlsKeyPair, oldKeyPair} + rotation.State = services.RotationStateInProgress + rotation.Phase = services.RotationPhaseUpdateClients + } + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// startRollingBackRotation starts rolls back rotation to the previous state +func startRollingBackRotation(ca services.CertAuthority) error { + rotation := ca.GetRotation() + + // rollback always sets rotation to manual mode + rotation.Mode = services.RotationModeManual + + // second part of the function rotates the certificate authority + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + // rotation sets the first key to be the new key + // and keep only public keys/certs for the new CA + signingKeys = [][]byte{signingKeys[1]} + checkingKeys = [][]byte{checkingKeys[1]} + + // here, keep the attempted key pair certificate as trusted + // as during rollback phases, both types of clients may be present in the cluster + keyPairs = []services.TLSKeyPair{keyPairs[1], services.TLSKeyPair{Cert: keyPairs[0].Cert}} + rotation.State = services.RotationStateInProgress + rotation.Phase = services.RotationPhaseRollback + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// completeRollingBackRotation completes rollback of the rotation +// sets it to the standby state +func completeRollingBackRotation(clock clockwork.Clock, ca services.CertAuthority) error { + rotation := ca.GetRotation() + + // clean up the state + rotation.Started = time.Time{} + rotation.State = services.RotationStateStandby + rotation.Phase = services.RotationPhaseStandby + rotation.Mode = "" + rotation.Schedule = services.RotationSchedule{} + + keyPairs := ca.GetTLSKeyPairs() + // only keep the original certificate authority as trusted + // and remove all extra + keyPairs = []services.TLSKeyPair{keyPairs[0]} + + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} + +// completeRotation completes certificate authority rotation +func completeRotation(clock clockwork.Clock, ca services.CertAuthority) error { + rotation := ca.GetRotation() + signingKeys := ca.GetSigningKeys() + checkingKeys := ca.GetCheckingKeys() + keyPairs := ca.GetTLSKeyPairs() + + signingKeys = signingKeys[:1] + checkingKeys = checkingKeys[:1] + keyPairs = keyPairs[:1] + + rotation.Started = time.Time{} + rotation.State = services.RotationStateStandby + rotation.Phase = services.RotationPhaseStandby + rotation.LastRotated = clock.Now() + rotation.Mode = "" + rotation.Schedule = services.RotationSchedule{} + + ca.SetSigningKeys(signingKeys) + ca.SetCheckingKeys(checkingKeys) + ca.SetTLSKeyPairs(keyPairs) + ca.SetRotation(rotation) + return nil +} diff --git a/lib/auth/state.go b/lib/auth/state.go new file mode 100644 index 0000000000000..049dc976ca2b2 --- /dev/null +++ b/lib/auth/state.go @@ -0,0 +1,237 @@ +package auth + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/dir" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" +) + +// ProcessStorage is a special local process state backend +// helps to manage rotation for certificate authorities +// and process-local state management +type ProcessStorage struct { + b backend.Backend +} + +func NewProcessStorage(path string) (*ProcessStorage, error) { + if path == "" { + return nil, trace.BadParameter("missing parameter state") + } + backend, err := dir.New(backend.Params{"path": path}) + if err != nil { + return nil, trace.Wrap(err) + } + return &ProcessStorage{b: backend}, nil +} + +func (p *ProcessStorage) Close() error { + return p.b.Close() +} + +const ( + // IdentityNameCurrent current is for the identity that is currently used + IdentityCurrent = "current" + // IdentityReplacement replacement is the identity that is replacing current + IdentityReplacement = "replacement" +) + +// stateName is a shared common state resource name used internally +const stateName = "state" + +// GetState reads rotation state +func (p *ProcessStorage) GetState(role teleport.Role) (*StateV2, error) { + data, err := p.b.GetVal([]string{"states", strings.ToLower(role.String())}, stateName) + if err != nil { + return nil, trace.Wrap(err) + } + var res StateV2 + if err := utils.UnmarshalWithSchema(GetStateSchema(), &res, data); err != nil { + return nil, trace.BadParameter(err.Error()) + } + return &res, nil +} + +// CreateState creates process state if it does not exist yet +func (p *ProcessStorage) CreateState(role teleport.Role, state StateV2) error { + if err := state.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(state) + if err != nil { + return trace.Wrap(err) + } + err = p.b.CreateVal([]string{"states", strings.ToLower(role.String())}, stateName, data, backend.Forever) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// WriteState writes local cluster state +func (p *ProcessStorage) WriteState(role teleport.Role, state StateV2) error { + if err := state.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(state) + if err != nil { + return trace.Wrap(err) + } + err = p.b.UpsertVal([]string{"states", strings.ToLower(role.String())}, stateName, data, backend.Forever) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// ReadIdentity reads identity using identity name and role +func (p *ProcessStorage) ReadIdentity(name string, role teleport.Role) (*Identity, error) { + if name == "" { + return nil, trace.BadParameter("missing parameter name") + } + data, err := p.b.GetVal([]string{"ids", strings.ToLower(role.String())}, name) + if err != nil { + return nil, trace.Wrap(err) + } + var res IdentityV2 + if err := utils.UnmarshalWithSchema(GetIdentitySchema(), &res, data); err != nil { + return nil, trace.BadParameter(err.Error()) + } + return ReadIdentityFromKeyPair(res.Spec.Key, res.Spec.SSHCert, res.Spec.TLSCert, res.Spec.TLSCACerts) +} + +// WriteIdentity writes identity to the local storage +func (p *ProcessStorage) WriteIdentity(name string, id Identity) error { + res := IdentityV2{ + ResourceHeader: services.ResourceHeader{ + Kind: services.KindIdentity, + Version: services.V2, + Metadata: services.Metadata{ + Name: name, + }, + }, + Spec: IdentitySpecV2{ + Key: id.KeyBytes, + SSHCert: id.CertBytes, + TLSCert: id.TLSCertBytes, + TLSCACerts: id.TLSCACertsBytes, + }, + } + if err := res.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + data, err := json.Marshal(res) + if err != nil { + return trace.Wrap(err) + } + return p.b.UpsertVal([]string{"ids", strings.ToLower(id.ID.Role.String())}, name, data, backend.Forever) +} + +// StateV2 is a local disk state +type StateV2 struct { + services.ResourceHeader + Spec StateSpecV2 `json:"spec"` +} + +// CheckAndSetDefaults checks and sets defaults value for state +func (s *StateV2) CheckAndSetDefaults() error { + s.Kind = services.KindState + s.Version = services.V2 + // for state resource name does not matter + if s.Metadata.Name == "" { + s.Metadata.Name = stateName + } + if err := s.Metadata.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + return nil +} + +// StateSpecV2 is a state spec +type StateSpecV2 struct { + Rotation services.Rotation `json:"rotation"` +} + +// IdentityV2 specifies local host identity +type IdentityV2 struct { + services.ResourceHeader + Spec IdentitySpecV2 `json:"spec"` +} + +// CheckAndSetDefaults checks and sets defaults value for identity +func (s *IdentityV2) CheckAndSetDefaults() error { + s.Kind = services.KindIdentity + s.Version = services.V2 + if err := s.Metadata.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if len(s.Spec.Key) == 0 { + return trace.BadParameter("missing parameter Key") + } + if len(s.Spec.SSHCert) == 0 { + return trace.BadParameter("missing parameter SSHCert") + } + if len(s.Spec.TLSCert) == 0 { + return trace.BadParameter("missing parameter TLSCert") + } + if len(s.Spec.TLSCACerts) == 0 { + return trace.BadParameter("missing parameter TLSCACerts") + } + return nil +} + +// IdentitySpecV2 is PEM encoded data with host identity +type IdentitySpecV2 struct { + // Key is a PEM encoded private key + Key []byte `json:"key,omitempty"` + // SSHCert is a PEM encoded SSH host cert + SSHCert []byte `json:"ssh_cert,omitempty"` + // TLSCert is a PEM encoded TLS x509 client certificate + TLSCert []byte `json:"tls_cert,omitempty"` + // TLSCACert is a list of PEM encoded TLS x509 certificate of certificate authority + // associated with auth server services + TLSCACerts [][]byte `json:"tls_ca_certs,omitempty"` +} + +// IdentitySpecV2Schema is a schema for identity spec +const IdentitySpecV2Schema = `{ + "type": "object", + "additionalProperties": false, + "required": ["key", "ssh_cert", "tls_cert", "tls_ca_certs"], + "properties": { + "key": {"type": "string"}, + "ssh_cert": {"type": "string"}, + "tls_cert": {"type": "string"}, + "tls_ca_certs": { + "type": "array", + "items": {"type": "string"} + } + } +}` + +// GetIdentitySchema returns JSON Schema for cert authorities +func GetIdentitySchema() string { + return fmt.Sprintf(services.V2SchemaTemplate, services.MetadataSchema, IdentitySpecV2Schema, services.DefaultDefinitions) +} + +// StateSpecV2Schema is a schema for local server state +const StateSpecV2Schema = `{ + "type": "object", + "additionalProperties": false, + "required": ["rotation"], + "properties": { + "rotation": %v + } +}` + +// GetStateSchema returns JSON Schema for cert authorities +func GetStateSchema() string { + return fmt.Sprintf(services.V2SchemaTemplate, services.MetadataSchema, fmt.Sprintf(StateSpecV2Schema, services.RotationSchema), services.DefaultDefinitions) +} diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 619bfaa03061f..ee6b00ddddc7a 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -47,6 +47,11 @@ type Backend interface { UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error // GetVal return a value for a given key in the bucket GetVal(path []string, key string) ([]byte, error) + // CompareAndSwapVal compares and swaps values in atomic operation, + // succeeds if prevVal matches the value stored in the database, + // requires prevVal as a non-empty value. Returns trace.CompareFailed + // in case if value did not match. + CompareAndSwapVal(bucket []string, key string, val []byte, prevVal []byte, ttl time.Duration) error // DeleteKey deletes a key in a bucket DeleteKey(bucket []string, key string) error // DeleteBucket deletes the bucket by a given path diff --git a/lib/backend/boltbk/boltbk.go b/lib/backend/boltbk/boltbk.go index 044e0f14f7b26..a0e2a3e975dbd 100644 --- a/lib/backend/boltbk/boltbk.go +++ b/lib/backend/boltbk/boltbk.go @@ -17,6 +17,7 @@ limitations under the License. package boltbk import ( + "bytes" "encoding/json" "path/filepath" "sort" @@ -141,6 +142,51 @@ func (b *BoltBackend) CreateVal(bucket []string, key string, val []byte, ttl tim return trace.Wrap(err) } +// CompareAndSwapVal compares and swap values in atomic operation, +// succeeds if prevVal matches the value stored in the databases, +// requires prevVal as a non-empty value. Returns trace.CompareFailed +// in case if value did not match +func (b *BoltBackend) CompareAndSwapVal(bucket []string, key string, newData []byte, prevData []byte, ttl time.Duration) error { + if len(prevData) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + v := &kv{ + Created: b.clock.Now().UTC(), + Value: newData, + TTL: ttl, + } + newEncodedData, err := json.Marshal(v) + if err != nil { + return trace.Wrap(err) + } + err = b.db.Update(func(tx *bolt.Tx) error { + bkt, err := GetBucket(tx, bucket) + if err != nil { + if trace.IsNotFound(err) { + return trace.CompareFailed("key %q is not found", key) + } + return trace.Wrap(err) + } + currentData := bkt.Get([]byte(key)) + if currentData == nil { + _, err := GetBucket(tx, append(bucket, key)) + if err == nil { + return trace.BadParameter("key %q is a bucket", key) + } + return trace.CompareFailed("%v %v is not found", bucket, key) + } + var currentVal kv + if err := json.Unmarshal(currentData, ¤tVal); err != nil { + return trace.Wrap(err) + } + if bytes.Compare(prevData, currentVal.Value) != 0 { + return trace.CompareFailed("%q is not matching expected value", key) + } + return boltErr(bkt.Put([]byte(key), newEncodedData)) + }) + return trace.Wrap(err) +} + func (b *BoltBackend) upsertVal(path []string, key string, val []byte, ttl time.Duration) error { v := &kv{ Created: b.clock.Now().UTC(), diff --git a/lib/backend/boltbk/boltbk_test.go b/lib/backend/boltbk/boltbk_test.go index 355729c521e4d..e938af0ec5d93 100644 --- a/lib/backend/boltbk/boltbk_test.go +++ b/lib/backend/boltbk/boltbk_test.go @@ -60,6 +60,10 @@ func (s *BoltSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *BoltSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *BoltSuite) TestExpiration(c *C) { s.suite.Expiration(c) } diff --git a/lib/backend/dir/impl.go b/lib/backend/dir/impl.go index 845738643385d..c5c997c1bd5a9 100644 --- a/lib/backend/dir/impl.go +++ b/lib/backend/dir/impl.go @@ -17,6 +17,7 @@ limitations under the License. package dir import ( + "bytes" "io" "io/ioutil" "os" @@ -174,6 +175,57 @@ func (bk *Backend) CreateVal(bucket []string, key string, val []byte, ttl time.D return trace.Wrap(bk.applyTTL(dirPath, key, ttl)) } +// CompareAndSwapVal compares and swap values in atomic operation +func (bk *Backend) CompareAndSwapVal(bucket []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + // do not allow keys that start with a dot + if key[0] == reservedPrefix { + return trace.BadParameter("invalid key: '%s'. Key names cannot start with '.'", key) + } + // create the directory: + dirPath := path.Join(bk.RootDir, path.Join(bucket...)) + err := os.MkdirAll(dirPath, defaultDirMode) + if err != nil { + return trace.ConvertSystemError(err) + } + // create the file (AKA "key"): + filename := path.Join(dirPath, key) + f, err := os.OpenFile(filename, os.O_RDWR|os.O_EXCL, defaultFileMode) + if err != nil { + err = trace.ConvertSystemError(err) + if trace.IsNotFound(err) { + return trace.CompareFailed("%v/%v did not match expected value", dirPath, key) + } + return trace.Wrap(err) + } + defer f.Close() + if err := utils.FSWriteLock(f); err != nil { + return trace.Wrap(err) + } + defer utils.FSUnlock(f) + // before writing, make sure the values are equal + oldVal, err := ioutil.ReadAll(f) + if err != nil { + return trace.ConvertSystemError(err) + } + if bytes.Compare(oldVal, prevVal) != 0 { + return trace.CompareFailed("%v/%v did not match expected value", dirPath, key) + } + if _, err := f.Seek(0, 0); err != nil { + return trace.ConvertSystemError(err) + } + if err := f.Truncate(0); err != nil { + return trace.ConvertSystemError(err) + } + n, err := f.Write(val) + if err == nil && n < len(val) { + return trace.Wrap(io.ErrShortWrite) + } + return trace.Wrap(bk.applyTTL(dirPath, key, ttl)) +} + // UpsertVal updates or inserts value with a given TTL into a bucket // ForeverTTL for no TTL func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error { @@ -280,11 +332,7 @@ func (bk *Backend) DeleteBucket(parent []string, bucket string) error { func removeFiles(dir string) error { d, err := os.Open(dir) if err != nil { - err = trace.ConvertSystemError(err) - if !trace.IsNotFound(err) { - return err - } - return nil + return trace.ConvertSystemError(err) } defer d.Close() names, err := d.Readdirnames(-1) @@ -308,6 +356,10 @@ func removeFiles(dir string) error { if err != nil { return err } + } else if fi.IsDir() { + if err := removeFiles(path); err != nil { + return err + } } } return nil diff --git a/lib/backend/dir/impl_test.go b/lib/backend/dir/impl_test.go index 5248a7cb81a7b..823b16bd5643f 100644 --- a/lib/backend/dir/impl_test.go +++ b/lib/backend/dir/impl_test.go @@ -88,6 +88,14 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { c.Assert(err, check.IsNil) }(i) + go func(cnt int) { + err := s.bk.CompareAndSwapVal(bucket, "key", []byte(value2), []byte(value1), time.Hour) + resultsC <- struct{}{} + if err != nil && !trace.IsCompareFailed(err) { + c.Assert(err, check.IsNil) + } + }(i) + go func(cnt int) { err := s.bk.CreateVal(bucket, "key", []byte(value2), time.Hour) resultsC <- struct{}{} @@ -113,12 +121,14 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { go func(cnt int) { err := s.bk.DeleteBucket([]string{"concurrent"}, "bucket") + if err != nil && !trace.IsNotFound(err) { + c.Assert(err, check.IsNil) + } resultsC <- struct{}{} - c.Assert(err, check.IsNil) }(i) } timeoutC := time.After(3 * time.Second) - for i := 0; i < attempts*4; i++ { + for i := 0; i < attempts*5; i++ { select { case <-resultsC: case <-timeoutC: @@ -127,6 +137,10 @@ func (s *Suite) TestConcurrentOperations(c *check.C) { } } +func (s *Suite) TestCompareAndSwap(c *check.C) { + s.suite.CompareAndSwap(c) +} + func (s *Suite) TestCreateAndRead(c *check.C) { bucket := []string{"one", "two"} diff --git a/lib/backend/dynamo/dynamodbbk.go b/lib/backend/dynamo/dynamodbbk.go index 0b89e95a3dad2..0a679c06c73ce 100644 --- a/lib/backend/dynamo/dynamodbbk.go +++ b/lib/backend/dynamo/dynamodbbk.go @@ -471,6 +471,51 @@ func (b *DynamoDBBackend) UpsertVal(path []string, key string, val []byte, ttl t return b.createKey(fullPath, val, ttl, true) } +// CompareAndSwapVal compares and swap values in atomic operation +func (b *DynamoDBBackend) CompareAndSwapVal(path []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + fullPath := b.fullPath(append(path, key)...) + r := record{ + HashKey: hashKey, + FullPath: fullPath, + Value: val, + TTL: ttl, + Timestamp: time.Now().UTC().Unix(), + } + if ttl != backend.Forever { + r.Expires = aws.Int64(b.clock.Now().UTC().Add(ttl).Unix()) + } + av, err := dynamodbattribute.MarshalMap(r) + if err != nil { + return trace.Wrap(err) + } + input := dynamodb.PutItemInput{ + Item: av, + TableName: aws.String(b.Tablename), + } + input.SetConditionExpression("#v = :prev") + input.SetExpressionAttributeNames(map[string]*string{ + "#v": aws.String("Value"), + }) + input.SetExpressionAttributeValues(map[string]*dynamodb.AttributeValue{ + ":prev": &dynamodb.AttributeValue{ + B: prevVal, + }, + }) + _, err = b.svc.PutItem(&input) + err = convertError(err) + if err != nil { + // in this case let's use more specific compare failed error + if trace.IsAlreadyExists(err) { + return trace.CompareFailed(err.Error()) + } + return trace.Wrap(err) + } + return nil +} + const delayBetweenLockAttempts = 100 * time.Millisecond // AcquireLock for a token diff --git a/lib/backend/dynamo/dynamodbbk_test.go b/lib/backend/dynamo/dynamodbbk_test.go index 04153c60ebbe4..11fc2af7eef94 100644 --- a/lib/backend/dynamo/dynamodbbk_test.go +++ b/lib/backend/dynamo/dynamodbbk_test.go @@ -66,6 +66,10 @@ func (s *DynamoDBSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *DynamoDBSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *DynamoDBSuite) TestBatchCRUD(c *C) { s.suite.BatchCRUD(c) } diff --git a/lib/backend/etcdbk/etcd.go b/lib/backend/etcdbk/etcd.go index b2ed02802a507..e41bac2345430 100644 --- a/lib/backend/etcdbk/etcd.go +++ b/lib/backend/etcdbk/etcd.go @@ -166,6 +166,26 @@ func (b *bk) CreateVal(path []string, key string, val []byte, ttl time.Duration) return trace.Wrap(convertErr(err)) } +// CompareAndSwapVal compares and swap values in atomic operation, +// succeeds if prevVal matches the value stored in the databases, +// requires prevVal as a non-empty value. Returns trace.CompareFailed +// in case if value did not match +func (b *bk) CompareAndSwapVal(path []string, key string, val []byte, prevVal []byte, ttl time.Duration) error { + if len(prevVal) == 0 { + return trace.BadParameter("missing prevVal parameter, to atomically create item, use CreateVal method") + } + encodedPrev := base64.StdEncoding.EncodeToString(prevVal) + _, err := b.api.Set( + context.Background(), + b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val), + &client.SetOptions{PrevValue: encodedPrev, PrevExist: client.PrevExist, TTL: ttl}) + err = convertErr(err) + if trace.IsNotFound(err) { + return trace.CompareFailed(err.Error()) + } + return trace.Wrap(err) +} + // maxOptimisticAttempts is the number of attempts optimistic locking const maxOptimisticAttempts = 5 diff --git a/lib/backend/etcdbk/etcd_test.go b/lib/backend/etcdbk/etcd_test.go index 7bb7855a7ae1c..f959ebc4c2e27 100644 --- a/lib/backend/etcdbk/etcd_test.go +++ b/lib/backend/etcdbk/etcd_test.go @@ -103,6 +103,10 @@ func (s *EtcdSuite) TestBasicCRUD(c *C) { s.suite.BasicCRUD(c) } +func (s *EtcdSuite) TestCompareAndSwap(c *C) { + s.suite.CompareAndSwap(c) +} + func (s *EtcdSuite) TestExpiration(c *C) { s.suite.Expiration(c) } diff --git a/lib/backend/test/suite.go b/lib/backend/test/suite.go index 57ff274abe88c..56f5c5b4aad06 100644 --- a/lib/backend/test/suite.go +++ b/lib/backend/test/suite.go @@ -24,6 +24,7 @@ import ( "time" "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/trace" . "gopkg.in/check.v1" @@ -115,6 +116,35 @@ func (s *BackendSuite) BasicCRUD(c *C) { c.Assert(trace.IsNotFound(err), Equals, true, Commentf("%#v", err)) } +// CompareAndSwap tests compare and swap functionality +func (s *BackendSuite) CompareAndSwap(c *C) { + bucket := []string{"test", "cas"} + + // compare and swap on non existing operation will fail + err := s.B.CompareAndSwapVal(bucket, "one", []byte("1"), []byte("2"), backend.Forever) + fixtures.ExpectCompareFailed(c, err) + + err = s.B.CreateVal(bucket, "one", []byte("1"), backend.Forever) + c.Assert(err, IsNil) + + // success CAS! + err = s.B.CompareAndSwapVal(bucket, "one", []byte("2"), []byte("1"), backend.Forever) + c.Assert(err, IsNil) + + val, err := s.B.GetVal(bucket, "one") + c.Assert(err, IsNil) + c.Assert(string(val), Equals, "2") + + // value has been updated - not '1' any more + err = s.B.CompareAndSwapVal(bucket, "one", []byte("3"), []byte("1"), backend.Forever) + fixtures.ExpectCompareFailed(c, err) + + // existing value has not been changed by the failed CAS operation + val, err = s.B.GetVal(bucket, "one") + c.Assert(err, IsNil) + c.Assert(string(val), Equals, "2") +} + // BatchCRUD tests batch CRUD operations if supported by the backend func (s *BackendSuite) BatchCRUD(c *C) { getter, ok := s.B.(backend.ItemsGetter) diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 9b4d0759d3994..29acd28f4ac8b 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -764,7 +764,7 @@ func (k *KeyPair) Identity() (*auth.Identity, error) { } else { certBytes = []byte(k.Cert) } - return auth.ReadIdentityFromKeyPair(keyBytes, certBytes, []byte(k.TLSCert), []byte(k.TLSCACert)) + return auth.ReadIdentityFromKeyPair(keyBytes, certBytes, []byte(k.TLSCert), [][]byte{[]byte(k.TLSCACert)}) } // Authority is a host or user certificate authority that diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 199b9a0c0e4e1..fb6e85bad23e6 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -244,6 +244,9 @@ var ( // ReportingPeriod is a period for reports in logs ReportingPeriod = 5 * time.Minute + + // HighResPollingPeriod is a default high resolution polling period + HighResPollingPeriod = 5 * time.Second ) // Default connection limits, they can be applied separately on any of the Teleport @@ -272,6 +275,10 @@ const ( // CertDuration is a default certificate duration // 12 is default as it' longer than average working day (I hope so) CertDuration = 12 * time.Hour + // RotationGracePeriod is a default rotation period for graceful + // certificate rotations, by default to set to maximum allowed user + // cert duration + RotationGracePeriod = MaxCertDuration ) // list of roles teleport service can run as: diff --git a/lib/fixtures/fixtures.go b/lib/fixtures/fixtures.go index 73cbaa9709d92..d1233b356cf1c 100644 --- a/lib/fixtures/fixtures.go +++ b/lib/fixtures/fixtures.go @@ -11,27 +11,27 @@ import ( // ExpectNotFound expects not found error func ExpectNotFound(c *check.C, err error) { - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected NotFound, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected NotFound, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectBadParameter expects bad parameter error func ExpectBadParameter(c *check.C, err error) { - c.Assert(trace.IsBadParameter(err), check.Equals, true, check.Commentf("expected BadParameter, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsBadParameter(err), check.Equals, true, check.Commentf("expected BadParameter, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectCompareFailed expects compare failed error func ExpectCompareFailed(c *check.C, err error) { - c.Assert(trace.IsCompareFailed(err), check.Equals, true, check.Commentf("expected CompareFailed, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsCompareFailed(err), check.Equals, true, check.Commentf("expected CompareFailed, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectAccessDenied expects error to be access denied func ExpectAccessDenied(c *check.C, err error) { - c.Assert(trace.IsAccessDenied(err), check.Equals, true, check.Commentf("expected AccessDenied, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsAccessDenied(err), check.Equals, true, check.Commentf("expected AccessDenied, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // ExpectAlreadyExists expects already exists error func ExpectAlreadyExists(c *check.C, err error) { - c.Assert(trace.IsAlreadyExists(err), check.Equals, true, check.Commentf("expected AlreadyExists, got %T %#v at %v", err, err, string(debug.Stack()))) + c.Assert(trace.IsAlreadyExists(err), check.Equals, true, check.Commentf("expected AlreadyExists, got %T %v at %v", trace.Unwrap(err), err, string(debug.Stack()))) } // DeepCompare uses gocheck DeepEquals but provides nice diff if things are not equal diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index ee7d76ee33f84..e07c6b1e791d1 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -104,9 +104,9 @@ func ReadJSON(r *http.Request, val interface{}) error { func ConvertResponse(re *roundtrip.Response, err error) (*roundtrip.Response, error) { if err != nil { if uerr, ok := err.(*url.Error); ok && uerr != nil && uerr.Err != nil { - return nil, trace.Wrap(uerr.Err) + return nil, trace.ConnectionProblem(uerr.Err, "failed connecting to endpoint") } - return nil, trace.Wrap(err) + return nil, trace.ConvertSystemError(err) } return re, trace.ReadError(re.Code(), re.Bytes()) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index fb8c63d9ece8a..a6797f5d3e6d8 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -76,6 +76,12 @@ type remoteSite struct { // remoteAccessPoint provides access to a cached subset of the Auth Server API of // the remote cluster this site belongs to. remoteAccessPoint auth.AccessPoint + + // remoteCA is a remote certificate authority as recorded by the client + // if the certificate authority changes rotation status compared + // to the last recorded state the reverse tunnel will force reconnect + // to re-create the client with new settings + remoteCA services.CertAuthority } func (s *remoteSite) getRemoteClient() (auth.ClientI, bool, error) { @@ -287,6 +293,24 @@ func (s *remoteSite) GetLastConnected() time.Time { return connInfo.GetLastHeartbeat() } +func (s *remoteSite) compareAndSwapCertAuthority(ca services.CertAuthority) error { + s.Lock() + defer s.Unlock() + + if s.remoteCA == nil { + s.remoteCA = ca + return nil + } + + rotation := s.remoteCA.GetRotation() + if rotation.Matches(ca.GetRotation()) { + s.remoteCA = ca + return nil + } + s.remoteCA = ca + return trace.CompareFailed("remote certificate authority rotation has been updated") +} + func (s *remoteSite) periodicSendDiscoveryRequests() { ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() @@ -307,68 +331,85 @@ func (s *remoteSite) periodicSendDiscoveryRequests() { } } -// DELETE IN: 2.6.0 -// attemptCertExchange tries to exchange TLS certificates with remote -// clusters that have upgraded to 2.5.0 -func (s *remoteSite) attemptCertExchange() error { - // this logic is explicitly using the local non cached - // client as it has to have write access to the auth server - localCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ +// updateCertAuthorities updates local and remote cert authorities +func (s *remoteSite) updateCertAuthorities() error { + // update main cluster cert authorities on the remote side + // remote side makes sure that only relevant fields + // are updated + hostCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: s.srv.ClusterName, }, false) if err != nil { return trace.Wrap(err) } - re, err := s.remoteClient.ExchangeCerts(auth.ExchangeCertsRequest{ - PublicKey: localCA.GetCheckingKeys()[0], - TLSCert: localCA.GetTLSKeyPairs()[0].Cert, - }) + err = s.remoteClient.RotateExternalCertAuthority(hostCA) + if err != nil { + return trace.Wrap(err) + } + + userCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ + Type: services.UserCA, + DomainName: s.srv.ClusterName, + }, false) + if err != nil { + return trace.Wrap(err) + } + err = s.remoteClient.RotateExternalCertAuthority(userCA) if err != nil { return trace.Wrap(err) } - remoteCA, err := s.localClient.GetCertAuthority(services.CertAuthID{ + + // update remote cluster's host cert authoritiy on a local cluster + // local proxy is authorized to perform this operation only for + // host authorities of remote clusters. + remoteCA, err := s.remoteClient.GetCertAuthority(services.CertAuthID{ Type: services.HostCA, DomainName: s.domainName, }, false) if err != nil { return trace.Wrap(err) } - _, err = s.localClient.ExchangeCerts(auth.ExchangeCertsRequest{ - PublicKey: remoteCA.GetCheckingKeys()[0], - TLSCert: re.TLSCert, - }) - return trace.Wrap(err) -} -// DELETE IN: 2.6.0 -// This logic is only relevant for upgrades from 2.5.0 to 2.6.0 -func (s *remoteSite) periodicAttemptCertExchange() { - ticker := time.NewTicker(defaults.NetworkBackoffDuration) - defer ticker.Stop() - if err := s.attemptCertExchange(); err != nil { - s.Warningf("Attempt at cert exchange failed: %v.", err) - } else { - s.Debugf("Certificate exchange has completed, going to force reconnect.") - s.srv.RemoveSite(s.domainName) - s.Close() - return + if remoteCA.GetClusterName() != s.domainName { + return trace.BadParameter( + "remote cluster sent different cluster name %v instead of expected one %v", + remoteCA.GetClusterName(), s.domainName) } + err = s.localClient.UpsertCertAuthority(remoteCA) + if err != nil { + return trace.Wrap(err) + } + + return s.compareAndSwapCertAuthority(remoteCA) +} +func (s *remoteSite) periodicUpdateCertAuthorities() { + ticker := time.NewTicker(defaults.HighResPollingPeriod) + defer ticker.Stop() for { select { case <-s.ctx.Done(): s.Debugf("Context is closing.") return case <-ticker.C: - err := s.attemptCertExchange() + err := s.updateCertAuthorities() if err != nil { - s.Warningf("Could not perform certificate exchange: %v.", trace.DebugReport(err)) + switch { + case trace.IsNotFound(err): + s.Debugf("Remote cluster %v does not support cert authorities rotation yet.", s.domainName) + case trace.IsCompareFailed(err): + s.Infof("Remote cluster has updated certificate authorities, going to force reconnect.") + s.srv.RemoveSite(s.domainName) + s.Close() + return + case trace.IsConnectionProblem(err): + s.Debugf("Remote cluster %v is offline.", s.domainName) + default: + s.Warningf("Could not perform cert authorities updated: %v.", trace.DebugReport(err)) + } } else { - s.Debugf("Certificate exchange has completed, going to force reconnect.") - s.srv.RemoveSite(s.domainName) - s.Close() - return + s.Debugf("Certificate authorities updated.") } } } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index df29b50441709..c785568c73ff9 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -827,7 +827,7 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { remoteSite.localClient = srv.localAuthClient remoteSite.localAccessPoint = srv.localAccessPoint - clt, isLegacyRemoteCluster, err := remoteSite.getRemoteClient() + clt, _, err := remoteSite.getRemoteClient() if err != nil { return nil, trace.Wrap(err) } @@ -852,11 +852,7 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { remoteSite.certificateCache = certificateCache go remoteSite.periodicSendDiscoveryRequests() - - // if remote cluster is legacy, attempt periodic certificate exchanges - if isLegacyRemoteCluster { - go remoteSite.periodicAttemptCertExchange() - } + go remoteSite.periodicUpdateCertAuthorities() return remoteSite, nil } diff --git a/lib/service/cfg.go b/lib/service/cfg.go index 8844e9d421551..9ef7ea5a9f837 100644 --- a/lib/service/cfg.go +++ b/lib/service/cfg.go @@ -140,6 +140,10 @@ type Config struct { // UploadEventsC is a channel for upload events // used in tests UploadEventsC chan *events.UploadEvent `json:"-"` + + // FileDescriptors is an optional list of file descriptors for the process + // to inherit and use for listeners, used during in process updates + FileDescriptors []FileDescriptor } // ApplyToken assigns a given token to all internal services but only if token diff --git a/lib/service/connect.go b/lib/service/connect.go new file mode 100644 index 0000000000000..814c6803576c9 --- /dev/null +++ b/lib/service/connect.go @@ -0,0 +1,409 @@ +package service + +import ( + "time" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" +) + +// connectToAuthService attempts to login into the auth servers specified in the +// configuration. Returns 'true' if successful +func (process *TeleportProcess) connectToAuthService(role teleport.Role) (*Connector, error) { + connector, err := process.connect(role) + if err != nil { + return nil, trace.Wrap(err) + } + process.addConnector(connector) + return connector, nil +} + +func (process *TeleportProcess) connect(role teleport.Role) (*Connector, error) { + // TODO (klizhentas) migrations should create state here + state, err := process.storage.GetState(role) + if err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + // no state recorded means that this is the first connect + // process haven't connected yet, so we expect the token to exist + return process.firstTimeConnect(role) + } + + identity, err := process.GetIdentity(role) + if err != nil { + return nil, trace.Wrap(err) + } + + rotation := state.Spec.Rotation + + switch rotation.State { + // rotation is on standby, so just use whatever is current + case "", services.RotationStateStandby: + // admin and auth are a bit special, as it does not need clients + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + log.Infof("Connecting to the cluster %v with TLS client certificate.", identity.ClusterName) + client, err := newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{Client: client, ClientIdentity: identity, ServerIdentity: identity}, nil + case services.RotationStateInProgress: + switch rotation.Phase { + case services.RotationPhaseUpdateClients: + // in this phase, clients should use updated credentials, + // while servers should use old credentials to answer auth requests + newIdentity, err := process.storage.ReadIdentity(auth.IdentityReplacement, role) + if err != nil { + return nil, trace.Wrap(err) + } + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: newIdentity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := newClient(process.Config.AuthServers, newIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: newIdentity, + ServerIdentity: identity, + }, nil + case services.RotationPhaseUpdateServers: + // in this phase, servers and clients are using new identity, but the + // identity is still set up to trust the old certificate authority certificates + newIdentity, err := process.storage.ReadIdentity(auth.IdentityReplacement, role) + if err != nil { + return nil, trace.Wrap(err) + } + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: newIdentity, + ServerIdentity: newIdentity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := newClient(process.Config.AuthServers, newIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: newIdentity, + ServerIdentity: newIdentity, + }, nil + case services.RotationPhaseRollback: + // in rollback phase, clients and servers should switch back + // to the old certificate authority-issued credentials, + // but new certificate authority should be trusted + // because not all clients can update at the same time + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + return &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + }, nil + } + client, err := newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + return &Connector{ + Client: client, + ClientIdentity: identity, + ServerIdentity: identity, + }, nil + default: + return nil, trace.BadParameter("unsupported rotation phase: %q", rotation.Phase) + } + default: + return nil, trace.BadParameter("unsupported rotation state: %q", rotation.State) + } +} + +func (process *TeleportProcess) firstTimeConnect(role teleport.Role) (*Connector, error) { + id := auth.IdentityID{ + Role: role, + HostUUID: process.Config.HostUUID, + NodeName: process.Config.Hostname, + } + additionalPrincipals, err := process.getAdditionalPrincipals(role) + if err != nil { + return nil, trace.Wrap(err) + } + var identity *auth.Identity + if process.getLocalAuth() != nil { + // Auth service is on the same host, no need to go though the invitation + // procedure + log.Debugf("This server has local Auth server started, using it to add role to the cluster.") + identity, err = auth.LocalRegister(id, process.getLocalAuth(), additionalPrincipals) + } else { + // Auth server is remote, so we need a provisioning token + if process.Config.Token == "" { + return nil, trace.BadParameter("%v must join a cluster and needs a provisioning token", role) + } + log.Infof("Joining the cluster with a token %v.", process.Config.Token) + identity, err = auth.Register(process.Config.DataDir, process.Config.Token, id, process.Config.AuthServers, additionalPrincipals) + } + if err != nil { + return nil, trace.Wrap(err) + } + + log.Infof("%v has successfully registered with the cluster.", role) + var connector *Connector + if role == teleport.RoleAdmin || role == teleport.RoleAuth { + connector = &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + AuthServer: process.getLocalAuth(), + } + } else { + client, err := newClient(process.Config.AuthServers, identity) + if err != nil { + return nil, trace.Wrap(err) + } + connector = &Connector{ + ClientIdentity: identity, + ServerIdentity: identity, + Client: client, + } + } + + // sync local rotation state to match remote rotation state + ca, err := connector.GetCertAuthority(services.CertAuthID{ + DomainName: connector.ClientIdentity.ClusterName, + Type: services.HostCA, + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + + err = process.storage.WriteIdentity(auth.IdentityCurrent, *identity) + if err != nil { + log.Warningf("%v failed to write identity: %v", err) + } + + err = process.storage.WriteState(role, auth.StateV2{ + Spec: auth.StateSpecV2{ + Rotation: ca.GetRotation(), + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + log.Infof("%v has successfully wrote credentials and state to disk..", role) + return connector, nil +} + +// periodicSyncRotationState checks rotation state periodically and +// takes action if necessary +func (process *TeleportProcess) periodicSyncRotationState() error { + t := time.NewTicker(defaults.HighResPollingPeriod) + defer t.Stop() + for { + select { + case <-t.C: + needsReload, err := process.syncRotationState() + if err != nil { + log.Warningf("Failed to sync rotation state: %v", trace.DebugReport(err)) + } else if needsReload { + // TODO: set context? + process.BroadcastEvent(Event{Name: TeleportReloadEvent}) + return nil + } + case <-process.Exiting(): + return nil + } + } +} + +// syncRotationState compares cluster rotation state with local services state +// and performs rotation if necessary +func (process *TeleportProcess) syncRotationState() (bool, error) { + var needsReload bool + connectors := process.getConnectors() + for _, conn := range connectors { + reload, err := process.syncServiceRotationState(conn) + if err != nil { + return false, trace.Wrap(err) + } + if reload { + needsReload = true + } + } + return needsReload, nil +} + +// syncServiceRotationState syncs up rotation state for individual service (Auth, Proxy, Node) and +// if necessary, updates credentials. Returns true if the service will need to reload. +func (process *TeleportProcess) syncServiceRotationState(conn *Connector) (bool, error) { + state, err := process.storage.GetState(conn.ClientIdentity.ID.Role) + if err != nil { + return false, trace.Wrap(err) + } + ca, err := conn.GetCertAuthority(services.CertAuthID{ + DomainName: conn.ClientIdentity.ClusterName, + Type: services.HostCA, + }, false) + if err != nil { + return false, trace.Wrap(err) + } + return process.rotate(conn, *state, ca.GetRotation()) +} + +// rotate is called to check if rotation should be triggered locally +func (process *TeleportProcess) rotate(conn *Connector, localState auth.StateV2, remote services.Rotation) (bool, error) { + id := conn.ClientIdentity.ID + local := localState.Spec.Rotation + if local.Matches(remote) { + // nothing to do, local state and rotation state are in sync + return false, nil + } + + additionalPrincipals, err := process.getAdditionalPrincipals(id.Role) + if err != nil { + return false, trace.Wrap(err) + } + + storage := process.storage + + const outOfSync = "%v and cluster rotation state (%v) is out of sync with local (%v). Clear local state and re-register this %v." + + writeStateAndIdentity := func(name string, identity *auth.Identity) error { + err = storage.WriteIdentity(name, *identity) + if err != nil { + return trace.Wrap(err) + } + localState.Spec.Rotation = remote + err = storage.WriteState(id.Role, localState) + if err != nil { + return trace.Wrap(err) + } + return nil + } + + // now, need to evaluate what is exact difference, there are + // several supported scenarios, that this logic should handle + switch remote.State { + case "", services.RotationStateStandby: + switch local.State { + // great, nothing to do, it could happen + // that the old node came up and missed the whole rotation + // rollback cycle, but there is nothing we can do at this point + case "", services.RotationStateStandby: + if len(additionalPrincipals) != 0 && !conn.ServerIdentity.HasPrincipals(additionalPrincipals) { + log.Infof("%v has updated principals to %q, going to request new principals and update") + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = storage.WriteIdentity(auth.IdentityCurrent, *identity) + if err != nil { + return false, trace.Wrap(err) + } + return true, nil + } + return false, nil + // local rotation is in progress, if it has + // just rolled back + case services.RotationStateInProgress: + // rollback phase has completed, all services + // will receive new identities + if local.Phase != services.RotationPhaseRollback && local.CurrentID != remote.CurrentID { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityCurrent, identity) + if err != nil { + return false, trace.Wrap(err) + } + return true, nil + default: + return false, trace.BadParameter("unsupported state: %q", localState) + } + case services.RotationStateInProgress: + switch remote.Phase { + case services.RotationPhaseStandby, "": + // nothing to do + return false, nil + case services.RotationPhaseUpdateClients: + // only allow transition in case if local rotation state is standby + // so this server is in the "clean" state + if local.State != services.RotationStateStandby && local.State != "" { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityReplacement, identity) + if err != nil { + return false, trace.Wrap(err) + } + // update of the servers and client requires reload of teleport process + return true, nil + case services.RotationPhaseUpdateServers: + // allow transition to this phase only if the previous + // phase was UpdateClients - as this is a happy scenario + // when all phases are traversed in succession + if local.Phase != services.RotationPhaseUpdateClients && local.CurrentID != remote.CurrentID { + return false, trace.CompareFailed(outOfSync, id.Role, remote, local, id.Role) + } + // write replacement identity as a current identity and reload the server + replacement, err := storage.ReadIdentity(auth.IdentityReplacement, id.Role) + if err != nil { + return false, trace.Wrap(err) + } + err = writeStateAndIdentity(auth.IdentityCurrent, replacement) + if err != nil { + return false, trace.Wrap(err) + } + // update of the servers requires reload of teleport process + return true, nil + case services.RotationPhaseRollback: + // allow transition to this phase from any other local phase + // because it will be widely used to recover cluster state to + // the previously valid state + // client will re-register to receive credentials signed by "old" CA + identity, err := conn.ReRegister(additionalPrincipals) + if err != nil { + return false, trace.Wrap(err) + } + // update of the servers requires reload of teleport process + err = writeStateAndIdentity(auth.IdentityCurrent, identity) + if err != nil { + return false, trace.Wrap(err) + } + return true, nil + default: + return false, trace.BadParameter("unsupported phase: %q", remote.Phase) + } + default: + return false, trace.BadParameter("unsupported state: %q", remote.State) + } +} + +func newClient(authServers []utils.NetAddr, identity *auth.Identity) (*auth.Client, error) { + tlsConfig, err := identity.TLSConfig() + if err != nil { + return nil, trace.Wrap(err) + } + return auth.NewTLSClient(authServers, tlsConfig) +} diff --git a/lib/service/service.go b/lib/service/service.go index de2cbda6c19de..c2ac9759ac89e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -113,6 +113,14 @@ const ( // TeleportExitEvent is generated when the Teleport process begins closing // all listening sockets and exiting. TeleportExitEvent = "TeleportExit" + + // TeleportReloadEvent is generated when the Teleport process needs + // to reload its configuration + TeleportReloadEvent = "TeleportReload" + + // TeleportStartEvent is generated when the Teleport process starts + // successfully + TeleportStartEvent = "TeleportStart" ) // RoleConfig is a configuration for a server role (either proxy or node) @@ -128,8 +136,33 @@ type RoleConfig struct { // Connector has all resources process needs to connect // to other parts of the cluster: client and identity type Connector struct { - Identity *auth.Identity - Client *auth.Client + // ClientIdentity is the identity to be used in client connections + ClientIdentity *auth.Identity + // ServerIdentity is the identity to be used in servers + ServerIdentity *auth.Identity + // Client is authenticated client + Client *auth.Client + // AuthServer is auth server, used for connector + // associated with auth server + AuthServer *auth.AuthServer +} + +// ReRegister receives new identity credentials for proxy, node and auth server +// in case of auth server the role is 'TeleportAdmin' and instead of using +// client it uses the local auth server +func (c *Connector) ReRegister(additionalPrincipals []string) (*auth.Identity, error) { + if c.ClientIdentity.ID.Role == teleport.RoleAdmin || c.ClientIdentity.ID.Role == teleport.RoleAuth { + return auth.GenerateIdentity(c.AuthServer, c.ClientIdentity.ID, additionalPrincipals) + } + return auth.ReRegister(c.Client, c.ClientIdentity.ID, additionalPrincipals) +} + +func (c *Connector) GetCertAuthority(id services.CertAuthID, loadPrivateKeys bool) (services.CertAuthority, error) { + if c.ClientIdentity.ID.Role == teleport.RoleAdmin || c.ClientIdentity.ID.Role == teleport.RoleAuth { + return c.AuthServer.GetCertAuthority(id, loadPrivateKeys) + } else { + return c.Client.GetCertAuthority(id, loadPrivateKeys) + } } // TeleportProcess structure holds the state of the Teleport daemon, controlling @@ -149,6 +182,10 @@ type TeleportProcess struct { // identities of this process (credentials to auth sever, basically) Identities map[teleport.Role]*auth.Identity + + // connectors is a list of connected clients and their identities + connectors map[teleport.Role]*Connector + // registeredListeners keeps track of all listeners created by the process // used to pass listeners to child processes during live reload registeredListeners []RegisteredListener @@ -160,6 +197,9 @@ type TeleportProcess struct { // during restart used to collect their status in case if the // child process crashed. forkedPIDs []int + + // storage is a server local storage + storage *auth.ProcessStorage } // GetAuthServer returns the process' auth server @@ -196,22 +236,25 @@ func (process *TeleportProcess) findStaticIdentity(id auth.IdentityID) (*auth.Id return nil, trace.NotFound("identity %v not found", &id) } -// readIdentity reads identity from disk and resets the local state -func (process *TeleportProcess) readIdentity(role teleport.Role) (*auth.Identity, error) { +// getConnectors returns a copy of the identities registered for auth server +func (process *TeleportProcess) getConnectors() []*Connector { process.Lock() defer process.Unlock() - id := auth.IdentityID{ - Role: role, - HostUUID: process.Config.HostUUID, - NodeName: process.Config.Hostname, + out := make([]*Connector, 0, len(process.connectors)) + for role := range process.connectors { + out = append(out, process.connectors[role]) } - identity, err := auth.ReadIdentity(process.Config.DataDir, id) - if err != nil { - return nil, trace.Wrap(err) - } - process.Identities[role] = identity - return identity, nil + return out +} + +// addConnector adds connector to registered connectors list, +// it will overwrite the connector for the same role +func (process *TeleportProcess) addConnector(connector *Connector) { + process.Lock() + defer process.Unlock() + + process.connectors[connector.ClientIdentity.ID.Role] = connector } // GetIdentity returns the process identity (credentials to the auth server) for a given @@ -227,63 +270,111 @@ func (process *TeleportProcess) GetIdentity(role teleport.Role) (i *auth.Identit if found { return i, nil } - + i, err = process.storage.ReadIdentity(auth.IdentityCurrent, role) id := auth.IdentityID{ Role: role, HostUUID: process.Config.HostUUID, NodeName: process.Config.Hostname, } - i, err = auth.ReadIdentity(process.Config.DataDir, id) if err != nil { - if trace.IsNotFound(err) { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + if role == teleport.RoleAdmin { + // for admin identity use local auth server + // because admin identity is requested by auth server + // itself + principals, err := process.getAdditionalPrincipals(role) + if err != nil { + return nil, trace.Wrap(err) + } + i, err = auth.GenerateIdentity(process.localAuth, id, principals) + } else { // try to locate static identity provided in the file i, err = process.findStaticIdentity(id) if err != nil { return nil, trace.Wrap(err) } log.Infof("Found static identity %v in the config file, writing to disk.", &id) - if err = auth.WriteIdentity(process.Config.DataDir, i); err != nil { + if err = process.storage.WriteIdentity(auth.IdentityCurrent, *i); err != nil { return nil, trace.Wrap(err) } - } else { - return nil, trace.Wrap(err) } } process.Identities[role] = i return i, nil } -// connectToAuthService attempts to login into the auth servers specified in the -// configuration. Returns 'true' if successful -func (process *TeleportProcess) connectToAuthService(role teleport.Role, additionalPrincipals []string) (*Connector, error) { - identity, err := process.GetIdentity(role) - if err != nil { - return nil, trace.Wrap(err) +// Process is a interface for processes +type Process interface { + // Closer closes all resources used by the process + io.Closer + // Start starts the process in a non-blocking way + Start() error + // WaitForSignals waits until the process recievies a signal + // and exits + WaitForSignals(context.Context) error + // ExportFileDescriptors exports service listeners + // file descriptors used by the process. + ExportFileDescriptors() ([]FileDescriptor, error) + // Shutdown starts graceful shutdown of the process, + // blocks until all resources are associated + Shutdown(context.Context) +} + +// NewProcess is a function that creates new teleport from config +type NewProcess func(cfg *Config) (Process, error) + +func newTeleportProcess(cfg *Config) (Process, error) { + return NewTeleport(cfg) +} + +// Run starts teleport processes, waits for signals +// and handles internal process signals. +func Run(ctx context.Context, cfg Config, newTeleport NewProcess) error { + if newTeleport == nil { + newTeleport = newTeleportProcess } - tlsConfig, err := identity.TLSConfig() + copyCfg := cfg + srv, err := newTeleport(©Cfg) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err, "Initialization failed") } - log.Infof("Connecting to the cluster %v with TLS client certificate.", identity.ClusterName) - client, err := auth.NewTLSClient(process.Config.AuthServers, tlsConfig) - if err != nil { - return nil, trace.Wrap(err) + if err := srv.Start(); err != nil { + return trace.Wrap(err, "Startup Failed") } - if len(additionalPrincipals) != 0 && !identity.HasPrincipals(additionalPrincipals) { - log.Infof("Identity %v needs principals %v, going to re-register.", identity.ID, additionalPrincipals) - if err := auth.ReRegister(process.Config.DataDir, client, identity.ID, additionalPrincipals); err != nil { - return nil, trace.Wrap(err) - } - if identity, err = process.readIdentity(role); err != nil { - return nil, trace.Wrap(err) - } - tlsConfig, err = identity.TLSConfig() - if err != nil { - return nil, trace.Wrap(err) - } +wait: + err = srv.WaitForSignals(ctx) + if err != ErrTeleportReloading { + return trace.Wrap(err) + } + fileDescriptors, err := srv.ExportFileDescriptors() + if err != nil { + warnOnErr(srv.Close()) + return trace.Wrap(err) } - // success ? we're logged in! - return &Connector{Client: client, Identity: identity}, nil + newCfg := cfg + newCfg.FileDescriptors = fileDescriptors + newSrv, err := newTeleport(&newCfg) + if err != nil { + warnOnErr(srv.Close()) + return trace.Wrap(err, "Reload failed") + } + if err := newSrv.Start(); err != nil { + warnOnErr(srv.Close()) + return trace.Wrap(err, "Startup of a reloaded process failed") + } + timeoutCtx, cancel := context.WithTimeout(ctx, defaults.DefaultIdleConnectionDuration*2) + defer cancel() + // TODO(klizhentas) wait until services are declared as started, before shutting down + // this one, otherwise some requests may fail + srv.Shutdown(timeoutCtx) + if timeoutCtx.Err() == context.DeadlineExceeded { + warnOnErr(srv.Close()) + return trace.Wrap(err, "Failed to shutdown the parent process") + } + srv = newSrv + goto wait } // NewTeleport takes the daemon configuration, instantiates all required services @@ -305,9 +396,11 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { } } - importedDescriptors, err := importFileDescriptors() - if err != nil { - return nil, trace.Wrap(err) + if len(cfg.FileDescriptors) == 0 { + cfg.FileDescriptors, err = importFileDescriptors() + if err != nil { + return nil, trace.Wrap(err) + } } // if there's no host uuid initialized yet, try to read one from the @@ -345,12 +438,19 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { } } + storage, err := auth.NewProcessStorage(filepath.Join(cfg.DataDir, teleport.ComponentProcess)) + if err != nil { + return nil, trace.Wrap(err) + } + process := &TeleportProcess{ Clock: clockwork.NewRealClock(), Supervisor: NewSupervisor(), Config: cfg, Identities: make(map[teleport.Role]*auth.Identity), - importedDescriptors: importedDescriptors, + connectors: make(map[teleport.Role]*Connector), + importedDescriptors: cfg.FileDescriptors, + storage: storage, } serviceStarted := false @@ -396,6 +496,10 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { warnOnErr(process.closeImportedDescriptors(teleport.ComponentProxy)) } + // TODO: klizhentas heartbeat the current state so tctl get nodes will report it + // sync rotation state periodically + process.RegisterFunc("common.rotate", process.periodicSyncRotationState) + if !serviceStarted { return nil, trace.BadParameter("all services failed to start") } @@ -573,7 +677,7 @@ func (process *TeleportProcess) initAuthService() error { } // first, create the AuthServer - authServer, identity, err := auth.Init(auth.InitConfig{ + authServer, err := auth.Init(auth.InitConfig{ Backend: b, Authority: cfg.Keygen, ClusterConfiguration: cfg.ClusterConfiguration, @@ -602,6 +706,11 @@ func (process *TeleportProcess) initAuthService() error { process.setLocalAuth(authServer) + connector, err := process.connectToAuthService(teleport.RoleAdmin) + if err != nil { + return trace.Wrap(err) + } + // second, create the API Server: it's actually a collection of API servers, // each serving requests for a "role" which is assigned to every connected // client based on their certificate (user, server, admin, etc) @@ -636,7 +745,7 @@ func (process *TeleportProcess) initAuthService() error { }) // Register TLS endpoint of the auth service - tlsConfig, err := identity.TLSConfig() + tlsConfig, err := connector.ServerIdentity.TLSConfig() if err != nil { return trace.Wrap(err) } @@ -688,7 +797,7 @@ func (process *TeleportProcess) initAuthService() error { process.RegisterFunc("auth.heartbeat.broadcast", func() error { // Heart beat auth server presence, this is not the best place for this // logic, consolidate it into auth package later - connector, err := process.connectToAuthService(teleport.RoleAdmin, nil) + connector, err := process.connectToAuthService(teleport.RoleAdmin) if err != nil { return trace.Wrap(err) } @@ -740,8 +849,16 @@ func (process *TeleportProcess) initAuthService() error { defer ticker.Stop() announce: for { + state, err := process.storage.GetState(teleport.RoleAdmin) + if err != nil { + if !trace.IsNotFound(err) { + log.Warningf("Failed to get rotation state: %v.", err) + } + } else { + srv.Spec.Rotation = state.Spec.Rotation + } srv.SetTTL(process, defaults.ServerHeartbeatTTL) - err := authServer.UpsertAuthServer(&srv) + err = authServer.UpsertAuthServer(&srv) if err != nil { log.Warningf("Failed to announce presence: %v.", err) } @@ -825,10 +942,17 @@ func (process *TeleportProcess) newLocalCache(clt auth.ClientI, cacheName []stri }) } +func (process *TeleportProcess) getRotation(role teleport.Role) (*services.Rotation, error) { + state, err := process.storage.GetState(role) + if err != nil { + return nil, trace.Wrap(err) + } + return &state.Spec.Rotation, nil +} + // initSSH initializes the "node" role, i.e. a simple SSH server connected to the auth server. func (process *TeleportProcess) initSSH() error { - process.RegisterWithAuthServer( - process.Config.Token, teleport.RoleNode, SSHIdentityEvent, nil) + process.registerWithAuthServer(teleport.RoleNode, SSHIdentityEvent) eventsC := make(chan Event) process.WaitForEvent(SSHIdentityEvent, eventsC, make(chan struct{})) @@ -878,7 +1002,7 @@ func (process *TeleportProcess) initSSH() error { s, err = regular.New(cfg.SSH.Addr, cfg.Hostname, - []ssh.Signer{conn.Identity.KeySigner}, + []ssh.Signer{conn.ServerIdentity.KeySigner}, authClient, cfg.DataDir, cfg.AdvertiseIP, @@ -894,6 +1018,7 @@ func (process *TeleportProcess) initSSH() error { regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), regular.SetPAMConfig(cfg.SSH.PAM), + regular.SetRotationGetter(process.getRotation), ) if err != nil { return trace.Wrap(err) @@ -938,21 +1063,15 @@ func (process *TeleportProcess) initSSH() error { return nil } -// RegisterWithAuthServer uses one time provisioning token obtained earlier +// registerWithAuthServer uses one time provisioning token obtained earlier // from the server to get a pair of SSH keys signed by Auth server host // certificate authority -func (process *TeleportProcess) RegisterWithAuthServer(token string, role teleport.Role, eventName string, additionalPrincipals []string) { - cfg := process.Config - identityID := auth.IdentityID{Role: role, HostUUID: cfg.HostUUID, NodeName: cfg.Hostname} - - // this means the server has not been initialized yet, we are starting - // the registering client that attempts to connect to the auth server - // and provision the keys +func (process *TeleportProcess) registerWithAuthServer(role teleport.Role, eventName string) { var authClient *auth.Client process.RegisterFunc(fmt.Sprintf("register.%v", strings.ToLower(role.String())), func() error { retryTime := defaults.ServerHeartbeatTTL / 3 for { - connector, err := process.connectToAuthService(role, additionalPrincipals) + connector, err := process.connectToAuthService(role) if err == nil { process.BroadcastEvent(Event{Name: eventName, Payload: connector}) authClient = connector.Client @@ -966,27 +1085,6 @@ func (process *TeleportProcess) RegisterWithAuthServer(token string, role telepo if !trace.IsNotFound(err) { return trace.Wrap(err) } - // we haven't connected yet, so we expect the token to exist - if process.getLocalAuth() != nil { - // Auth service is on the same host, no need to go though the invitation - // procedure - log.Debugf("This server has local Auth server started, using it to add role to the cluster.") - err = auth.LocalRegister(cfg.DataDir, identityID, process.getLocalAuth(), additionalPrincipals) - } else { - // Auth server is remote, so we need a provisioning token - if token == "" { - return trace.BadParameter("%v must join a cluster and needs a provisioning token", role) - } - log.Infof("Joining the cluster with a token %v.", token) - err = auth.Register(cfg.DataDir, token, identityID, cfg.AuthServers, additionalPrincipals) - } - if err != nil { - log.Errorf("Failed to join the cluster: %v.", err) - time.Sleep(retryTime) - } else { - log.Infof("%v has successfully registered with the cluster.", role) - continue - } } }) @@ -1111,6 +1209,22 @@ func (process *TeleportProcess) initDiagnosticService() error { return nil } +// getAdditionalPrincipals returns a list of additional principals set up for role +func (process *TeleportProcess) getAdditionalPrincipals(role teleport.Role) ([]string, error) { + var principals []string + if process.Config.Hostname != "" { + principals = append(principals, process.Config.Hostname) + } + if process.Config.Proxy.PublicAddr.Addr != "" { + host, err := utils.Host(process.Config.Proxy.PublicAddr.Addr) + if err != nil { + return nil, trace.Wrap(err) + } + principals = append(principals, host) + } + return principals, nil +} + // initProxy gets called if teleport runs with 'proxy' role enabled. // this means it will do two things: // 1. serve a web UI @@ -1124,17 +1238,7 @@ func (process *TeleportProcess) initProxy() error { return trace.Wrap(err) } } - - var additionalPrincipals []string - if process.Config.Proxy.PublicAddr.Addr != "" { - host, err := utils.Host(process.Config.Proxy.PublicAddr.Addr) - if err != nil { - return trace.Wrap(err) - } - additionalPrincipals = []string{host} - } - - process.RegisterWithAuthServer(process.Config.Token, teleport.RoleProxy, ProxyIdentityEvent, additionalPrincipals) + process.registerWithAuthServer(teleport.RoleProxy, ProxyIdentityEvent) process.RegisterFunc("proxy.init", func() error { eventsC := make(chan Event) process.WaitForEvent(ProxyIdentityEvent, eventsC, make(chan struct{})) @@ -1271,7 +1375,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } - tlsConfig, err := conn.Identity.TLSConfig() + clientTLSConfig, err := conn.ClientIdentity.TLSConfig() if err != nil { return trace.Wrap(err) } @@ -1283,11 +1387,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // Register reverse tunnel agents pool agentPool, err := reversetunnel.NewAgentPool(reversetunnel.AgentPoolConfig{ - HostUUID: conn.Identity.ID.HostUUID, + HostUUID: conn.ServerIdentity.ID.HostUUID, Client: conn.Client, AccessPoint: accessPoint, - HostSigners: []ssh.Signer{conn.Identity.KeySigner}, - Cluster: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], + HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner}, + Cluster: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], }) if err != nil { return trace.Wrap(err) @@ -1304,17 +1408,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { tsrv, err = reversetunnel.NewServer( reversetunnel.Config{ ID: process.Config.HostUUID, - ClusterName: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], - ClientTLS: tlsConfig, + ClusterName: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], + ClientTLS: clientTLSConfig, Listener: listeners.reverseTunnel, - HostSigners: []ssh.Signer{conn.Identity.KeySigner}, + HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner}, LocalAuthClient: conn.Client, LocalAccessPoint: accessPoint, NewCachingAccessPoint: process.newLocalCache, Limiter: reverseTunnelLimiter, DirectClusters: []reversetunnel.DirectCluster{ { - Name: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], + Name: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority], Client: conn.Client, }, }, @@ -1392,7 +1496,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } sshProxy, err := regular.New(cfg.Proxy.SSHAddr, cfg.Hostname, - []ssh.Signer{conn.Identity.KeySigner}, + []ssh.Signer{conn.ServerIdentity.KeySigner}, accessPoint, cfg.DataDir, nil, @@ -1405,6 +1509,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), regular.SetNamespace(defaults.Namespace), + regular.SetRotationGetter(process.getRotation), ) if err != nil { return trace.Wrap(err) @@ -1532,12 +1637,17 @@ func (process *TeleportProcess) Close() error { process.Config.Keygen.Close() + var errors []error localAuth := process.getLocalAuth() if localAuth != nil { - return trace.Wrap(process.localAuth.Close()) + errors = append(errors, process.localAuth.Close()) } - return nil + if process.storage != nil { + errors = append(errors, process.storage.Close()) + } + + return trace.NewAggregate(errors...) } func validateConfig(cfg *Config) error { diff --git a/lib/service/signals.go b/lib/service/signals.go index bbcc23590ef8e..fc7d6f00ea7ca 100644 --- a/lib/service/signals.go +++ b/lib/service/signals.go @@ -124,6 +124,9 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { default: log.Infof("Ignoring %q.", signal) } + case <-process.Reloading(): + log.Infof("Exiting signal handler: process has started internal reload.") + return ErrTeleportReloading case <-ctx.Done(): process.Close() process.Wait() @@ -133,6 +136,8 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { } } +var ErrTeleportReloading = trace.CompareFailed("teleport process is reloading") + func (process *TeleportProcess) writeToSignalPipe(message string) error { signalPipe, err := process.importSignalPipe() if err != nil { @@ -248,8 +253,8 @@ func (process *TeleportProcess) createListener(listenerType, address string) (ne return listener, nil } -// exportFileDescriptors exports file descriptors to be passed to child process -func (process *TeleportProcess) exportFileDescriptors() ([]FileDescriptor, error) { +// ExportFileDescriptors exports file descriptors to be passed to child process +func (process *TeleportProcess) ExportFileDescriptors() ([]FileDescriptor, error) { var out []FileDescriptor process.Lock() defer process.Unlock() @@ -402,7 +407,7 @@ func (process *TeleportProcess) forkChild() error { log.Info("Forking child.") - listenerFiles, err := process.exportFileDescriptors() + listenerFiles, err := process.ExportFileDescriptors() if err != nil { return trace.Wrap(err) } diff --git a/lib/service/supervisor.go b/lib/service/supervisor.go index fe469759d43c6..9af92357be543 100644 --- a/lib/service/supervisor.go +++ b/lib/service/supervisor.go @@ -62,8 +62,16 @@ type Supervisor interface { // WaitForEvent waits for event to be broadcasted, if the event // was already broadcasted, payloadC will receive current event immediately - // CLose 'cancelC' channel to force WaitForEvent to return prematurely + // Close 'cancelC' channel to force WaitForEvent to return prematurely WaitForEvent(name string, eventC chan Event, cancelC chan struct{}) + + // Exiting channel will be closed when + // TeleportExitEvent will be broadcasted by any caller + Exiting() <-chan struct{} + + // Reloading channel will be closed when + // TeleportReloadEvent will be broadcasted by any caller + Reloading() <-chan struct{} } type LocalSupervisor struct { @@ -75,13 +83,25 @@ type LocalSupervisor struct { events map[string]Event eventsC chan Event eventWaiters map[string][]*waiter + closeContext context.Context signalClose context.CancelFunc + + // exitContext is closed when someone emits Exit event + exitContext context.Context + signalExit context.CancelFunc + + reloadContext context.Context + signalReload context.CancelFunc } // NewSupervisor returns new instance of initialized supervisor func NewSupervisor() Supervisor { closeContext, cancel := context.WithCancel(context.TODO()) + + exitContext, signalExit := context.WithCancel(context.TODO()) + reloadContext, signalReload := context.WithCancel(context.TODO()) + srv := &LocalSupervisor{ services: []Service{}, wg: &sync.WaitGroup{}, @@ -90,6 +110,12 @@ func NewSupervisor() Supervisor { eventWaiters: make(map[string][]*waiter), closeContext: closeContext, signalClose: cancel, + + exitContext: exitContext, + signalExit: signalExit, + + reloadContext: reloadContext, + signalReload: signalReload, } go srv.fanOut() return srv @@ -201,9 +227,25 @@ func (s *LocalSupervisor) Run() error { return s.Wait() } +func (s *LocalSupervisor) Exiting() <-chan struct{} { + return s.exitContext.Done() +} + +func (s *LocalSupervisor) Reloading() <-chan struct{} { + return s.reloadContext.Done() +} + func (s *LocalSupervisor) BroadcastEvent(event Event) { s.Lock() defer s.Unlock() + + switch event.Name { + case TeleportExitEvent: + s.signalExit() + case TeleportReloadEvent: + s.signalReload() + } + s.events[event.Name] = event log.WithFields(logrus.Fields{"event": event.String()}).Debugf("Broadcasting event.") diff --git a/lib/services/authority.go b/lib/services/authority.go index 9defe4fc40a0b..3b003def6c95c 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -169,6 +169,8 @@ type CertAuthority interface { CheckAndSetDefaults() error // SetSigningKeys sets signing keys SetSigningKeys([][]byte) error + // SetCheckingKeys sets signing keys + SetCheckingKeys([][]byte) error // AddRole adds a role to ca role list AddRole(name string) // Checkers returns public keys that can be used to check cert authorities @@ -187,6 +189,12 @@ type CertAuthority interface { SetTLSKeyPairs(keyPairs []TLSKeyPair) // GetTLSKeyPairs returns first PEM encoded TLS cert GetTLSKeyPairs() []TLSKeyPair + // GetRotation returns rotation infor + GetRotation() Rotation + // SetRotation sets rotation state + SetRotation(Rotation) + // Clone returns shallow copy of the cert authority object + Clone() CertAuthority } // CertPoolFromCertAuthorities returns certificate pools from TLS certificates @@ -296,6 +304,32 @@ type CertAuthorityV2 struct { rawObject interface{} } +// Clone returns shallow copy of the cert authority +func (c *CertAuthorityV2) Clone() CertAuthority { + out := *c + out.rawObject = nil + out.Spec.CheckingKeys = utils.CopyByteSlices(c.Spec.CheckingKeys) + out.Spec.SigningKeys = utils.CopyByteSlices(c.Spec.SigningKeys) + for i, kp := range c.Spec.TLSKeyPairs { + out.Spec.TLSKeyPairs[i] = TLSKeyPair{ + Key: utils.CopyByteSlice(kp.Key), + Cert: utils.CopyByteSlice(kp.Cert), + } + } + out.Spec.Roles = utils.CopyStrings(c.Spec.Roles) + return &out +} + +// GetRotation returns rotation infor +func (c *CertAuthorityV2) GetRotation() Rotation { + return c.Spec.Rotation +} + +// SetRotation sets rotation state +func (c *CertAuthorityV2) SetRotation(r Rotation) { + c.Spec.Rotation = r +} + // TLSCA returns TLS certificate authority func (c *CertAuthorityV2) TLSCA() (*tlsca.CertAuthority, error) { if len(c.Spec.TLSKeyPairs) == 0 { @@ -375,6 +409,12 @@ func (ca *CertAuthorityV2) SetSigningKeys(keys [][]byte) error { return nil } +// SetCheckingKeys sets SSH public keys +func (ca *CertAuthorityV2) SetCheckingKeys(keys [][]byte) error { + ca.Spec.CheckingKeys = keys + return nil +} + // GetID returns certificate authority ID - // combined type and name func (ca *CertAuthorityV2) GetID() CertAuthID { @@ -520,6 +560,193 @@ func (ca *CertAuthorityV2) CheckAndSetDefaults() error { return nil } +const ( + // RotationStateStandby is initial status of the rotation - + // nothing is being rotated + RotationStateStandby = "standby" + // RotationStateInProgress specifies that rotation is in progress + RotationStateInProgress = "in_progress" + // RotationPhaseStandby is initial phase of the rotation + // means no operations have started + RotationPhaseStandby = "standby" + // RotationPhaseUpdateClients is a phase of the rotation + // when client credentials will have to be updated and reloaded + // but servers will use and respond with old credentials + // becasue clients have no idea about new credentials at first + RotationPhaseUpdateClients = "update_clients" + // RotationPhaseUpdateServers is a phase of the rotation + // when servers will have to reload and should start serving + // TLS and SSH certificates signed by newly issued CA + RotationPhaseUpdateServers = "update_servers" + // RotationPhaseRollback means that rotation is rolling + // back to the old certificate authority + RotationPhaseRollback = "rollback" + // RotationModeManual is a manual rotation mode when all phases + // are set automatically + RotationModeManual = "manual" + // RotationModeAuto is set to go through all phases by the schedule + RotationModeAuto = "auto" +) + +// RotatePhases lists all supported rotation phases, +// used to show help +var RotatePhases = []string{ + RotationPhaseStandby, + RotationPhaseUpdateClients, + RotationPhaseUpdateServers, + RotationPhaseRollback, +} + +// Rotation is a status of the rotation of the certificate authority +type Rotation struct { + // State could be one of "init", means not active or "in_progress" + State string `json:"state,omitempty"` + // Phase is a current rotation phase + Phase string `json:"phase,omitempty"` + // Mode sets manual or automatic rotation mode + Mode string `json:"mode,omitempty"` + // CurrentID is the ID of the rotation operation + // to differentiate between rotation attempts + CurrentID string `json:"current_id"` + // Started is set to the time when rotation has been started + // in case if the state of the rotation is "in_progress" + Started time.Time `json:"started,omitempty"` + // LastRotation is a time of last rotation + // GracePeriod is a period during which old and new CA + // are valid for checking purposes, but only new CA is issuing certificates + GracePeriod Duration `json:"grace_period,omitempty"` + // LastRotated specifies the last time of the completed rotation + LastRotated time.Time `json:"last_rotated,omitempty"` + // Schedule is a rotation schedule - used in + // automatic mode based on the provided grace period to switch + // beetween phases. + Schedule RotationSchedule `json:"schedule,omitempty"` +} + +// Matches returns true if this state rotation matches +// external rotation - state, phase and rotation ID should all match, +// notice that matches is not Equals as it does not require +// all fields to be the same +func (s *Rotation) Matches(rotation Rotation) bool { + return s.CurrentID == rotation.CurrentID && s.State == rotation.State && s.Phase == rotation.Phase +} + +// LastRotatedDescription returns human friendly descriptoin +func (r *Rotation) LastRotatedDescription() string { + if r.LastRotated.IsZero() { + return "never updated" + } + return fmt.Sprintf("last rotated %v", r.LastRotated.Format(teleport.HumanDateFormatSeconds)) +} + +// PhaseDescription returns human friendly description of a current rotation phase +func (r *Rotation) PhaseDescription() string { + switch r.Phase { + case RotationPhaseStandby, "": + return "on standby" + case RotationPhaseUpdateClients: + return "rotating clients" + case RotationPhaseUpdateServers: + return "rotating servers" + case RotationPhaseRollback: + return "rolling back" + default: + return fmt.Sprintf("unknown phase: %q", r.Phase) + } +} + +// String returns user friendly information about certificate authority +func (r *Rotation) String() string { + switch r.State { + case "", RotationStateStandby: + if r.LastRotated.IsZero() { + return "never updated" + } + return fmt.Sprintf("rotated %v", r.LastRotated.Format(teleport.HumanDateFormatSeconds)) + case RotationStateInProgress: + return fmt.Sprintf("%v (mode: %v, started: %v, ending: %v)", + r.PhaseDescription(), + r.Mode, + r.Started.Format(teleport.HumanDateFormatSeconds), + r.Started.Add(r.GracePeriod.Duration).Format(teleport.HumanDateFormatSeconds), + ) + default: + return "unknown" + } +} + +// CheckAndSetDefaults checks and sets default rotation parameters +func (r *Rotation) CheckAndSetDefaults(clock clockwork.Clock) error { + switch r.Phase { + case "", RotationPhaseRollback, RotationPhaseUpdateClients, RotationPhaseUpdateServers: + default: + return trace.BadParameter("unsupported phase: %q", r.Phase) + } + switch r.Mode { + case "", RotationModeAuto, RotationModeManual: + default: + return trace.BadParameter("unsupported mode: %q", r.Mode) + } + switch r.State { + case "": + r.State = RotationStateStandby + case RotationStateStandby: + case RotationStateInProgress: + if r.CurrentID == "" { + return trace.BadParameter("set 'current_id' parameter for in progress rotation") + } + if r.Started.IsZero() { + return trace.BadParameter("set 'started' parameter for in progress rotation") + } + default: + return trace.BadParameter( + "unsupported rotation 'state': %q, supported states are: %q, %q", + r.State, RotationStateStandby, RotationStateInProgress) + } + return nil +} + +// GenerateSchedule generates schedule based on the time period, splitting +// two phases scheduled time +func GenerateSchedule(clock clockwork.Clock, gracePeriod time.Duration) (*RotationSchedule, error) { + if gracePeriod == 0 { + return nil, trace.BadParameter("empty grace period") + } + return &RotationSchedule{ + UpdateServers: clock.Now().UTC().Add(gracePeriod / 2).UTC(), + Standby: clock.Now().UTC().Add(gracePeriod).UTC(), + }, nil +} + +// RotationSchedule is a rotation schedule setting time switches +// for different phases +type RotationSchedule struct { + // UpdateServers specifies time to switch to update servers + UpdateServers time.Time `json:"update_servers,omitempty"` + // Standby specifies time to switch to standby phase + Standby time.Time `json:"standby,omitempty"` +} + +// CheckAndSetDefaults checks and sets default values on rotation schedule +func (s *RotationSchedule) CheckAndSetDefaults(clock clockwork.Clock) error { + if s.UpdateServers.IsZero() { + return trace.BadParameter("phase %q has no time switch scheduled", RotationPhaseUpdateServers) + } + if s.Standby.IsZero() { + return trace.BadParameter("phase %q has no time switch scheduled", RotationPhaseStandby) + } + if s.Standby.Before(s.UpdateServers) { + return trace.BadParameter("phase %q can not be scheduled before %q", RotationPhaseStandby, RotationPhaseUpdateServers) + } + if s.UpdateServers.Before(clock.Now()) { + return trace.BadParameter("phase %q can not be scheduled in the past", RotationPhaseUpdateServers) + } + if s.Standby.Before(clock.Now()) { + return trace.BadParameter("phase %q can not be scheduled in the past", RotationPhaseStandby) + } + return nil +} + // CertAuthoritySpecV2 is a host or user certificate authority that // can check and if it has private key stored as well, sign it too type CertAuthoritySpecV2 struct { @@ -540,6 +767,8 @@ type CertAuthoritySpecV2 struct { RoleMap RoleMap `json:"role_map,omitempty"` // TLS is a list of TLS key pairs TLSKeyPairs []TLSKeyPair `json:"tls_key_pairs,omitempty"` + // Rotation is a status of the certificate authority rotation + Rotation Rotation `json:"rotation,omitempty"` } // CertAuthoritySpecV2Schema is JSON schema for cert authority V2 @@ -579,10 +808,32 @@ const CertAuthoritySpecV2Schema = `{ } } }, + "rotation": %v, "role_map": %v } }` +const RotationSchema = `{ + "type": "object", + "additionalProperties": false, + "properties": { + "state": {"type": "string"}, + "phase": {"type": "string"}, + "mode": {"type": "string"}, + "current_id": {"type": "string"}, + "started": {"type": "string"}, + "grace_period": {"type": "string"}, + "last_rotated": {"type": "string"}, + "schedule": { + "type": "object", + "properties": { + "update_servers": {"type": "string"}, + "standby": {"type": "string"} + } + } + } +}` + // CertAuthorityV1 is a host or user certificate authority that // can check and if it has private key stored as well, sign it too type CertAuthorityV1 struct { @@ -677,7 +928,7 @@ type CertAuthorityMarshaler interface { // GetCertAuthoritySchema returns JSON Schema for cert authorities func GetCertAuthoritySchema() string { - return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(CertAuthoritySpecV2Schema, RoleMapSchema), DefaultDefinitions) + return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(CertAuthoritySpecV2Schema, RotationSchema, RoleMapSchema), DefaultDefinitions) } type TeleportCertAuthorityMarshaler struct{} diff --git a/lib/services/local/configuration_test.go b/lib/services/local/configuration_test.go index c142681a06dce..5ae2f050a34f6 100644 --- a/lib/services/local/configuration_test.go +++ b/lib/services/local/configuration_test.go @@ -22,7 +22,7 @@ import ( "os" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/boltbk" + "github.com/gravitational/teleport/lib/backend/dir" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -50,7 +50,7 @@ func (s *ClusterConfigurationSuite) SetUpTest(c *check.C) { s.tempDir, err = ioutil.TempDir("", "preference-test-") c.Assert(err, check.IsNil) - s.bk, err = boltbk.New(backend.Params{"path": s.tempDir}) + s.bk, err = dir.New(backend.Params{"path": s.tempDir}) c.Assert(err, check.IsNil) } diff --git a/lib/services/local/presence_test.go b/lib/services/local/presence_test.go index e02c0492946a4..5bb45fd083125 100644 --- a/lib/services/local/presence_test.go +++ b/lib/services/local/presence_test.go @@ -22,7 +22,7 @@ import ( "os" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/boltbk" + "github.com/gravitational/teleport/lib/backend/dir" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -52,7 +52,7 @@ func (s *PresenceSuite) SetUpTest(c *check.C) { s.tempDir, err = ioutil.TempDir("", "trusted-clusters-") c.Assert(err, check.IsNil) - s.bk, err = boltbk.New(backend.Params{"path": s.tempDir}) + s.bk, err = dir.New(backend.Params{"path": s.tempDir}) c.Assert(err, check.IsNil) } diff --git a/lib/services/local/trust.go b/lib/services/local/trust.go index 9e7ddfcebae3b..3482edff13967 100644 --- a/lib/services/local/trust.go +++ b/lib/services/local/trust.go @@ -66,6 +66,32 @@ func (s *CA) UpsertCertAuthority(ca services.CertAuthority) error { return nil } +// CompareAndSwapCertAuthority updates the cert authority value +// if existing value matches existing parameter, +// returns nil if succeeds, trace.CompareFailed otherwise +func (s *CA) CompareAndSwapCertAuthority(new, existing services.CertAuthority) error { + if err := new.Check(); err != nil { + return trace.Wrap(err) + } + newData, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(new) + if err != nil { + return trace.Wrap(err) + } + existingData, err := services.GetCertAuthorityMarshaler().MarshalCertAuthority(existing) + if err != nil { + return trace.Wrap(err) + } + ttl := backend.TTL(s.Clock(), new.Expiry()) + err = s.CompareAndSwapVal([]string{"authorities", string(new.GetType())}, new.GetName(), newData, existingData, ttl) + if err != nil { + if trace.IsCompareFailed(err) { + return trace.CompareFailed("cluster %v settings have been updated, try again", new.GetName()) + } + return trace.Wrap(err) + } + return nil +} + // DeleteCertAuthority deletes particular certificate authority func (s *CA) DeleteCertAuthority(id services.CertAuthID) error { if err := id.Check(); err != nil { @@ -172,26 +198,6 @@ func setSigningKeys(ca services.CertAuthority, loadSigningKeys bool) { ca.SetTLSKeyPairs(keyPairs) } -// DELETE IN: 2.6.0 -// GetAnyCertAuthority returns activated or deactivated certificate authority -// by given id whether it is activated or not. This method is used in migrations. -func (s *CA) GetAnyCertAuthority(id services.CertAuthID) (services.CertAuthority, error) { - if err := id.Check(); err != nil { - return nil, trace.Wrap(err) - } - data, err := s.GetVal([]string{"authorities", string(id.Type)}, id.DomainName) - if err != nil { - if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - data, err = s.GetVal([]string{"authorities", "deactivated", string(id.Type)}, id.DomainName) - if err != nil { - return nil, trace.Wrap(err) - } - } - return services.GetCertAuthorityMarshaler().UnmarshalCertAuthority(data) -} - // GetCertAuthorities returns a list of authorities of a given type // loadSigningKeys controls whether signing keys should be loaded or not func (s *CA) GetCertAuthorities(caType services.CertAuthType, loadSigningKeys bool) ([]services.CertAuthority, error) { diff --git a/lib/services/local/users.go b/lib/services/local/users.go index b16960bfc32b3..3d3798efdf783 100644 --- a/lib/services/local/users.go +++ b/lib/services/local/users.go @@ -178,7 +178,7 @@ func (s *IdentityService) DeleteUser(user string) error { err := s.DeleteBucket([]string{"web", "users"}, user) if err != nil { if trace.IsNotFound(err) { - return trace.NotFound(fmt.Sprintf("user '%v' is not found", user)) + return trace.NotFound("user %q is not found", user) } } return trace.Wrap(err) diff --git a/lib/services/parser.go b/lib/services/parser.go index c1673a5dc976a..49034d247e12a 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -28,6 +28,7 @@ import ( "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "github.com/vulcand/predicate" + "github.com/vulcand/predicate/builder" ) // RuleContext specifies context passed to the @@ -38,18 +39,49 @@ type RuleContext interface { GetIdentifier(fields []string) (interface{}, error) // String returns human friendly representation of a context String() string + // GetResource returns resource if specified in the context, + // if unpecified, returns error + GetResource() (Resource, error) } +var ( + // ResourceNameExpr is identifer that specifies resource name + ResourceNameExpr = builder.Identifier("resource.metadata.name") + // CertAuthorityTypeExpr is a function call that returns + // cert authority type + CertAuthorityTypeExpr = builder.Identifier(`system.catype()`) +) + // NewWhereParser returns standard parser for `where` section in access rules func NewWhereParser(ctx RuleContext) (predicate.Parser, error) { + // get is a function that returns identifier field + // used in cases when there are some reserved words by go return predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: predicate.And, OR: predicate.Or, + NOT: predicate.Not, }, Functions: map[string]interface{}{ "equals": predicate.Equals, "contains": predicate.Contains, + // system.catype is a function that returns cert authority type + // returns empty values for unrecognized values to + // pass static rule checks + "system.catype": func() (interface{}, error) { + resource, err := ctx.GetResource() + if err != nil { + if trace.IsNotFound(err) { + return "", nil + } + return nil, trace.Wrap(err) + } + ca, ok := resource.(CertAuthority) + if !ok { + return "", nil + } + return string(ca.GetType()), nil + }, }, GetIdentifier: ctx.GetIdentifier, GetProperty: predicate.GetStringMapValue, @@ -124,6 +156,15 @@ const ( ResourceIdentifier = "resource" ) +// GetResource returns resource specified in the context, +// returns error if not specified +func (ctx *Context) GetResource() (Resource, error) { + if ctx.Resource == nil { + return nil, trace.NotFound("resource is not set in the context") + } + return ctx.Resource, nil +} + // GetIdentifier returns identifier defined in a context func (ctx *Context) GetIdentifier(fields []string) (interface{}, error) { switch fields[0] { diff --git a/lib/services/resource.go b/lib/services/resource.go index be87f3979f0b8..357ec7bc802a2 100644 --- a/lib/services/resource.go +++ b/lib/services/resource.go @@ -153,6 +153,12 @@ const ( // to proxy KindRemoteCluster = "remote_cluster" + // KindIdenity is local on disk identity resource + KindIdentity = "identity" + + // KindState is local on disk process state + KindState = "state" + // V3 is the third version of resources. V3 = "v3" @@ -182,6 +188,10 @@ const ( // VerbDelete is used to remove an object. VerbDelete = "delete" + + // VerbRotate is used to rotate certificate authorities + // used only internally + VerbRotate = "rotate" ) func collectOptions(opts []MarshalOption) (*MarshalConfig, error) { diff --git a/lib/services/role.go b/lib/services/role.go index 9883636616fae..60788c245a85b 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -1637,7 +1637,7 @@ func (set RoleSet) CheckAccessToRule(ctx RuleContext, namespace string, resource } log.Infof("[RBAC] %s access to %s [namespace %s] denied for %v: no allow rule matched", verb, resource, namespace, set) - return trace.AccessDenied("access denied to perform action '%s' on %s", verb, resource) + return trace.AccessDenied("access denied to perform action %q on %q", verb, resource) } // ProcessNamespace sets default namespace in case if namespace is empty diff --git a/lib/services/server.go b/lib/services/server.go index c83f376ee8735..4aef105113d48 100644 --- a/lib/services/server.go +++ b/lib/services/server.go @@ -249,6 +249,8 @@ type ServerSpecV2 struct { Hostname string `json:"hostname"` // CmdLabels is server dynamic labels CmdLabels map[string]CommandLabelV2 `json:"cmd_labels,omitempty"` + // Rotation specifies server rotatoin status + Rotation Rotation `json:"rotation,omitempty"` } // ServerSpecV2Schema is JSON schema for server @@ -279,7 +281,8 @@ const ServerSpecV2Schema = `{ } } } - } + }, + "rotation": %v } }` @@ -423,7 +426,7 @@ func (c *CommandLabels) SetEnv(v string) error { // GetServerSchema returns role schema with optionally injected // schema for extensions func GetServerSchema() string { - return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, ServerSpecV2Schema, DefaultDefinitions) + return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, fmt.Sprintf(ServerSpecV2Schema, RotationSchema), DefaultDefinitions) } // UnmarshalServerResource unmarshals role from JSON or YAML, diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 0404a5e95221a..095bde88e23b9 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -174,11 +174,11 @@ func (s *ServicesTestSuite) UsersCRUD(c *C) { userSlicesEqual(c, u, []services.User{newUser("user2", nil)}) err = s.WebS.DeleteUser("user1") - c.Assert(trace.IsNotFound(err), Equals, true, Commentf("unexpected %T %#v", err, err)) + fixtures.ExpectNotFound(c, err) // bad username err = s.WebS.UpsertUser(newUser("", nil)) - c.Assert(trace.IsBadParameter(err), Equals, true, Commentf("expected bad parameter error, got %T", err)) + fixtures.ExpectBadParameter(c, err) } func (s *ServicesTestSuite) LoginAttempts(c *C) { @@ -226,6 +226,26 @@ func (s *ServicesTestSuite) CertAuthCRUD(c *C) { err = s.CAS.DeleteCertAuthority(*ca.ID()) c.Assert(err, IsNil) + + // test compare and swap + ca = NewTestCA(services.UserCA, "example.com") + c.Assert(s.CAS.CreateCertAuthority(ca), IsNil) + + clock := clockwork.NewFakeClock() + newCA := *ca + rotation := services.Rotation{ + State: services.RotationStateInProgress, + CurrentID: "id1", + GracePeriod: services.NewDuration(time.Hour), + Started: clock.Now(), + } + newCA.SetRotation(rotation) + + err = s.CAS.CompareAndSwapCertAuthority(&newCA, ca) + + out, err = s.CAS.GetCertAuthority(ca.GetID(), true) + c.Assert(err, IsNil) + fixtures.DeepCompare(c, &newCA, out) } func newServer(kind, name, addr, namespace string) *services.ServerV2 { diff --git a/lib/services/trust.go b/lib/services/trust.go index 7145e995c6160..8d379cc93a0b8 100644 --- a/lib/services/trust.go +++ b/lib/services/trust.go @@ -39,6 +39,11 @@ type Trust interface { // UpsertCertAuthority updates or inserts a new certificate authority UpsertCertAuthority(ca CertAuthority) error + // CompareAndSwapCertAuthority updates the cert authority value + // if existing value matches existing parameter, + // returns nil if succeeds, trace.CompareFailed otherwise + CompareAndSwapCertAuthority(new, existing CertAuthority) error + // DeleteCertAuthority deletes particular certificate authority DeleteCertAuthority(id CertAuthID) error @@ -49,12 +54,6 @@ type Trust interface { // controls if signing keys are loaded GetCertAuthority(id CertAuthID, loadSigningKeys bool) (CertAuthority, error) - // DELETE IN: 2.6.0 - // GetAnyCertAuthority returns activated or deactivated certificate authority - // by given id whether it is activated or not. Signing keys are never loaded. - // This method is used in migrations. - GetAnyCertAuthority(id CertAuthID) (ca CertAuthority, error error) - // GetCertAuthorities returns a list of authorities of a given type // loadSigningKeys controls whether signing keys should be loaded or not GetCertAuthorities(caType CertAuthType, loadSigningKeys bool) ([]CertAuthority, error) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index fd4fa6654bbb8..7839418748a2b 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -66,6 +66,7 @@ type Server struct { srv *sshutils.Server hostSigner ssh.Signer shell string + getRotation RotationGetter authService auth.AccessPoint reg *srv.SessionRegistry sessionServer rsession.Service @@ -219,6 +220,17 @@ func (s *Server) Wait() { s.srv.Wait(context.TODO()) } +// RotationGetter returns rotation state +type RotationGetter func(role teleport.Role) (*services.Rotation, error) + +// SetRotationGetter sets rotation state getter +func SetRotationGetter(getter RotationGetter) ServerOption { + return func(s *Server) error { + s.getRotation = getter + return nil + } +} + // SetShell sets default shell that will be executed for interactive // sessions func SetShell(shell string) ServerOption { @@ -466,7 +478,14 @@ func (s *Server) AdvertiseAddr() string { return net.JoinHostPort(s.getAdvertiseIP().String(), port) } -func (s *Server) getInfo() services.Server { +func (s *Server) getRole() teleport.Role { + if s.proxyMode { + return teleport.RoleProxy + } + return teleport.RoleNode +} + +func (s *Server) getInfo() *services.ServerV2 { return &services.ServerV2{ Kind: services.KindNode, Version: services.V2, @@ -486,6 +505,16 @@ func (s *Server) getInfo() services.Server { // registerServer attempts to register server in the cluster func (s *Server) registerServer() error { server := s.getInfo() + if s.getRotation != nil { + rotation, err := s.getRotation(s.getRole()) + if err != nil { + if !trace.IsNotFound(err) { + log.Warningf("Failed to get rotation state: %v", err) + } + } else { + server.Spec.Rotation = *rotation + } + } server.SetTTL(s.clock, defaults.ServerHeartbeatTTL) if !s.proxyMode { return trace.Wrap(s.authService.UpsertNode(server)) diff --git a/lib/utils/copy.go b/lib/utils/copy.go index 7901e4dd10880..a284220fbce7f 100644 --- a/lib/utils/copy.go +++ b/lib/utils/copy.go @@ -16,6 +16,28 @@ limitations under the License. package utils +// CopyByteSlice returns a copy of the byte slice +func CopyByteSlice(in []byte) []byte { + if in == nil { + return nil + } + out := make([]byte, len(in)) + copy(out, in) + return out +} + +// CopyByteSlices returns a copy of the byte slices +func CopyByteSlices(in [][]byte) [][]byte { + if in == nil { + return nil + } + out := make([][]byte, len(in)) + for i := range in { + out[i] = CopyByteSlice(in[i]) + } + return out +} + // CopyStrings makes a deep copy of the passed in string slice and returns // the copy. func CopyStrings(in []string) []string { diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index 66507ae5e1d57..d5fc142744a36 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -17,9 +17,9 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" "github.com/gravitational/kingpin" + "github.com/gravitational/trace" ) // AuthCommand implements `tctl auth` group of commands @@ -38,9 +38,15 @@ type AuthCommand struct { compatVersion string compatibility string + rotateGracePeriod time.Duration + rotateType string + rotateManualMode bool + rotateTargetPhase string + authGenerate *kingpin.CmdClause authExport *kingpin.CmdClause authSign *kingpin.CmdClause + authRotate *kingpin.CmdClause } // Initialize allows TokenCommand to plug itself into the CLI parser @@ -66,6 +72,12 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi a.authSign.Flag("format", "identity format: 'file' (default) or 'dir'").Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat)) a.authSign.Flag("ttl", "TTL (time to live) for the generated certificate").Default(fmt.Sprintf("%v", defaults.CertDuration)).DurationVar(&a.genTTL) a.authSign.Flag("compat", "OpenSSH compatibility flag").StringVar(&a.compatibility) + + a.authRotate = auth.Command("rotate", "Rotate certificate authorities in the cluster") + a.authRotate.Flag("grace-period", "Grace period keeps previous certificate authorities signatures valid, if set to 0 will force users to relogin and nodes to re-register.").Default(fmt.Sprintf("%v", defaults.RotationGracePeriod)).DurationVar(&a.rotateGracePeriod) + a.authRotate.Flag("manual", "Activate manual rotation , set rotation phases manually").BoolVar(&a.rotateManualMode) + a.authRotate.Flag("type", "Certificate authority to rotate, rotates both host and user CA by default").StringVar(&a.rotateType) + a.authRotate.Flag("phase", fmt.Sprintf("Target rotation phase to set, used in manual rotation, one of: %v", strings.Join(services.RotatePhases, ", "))).StringVar(&a.rotateTargetPhase) } // TryRun takes the CLI command as an argument (like "auth gen") and executes it @@ -78,7 +90,8 @@ func (a *AuthCommand) TryRun(cmd string, client auth.ClientI) (match bool, err e err = a.ExportAuthorities(client) case a.authSign.FullCommand(): err = a.GenerateAndSignKeys(client) - + case a.authRotate.FullCommand(): + err = a.RotateCertAuthority(client) default: return false, nil } @@ -236,6 +249,30 @@ func (a *AuthCommand) GenerateAndSignKeys(clusterApi auth.ClientI) error { } } +// RotateCertAuthority starts or restarts certificate authority rotation process +func (a *AuthCommand) RotateCertAuthority(client auth.ClientI) error { + req := auth.RotateRequest{ + Type: services.CertAuthType(a.rotateType), + GracePeriod: &a.rotateGracePeriod, + TargetPhase: a.rotateTargetPhase, + } + if a.rotateManualMode { + req.Mode = services.RotationModeManual + } else { + req.Mode = services.RotationModeAuto + } + if err := client.RotateCertAuthority(req); err != nil { + return err + } + if a.rotateTargetPhase != "" { + fmt.Printf("Updated rotation phase to %q. To check status use 'tctl status'\n", a.rotateTargetPhase) + } else { + fmt.Printf("Initiated certificate authority rotation. To check status use 'tctl status'\n") + } + + return nil +} + func (a *AuthCommand) generateHostKeys(clusterApi auth.ClientI) error { // only format=openssh is supported if a.outputFormat != client.IdentityFormatOpenSSH { diff --git a/tool/tctl/common/status_command.go b/tool/tctl/common/status_command.go new file mode 100644 index 0000000000000..4f1d25937bde9 --- /dev/null +++ b/tool/tctl/common/status_command.go @@ -0,0 +1,118 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + "strings" + + "github.com/gravitational/kingpin" + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/trace" +) + +// StatusCommand implements `tctl token` group of commands +type StatusCommand struct { + config *service.Config + + // CLI clauses (subcommands) + status *kingpin.CmdClause +} + +// Initialize allows StatusCommand to plug itself into the CLI parser +func (c *StatusCommand) Initialize(app *kingpin.Application, config *service.Config) { + c.config = config + c.status = app.Command("status", "Report cluster status") +} + +// TryRun takes the CLI command as an argument (like "nodes ls") and executes it. +func (c *StatusCommand) TryRun(cmd string, client auth.ClientI) (match bool, err error) { + switch cmd { + case c.status.FullCommand(): + err = c.Status(client) + default: + return false, nil + } + return true, trace.Wrap(err) +} + +// onStatus is called to execute "status" +func (c *StatusCommand) Status(client auth.ClientI) error { + clusterNameResource, err := client.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + clusterName := clusterNameResource.GetClusterName() + + hostCAs, err := client.GetCertAuthorities(services.HostCA, false) + if err != nil { + return trace.Wrap(err) + } + + userCAs, err := client.GetCertAuthorities(services.UserCA, false) + if err != nil { + return trace.Wrap(err) + } + + authorities := append(userCAs, hostCAs...) + view := func() string { + table := asciitable.MakeHeadlessTable(2) + table.AddRow([]string{"Cluster", clusterName}) + for _, ca := range authorities { + if ca.GetClusterName() != clusterName { + continue + } + info := fmt.Sprintf("%v CA ", strings.Title(string(ca.GetType()))) + rotation := ca.GetRotation() + if c.config.Debug { + table.AddRow([]string{info, + fmt.Sprintf("%v, update_servers: %v, complete: %v", + rotation.String(), + rotation.Schedule.UpdateServers.Format(teleport.HumanDateFormatSeconds), + rotation.Schedule.Standby.Format(teleport.HumanDateFormatSeconds), + )}) + } else { + table.AddRow([]string{info, rotation.String()}) + } + + } + return table.AsBuffer().String() + } + fmt.Printf(view()) + + // in debug mode, output mode of remote certificate authorities + if c.config.Debug { + view := func() string { + table := asciitable.MakeHeadlessTable(2) + for _, ca := range authorities { + if ca.GetClusterName() == clusterName { + continue + } + info := fmt.Sprintf("Remote %v CA %q", strings.Title(string(ca.GetType())), ca.GetClusterName()) + rotation := ca.GetRotation() + table.AddRow([]string{info, rotation.String()}) + } + return "Remote clusters\n\n" + table.AsBuffer().String() + } + fmt.Printf(view()) + } + return nil +} diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 16acb75d8b6f5..8e1de1d1b2a8a 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -19,6 +19,7 @@ package common import ( "fmt" "os" + "path/filepath" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" @@ -86,7 +87,7 @@ func Run(commands []CLICommand) { "Base64 encoded configuration string").Hidden().Envar(defaults.ConfigEnvar).StringVar(&ccf.ConfigString) // "version" command is always available: - ver := app.Command("version", "Print the version.") + ver := app.Command("version", "Print cluster version") app.HelpFlag.Short('h') // parse CLI commands+flags: @@ -133,7 +134,7 @@ func connectToAuthService(cfg *service.Config) (client auth.ClientI, err error) } } // read the host SSH keys and use them to open an SSH connection to the auth service - i, err := auth.ReadIdentity(cfg.DataDir, auth.IdentityID{Role: teleport.RoleAdmin, HostUUID: cfg.HostUUID}) + i, err := auth.ReadLocalIdentity(filepath.Join(cfg.DataDir, teleport.ComponentProcess), auth.IdentityID{Role: teleport.RoleAdmin, HostUUID: cfg.HostUUID}) if err != nil { // the "admin" identity is not present? this means the tctl is running NOT on the auth server. if trace.IsNotFound(err) { @@ -182,8 +183,9 @@ func applyConfig(ccf *GlobalCLIFlags, cfg *service.Config) error { } // --debug flag if ccf.Debug { + cfg.Debug = ccf.Debug utils.InitLogger(utils.LoggingForCLI, logrus.DebugLevel) - logrus.Debugf("DEBUG loggign enabled") + logrus.Debugf("DEBUG logging enabled") } // read a host UUID for this node diff --git a/tool/tctl/main.go b/tool/tctl/main.go index 45fcc59c4dc8e..5b5974bf1cef3 100644 --- a/tool/tctl/main.go +++ b/tool/tctl/main.go @@ -27,6 +27,7 @@ func main() { &common.TokenCommand{}, &common.AuthCommand{}, &common.ResourceCommand{}, + &common.StatusCommand{}, } common.Run(commands) } diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 4593b96f8502a..0c168dfa5e91d 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -165,16 +165,7 @@ func Run(options Options) (executedCommand string, conf *service.Config) { // OnStart is the handler for "start" CLI command func OnStart(config *service.Config) error { - srv, err := service.NewTeleport(config) - if err != nil { - return trace.Wrap(err, "Initialization failed") - } - - if err := srv.Start(); err != nil { - return trace.Wrap(err, "Startup Failed") - } - - return srv.WaitForSignals(context.TODO()) + return service.Run(context.TODO(), *config, nil) } // onStatus is the handler for "status" CLI command diff --git a/vendor/github.com/cenkalti/backoff/.gitignore b/vendor/github.com/cenkalti/backoff/.gitignore deleted file mode 100644 index 00268614f0456..0000000000000 --- a/vendor/github.com/cenkalti/backoff/.gitignore +++ /dev/null @@ -1,22 +0,0 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe diff --git a/vendor/github.com/cenkalti/backoff/.travis.yml b/vendor/github.com/cenkalti/backoff/.travis.yml deleted file mode 100644 index 1040404bfbc06..0000000000000 --- a/vendor/github.com/cenkalti/backoff/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go -go: - - 1.3.3 - - tip -before_install: - - go get github.com/mattn/goveralls - - go get golang.org/x/tools/cmd/cover -script: - - $HOME/gopath/bin/goveralls -service=travis-ci diff --git a/vendor/github.com/cenkalti/backoff/LICENSE b/vendor/github.com/cenkalti/backoff/LICENSE deleted file mode 100644 index 89b8179965581..0000000000000 --- a/vendor/github.com/cenkalti/backoff/LICENSE +++ /dev/null @@ -1,20 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2014 Cenk Altı - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/cenkalti/backoff/README.md b/vendor/github.com/cenkalti/backoff/README.md deleted file mode 100644 index 13b347fb95179..0000000000000 --- a/vendor/github.com/cenkalti/backoff/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls] - -This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client]. - -[Exponential backoff][exponential backoff wiki] -is an algorithm that uses feedback to multiplicatively decrease the rate of some process, -in order to gradually find an acceptable rate. -The retries exponentially increase and stop increasing when a certain threshold is met. - -## Usage - -See https://godoc.org/github.com/cenkalti/backoff#pkg-examples - -## Contributing - -* I would like to keep this library as small as possible. -* Please don't send a PR without opening an issue and discussing it first. -* If proposed change is not a common use case, I will probably not accept it. - -[godoc]: https://godoc.org/github.com/cenkalti/backoff -[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png -[travis]: https://travis-ci.org/cenkalti/backoff -[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master -[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master -[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master - -[google-http-java-client]: https://github.com/google/google-http-java-client -[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff - -[advanced example]: https://godoc.org/github.com/cenkalti/backoff#example_ diff --git a/vendor/github.com/cenkalti/backoff/backoff.go b/vendor/github.com/cenkalti/backoff/backoff.go deleted file mode 100644 index 2102c5f2de96b..0000000000000 --- a/vendor/github.com/cenkalti/backoff/backoff.go +++ /dev/null @@ -1,66 +0,0 @@ -// Package backoff implements backoff algorithms for retrying operations. -// -// Use Retry function for retrying operations that may fail. -// If Retry does not meet your needs, -// copy/paste the function into your project and modify as you wish. -// -// There is also Ticker type similar to time.Ticker. -// You can use it if you need to work with channels. -// -// See Examples section below for usage examples. -package backoff - -import "time" - -// BackOff is a backoff policy for retrying an operation. -type BackOff interface { - // NextBackOff returns the duration to wait before retrying the operation, - // or backoff.Stop to indicate that no more retries should be made. - // - // Example usage: - // - // duration := backoff.NextBackOff(); - // if (duration == backoff.Stop) { - // // Do not retry operation. - // } else { - // // Sleep for duration and retry operation. - // } - // - NextBackOff() time.Duration - - // Reset to initial state. - Reset() -} - -// Stop indicates that no more retries should be made for use in NextBackOff(). -const Stop time.Duration = -1 - -// ZeroBackOff is a fixed backoff policy whose backoff time is always zero, -// meaning that the operation is retried immediately without waiting, indefinitely. -type ZeroBackOff struct{} - -func (b *ZeroBackOff) Reset() {} - -func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 } - -// StopBackOff is a fixed backoff policy that always returns backoff.Stop for -// NextBackOff(), meaning that the operation should never be retried. -type StopBackOff struct{} - -func (b *StopBackOff) Reset() {} - -func (b *StopBackOff) NextBackOff() time.Duration { return Stop } - -// ConstantBackOff is a backoff policy that always returns the same backoff delay. -// This is in contrast to an exponential backoff policy, -// which returns a delay that grows longer as you call NextBackOff() over and over again. -type ConstantBackOff struct { - Interval time.Duration -} - -func (b *ConstantBackOff) Reset() {} -func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval } - -func NewConstantBackOff(d time.Duration) *ConstantBackOff { - return &ConstantBackOff{Interval: d} -} diff --git a/vendor/github.com/cenkalti/backoff/backoff_test.go b/vendor/github.com/cenkalti/backoff/backoff_test.go deleted file mode 100644 index 91f27c4f19010..0000000000000 --- a/vendor/github.com/cenkalti/backoff/backoff_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package backoff - -import ( - "testing" - "time" -) - -func TestNextBackOffMillis(t *testing.T) { - subtestNextBackOff(t, 0, new(ZeroBackOff)) - subtestNextBackOff(t, Stop, new(StopBackOff)) -} - -func subtestNextBackOff(t *testing.T, expectedValue time.Duration, backOffPolicy BackOff) { - for i := 0; i < 10; i++ { - next := backOffPolicy.NextBackOff() - if next != expectedValue { - t.Errorf("got: %d expected: %d", next, expectedValue) - } - } -} - -func TestConstantBackOff(t *testing.T) { - backoff := NewConstantBackOff(time.Second) - if backoff.NextBackOff() != time.Second { - t.Error("invalid interval") - } -} diff --git a/vendor/github.com/cenkalti/backoff/context.go b/vendor/github.com/cenkalti/backoff/context.go deleted file mode 100644 index 5d157092544fd..0000000000000 --- a/vendor/github.com/cenkalti/backoff/context.go +++ /dev/null @@ -1,60 +0,0 @@ -package backoff - -import ( - "time" - - "golang.org/x/net/context" -) - -// BackOffContext is a backoff policy that stops retrying after the context -// is canceled. -type BackOffContext interface { - BackOff - Context() context.Context -} - -type backOffContext struct { - BackOff - ctx context.Context -} - -// WithContext returns a BackOffContext with context ctx -// -// ctx must not be nil -func WithContext(b BackOff, ctx context.Context) BackOffContext { - if ctx == nil { - panic("nil context") - } - - if b, ok := b.(*backOffContext); ok { - return &backOffContext{ - BackOff: b.BackOff, - ctx: ctx, - } - } - - return &backOffContext{ - BackOff: b, - ctx: ctx, - } -} - -func ensureContext(b BackOff) BackOffContext { - if cb, ok := b.(BackOffContext); ok { - return cb - } - return WithContext(b, context.Background()) -} - -func (b *backOffContext) Context() context.Context { - return b.ctx -} - -func (b *backOffContext) NextBackOff() time.Duration { - select { - case <-b.Context().Done(): - return Stop - default: - return b.BackOff.NextBackOff() - } -} diff --git a/vendor/github.com/cenkalti/backoff/context_test.go b/vendor/github.com/cenkalti/backoff/context_test.go deleted file mode 100644 index 993fa6149db77..0000000000000 --- a/vendor/github.com/cenkalti/backoff/context_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package backoff - -import ( - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestContext(t *testing.T) { - b := NewConstantBackOff(time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cb := WithContext(b, ctx) - - if cb.Context() != ctx { - t.Error("invalid context") - } - - cancel() - - if cb.NextBackOff() != Stop { - t.Error("invalid next back off") - } -} diff --git a/vendor/github.com/cenkalti/backoff/example_test.go b/vendor/github.com/cenkalti/backoff/example_test.go deleted file mode 100644 index d97a8db8e34c3..0000000000000 --- a/vendor/github.com/cenkalti/backoff/example_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package backoff - -import ( - "log" - - "golang.org/x/net/context" -) - -func ExampleRetry() { - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - err := Retry(operation, NewExponentialBackOff()) - if err != nil { - // Handle error. - return - } - - // Operation is successful. -} - -func ExampleRetryContext() { - // A context - ctx := context.Background() - - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - b := WithContext(NewExponentialBackOff(), ctx) - - err := Retry(operation, b) - if err != nil { - // Handle error. - return - } - - // Operation is successful. -} - -func ExampleTicker() { - // An operation that may fail. - operation := func() error { - return nil // or an error - } - - ticker := NewTicker(NewExponentialBackOff()) - - var err error - - // Ticks will continue to arrive when the previous operation is still running, - // so operations that take a while to fail could run in quick succession. - for _ = range ticker.C { - if err = operation(); err != nil { - log.Println(err, "will retry...") - continue - } - - ticker.Stop() - break - } - - if err != nil { - // Operation has failed. - return - } - - // Operation is successful. - return -} diff --git a/vendor/github.com/cenkalti/backoff/exponential.go b/vendor/github.com/cenkalti/backoff/exponential.go deleted file mode 100644 index 9a6addf075016..0000000000000 --- a/vendor/github.com/cenkalti/backoff/exponential.go +++ /dev/null @@ -1,156 +0,0 @@ -package backoff - -import ( - "math/rand" - "time" -) - -/* -ExponentialBackOff is a backoff implementation that increases the backoff -period for each retry attempt using a randomization function that grows exponentially. - -NextBackOff() is calculated using the following formula: - - randomized interval = - RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor]) - -In other words NextBackOff() will range between the randomization factor -percentage below and above the retry interval. - -For example, given the following parameters: - - RetryInterval = 2 - RandomizationFactor = 0.5 - Multiplier = 2 - -the actual backoff period used in the next retry attempt will range between 1 and 3 seconds, -multiplied by the exponential, that is, between 2 and 6 seconds. - -Note: MaxInterval caps the RetryInterval and not the randomized interval. - -If the time elapsed since an ExponentialBackOff instance is created goes past the -MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop. - -The elapsed time can be reset by calling Reset(). - -Example: Given the following default arguments, for 10 tries the sequence will be, -and assuming we go over the MaxElapsedTime on the 10th try: - - Request # RetryInterval (seconds) Randomized Interval (seconds) - - 1 0.5 [0.25, 0.75] - 2 0.75 [0.375, 1.125] - 3 1.125 [0.562, 1.687] - 4 1.687 [0.8435, 2.53] - 5 2.53 [1.265, 3.795] - 6 3.795 [1.897, 5.692] - 7 5.692 [2.846, 8.538] - 8 8.538 [4.269, 12.807] - 9 12.807 [6.403, 19.210] - 10 19.210 backoff.Stop - -Note: Implementation is not thread-safe. -*/ -type ExponentialBackOff struct { - InitialInterval time.Duration - RandomizationFactor float64 - Multiplier float64 - MaxInterval time.Duration - // After MaxElapsedTime the ExponentialBackOff stops. - // It never stops if MaxElapsedTime == 0. - MaxElapsedTime time.Duration - Clock Clock - - currentInterval time.Duration - startTime time.Time - random *rand.Rand -} - -// Clock is an interface that returns current time for BackOff. -type Clock interface { - Now() time.Time -} - -// Default values for ExponentialBackOff. -const ( - DefaultInitialInterval = 500 * time.Millisecond - DefaultRandomizationFactor = 0.5 - DefaultMultiplier = 1.5 - DefaultMaxInterval = 60 * time.Second - DefaultMaxElapsedTime = 15 * time.Minute -) - -// NewExponentialBackOff creates an instance of ExponentialBackOff using default values. -func NewExponentialBackOff() *ExponentialBackOff { - b := &ExponentialBackOff{ - InitialInterval: DefaultInitialInterval, - RandomizationFactor: DefaultRandomizationFactor, - Multiplier: DefaultMultiplier, - MaxInterval: DefaultMaxInterval, - MaxElapsedTime: DefaultMaxElapsedTime, - Clock: SystemClock, - random: rand.New(rand.NewSource(time.Now().UnixNano())), - } - b.Reset() - return b -} - -type systemClock struct{} - -func (t systemClock) Now() time.Time { - return time.Now() -} - -// SystemClock implements Clock interface that uses time.Now(). -var SystemClock = systemClock{} - -// Reset the interval back to the initial retry interval and restarts the timer. -func (b *ExponentialBackOff) Reset() { - b.currentInterval = b.InitialInterval - b.startTime = b.Clock.Now() -} - -// NextBackOff calculates the next backoff interval using the formula: -// Randomized interval = RetryInterval +/- (RandomizationFactor * RetryInterval) -func (b *ExponentialBackOff) NextBackOff() time.Duration { - // Make sure we have not gone over the maximum elapsed time. - if b.MaxElapsedTime != 0 && b.GetElapsedTime() > b.MaxElapsedTime { - return Stop - } - defer b.incrementCurrentInterval() - if b.random == nil { - b.random = rand.New(rand.NewSource(time.Now().UnixNano())) - } - return getRandomValueFromInterval(b.RandomizationFactor, b.random.Float64(), b.currentInterval) -} - -// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance -// is created and is reset when Reset() is called. -// -// The elapsed time is computed using time.Now().UnixNano(). -func (b *ExponentialBackOff) GetElapsedTime() time.Duration { - return b.Clock.Now().Sub(b.startTime) -} - -// Increments the current interval by multiplying it with the multiplier. -func (b *ExponentialBackOff) incrementCurrentInterval() { - // Check for overflow, if overflow is detected set the current interval to the max interval. - if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier { - b.currentInterval = b.MaxInterval - } else { - b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier) - } -} - -// Returns a random value from the following interval: -// [randomizationFactor * currentInterval, randomizationFactor * currentInterval]. -func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration { - var delta = randomizationFactor * float64(currentInterval) - var minInterval = float64(currentInterval) - delta - var maxInterval = float64(currentInterval) + delta - - // Get a random value from the range [minInterval, maxInterval]. - // The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then - // we want a 33% chance for selecting either 1, 2 or 3. - return time.Duration(minInterval + (random * (maxInterval - minInterval + 1))) -} diff --git a/vendor/github.com/cenkalti/backoff/exponential_test.go b/vendor/github.com/cenkalti/backoff/exponential_test.go deleted file mode 100644 index 11b95e4f61d43..0000000000000 --- a/vendor/github.com/cenkalti/backoff/exponential_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package backoff - -import ( - "math" - "testing" - "time" -) - -func TestBackOff(t *testing.T) { - var ( - testInitialInterval = 500 * time.Millisecond - testRandomizationFactor = 0.1 - testMultiplier = 2.0 - testMaxInterval = 5 * time.Second - testMaxElapsedTime = 15 * time.Minute - ) - - exp := NewExponentialBackOff() - exp.InitialInterval = testInitialInterval - exp.RandomizationFactor = testRandomizationFactor - exp.Multiplier = testMultiplier - exp.MaxInterval = testMaxInterval - exp.MaxElapsedTime = testMaxElapsedTime - exp.Reset() - - var expectedResults = []time.Duration{500, 1000, 2000, 4000, 5000, 5000, 5000, 5000, 5000, 5000} - for i, d := range expectedResults { - expectedResults[i] = d * time.Millisecond - } - - for _, expected := range expectedResults { - assertEquals(t, expected, exp.currentInterval) - // Assert that the next backoff falls in the expected range. - var minInterval = expected - time.Duration(testRandomizationFactor*float64(expected)) - var maxInterval = expected + time.Duration(testRandomizationFactor*float64(expected)) - var actualInterval = exp.NextBackOff() - if !(minInterval <= actualInterval && actualInterval <= maxInterval) { - t.Error("error") - } - } -} - -func TestGetRandomizedInterval(t *testing.T) { - // 33% chance of being 1. - assertEquals(t, 1, getRandomValueFromInterval(0.5, 0, 2)) - assertEquals(t, 1, getRandomValueFromInterval(0.5, 0.33, 2)) - // 33% chance of being 2. - assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.34, 2)) - assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.66, 2)) - // 33% chance of being 3. - assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.67, 2)) - assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.99, 2)) -} - -type TestClock struct { - i time.Duration - start time.Time -} - -func (c *TestClock) Now() time.Time { - t := c.start.Add(c.i) - c.i += time.Second - return t -} - -func TestGetElapsedTime(t *testing.T) { - var exp = NewExponentialBackOff() - exp.Clock = &TestClock{} - exp.Reset() - - var elapsedTime = exp.GetElapsedTime() - if elapsedTime != time.Second { - t.Errorf("elapsedTime=%d", elapsedTime) - } -} - -func TestMaxElapsedTime(t *testing.T) { - var exp = NewExponentialBackOff() - exp.Clock = &TestClock{start: time.Time{}.Add(10000 * time.Second)} - // Change the currentElapsedTime to be 0 ensuring that the elapsed time will be greater - // than the max elapsed time. - exp.startTime = time.Time{} - assertEquals(t, Stop, exp.NextBackOff()) -} - -func TestBackOffOverflow(t *testing.T) { - var ( - testInitialInterval time.Duration = math.MaxInt64 / 2 - testMaxInterval time.Duration = math.MaxInt64 - testMultiplier = 2.1 - ) - - exp := NewExponentialBackOff() - exp.InitialInterval = testInitialInterval - exp.Multiplier = testMultiplier - exp.MaxInterval = testMaxInterval - exp.Reset() - - exp.NextBackOff() - // Assert that when an overflow is possible the current varerval time.Duration is set to the max varerval time.Duration . - assertEquals(t, testMaxInterval, exp.currentInterval) -} - -func assertEquals(t *testing.T, expected, value time.Duration) { - if expected != value { - t.Errorf("got: %d, expected: %d", value, expected) - } -} diff --git a/vendor/github.com/cenkalti/backoff/retry.go b/vendor/github.com/cenkalti/backoff/retry.go deleted file mode 100644 index 5dbd825b5c8b5..0000000000000 --- a/vendor/github.com/cenkalti/backoff/retry.go +++ /dev/null @@ -1,78 +0,0 @@ -package backoff - -import "time" - -// An Operation is executing by Retry() or RetryNotify(). -// The operation will be retried using a backoff policy if it returns an error. -type Operation func() error - -// Notify is a notify-on-error function. It receives an operation error and -// backoff delay if the operation failed (with an error). -// -// NOTE that if the backoff policy stated to stop retrying, -// the notify function isn't called. -type Notify func(error, time.Duration) - -// Retry the operation o until it does not return error or BackOff stops. -// o is guaranteed to be run at least once. -// It is the caller's responsibility to reset b after Retry returns. -// -// If o returns a *PermanentError, the operation is not retried, and the -// wrapped error is returned. -// -// Retry sleeps the goroutine for the duration returned by BackOff after a -// failed operation returns. -func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) } - -// RetryNotify calls notify function with the error and wait duration -// for each failed attempt before sleep. -func RetryNotify(operation Operation, b BackOff, notify Notify) error { - var err error - var next time.Duration - - cb := ensureContext(b) - - b.Reset() - for { - if err = operation(); err == nil { - return nil - } - - if permanent, ok := err.(*PermanentError); ok { - return permanent.Err - } - - if next = b.NextBackOff(); next == Stop { - return err - } - - if notify != nil { - notify(err, next) - } - - t := time.NewTimer(next) - - select { - case <-cb.Context().Done(): - t.Stop() - return err - case <-t.C: - } - } -} - -// PermanentError signals that the operation should not be retried. -type PermanentError struct { - Err error -} - -func (e *PermanentError) Error() string { - return e.Err.Error() -} - -// Permanent wraps the given err in a *PermanentError. -func Permanent(err error) *PermanentError { - return &PermanentError{ - Err: err, - } -} diff --git a/vendor/github.com/cenkalti/backoff/retry_test.go b/vendor/github.com/cenkalti/backoff/retry_test.go deleted file mode 100644 index 5af288897df6a..0000000000000 --- a/vendor/github.com/cenkalti/backoff/retry_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package backoff - -import ( - "errors" - "fmt" - "log" - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestRetry(t *testing.T) { - const successOn = 3 - var i = 0 - - // This function is successful on "successOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - if i == successOn { - log.Println("OK") - return nil - } - - log.Println("error") - return errors.New("error") - } - - err := Retry(f, NewExponentialBackOff()) - if err != nil { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != successOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestRetryContext(t *testing.T) { - var cancelOn = 3 - var i = 0 - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // This function cancels context on "cancelOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - // cancelling the context in the operation function is not a typical - // use-case, however it allows to get predictable test results. - if i == cancelOn { - cancel() - } - - log.Println("error") - return fmt.Errorf("error (%d)", i) - } - - err := Retry(f, WithContext(NewConstantBackOff(time.Millisecond), ctx)) - if err == nil { - t.Errorf("error is unexpectedly nil") - } - if err.Error() != "error (3)" { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != cancelOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestRetryPermenent(t *testing.T) { - const permanentOn = 3 - var i = 0 - - // This function fails permanently after permanentOn tries - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - if i == permanentOn { - log.Println("permanent error") - return Permanent(errors.New("permanent error")) - } - - log.Println("error") - return errors.New("error") - } - - err := Retry(f, NewExponentialBackOff()) - if err == nil || err.Error() != "permanent error" { - t.Errorf("unexpected error: %s", err) - } - if i != permanentOn { - t.Errorf("invalid number of retries: %d", i) - } -} diff --git a/vendor/github.com/cenkalti/backoff/ticker.go b/vendor/github.com/cenkalti/backoff/ticker.go deleted file mode 100644 index 49a99718d7425..0000000000000 --- a/vendor/github.com/cenkalti/backoff/ticker.go +++ /dev/null @@ -1,81 +0,0 @@ -package backoff - -import ( - "runtime" - "sync" - "time" -) - -// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff. -// -// Ticks will continue to arrive when the previous operation is still running, -// so operations that take a while to fail could run in quick succession. -type Ticker struct { - C <-chan time.Time - c chan time.Time - b BackOffContext - stop chan struct{} - stopOnce sync.Once -} - -// NewTicker returns a new Ticker containing a channel that will send the time at times -// specified by the BackOff argument. Ticker is guaranteed to tick at least once. -// The channel is closed when Stop method is called or BackOff stops. -func NewTicker(b BackOff) *Ticker { - c := make(chan time.Time) - t := &Ticker{ - C: c, - c: c, - b: ensureContext(b), - stop: make(chan struct{}), - } - go t.run() - runtime.SetFinalizer(t, (*Ticker).Stop) - return t -} - -// Stop turns off a ticker. After Stop, no more ticks will be sent. -func (t *Ticker) Stop() { - t.stopOnce.Do(func() { close(t.stop) }) -} - -func (t *Ticker) run() { - c := t.c - defer close(c) - t.b.Reset() - - // Ticker is guaranteed to tick at least once. - afterC := t.send(time.Now()) - - for { - if afterC == nil { - return - } - - select { - case tick := <-afterC: - afterC = t.send(tick) - case <-t.stop: - t.c = nil // Prevent future ticks from being sent to the channel. - return - case <-t.b.Context().Done(): - return - } - } -} - -func (t *Ticker) send(tick time.Time) <-chan time.Time { - select { - case t.c <- tick: - case <-t.stop: - return nil - } - - next := t.b.NextBackOff() - if next == Stop { - t.Stop() - return nil - } - - return time.After(next) -} diff --git a/vendor/github.com/cenkalti/backoff/ticker_test.go b/vendor/github.com/cenkalti/backoff/ticker_test.go deleted file mode 100644 index 085828cca5276..0000000000000 --- a/vendor/github.com/cenkalti/backoff/ticker_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package backoff - -import ( - "errors" - "fmt" - "log" - "testing" - "time" - - "golang.org/x/net/context" -) - -func TestTicker(t *testing.T) { - const successOn = 3 - var i = 0 - - // This function is successful on "successOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - if i == successOn { - log.Println("OK") - return nil - } - - log.Println("error") - return errors.New("error") - } - - b := NewExponentialBackOff() - ticker := NewTicker(b) - - var err error - for _ = range ticker.C { - if err = f(); err != nil { - t.Log(err) - continue - } - - break - } - if err != nil { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != successOn { - t.Errorf("invalid number of retries: %d", i) - } -} - -func TestTickerContext(t *testing.T) { - const cancelOn = 3 - var i = 0 - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // This function cancels context on "cancelOn" calls. - f := func() error { - i++ - log.Printf("function is called %d. time\n", i) - - // cancelling the context in the operation function is not a typical - // use-case, however it allows to get predictable test results. - if i == cancelOn { - cancel() - } - - log.Println("error") - return fmt.Errorf("error (%d)", i) - } - - b := WithContext(NewConstantBackOff(time.Millisecond), ctx) - ticker := NewTicker(b) - - var err error - for _ = range ticker.C { - if err = f(); err != nil { - t.Log(err) - continue - } - - break - } - if err == nil { - t.Errorf("error is unexpectedly nil") - } - if err.Error() != "error (3)" { - t.Errorf("unexpected error: %s", err.Error()) - } - if i != cancelOn { - t.Errorf("invalid number of retries: %d", i) - } -} diff --git a/vendor/github.com/cenkalti/backoff/tries.go b/vendor/github.com/cenkalti/backoff/tries.go deleted file mode 100644 index d2da7308b6aa9..0000000000000 --- a/vendor/github.com/cenkalti/backoff/tries.go +++ /dev/null @@ -1,35 +0,0 @@ -package backoff - -import "time" - -/* -WithMaxTries creates a wrapper around another BackOff, which will -return Stop if NextBackOff() has been called too many times since -the last time Reset() was called - -Note: Implementation is not thread-safe. -*/ -func WithMaxTries(b BackOff, max uint64) BackOff { - return &backOffTries{delegate: b, maxTries: max} -} - -type backOffTries struct { - delegate BackOff - maxTries uint64 - numTries uint64 -} - -func (b *backOffTries) NextBackOff() time.Duration { - if b.maxTries > 0 { - if b.maxTries <= b.numTries { - return Stop - } - b.numTries++ - } - return b.delegate.NextBackOff() -} - -func (b *backOffTries) Reset() { - b.numTries = 0 - b.delegate.Reset() -} diff --git a/vendor/github.com/cenkalti/backoff/tries_test.go b/vendor/github.com/cenkalti/backoff/tries_test.go deleted file mode 100644 index bd1021143ed45..0000000000000 --- a/vendor/github.com/cenkalti/backoff/tries_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package backoff - -import ( - "math/rand" - "testing" - "time" -) - -func TestMaxTriesHappy(t *testing.T) { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - max := 17 + r.Intn(13) - bo := WithMaxTries(&ZeroBackOff{}, uint64(max)) - - // Load up the tries count, but reset should clear the record - for ix := 0; ix < max/2; ix++ { - bo.NextBackOff() - } - bo.Reset() - - // Now fill the tries count all the way up - for ix := 0; ix < max; ix++ { - d := bo.NextBackOff() - if d == Stop { - t.Errorf("returned Stop on try %d", ix) - } - } - - // We have now called the BackOff max number of times, we expect - // the next result to be Stop, even if we try it multiple times - for ix := 0; ix < 7; ix++ { - d := bo.NextBackOff() - if d != Stop { - t.Error("invalid next back off") - } - } - - // Reset makes it all work again - bo.Reset() - d := bo.NextBackOff() - if d == Stop { - t.Error("returned Stop after reset") - } - -} - -func TestMaxTriesZero(t *testing.T) { - // It might not make sense, but its okay to send a zero - bo := WithMaxTries(&ZeroBackOff{}, uint64(0)) - for ix := 0; ix < 11; ix++ { - d := bo.NextBackOff() - if d == Stop { - t.Errorf("returned Stop on try %d", ix) - } - } -} diff --git a/vendor/github.com/vulcand/predicate/builder/builder.go b/vendor/github.com/vulcand/predicate/builder/builder.go new file mode 100644 index 0000000000000..148673966e373 --- /dev/null +++ b/vendor/github.com/vulcand/predicate/builder/builder.go @@ -0,0 +1,169 @@ +/* +Copyright 2014-2018 Vulcand Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +// package builder is used to construct predicate +// expressions using builder functions. +package builder + +import ( + "fmt" + "strings" +) + +// Expr is an expression builder, +// used to create expressions in rules definitions +type Expr interface { + // String serializes expression into format parsed by rules engine + // (golang based syntax) + String() string +} + +// IdentiferExpr is identifer expression +type IdentifierExpr string + +// String serializes identifer expression into format parsed by rules engine +func (i IdentifierExpr) String() string { + return string(i) +} + +// Identifer returns identifier expression +func Identifier(v string) IdentifierExpr { + return IdentifierExpr(v) +} + +// String returns string expression +func String(v string) StringExpr { + return StringExpr(v) +} + +// StringExpr is a string expression +type StringExpr string + +func (s StringExpr) String() string { + return fmt.Sprintf("%q", string(s)) +} + +// StringsExpr is a slice of strings +type StringsExpr []string + +func (s StringsExpr) String() string { + var out []string + for _, val := range s { + out = append(out, fmt.Sprintf("%q", val)) + } + return strings.Join(out, ",") +} + +// Equals returns equals expression +func Equals(left, right Expr) EqualsExpr { + return EqualsExpr{Left: left, Right: right} +} + +// EqualsExpr constructs function expression used in rules specifications +// that checks if one value is equal to another +// e.g. equals("a", "b") where Left is "a" and right is "b" +type EqualsExpr struct { + // Left is a left argument of Equals expression + Left Expr + // Value to check + Right Expr +} + +// String returns function call expression used in rules +func (i EqualsExpr) String() string { + return fmt.Sprintf("equals(%v, %v)", i.Left, i.Right) +} + +// Not returns ! expression +func Not(expr Expr) NotExpr { + return NotExpr{Expr: expr} +} + +// NotExpr constructs function expression used in rules specifications +// that negates the result of the boolean predicate +// e.g. ! equals"a", "b") where Left is "a" and right is "b" +type NotExpr struct { + // Expr is an expression to negate + Expr Expr +} + +// String returns function call expression used in rules +func (n NotExpr) String() string { + return fmt.Sprintf("!%v", n.Expr) +} + +// Contains returns contains function call expression +func Contains(a, b Expr) ContainsExpr { + return ContainsExpr{Left: a, Right: b} +} + +// ContainsExpr constructs function expression used in rules specifications +// that checks if one value contains the other, e.g. +// contains([]string{"a"}, "b") where left is []string{"a"} and right is "b" +type ContainsExpr struct { + // Left is a left argument of Contains expression + Left Expr + // Right is a right argument of Contains expression + Right Expr +} + +// String rturns function call expression used in rules +func (i ContainsExpr) String() string { + return fmt.Sprintf("contains(%v, %v)", i.Left, i.Right) +} + +// And returns && expression +func And(left, right Expr) AndExpr { + return AndExpr{ + Left: left, + Right: right, + } +} + +// AndExpr returns && expression +type AndExpr struct { + // Left is a left argument of && operator expression + Left Expr + // Right is a right argument of && operator expression + Right Expr +} + +// String returns expression text used in rules +func (a AndExpr) String() string { + return fmt.Sprintf("%v && %v", a.Left, a.Right) +} + +// Or returns || expression +func Or(left, right Expr) OrExpr { + return OrExpr{ + Left: left, + Right: right, + } +} + +// OrExpr returns || expression +type OrExpr struct { + // Left is a left argument of || operator expression + Left Expr + // Right is a right argument of || operator expression + Right Expr +} + +// String returns expression text used in rules +func (a OrExpr) String() string { + return fmt.Sprintf("%v || %v", a.Left, a.Right) +} diff --git a/vendor/github.com/vulcand/predicate/lib.go b/vendor/github.com/vulcand/predicate/lib.go index 013e16c5068b5..c313c1bae71a4 100644 --- a/vendor/github.com/vulcand/predicate/lib.go +++ b/vendor/github.com/vulcand/predicate/lib.go @@ -119,6 +119,14 @@ func Or(a, b BoolPredicate) BoolPredicate { } } +// Not is a boolean predicate that calls a boolean predicate +// and returns negated result +func Not(a BoolPredicate) BoolPredicate { + return func() bool { + return !a() + } +} + // GetFieldByTag returns a field from the object based on the tag func GetFieldByTag(ival interface{}, tagName string, fieldNames []string) (interface{}, error) { if len(fieldNames) == 0 { diff --git a/vendor/github.com/vulcand/predicate/parse.go b/vendor/github.com/vulcand/predicate/parse.go index 03518cf164045..b80e7361a4bd7 100644 --- a/vendor/github.com/vulcand/predicate/parse.go +++ b/vendor/github.com/vulcand/predicate/parse.go @@ -59,6 +59,16 @@ func (p *predicateParser) parseNode(node ast.Node) (interface{}, error) { return callFunction(fn, arguments) case *ast.ParenExpr: return p.parseNode(n.X) + case *ast.UnaryExpr: + joinFn, err := p.getJoinFunction(n.Op) + if err != nil { + return nil, err + } + node, err := p.parseNode(n.X) + if err != nil { + return nil, err + } + return callFunction(joinFn, []interface{}{node}) } return nil, trace.BadParameter("unsupported %T", node) } @@ -122,6 +132,20 @@ func (p *predicateParser) evaluateExpr(n ast.Expr) (interface{}, error) { return nil, trace.Wrap(err) } return val, nil + case *ast.CallExpr: + name, err := getIdentifier(l.Fun) + if err != nil { + return nil, err + } + fn, err := p.getFunction(name) + if err != nil { + return nil, err + } + arguments, err := p.evaluateArguments(l.Args) + if err != nil { + return nil, err + } + return callFunction(fn, arguments) default: return nil, trace.BadParameter("%T is not supported", n) } @@ -161,6 +185,8 @@ func (p *predicateParser) joinPredicates(op token.Token, a, b interface{}) (inte func (p *predicateParser) getJoinFunction(op token.Token) (interface{}, error) { var fn interface{} switch op { + case token.NOT: + fn = p.d.Operators.NOT case token.LAND: fn = p.d.Operators.AND case token.LOR: diff --git a/vendor/github.com/vulcand/predicate/parse_test.go b/vendor/github.com/vulcand/predicate/parse_test.go index 24392d41d41eb..dd11b816b76cb 100644 --- a/vendor/github.com/vulcand/predicate/parse_test.go +++ b/vendor/github.com/vulcand/predicate/parse_test.go @@ -30,6 +30,7 @@ func (s *PredicateSuite) getParserWithOpts(c *check.C, getID GetIdentifierFn, ge NEQ: numberNEQ, LE: numberLE, GE: numberGE, + NOT: numberNOT, }, Functions: map[string]interface{}{ "DivisibleBy": divisibleBy, @@ -38,6 +39,12 @@ func (s *PredicateSuite) getParserWithOpts(c *check.C, getID GetIdentifierFn, ge "number.DivisibleBy": divisibleBy, "Equals": Equals, "Contains": Contains, + "fnreturn": func(arg interface{}) (interface{}, error) { + return arg, nil + }, + "fnerr": func(arg interface{}) (interface{}, error) { + return nil, trace.BadParameter("don't like this parameter") + }, }, GetIdentifier: getID, GetProperty: getProperty, @@ -58,6 +65,35 @@ func (s *PredicateSuite) TestSinglePredicate(c *check.C) { c.Assert(fn(3), check.Equals, false) } +func (s *PredicateSuite) TestSinglePredicateNot(c *check.C) { + p := s.getParser(c) + + pr, err := p.Parse("!DivisibleBy(2)") + c.Assert(err, check.IsNil) + c.Assert(pr, check.FitsTypeOf, divisibleBy(2)) + fn := pr.(numberPredicate) + c.Assert(fn(2), check.Equals, false) + c.Assert(fn(3), check.Equals, true) +} + +func (s *PredicateSuite) TestSinglePredicateWithFunc(c *check.C) { + p := s.getParser(c) + + pr, err := p.Parse("DivisibleBy(fnreturn(2))") + c.Assert(err, check.IsNil) + c.Assert(pr, check.FitsTypeOf, divisibleBy(2)) + fn := pr.(numberPredicate) + c.Assert(fn(2), check.Equals, true) + c.Assert(fn(3), check.Equals, false) +} + +func (s *PredicateSuite) TestSinglePredicateWithFuncErr(c *check.C) { + p := s.getParser(c) + + _, err := p.Parse("DivisibleBy(fnerr(2))") + c.Assert(err, check.NotNil) +} + func (s *PredicateSuite) TestModulePredicate(c *check.C) { p := s.getParser(c) @@ -380,7 +416,6 @@ func (s *PredicateSuite) TestUnhappyCases(c *check.C) { "Remainder(banana)", // unsupported argument "Remainder(1, 2)", // unsupported arguments count "Remainder(Len)", // unsupported argument - `Remainder(Len("Ho"))`, // unsupported argument "Bla(1)", // unknown method call "0.2 && Remainder(1)", // unsupported value `Len("Ho") && 0.2`, // unsupported value @@ -405,6 +440,12 @@ func divisibleBy(divisor int) numberPredicate { } } +func numberNOT(a numberPredicate) numberPredicate { + return func(v int) bool { + return !a(v) + } +} + func numberAND(a, b numberPredicate) numberPredicate { return func(v int) bool { return a(v) && b(v) diff --git a/vendor/github.com/vulcand/predicate/predicate.go b/vendor/github.com/vulcand/predicate/predicate.go index 256e701461851..c5aacb4881671 100644 --- a/vendor/github.com/vulcand/predicate/predicate.go +++ b/vendor/github.com/vulcand/predicate/predicate.go @@ -1,3 +1,20 @@ +/* +Copyright 2014-2018 Vulcand Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + /* Predicate package used to create interpreted mini languages with Go syntax - mostly to define various predicates for configuration, e.g. Latency() > 40 || ErrorRate() > 0.5. @@ -76,6 +93,7 @@ type Operators struct { OR interface{} AND interface{} + NOT interface{} } // Parser takes the string with expression and calls the operators and functions.