Skip to content

Commit

Permalink
way too many states, needs revising...
Browse files Browse the repository at this point in the history
  • Loading branch information
r3w0p committed Jan 4, 2025
1 parent 88ee9d9 commit 7a24ee3
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 45 deletions.
71 changes: 48 additions & 23 deletions src/caravan/core/training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,27 @@ void populate_action_space(ActionSpace *as) {
}
}
}

// Total: 240 + 48 + 6 + 5 = 299
}

void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, TrainConfig &tc, std::mt19937 &gen) {
GameState gs;
GameCommand command;
uint32_t num_moves = 0;

std::string action;
uint16_t action_index;
GameCommand command;
float action_value;

std::uniform_int_distribution<uint16_t> dist_action(0, SIZE_ACTION_SPACE - 1);
GameState last_gs_abc;
std::string last_action_abc;

GameState last_gs_def;
std::string last_action_def;

//std::uniform_int_distribution<uint16_t> dist_action(0, SIZE_ACTION_SPACE - 1);
std::uniform_real_distribution<float> dist_explore(0, 1);
bool explore = dist_explore(gen) < tc.explore;

// Play until winner
while (game->get_winner() == NO_PLAYER) {
Expand All @@ -152,18 +162,22 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train
action_pool.push_back(action_space[i]);
}

// Determine whether to explore for next move
bool explore = dist_explore(gen) < tc.explore;

// Find a valid action
while (true) {
if (explore) {
// If exploring, fetch a random action from the action pool
std::uniform_int_distribution<uint16_t> dist_pool(0, action_pool.size() - 1);
action_index = dist_pool(gen);
action = action_pool[action_index];
action_value = q_table[gs][action];

} else {
// Otherwise, pick the optimal action from the q-table
action_index = 0;
float action_value = q_table[gs][action_pool[action_index]];
action_value = q_table[gs][action_pool[action_index]];

for (uint16_t i_action = 1; i_action < action_pool.size(); i_action++) {
// Change pick if next action has greater value
Expand All @@ -188,33 +202,44 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train
action_pool.erase(action_pool.begin() + action_index);
}

printf("[%s] %s (%hu, %llu)\n",
pturn == PLAYER_ABC ? "ABC" : "DEF",
action.c_str(),
action_index,
action_pool.size());
if (action_value > 0)
printf("[%s] %s (i=%hu, v=%.2f)\n", pturn == PLAYER_ABC ? "ABC" : "DEF", action.c_str(), action_index, action_value);

// Perform action
// (Exceptions intentionally not handled)
game->play_option(&command);

/*
// Measure reward (1 = win, -1 = loss, 0 = neither)
uint16_t reward;
// Update q-table
if (num_moves >= 2) {
GameState last_gs = pturn == PLAYER_ABC ? last_gs_abc : last_gs_def;
std::string last_action = pturn == PLAYER_ABC ? last_action_abc : last_action_def;

if (game->get_winner() != NO_PLAYER) {
if (game->get_winner() == pturn) {
q_table[gs][action] = 1;
} else {
q_table[gs][action] = -1;
}
}

q_table[last_gs][last_action] = q_table[last_gs][last_action] + tc.learning * (tc.discount * q_table[gs][action] - q_table[last_gs][last_action]);
/*
if (game->get_winner() != NO_PLAYER) {
//printf("%f\n", q_table[gs][action]);
printf("%f\n", q_table[last_gs][last_action]);
}
*/
}

if (game->get_winner() == pturn) {
reward = 1;
} else if (game->get_winner() == popp) {
reward = -1;
// Log last move
if (pturn == PLAYER_ABC) {
last_gs_abc = gs;
last_action_abc = action;
} else {
reward = 0;
last_gs_def = gs;
last_action_def = action;
}

// TODO update q_table
// float q_value_former = q_table[gs][action];
// GameState gs_new;
// get_game_state(&gs_new, game, pturn);
// if a winner: +1 for winning player, -1 for losing player
*/
num_moves += 1;
}
}
11 changes: 3 additions & 8 deletions src/caravan/model/game.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,9 @@ int8_t Game::compare_bids(CaravanName cvname1, CaravanName cvname2) {
/**
 * Pick the caravan indicated by the result of compare_bids().
 *
 * @param cvname1 First caravan passed to compare_bids().
 * @param cvname2 Second caravan passed to compare_bids().
 * @return cvname1 when compare_bids() returns a negative value, cvname2 when
 *         it returns a positive value, and NO_CARAVAN when it returns zero.
 */
CaravanName Game::winning_bid(CaravanName cvname1, CaravanName cvname2) {
    const int8_t bidcomp = compare_bids(cvname1, cvname2);

    // Guard clauses: each outcome returns immediately, so there is exactly
    // one decision path and no unreachable trailing statements.
    if (bidcomp < 0) return cvname1;
    if (bidcomp > 0) return cvname2;
    return NO_CARAVAN;
}

bool Game::has_sold(CaravanName cvname) {
Expand Down Expand Up @@ -313,7 +309,6 @@ bool Game::option_play(GameCommand *command, bool check_only) {
c_hand = player_turn->get_from_hand_at(command->pos_hand);
} catch (CaravanGameException &e) {
if (check_only) return false;

throw;
}

Expand Down
28 changes: 14 additions & 14 deletions src/caravan/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

int main(int argc, char *argv[]) {
// Game and config
std::unique_ptr<Game> game;
std::unique_ptr<Game> game = nullptr;
GameConfig gc;
TrainConfig tc;
uint8_t rand_first;
Expand All @@ -33,8 +33,8 @@ int main(int argc, char *argv[]) {

// Training parameters TODO user-defined arguments
float discount = 0.95;
float learning = 0.75;
uint32_t episode_max = 10;
float learning = 0.7;
uint32_t episode_max = 1000000;

// Game config uses largest deck with most samples and balance to
// maximise chance of encountering every player hand combination.
Expand All @@ -54,7 +54,10 @@ int main(int argc, char *argv[]) {
};

for(; tc.episode <= tc.episode_max; tc.episode++) {
printf("Episode %d\n", tc.episode);
if (tc.episode % 100 == 0) {
printf("Episode %d\n", tc.episode);
printf("- states: %llu\n", q_table.size());
}

// Random first player
rand_first = dist_first_player(gen);
Expand All @@ -64,26 +67,23 @@ int main(int argc, char *argv[]) {
// Set training parameters
tc.discount = discount;

// TODO tc.explore =
// static_cast<float>(tc.episode_max - (tc.episode - 1)) /
// static_cast<float>(tc.episode_max);
tc.explore = 1.0;
tc.explore =
static_cast<float>(tc.episode_max - (tc.episode - 1)) /
static_cast<float>(tc.episode_max);
//tc.explore = 1.0;

tc.learning = learning;

// Start a new game
game = std::make_unique<Game>(&gc);
game.reset(new Game(&gc));

// Train on game until completion
train_on_game(game.get(), q_table, action_space, tc, gen);

printf("Winner: %s\n", game->get_winner() == PLAYER_ABC ? "ABC" : "DEF");

// Close game
game.reset();
//printf("Winner: %s\n", game->get_winner() == PLAYER_ABC ? "ABC" : "DEF");

//std::this_thread::sleep_for(std::chrono::milliseconds(1000));
printf("\n");
//printf("\n");
}

} catch (CaravanException &e) {
Expand Down

0 comments on commit 7a24ee3

Please sign in to comment.