Skip to content

Commit

Permalink
[branch/v6.2] Backport #7360 (#7574)
Browse files Browse the repository at this point in the history
  • Loading branch information
xacrimon authored Jul 21, 2021
1 parent 8233ec1 commit fe6824e
Show file tree
Hide file tree
Showing 17 changed files with 1,095 additions and 451 deletions.
48 changes: 48 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,54 @@ func (c *Client) DeleteAllNodes(ctx context.Context, namespace string) error {
return trail.FromGRPC(err)
}

// StreamSessionEvents streams audit events from a given session recording.
func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) {
request := &proto.StreamSessionEventsRequest{
SessionID: sessionID,
StartIndex: int32(startIndex),
}

ch := make(chan events.AuditEvent)
e := make(chan error, 1)

stream, err := c.grpc.StreamSessionEvents(ctx, request)
if err != nil {
e <- trace.Wrap(err)
return ch, e
}

go func() {
outer:
for {
oneOf, err := stream.Recv()
if err != nil {
if err != io.EOF {
e <- trace.Wrap(trail.FromGRPC(err))
} else {
close(ch)
}

break outer
}

event, err := events.FromOneOf(*oneOf)
if err != nil {
e <- trace.Wrap(trail.FromGRPC(err))
break outer
}

select {
case ch <- event:
case <-ctx.Done():
e <- trace.Wrap(ctx.Err())
break outer
}
}
}()

return ch, e
}

// SearchEvents allows searching for events with a full pagination support.
func (c *Client) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
request := &proto.GetEventsRequest{
Expand Down
1,169 changes: 720 additions & 449 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions api/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,15 @@ message IsMFARequiredRequest {
}
}

// StreamSessionEventsRequest is a request containing needed data to fetch a session recording.
message StreamSessionEventsRequest {
// SessionID is the ID for a given session in an UUIDv4 format.
string SessionID = 1;
// StartIndex is the index of the event to resume the stream after.
// A StartIndex of 0 creates a new stream.
int32 StartIndex = 2;
}

// NodeLogin specifies an SSH node and OS login.
message NodeLogin {
// Node can be node's hostname or UUID.
Expand Down Expand Up @@ -1130,4 +1139,7 @@ service AuthService {
rpc GetAuthPreference(google.protobuf.Empty) returns (types.AuthPreferenceV2);
// SetAuthPreference sets cluster auth preference.
rpc SetAuthPreference(types.AuthPreferenceV2) returns (google.protobuf.Empty);

// StreamSessionEvents streams audit events from a given session recording.
rpc StreamSessionEvents(StreamSessionEventsRequest) returns (stream events.OneOf);
}
2 changes: 1 addition & 1 deletion e
Submodule e updated from ac31c6 to 878ae1
88 changes: 88 additions & 0 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"os/exec"
"os/user"
"path/filepath"
"reflect"
"regexp"
"runtime/pprof"
"strconv"
Expand All @@ -47,6 +48,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/testauthority"
Expand Down Expand Up @@ -5567,3 +5569,89 @@ func TestTraitsPropagation(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "hello leaf", strings.TrimSpace(outputLeaf))
}

