Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch classic search to Backend interface. #2109

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 24 additions & 34 deletions src/engine_classic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,16 @@ void EngineClassic::UpdateFromUciOptions() {
const auto network_configuration =
NetworkFactory::BackendConfiguration(options_);
if (network_configuration_ != network_configuration) {
network_ = NetworkFactory::LoadNetwork(options_);
backend_ =
CreateMemCache(BackendManager::Get()->CreateFromParams(options_),
options_.Get<int>(SharedBackendParams::kNNCacheSizeId));
network_configuration_ = network_configuration;
} else {
// If the network is not changed, the cache size may still have changed.
backend_->SetCacheCapacity(
options_.Get<int>(SharedBackendParams::kNNCacheSizeId));
}

// Cache size.
cache_.SetCapacity(options_.Get<int>(SharedBackendParams::kNNCacheSizeId));

// Check whether we can update the move timer in "Go".
strict_uci_timing_ = options_.Get<bool>(kStrictUciTiming);
}
Expand All @@ -178,12 +181,12 @@ void EngineClassic::NewGame() {
// newgame and goes straight into go.
ResetMoveTimer();
SharedLock lock(busy_mutex_);
cache_.Clear();
search_.reset();
tree_.reset();
CreateFreshTimeManager();
current_position_ = {ChessBoard::kStartposFen, {}};
UpdateFromUciOptions();
backend_->ClearCache();
}

