diff --git a/core/circuitbreaker/circuit_breaker.go b/core/circuitbreaker/circuit_breaker.go index 976aade0b..531b968f5 100644 --- a/core/circuitbreaker/circuit_breaker.go +++ b/core/circuitbreaker/circuit_breaker.go @@ -2,7 +2,6 @@ package circuitbreaker import ( "sync/atomic" - "unsafe" "github.com/alibaba/sentinel-golang/core/base" sbase "github.com/alibaba/sentinel-golang/core/stat/base" @@ -33,6 +32,13 @@ const ( Open ) +func newState() *State { + var state State + state = Closed + + return &state +} + func (s *State) String() string { switch s.get() { case Closed: @@ -47,21 +53,15 @@ func (s *State) String() string { } func (s *State) get() State { - statePtr := (*int32)(unsafe.Pointer(s)) - return State(atomic.LoadInt32(statePtr)) + return State(atomic.LoadInt32((*int32)(s))) } func (s *State) set(update State) { - statePtr := (*int32)(unsafe.Pointer(s)) - newState := int32(update) - atomic.StoreInt32(statePtr, newState) + atomic.StoreInt32((*int32)(s), int32(update)) } -func (s *State) casState(expect State, update State) bool { - statePtr := (*int32)(unsafe.Pointer(s)) - oldState := int32(expect) - newState := int32(update) - return atomic.CompareAndSwapInt32(statePtr, oldState, newState) +func (s *State) cas(expect State, update State) bool { + return atomic.CompareAndSwapInt32((*int32)(s), int32(expect), int32(update)) } // StateChangeListener listens on the circuit breaker state change event @@ -132,7 +132,7 @@ func (b *circuitBreakerBase) updateNextRetryTimestamp() { // fromClosedToOpen updates circuit breaker state machine from closed to open. // Return true only if current goroutine successfully accomplished the transformation. func (b *circuitBreakerBase) fromClosedToOpen(snapshot interface{}) bool { - if b.state.casState(Closed, Open) { + if b.state.cas(Closed, Open) { b.updateNextRetryTimestamp() for _, listener := range stateChangeListeners { listener.OnTransformToOpen(Closed, *b.rule, snapshot) @@ -145,7 +145,7 @@ func (b *circuitBreakerBase) fromClosedToOpen(snapshot interface{}) bool { // fromOpenToHalfOpen updates circuit breaker state machine from open to half-open. // Return true only if current goroutine successfully accomplished the transformation. func (b *circuitBreakerBase) fromOpenToHalfOpen(ctx *base.EntryContext) bool { - if b.state.casState(Open, HalfOpen) { + if b.state.cas(Open, HalfOpen) { for _, listener := range stateChangeListeners { listener.OnTransformToHalfOpen(Open, *b.rule) } @@ -158,7 +158,7 @@ func (b *circuitBreakerBase) fromOpenToHalfOpen(ctx *base.EntryContext) bool { // if the current circuit breaker performs the probe through this entry, but the entry was blocked, // this hook will guarantee current circuit breaker state machine will rollback to Open from Half-Open entry.WhenExit(func(entry *base.SentinelEntry, ctx *base.EntryContext) error { - if ctx.IsBlocked() && b.state.casState(HalfOpen, Open) { + if ctx.IsBlocked() && b.state.cas(HalfOpen, Open) { for _, listener := range stateChangeListeners { listener.OnTransformToOpen(HalfOpen, *b.rule, 1.0) } @@ -174,7 +174,7 @@ func (b *circuitBreakerBase) fromOpenToHalfOpen(ctx *base.EntryContext) bool { // fromHalfOpenToOpen updates circuit breaker state machine from half-open to open. // Return true only if current goroutine successfully accomplished the transformation. func (b *circuitBreakerBase) fromHalfOpenToOpen(snapshot interface{}) bool { - if b.state.casState(HalfOpen, Open) { + if b.state.cas(HalfOpen, Open) { b.updateNextRetryTimestamp() for _, listener := range stateChangeListeners { listener.OnTransformToOpen(HalfOpen, *b.rule, snapshot) @@ -187,7 +187,7 @@ func (b *circuitBreakerBase) fromHalfOpenToOpen(snapshot interface{}) bool { // fromHalfOpenToOpen updates circuit breaker state machine from half-open to closed // Return true only if current goroutine successfully accomplished the transformation. func (b *circuitBreakerBase) fromHalfOpenToClosed() bool { - if b.state.casState(HalfOpen, Closed) { + if b.state.cas(HalfOpen, Closed) { for _, listener := range stateChangeListeners { listener.OnTransformToClosed(HalfOpen, *b.rule) } @@ -206,14 +206,12 @@ type slowRtCircuitBreaker struct { } func newSlowRtCircuitBreakerWithStat(r *Rule, stat *slowRequestLeapArray) *slowRtCircuitBreaker { - status := new(State) - status.set(Closed) return &slowRtCircuitBreaker{ circuitBreakerBase: circuitBreakerBase{ rule: r, retryTimeoutMs: r.RetryTimeoutMs, nextRetryTimestampMs: 0, - state: status, + state: newState(), }, stat: stat, maxAllowedRt: r.MaxAllowedRtMs, @@ -392,15 +390,12 @@ type errorRatioCircuitBreaker struct { } func newErrorRatioCircuitBreakerWithStat(r *Rule, stat *errorCounterLeapArray) *errorRatioCircuitBreaker { - status := new(State) - status.set(Closed) - return &errorRatioCircuitBreaker{ circuitBreakerBase: circuitBreakerBase{ rule: r, retryTimeoutMs: r.RetryTimeoutMs, nextRetryTimestampMs: 0, - state: status, + state: newState(), }, minRequestAmount: r.MinRequestAmount, errorRatioThreshold: r.Threshold, @@ -572,15 +567,12 @@ type errorCountCircuitBreaker struct { } func newErrorCountCircuitBreakerWithStat(r *Rule, stat *errorCounterLeapArray) *errorCountCircuitBreaker { - status := new(State) - status.set(Closed) - return &errorCountCircuitBreaker{ circuitBreakerBase: circuitBreakerBase{ rule: r, retryTimeoutMs: r.RetryTimeoutMs, nextRetryTimestampMs: 0, - state: status, + state: newState(), }, minRequestAmount: r.MinRequestAmount, errorCountThreshold: uint64(r.Threshold), diff --git a/core/circuitbreaker/circuit_breaker_test.go b/core/circuitbreaker/circuit_breaker_test.go index 0fd879599..92d0489e9 100644 --- a/core/circuitbreaker/circuit_breaker_test.go +++ b/core/circuitbreaker/circuit_breaker_test.go @@ -64,7 +64,7 @@ func (s *StateChangeListenerMock) OnTransformToHalfOpen(prev State, rule Rule) { func TestStatus(t *testing.T) { t.Run("get_set", func(t *testing.T) { - status := new(State) + status := newState() assert.True(t, status.get() == Closed) status.set(Open) @@ -72,13 +72,13 @@ func TestStatus(t *testing.T) { }) t.Run("cas", func(t *testing.T) { - status := new(State) + status := newState() assert.True(t, status.get() == Closed) - assert.True(t, status.casState(Closed, Open)) - assert.True(t, !status.casState(Closed, Open)) + assert.True(t, status.cas(Closed, Open)) + assert.True(t, !status.cas(Closed, Open)) status.set(HalfOpen) - assert.True(t, status.casState(HalfOpen, Open)) + assert.True(t, status.cas(HalfOpen, Open)) }) }