forked from QueensGambit/CrazyAra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstateobj.cpp
70 lines (59 loc) · 2.64 KB
/
stateobj.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/*
CrazyAra, a deep learning chess variant engine
Copyright (C) 2018 Johannes Czech, Moritz Willig, Alena Beyer
Copyright (C) 2019-2020 Johannes Czech
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
/*
* @file: stateobj.cpp
* Created on 17.07.2020
* @author: queensgambit
*/
#include "stateobj.h"
#include "constants.h"
void get_probs_of_move_list(const size_t batchIdx, const float* policyProb, const std::vector<Action>& legalMoves, bool mirrorPolicy, bool normalize, DynamicVector<double> &policyProbSmall, bool selectPolicyFromPlane)
{
    // Extracts the network's probability for every legal move of one batch
    // sample into policyProbSmall, optionally renormalizing with softmax.
    // The per-sample stride depends on whether the policy-map representation
    // is used, so resolve it once up front.
    const size_t numLabels = selectPolicyFromPlane ? StateConstants::NB_LABELS_POLICY_MAP()
                                                   : StateConstants::NB_LABELS();
    // Reading from the raw float buffer is faster than calling
    // policyProb.At(batchIdx, vectorIdx) for every move.
    const float* sampleData = policyProb + batchIdx * numLabels;

    for (size_t mvIdx = 0; mvIdx < legalMoves.size(); ++mvIdx) {
        // Map the action to its flat policy index, selecting the mirrored or
        // the non-mirrored look-up table depending on the side to move.
        const size_t vectorIdx = mirrorPolicy ?
                    StateConstants::action_to_index<normal,mirrored>(legalMoves[mvIdx]) :
                    StateConstants::action_to_index<normal,notMirrored>(legalMoves[mvIdx]);
        assert(vectorIdx < numLabels);
        policyProbSmall[mvIdx] = sampleData[vectorIdx];
    }
    if (normalize) {
        policyProbSmall = softmax(policyProbSmall);
    }
}
const float* get_policy_data_batch(const size_t batchIdx, const float* probOutputs, bool isPolicyMap)
{
    // Returns a pointer to the start of sample batchIdx inside the flat
    // policy output buffer. The per-sample stride is the number of policy
    // labels, which differs between the plain and the policy-map encoding.
    const size_t stride = isPolicyMap ? StateConstants::NB_LABELS_POLICY_MAP()
                                      : StateConstants::NB_LABELS();
    return probOutputs + batchIdx * stride;
}
const float* get_auxiliary_data_batch(const size_t batchIdx, const float* auxiliaryOutputs)
{
    // Returns a pointer to the start of sample batchIdx inside the flat
    // auxiliary output buffer; each sample spans NB_AUXILIARY_OUTPUTS floats.
    const size_t offset = batchIdx * StateConstants::NB_AUXILIARY_OUTPUTS();
    return auxiliaryOutputs + offset;
}