Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Extract a new class for entity frequenty tracking #389

Merged
merged 2 commits into from
Feb 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ List<String> jacocoExclusions = [
'com.amazon.opendistroforelasticsearch.ad.transport.SearchAnomalyDetectorInfoTransportAction*',

// TODO: hc caused coverage to drop
'com.amazon.opendistroforelasticsearch.ad.NodeStateManager',
'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices',
'com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler',
'com.amazon.opendistroforelasticsearch.ad.util.ThrowingSupplierWrapper',
Expand Down
Binary file added docs/entity-priority.pdf
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,21 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Comparator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.amazon.opendistroforelasticsearch.ad.ExpiringState;
import com.amazon.opendistroforelasticsearch.ad.MaintenanceState;
import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.MemoryTracker.Origin;
import com.amazon.opendistroforelasticsearch.ad.annotation.Generated;
import com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao;
import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelState;
Expand All @@ -64,90 +59,22 @@
public class CacheBuffer implements ExpiringState, MaintenanceState {
private static final Logger LOG = LogManager.getLogger(CacheBuffer.class);

static class PriorityNode {
private String key;
private float priority;

PriorityNode(String key, float priority) {
this.priority = priority;
this.key = key;
}

@Generated
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
if (obj instanceof PriorityNode) {
PriorityNode other = (PriorityNode) obj;

EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(key, other.key);
return equalsBuilder.isEquals();
}
return false;
}

@Generated
@Override
public int hashCode() {
return new HashCodeBuilder().append(key).toHashCode();
}

@Generated
@Override
public String toString() {
ToStringBuilder builder = new ToStringBuilder(this);
builder.append("key", key);
builder.append("priority", priority);
return builder.toString();
}
}

static class PriorityNodeComparator implements Comparator<PriorityNode> {

@Override
public int compare(PriorityNode priority, PriorityNode priority2) {
int equality = priority.key.compareTo(priority2.key);
if (equality == 0) {
// this is consistent with PriorityNode's equals method
return 0;
}
// if not equal, first check priority
int cmp = Float.compare(priority.priority, priority2.priority);
if (cmp == 0) {
// if priority is equal, use lexicographical order of key
cmp = equality;
}
return cmp;
}
}
// max entities to track per detector
private final int MAX_TRACKING_ENTITIES = 1000000;

private final int minimumCapacity;
// key -> Priority node
private final ConcurrentHashMap<String, PriorityNode> key2Priority;
private final ConcurrentSkipListSet<PriorityNode> priorityList;
// key -> value
private final ConcurrentHashMap<String, ModelState<EntityModel>> items;
// when detector is created.  Can be reset.  Unit: seconds
private long landmarkSecs;
// length of seconds in one interval.  Used to compute elapsed periods
// since the detector has been enabled.
private long intervalSecs;
// memory consumption per entity
private final long memoryConsumptionPerEntity;
private final MemoryTracker memoryTracker;
private final Clock clock;
private final CheckpointDao checkpointDao;
private final Duration modelTtl;
private final String detectorId;
private Instant lastUsedTime;
private final int DECAY_CONSTANT;
private final long reservedBytes;
private final PriorityTracker priorityTracker;
private final Clock clock;

public CacheBuffer(
int minimumCapacity,
Expand All @@ -163,20 +90,20 @@ public CacheBuffer(
throw new IllegalArgumentException("minimum capacity should be larger than 0");
}
this.minimumCapacity = minimumCapacity;
this.key2Priority = new ConcurrentHashMap<>();
this.priorityList = new ConcurrentSkipListSet<>(new PriorityNodeComparator());

this.items = new ConcurrentHashMap<>();
this.landmarkSecs = clock.instant().getEpochSecond();
this.intervalSecs = intervalSecs;

this.memoryConsumptionPerEntity = memoryConsumptionPerEntity;
this.memoryTracker = memoryTracker;
this.clock = clock;

this.checkpointDao = checkpointDao;
this.modelTtl = modelTtl;
this.detectorId = detectorId;
this.lastUsedTime = clock.instant();
this.DECAY_CONSTANT = 3;

this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity;
this.clock = clock;
this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES);
}

/**
Expand All @@ -186,50 +113,13 @@ public CacheBuffer(
* @param entityModelId model Id
*/
private void update(String entityModelId) {
PriorityNode node = key2Priority.computeIfAbsent(entityModelId, k -> new PriorityNode(entityModelId, 0f));
// reposition this node
this.priorityList.remove(node);
node.priority = getUpdatedPriority(node.priority);
this.priorityList.add(node);
priorityTracker.updatePriority(entityModelId);

Instant now = clock.instant();
items.get(entityModelId).setLastUsedTime(now);
lastUsedTime = now;
}

public float getUpdatedPriority(float oldPriority) {
long increment = computeWeightedCountIncrement();
// if overflowed, we take the short cut from now on
oldPriority += Math.log(1 + Math.exp(increment - oldPriority));
// if overflow happens, using \log(g(t_k-L)) instead.
if (oldPriority == Float.POSITIVE_INFINITY) {
oldPriority = increment;
}
return oldPriority;
}

/**
* Compute periods relative to landmark and the weighted count increment using 0.125n.
* Multiply by 0.125 is implemented using right shift for efficiency.
* @return the weighted count increment used in the priority update step.
*/
private long computeWeightedCountIncrement() {
long periods = (clock.instant().getEpochSecond() - landmarkSecs) / intervalSecs;
return periods >> DECAY_CONSTANT;
}