// TestSessionStreaming tests streaming events from session recordings.
func TestSessionStreaming(t *testing.T) {
ctx := context.Background()
sessionID := session.ID(uuid.New())
log := testlog.FailureOnly(t)
privateKey, publicKey, err := testauthority.New().GenerateKeyPair("")
require.NoError(t, err)

teleport := NewInstance(InstanceConfig{
ClusterName: Site,
HostID: "00000000-0000-0000-0000-000000000000",
NodeName: Host,
Ports: ports.PopIntSlice(6),
Priv: privateKey,
Pub: publicKey,
log: log,
})

if err := teleport.Create(nil, true, nil); err != nil {
t.Fatalf("Unexpected response from Create: %v", err)
}

if err := teleport.Start(); err != nil {
t.Fatalf("Unexpected response from Start: %v", err)
}

defer teleport.StopAll()

api := teleport.GetSiteAPI(Site)
uploadStream, err := api.CreateAuditStream(ctx, sessionID)
require.Nil(t, err)

generatedSession := events.GenerateTestSession(events.SessionParams{
PrintEvents: 100,
SessionID: string(sessionID),
ServerID: "00000000-0000-0000-0000-000000000000",
})

for _, event := range generatedSession {
err := uploadStream.EmitAuditEvent(ctx, event)
require.NoError(t, err)
}

err = uploadStream.Complete(ctx)
require.Nil(t, err)
start := time.Now()

// retry in case of error
outer:
for time.Since(start) < time.Minute*5 {
time.Sleep(time.Second * 5)

receivedSession := make([]apievents.AuditEvent, 0)
sessionPlayback, e := api.StreamSessionEvents(ctx, sessionID, 0)

inner:
for {
select {
case event, more := <-sessionPlayback:
if !more {
break inner
}

receivedSession = append(receivedSession, event)
case <-ctx.Done():
require.Nil(t, ctx.Err())
case err := <-e:
require.Nil(t, err)
case <-time.After(time.Minute * 5):
t.FailNow()
}
}

for i := range generatedSession {
receivedSession[i].SetClusterName("")
if !reflect.DeepEqual(generatedSession[i], receivedSession[i]) {
continue outer
}
}

return
}

t.FailNow()
}
13 changes: 13 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -2958,6 +2958,19 @@ func (a *ServerWithRoles) SearchSessionEvents(fromUTC, toUTC time.Time, limit in
return events, lastKey, nil
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan events.AuditEvent, chan error) {
if err := a.action(defaults.Namespace, types.KindSession, types.VerbList); err != nil {
c, e := make(chan events.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
}

// NewAdminAuthServer returns auth server authorized as admin,
// used for auth server cached access
func NewAdminAuthServer(authServer *Server, sessions session.Service, alog events.IAuditLog) (ClientI, error) {
Expand Down
7 changes: 7 additions & 0 deletions lib/auth/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,13 @@ func (c *Client) GetSessionEvents(namespace string, sid session.ID, afterN int,
return retval, nil
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (c *Client) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan events.AuditEvent, chan error) {
return c.APIClient.StreamSessionEvents(ctx, string(sessionID), startIndex)
}

// GetNamespaces returns a list of namespaces
func (c *Client) GetNamespaces() ([]services.Namespace, error) {
out, err := c.Get(c.Endpoint("namespaces"), url.Values{})
Expand Down
32 changes: 32 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2537,6 +2537,38 @@ func (g *GRPCServer) authenticate(ctx context.Context) (*grpcContext, error) {
}, nil
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (g *GRPCServer) StreamSessionEvents(req *proto.StreamSessionEventsRequest, stream proto.AuthService_StreamSessionEventsServer) error {
auth, err := g.authenticate(stream.Context())
if err != nil {
return trace.Wrap(err)
}

c, e := auth.ServerWithRoles.StreamSessionEvents(stream.Context(), session.ID(req.SessionID), int64(req.StartIndex))

for {
select {
case event, more := <-c:
if !more {
return nil
}

oneOf, err := apievents.ToOneOf(event)
if err != nil {
return trail.ToGRPC(trace.Wrap(err))
}

if err := stream.Send(oneOf); err != nil {
return trail.ToGRPC(trace.Wrap(err))
}
case err := <-e:
return trail.ToGRPC(trace.Wrap(err))
}
}
}

// GetEvents searches for events on the backend and sends them back in a response.
func (g *GRPCServer) GetEvents(ctx context.Context, req *proto.GetEventsRequest) (*proto.Events, error) {
auth, err := g.authenticate(ctx)
Expand Down
5 changes: 5 additions & 0 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,11 @@ type IAuditLog interface {
// WaitForDelivery waits for resources to be released and outstanding requests to
// complete after calling Close method
WaitForDelivery(context.Context) error

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error)
}

// EventFields instance is attached to every logged event
Expand Down
87 changes: 87 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,86 @@ func (l *AuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, orde
return l.localLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey)
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
l.log.Debugf("StreamSessionEvents(%v)", sessionID)
e := make(chan error, 1)
c := make(chan apievents.AuditEvent)

tarballPath := filepath.Join(l.playbackDir, string(sessionID)+".stream.tar")
downloadCtx, cancel := l.createOrGetDownload(tarballPath)

// Wait until another in progress download finishes and use it's tarball.
if cancel == nil {
l.log.Debugf("Another download is in progress for %v, waiting until it gets completed.", sessionID)
select {
case <-downloadCtx.Done():
case <-l.ctx.Done():
e <- trace.BadParameter("audit log is closing, aborting the download")
return c, e
}
}
defer cancel()
rawSession, err := os.OpenFile(tarballPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0640)
if err != nil {
e <- trace.Wrap(err)
return c, e
}

start := time.Now()
if err := l.UploadHandler.Download(l.ctx, sessionID, rawSession); err != nil {
// remove partially downloaded tarball
if rmErr := os.Remove(tarballPath); rmErr != nil {
l.log.WithError(rmErr).Warningf("Failed to remove file %v.", tarballPath)
}

e <- trace.Wrap(err)
return c, e
}

l.log.WithField("duration", time.Since(start)).Debugf("Downloaded %v to %v.", sessionID, tarballPath)
_, err = rawSession.Seek(0, 0)
if err != nil {
e <- trace.Wrap(err)
return c, e
}

if err != nil {
e <- trace.Wrap(err)
return c, e
}

protoReader := NewProtoReader(rawSession)

go func() {
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
break
}

event, err := protoReader.Read(ctx)
if err != nil {
if err != io.EOF {
e <- trace.Wrap(err)
} else {
close(c)
}

break
}

if event.GetIndex() >= startIndex {
c <- event
}
}
}()

return c, e
}

// getLocalLog returns the local (file based) audit log.
func (l *AuditLog) getLocalLog() IAuditLog {
l.RLock()
Expand Down Expand Up @@ -1231,3 +1311,10 @@ func (a *closedLogger) WaitForDelivery(context.Context) error {
func (a *closedLogger) Close() error {
return trace.NotImplemented(loggerClosedMessage)
}

func (a *closedLogger) StreamSessionEvents(_ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.NotImplemented(loggerClosedMessage)

return c, e
}
6 changes: 5 additions & 1 deletion lib/events/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ func (d *DiscardAuditLog) SearchEvents(fromUTC, toUTC time.Time, namespace strin
func (d *DiscardAuditLog) SearchSessionEvents(fromUTC time.Time, toUTC time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) {
return make([]apievents.AuditEvent, 0), "", nil
}

func (d *DiscardAuditLog) UploadSessionRecording(SessionRecording) error {
return nil
}
func (d *DiscardAuditLog) EmitAuditEvent(ctx context.Context, event AuditEvent) error {
return nil
}
func (d *DiscardAuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
close(c)
return c, e
}
9 changes: 9 additions & 0 deletions lib/events/dynamoevents/dynamoevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,3 +1436,12 @@ func convertError(err error) error {
return err
}
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise it is simply closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (l *Log) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.NotImplemented("not implemented")
return c, e
}
Loading

0 comments on commit fe6824e

Please sign in to comment.