Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Circuit breaker #10

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,14 @@ public ADCircuitBreakerService init() {

return this;
}

public Boolean isOpen() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion. This can also be written as stream().anyMatch(CircuitBreaker::isOpen) to be shorter and clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, will rewrite it in next pr. Thanks Lai!

for (CircuitBreaker breaker : breakers.values()) {
if (breaker.isOpen()) {
return true;
}
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ public class CommonErrorMessages {
public static final String NO_CHECKPOINT_ERR_MSG = "No checkpoints found for model id ";
public static final String MEMORY_LIMIT_EXCEEDED_ERR_MSG = "AD models memory usage exceeds our limit.";
public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window.";
public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = "AD memory circuit is broken.";
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.time.Instant;

import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService;
import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing;
import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.ClientException;
Expand Down Expand Up @@ -120,13 +121,14 @@ public class AnomalyResultTransportAction extends HandledTransportAction<ActionR
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final ThreadPool threadPool;
private final BackoffPolicy resultSavingBackoffPolicy;
private final ADCircuitBreakerService adCircuitBreakerService;

@Inject
public AnomalyResultTransportAction(ActionFilters actionFilters, TransportService transportService, Client client,
Settings settings, ADStateManager manager, ColdStartRunner eventExecutor,
AnomalyDetectionIndices anomalyDetectionIndices, FeatureManager featureManager, ModelManager modelManager,
HashRing hashRing, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver,
ThreadPool threadPool) {
ThreadPool threadPool, ADCircuitBreakerService adCircuitBreakerService) {
super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new);
this.transportService = transportService;
this.client = client;
Expand All @@ -145,6 +147,7 @@ public AnomalyResultTransportAction(ActionFilters actionFilters, TransportServic
this.threadPool = threadPool;
this.resultSavingBackoffPolicy = BackoffPolicy.exponentialBackoff(AnomalyDetectorSettings.BACKOFF_INITIAL_DELAY.get(settings),
AnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF.get(settings));
this.adCircuitBreakerService = adCircuitBreakerService;
}

private List<FeatureData> getFeatureData(double[] currentFeature, AnomalyDetector detector) {
Expand Down Expand Up @@ -207,6 +210,11 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<
AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest);
String adID = request.getAdID();

if (adCircuitBreakerService.isOpen()) {
listener.onFailure(new LimitExceededException(adID, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG));
return;
}

try {
Optional<AnomalyDetector> detector = stateManager.getAnomalyDetector(adID);
if (!detector.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

package com.amazon.opendistroforelasticsearch.ad.transport;

import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult;
import org.apache.logging.log4j.LogManager;
Expand All @@ -30,17 +33,24 @@ public class RCFResultTransportAction extends HandledTransportAction<RCFResultRe

private static final Logger LOG = LogManager.getLogger(RCFResultTransportAction.class);
private ModelManager manager;
private ADCircuitBreakerService adCircuitBreakerService;

@Inject
public RCFResultTransportAction(ActionFilters actionFilters, TransportService transportService,
ModelManager manager) {
ModelManager manager, ADCircuitBreakerService adCircuitBreakerService) {
super(RCFResultAction.NAME, transportService, actionFilters, RCFResultRequest::new);
this.manager = manager;
this.adCircuitBreakerService = adCircuitBreakerService;
}

@Override
protected void doExecute(Task task, RCFResultRequest request, ActionListener<RCFResultResponse> listener) {

if (adCircuitBreakerService.isOpen()) {
listener.onFailure(new LimitExceededException(request.getAdID(), CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG));
return;
}

try {
LOG.info("Serve rcf request for {}", request.getModelID());
RcfResult result = manager.getRcfResult(request.getAdID(), request.getModelID(), request.getFeatures());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
package com.amazon.opendistroforelasticsearch.ad.breaker;

import org.elasticsearch.monitor.jvm.JvmService;
import org.elasticsearch.monitor.jvm.JvmStats;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.when;

public class ADCircuitBreakerServiceTests {

Expand All @@ -35,6 +38,12 @@ public class ADCircuitBreakerServiceTests {
@Mock
JvmService jvmService;

@Mock
JvmStats jvmStats;

@Mock
JvmStats.Mem mem;

@Before
public void setup() {
MockitoAnnotations.initMocks(this);
Expand Down Expand Up @@ -88,4 +97,23 @@ public void testInit() {
assertThat(adCircuitBreakerService.init(), is(notNullValue()));
}

@Test
public void testIsOpen() {
when(jvmService.stats()).thenReturn(jvmStats);
when(jvmStats.getMem()).thenReturn(mem);
when(mem.getHeapUsedPercent()).thenReturn((short)50);

adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService));
assertThat(adCircuitBreakerService.isOpen(), equalTo(false));
}

@Test
public void testIsOpen1() {
when(jvmService.stats()).thenReturn(jvmStats);
when(jvmStats.getMem()).thenReturn(mem);
when(mem.getHeapUsedPercent()).thenReturn((short)90);

adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService));
assertThat(adCircuitBreakerService.isOpen(), equalTo(true));
}
}
Loading