/**
* Compute the weighted total count by considering landmark
* \log(C)=\log(\sum_{i=1}^{n} (g(t_i-L)/g(t-L)))=\log(\sum_{i=1}^{n} (g(t_i-L))-\log(g(t-L))
* @return the minimum priority entity's ID and priority
*/
public Entry<String, Float> getMinimumPriority() {
PriorityNode smallest = priorityList.first();
long periods = (clock.instant().getEpochSecond() - landmarkSecs) / intervalSecs;
float detectorWeight = periods >> DECAY_CONSTANT;
return new SimpleImmutableEntry<>(smallest.key, smallest.priority - detectorWeight);
}

/**
* Insert the model state associated with a model Id to the cache
* @param entityModelId the model Id
Expand Down Expand Up @@ -257,9 +147,7 @@ public void put(String entityModelId, ModelState<EntityModel> value) {
private void put(String entityModelId, ModelState<EntityModel> value, float priority) {
ModelState<EntityModel> contentNode = items.get(entityModelId);
if (contentNode == null) {
PriorityNode node = new PriorityNode(entityModelId, priority);
key2Priority.put(entityModelId, node);
priorityList.add(node);
priorityTracker.addPriority(entityModelId, priority);
items.put(entityModelId, value);
Instant now = clock.instant();
value.setLastUsedTime(now);
Expand Down Expand Up @@ -319,9 +207,9 @@ public ModelState<EntityModel> remove() {
// The removed one loses references and soon GC will collect it.
// We have memory tracking correction to fix incorrect memory usage record.
// put: not a problem as it is unlikely we are removing and putting the same thing
PriorityNode smallest = priorityList.first();
if (smallest != null) {
return remove(smallest.key);
Optional<String> key = priorityTracker.getMinimumPriorityEntityId();
if (key.isPresent()) {
return remove(key.get());
}
return null;
}
Expand All @@ -334,12 +222,11 @@ public ModelState<EntityModel> remove() {
* is no associated ModelState for the key
*/
public ModelState<EntityModel> remove(String keyToRemove) {
// remove if the key matches; priority does not matter
priorityList.remove(new PriorityNode(keyToRemove, 0));
priorityTracker.removePriority(keyToRemove);

// if shared cache is empty, we are using reserved memory
boolean reserved = sharedCacheEmpty();

key2Priority.remove(keyToRemove);
ModelState<EntityModel> valueRemoved = items.remove(keyToRemove);

if (valueRemoved != null) {
Expand Down Expand Up @@ -382,15 +269,17 @@ public long getMemoryConsumptionPerEntity() {

/**
*
* If the cache is not full, check if some other items can replace internal entities.
* If the cache is not full, check if some other items can replace internal entities
* within the same detector.
*
* @param priority another entity's priority
* @return whether one entity can be replaced by another entity with a certain priority
*/
public boolean canReplace(float priority) {
public boolean canReplaceWithinDetector(float priority) {
if (items.isEmpty()) {
return false;
}
Entry<String, Float> minPriorityItem = getMinimumPriority();
Entry<String, Float> minPriorityItem = priorityTracker.getMinimumPriority();
return minPriorityItem != null && priority > minPriorityItem.getValue();
}

Expand All @@ -415,15 +304,6 @@ public void maintenance() {
ModelState<EntityModel> modelState = entry.getValue();
Instant now = clock.instant();

// we can have ConcurrentModificationException when serializing
// and updating rcf model at the same time. To prevent this,
// we need to have a deep copy of models or have a lock. Both
// options are costly.
// As we are gonna retry serializing either when the entity is
// evicted out of cache or during the next maintenance period,
// don't do anything when the exception happens.
checkpointDao.write(modelState, entityModelId);

if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) {
// race conditions can happen between the put and one of the following operations:
// remove: not a problem as all of the data structures are concurrent.
Expand All @@ -433,7 +313,17 @@ public void maintenance() {
// We have memory tracking correction to fix incorrect memory usage record.
// put: not a problem as we are unlikely to maintain an entry that's not
// already in the cache
// remove method saves checkpoint as well
remove(entityModelId);
} else {
// we can have ConcurrentModificationException when serializing
// and updating rcf model at the same time. To prevent this,
// we need to have a deep copy of models or have a lock. Both
// options are costly.
// As we are gonna retry serializing either when the entity is
// evicted out of cache or during the next maintenance period,
// don't do anything when the exception happens.
checkpointDao.write(modelState, entityModelId);
}
} catch (Exception e) {
LOG.warn("Failed to finish maintenance for model id " + entityModelId, e);
Expand Down Expand Up @@ -471,14 +361,6 @@ public long getLastUsedTime(String entityModelId) {
return -1;
}

/**
*
* @return Get the model of highest priority entity
*/
public Optional<String> getHighestPriorityEntityModelId() {
return Optional.of(priorityList).map(list -> list.last()).map(node -> node.key);
}

/**
*
* @param entityModelId entity Id
Expand All @@ -501,8 +383,7 @@ public void clear() {
memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.MULTI_ENTITY_DETECTOR);
}
items.clear();
key2Priority.clear();
priorityList.clear();
priorityTracker.clearPriority();
}

/**
Expand Down Expand Up @@ -561,4 +442,8 @@ public String getDetectorId() {
public List<ModelState<?>> getAllModels() {
return items.values().stream().collect(Collectors.toList());
}

public PriorityTracker getPriorityTracker() {
return priorityTracker;
}
}
Loading