Skip to content

Commit

Permalink
Merge pull request #6 from charmbracelet/ctx-race
Browse files Browse the repository at this point in the history
Ctx race
  • Loading branch information
aymanbagabas authored Aug 22, 2023
2 parents efe1ff2 + a30642c commit 1a051f8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
51 changes: 40 additions & 11 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"net"
"sync"
"time"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -92,10 +93,14 @@ type Context interface {
}

type sshContext struct {
context.Context
*sync.RWMutex
ctx context.Context
mtx *sync.RWMutex
}

var _ context.Context = &sshContext{}

var _ sync.Locker = &sshContext{}

func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx, &sync.RWMutex{}}
Expand All @@ -120,21 +125,45 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.RWMutex.Lock()
defer ctx.RWMutex.Unlock()
ctx.Context = context.WithValue(ctx.Context, key, value)
ctx.mtx.Lock()
defer ctx.mtx.Unlock()
ctx.ctx = context.WithValue(ctx.ctx, key, value)
}

func (ctx *sshContext) Value(key interface{}) interface{} {
ctx.RWMutex.RLock()
defer ctx.RWMutex.RUnlock()
return ctx.Context.Value(key)
ctx.mtx.RLock()
defer ctx.mtx.RUnlock()
return ctx.ctx.Value(key)
}

func (ctx *sshContext) Done() <-chan struct{} {
ctx.RWMutex.RLock()
defer ctx.RWMutex.RUnlock()
return ctx.Context.Done()
ctx.mtx.RLock()
defer ctx.mtx.RUnlock()
return ctx.ctx.Done()
}

// Deadline implements context.Context.
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) {
ctx.mtx.RLock()
defer ctx.mtx.RUnlock()
return ctx.ctx.Deadline()
}

// Err implements context.Context.
func (ctx *sshContext) Err() error {
ctx.mtx.RLock()
defer ctx.mtx.RUnlock()
return ctx.ctx.Err()
}

// Lock implements sync.Locker.
func (ctx *sshContext) Lock() {
ctx.mtx.Lock()
}

// Unlock implements sync.Locker.
func (ctx *sshContext) Unlock() {
ctx.mtx.Unlock()
}

func (ctx *sshContext) User() string {
Expand Down
41 changes: 40 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ssh

import "testing"
import (
"testing"
"time"
)

func TestSetPermissions(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -69,3 +72,39 @@ func TestRaceRWIssue160(t *testing.T) {
t.Fatal(err)
}
}

// Taken from https://github.com/gliderlabs/ssh/pull/211/commits/02f9d573009f8c13755b6b90fa14a4f549b17b22
func TestSetValueConcurrency(t *testing.T) {
ctx, cancel := newContext(nil)
defer cancel()

go func() {
for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue
_, _ = ctx.Deadline()
_ = ctx.Err()
_ = ctx.Value("foo")
select {
case <-ctx.Done():
break
default:
}
}
}()
ctx.SetValue("bar", -1) // a context value which never changes
now := time.Now()
var cnt int64
go func() {
for time.Since(now) < 100*time.Millisecond {
cnt++
ctx.SetValue("foo", cnt) // a context value which changes a lot
}
cancel()
}()
<-ctx.Done()
if ctx.Value("foo") != cnt {
t.Fatal("context.Value(foo) doesn't match latest SetValue")
}
if ctx.Value("bar") != -1 {
t.Fatal("context.Value(bar) doesn't match latest SetValue")
}
}

0 comments on commit 1a051f8

Please sign in to comment.