void EngineClassic::SetPosition(const std::string& fen,
Expand Down Expand Up @@ -265,52 +268,39 @@ class PonderResponseTransformer : public TransformingUciResponder {
std::string ponder_move_;
};

void ValueOnlyGo(classic::NodeTree* tree, Network* network,
const OptionsDict& options,
void ValueOnlyGo(classic::NodeTree* tree, Backend* backend,
std::unique_ptr<UciResponder> responder) {
auto input_format = network->GetCapabilities().input_format;

const auto& board = tree->GetPositionHistory().Last().GetBoard();
auto legal_moves = board.GenerateLegalMoves();
tree->GetCurrentHead()->CreateEdges(legal_moves);
PositionHistory history = tree->GetPositionHistory();
std::vector<InputPlanes> planes;
std::vector<float> comp_q;
comp_q.reserve(legal_moves.size());
auto comp = backend->CreateComputation();
for (auto edge : tree->GetCurrentHead()->Edges()) {
history.Append(edge.GetMove());
if (history.ComputeGameResult() == GameResult::UNDECIDED) {
planes.emplace_back(EncodePositionForNN(
input_format, history, 8, FillEmptyHistory::FEN_ONLY, nullptr));
comp_q.emplace_back();
comp->AddInput(
EvalPosition{
.pos = history.GetPositions(),
.legal_moves = {},
},
EvalResultPtr{.q = &comp_q.back()});
}
history.Pop();
}

std::vector<float> comp_q;
int batch_size = options.Get<int>(classic::SearchParams::kMiniBatchSizeId);
if (batch_size == 0) batch_size = network->GetMiniBatchSize();

for (size_t i = 0; i < planes.size(); i += batch_size) {
auto comp = network->NewComputation();
for (int j = 0; j < batch_size; j++) {
comp->AddInput(std::move(planes[i + j]));
if (i + j + 1 == planes.size()) break;
}
comp->ComputeBlocking();

for (int j = 0; j < batch_size; j++) comp_q.push_back(comp->GetQVal(j));
}

Move best;
int comp_idx = 0;
float max_q = std::numeric_limits<float>::lowest();
for (auto edge : tree->GetCurrentHead()->Edges()) {
for (size_t comp_idx = 0; auto edge : tree->GetCurrentHead()->Edges()) {
history.Append(edge.GetMove());
auto result = history.ComputeGameResult();
float q = -1;
if (result == GameResult::UNDECIDED) {
// NN eval is from the side-to-move perspective - so if it's good, it's bad
// for us.
q = -comp_q[comp_idx];
comp_idx++;
q = -comp_q[comp_idx++];
} else if (result == GameResult::DRAW) {
q = 0;
} else {
Expand Down Expand Up @@ -375,7 +365,7 @@ void EngineClassic::Go(const GoParams& params) {
responder = std::make_unique<MovesLeftResponseFilter>(std::move(responder));
}
if (options_.Get<bool>(kValueOnly)) {
ValueOnlyGo(tree_.get(), network_.get(), options_, std::move(responder));
ValueOnlyGo(tree_.get(), backend_.get(), std::move(responder));
return;
}

Expand All @@ -385,10 +375,10 @@ void EngineClassic::Go(const GoParams& params) {

auto stopper = time_manager_->GetStopper(params, *tree_.get());
search_ = std::make_unique<classic::Search>(
*tree_, network_.get(), std::move(responder),
*tree_, backend_.get(), std::move(responder),
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite, params.ponder,
options_, &cache_, syzygy_tb_.get());
options_, syzygy_tb_.get());

LOGFILE << "Timer started at "
<< FormatTime(SteadyClockToSystemClock(*move_start_time_));
Expand Down
5 changes: 2 additions & 3 deletions src/engine_classic.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "engine_loop.h"
#include "neural/cache.h"
#include "neural/factory.h"
#include "neural/network.h"
#include "neural/memcache.h"
#include "search/classic/search.h"
#include "syzygy/syzygy.h"
#include "utils/mutex.h"
Expand Down Expand Up @@ -94,8 +94,7 @@ class EngineClassic : public EngineControllerBase {
std::unique_ptr<classic::Search> search_;
std::unique_ptr<classic::NodeTree> tree_;
std::unique_ptr<SyzygyTablebase> syzygy_tb_;
std::unique_ptr<Network> network_;
NNCache cache_;
std::unique_ptr<CachingBackend> backend_;

// Store current TB and network settings to track when they change so that
// they are reloaded.
Expand Down
16 changes: 10 additions & 6 deletions src/neural/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,22 @@ struct BackendAttributes {
int maximum_batch_size;
};

struct EvalResultPtr {
float* q = nullptr;
float* d = nullptr;
float* m = nullptr;
std::span<float> p = {};
};

struct EvalResult {
float q;
float d;
float m;
std::vector<float> p;
};

struct EvalResultPtr {
float* q = nullptr;
float* d = nullptr;
float* m = nullptr;
std::span<float> p;
EvalResultPtr AsPtr() {
return EvalResultPtr{.q = &q, .d = &d, .m = &m, .p = p};
}
};

struct EvalPosition {
Expand Down
69 changes: 43 additions & 26 deletions src/neural/memcache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "neural/memcache.h"

#include "neural/cache.h"
#include "utils/atomic_vector.h"
#include "utils/smallarray.h"

namespace lczero {
Expand All @@ -54,7 +55,7 @@ void CachedValueToEvalResult(const CachedValue& cv, const EvalResultPtr& ptr) {
std::copy(cv.p.get(), cv.p.get() + ptr.p.size(), ptr.p.begin());
}

class MemCache : public Backend {
class MemCache : public CachingBackend {
public:
MemCache(std::unique_ptr<Backend> wrapped, size_t capacity)
: wrapped_backend_(std::move(wrapped)),
Expand All @@ -67,6 +68,11 @@ class MemCache : public Backend {
std::unique_ptr<BackendComputation> CreateComputation() override;
std::optional<EvalResult> GetCachedEvaluation(const EvalPosition&) override;

void ClearCache() override { cache_.Clear(); }
void SetCacheCapacity(size_t capacity) override {
cache_.SetCapacity(capacity);
}

private:
std::unique_ptr<Backend> wrapped_backend_;
HashKeyedCache<CachedValue> cache_;
Expand All @@ -79,50 +85,56 @@ class MemCacheComputation : public BackendComputation {
MemCacheComputation(std::unique_ptr<BackendComputation> wrapped_computation,
MemCache* memcache)
: wrapped_computation_(std::move(wrapped_computation)),
memcache_(memcache) {
keys_.reserve(memcache_->max_batch_size_);
values_.reserve(memcache_->max_batch_size_);
result_ptrs_.reserve(memcache_->max_batch_size_);
}
memcache_(memcache),
entries_(memcache->max_batch_size_) {}

private:
size_t UsedBatchSize() const override {
return wrapped_computation_->UsedBatchSize();
}
virtual AddInputResult AddInput(const EvalPosition& pos,
EvalResultPtr result) override {
assert(pos.legal_moves.size() == result.p.size() || result.p.empty());
const uint64_t hash = ComputeEvalPositionHash(pos);
{
HashKeyedCacheLock<CachedValue> lock(&memcache_->cache_, hash);
if (lock.holds_value()) {
// Sometimes search queries NN without passing the legal moves. It is
// still cached in this case, but in subsequent queries we only return it if
// the legal moves are not passed again.
if (lock.holds_value() && (pos.legal_moves.empty() || lock->p)) {
CachedValueToEvalResult(**lock, result);
return AddInputResult::FETCHED_IMMEDIATELY;
}
}
keys_.push_back(hash);
auto value = std::make_unique<CachedValue>();
value->p.reset(new float[result.p.size()]);
result_ptrs_.push_back(result);
size_t entry_idx = entries_.emplace_back(
Entry{hash, std::make_unique<CachedValue>(), result});
auto& value = entries_[entry_idx].value;
value->p.reset(pos.legal_moves.empty() ? nullptr
: new float[pos.legal_moves.size()]);
return wrapped_computation_->AddInput(
pos, EvalResultPtr{&value->q,
&value->d,
&value->m,
{value->p.get(), pos.legal_moves.size()}});
pos, EvalResultPtr{&value->q, &value->d, &value->m,
value->p ? std::span<float>{value->p.get(),
pos.legal_moves.size()}
: std::span<float>{}});
}

virtual void ComputeBlocking() override {
wrapped_computation_->ComputeBlocking();
for (size_t i = 0; i < keys_.size(); ++i) {
CachedValueToEvalResult(*values_[i], result_ptrs_[i]);
memcache_->cache_.Insert(keys_[i], std::move(values_[i]));
for (auto& entry : entries_) {
CachedValueToEvalResult(*entry.value, entry.result_ptr);
memcache_->cache_.Insert(entry.key, std::move(entry.value));
}
}

struct Entry {
uint64_t key;
std::unique_ptr<CachedValue> value;
EvalResultPtr result_ptr;
};

std::unique_ptr<BackendComputation> wrapped_computation_;
std::vector<uint64_t> keys_;
std::vector<std::unique_ptr<CachedValue>> values_;
std::vector<EvalResultPtr> result_ptrs_;
MemCache* memcache_;
AtomicVector<Entry> entries_;
};

std::unique_ptr<BackendComputation> MemCache::CreateComputation() {
Expand All @@ -133,20 +145,25 @@ std::optional<EvalResult> MemCache::GetCachedEvaluation(
const EvalPosition& pos) {
const uint64_t hash = ComputeEvalPositionHash(pos);
HashKeyedCacheLock<CachedValue> lock(&cache_, hash);
if (!lock.holds_value()) return std::nullopt;
if (!lock.holds_value() || (!pos.legal_moves.empty() && !lock->p)) {
return std::nullopt;
}
EvalResult result;
result.d = lock->d;
result.q = lock->q;
result.m = lock->m;
std::copy(lock->p.get(), lock->p.get() + pos.legal_moves.size(),
result.p.begin());
if (lock->p) {
result.p.reserve(pos.legal_moves.size());
std::copy(lock->p.get(), lock->p.get() + pos.legal_moves.size(),
std::back_inserter(result.p));
}
return result;
}

} // namespace

std::unique_ptr<Backend> CreateMemCache(std::unique_ptr<Backend> wrapped,
size_t capacity) {
std::unique_ptr<CachingBackend> CreateMemCache(std::unique_ptr<Backend> wrapped,
size_t capacity) {
return std::make_unique<MemCache>(std::move(wrapped), capacity);
}

Expand Down
12 changes: 10 additions & 2 deletions src/neural/memcache.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@

namespace lczero {

// Backend extended with cache-management operations. Implementations (e.g.
// the memcache wrapper created by CreateMemCache below) expose control over
// their evaluation cache in addition to the plain Backend API.
class CachingBackend : public Backend {
 public:
  // Clears the cache, discarding all stored evaluations.
  virtual void ClearCache() = 0;
  // Sets the maximum number of items in the cache.
  virtual void SetCacheCapacity(size_t capacity) = 0;
};

// Creates a caching backend wrapper, which returns values immediately if they
// are found, and forwards the request to the wrapped backend otherwise (and
// caches the result).
std::unique_ptr<CachingBackend> CreateMemCache(std::unique_ptr<Backend> parent,
                                               size_t capacity);

} // namespace lczero
Loading