Skip to content

Commit

Permalink
Merge pull request #292 from Bam4d/mypy_support
Browse files Browse the repository at this point in the history
Mypy support
  • Loading branch information
Bam4d authored Oct 18, 2023
2 parents 2ea37c2 + 2600b01 commit 1a5c3a5
Show file tree
Hide file tree
Showing 67 changed files with 2,221 additions and 3,046 deletions.
111 changes: 53 additions & 58 deletions bindings/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#include <pybind11/stl.h>
#include <spdlog/spdlog.h>

#include "wrapper/GriddlyLoaderWrapper.cpp"
#include "wrapper/GDYWrapper.cpp"
#include "wrapper/GDY.cpp"
#include "wrapper/GDYLoader.cpp"
#include "wrapper/NumpyWrapper.cpp"

namespace py = pybind11;
Expand All @@ -22,98 +22,93 @@ PYBIND11_MODULE(python_griddly, m) {

spdlog::debug("Python Griddly module loaded!");

py::class_<Py_GriddlyLoaderWrapper, std::shared_ptr<Py_GriddlyLoaderWrapper>> gdy_reader(m, "GDYReader");
py::class_<Py_GDYLoader, std::shared_ptr<Py_GDYLoader>> gdy_reader(m, "GDYLoader");
gdy_reader.def(py::init<std::string, std::string, std::string>());
gdy_reader.def("load", &Py_GriddlyLoaderWrapper::loadGDYFile);
gdy_reader.def("load_string", &Py_GriddlyLoaderWrapper::loadGDYString);

py::class_<Py_GDYWrapper, std::shared_ptr<Py_GDYWrapper>> gdy(m, "GDY");
gdy.def("set_max_steps", &Py_GDYWrapper::setMaxSteps);
gdy.def("get_player_count", &Py_GDYWrapper::getPlayerCount);
gdy.def("get_action_names", &Py_GDYWrapper::getExternalActionNames);
gdy.def("get_action_input_mappings", &Py_GDYWrapper::getActionInputMappings);
gdy.def("get_avatar_object", &Py_GDYWrapper::getAvatarObject);
gdy.def("create_game", &Py_GDYWrapper::createGame);
gdy.def("get_level_count", &Py_GDYWrapper::getLevelCount);
gdy.def("get_observer_type", &Py_GDYWrapper::getObserverType);


py::class_<Py_GameWrapper, std::shared_ptr<Py_GameWrapper>> game_process(m, "GameProcess");

gdy_reader.def("load", &Py_GDYLoader::loadGDYFile);
gdy_reader.def("load_string", &Py_GDYLoader::loadGDYString);

py::class_<Py_GDY, std::shared_ptr<Py_GDY>> gdy(m, "GDY");
gdy.def("set_max_steps", &Py_GDY::setMaxSteps);
gdy.def("get_player_count", &Py_GDY::getPlayerCount);
gdy.def("get_action_names", &Py_GDY::getExternalActionNames);
gdy.def("get_action_input_mappings", &Py_GDY::getActionInputMappings);
gdy.def("get_avatar_object", &Py_GDY::getAvatarObject);
gdy.def("create_game", &Py_GDY::createGame);
gdy.def("get_level_count", &Py_GDY::getLevelCount);
gdy.def("get_observer_type", &Py_GDY::getObserverType);

py::class_<Py_GameProcess, std::shared_ptr<Py_GameProcess>> game_process(m, "GameProcess");

// Register a player to the game
game_process.def("register_player", &Py_GameWrapper::registerPlayer);
game_process.def("register_player", &Py_GameProcess::registerPlayer);

// Initialize the game or reset the game state
game_process.def("init", &Py_GameWrapper::init);
game_process.def("reset", &Py_GameWrapper::reset);
game_process.def("init", &Py_GameProcess::init);
game_process.def("reset", &Py_GameProcess::reset);

// Generic step function for multiple players and multiple actions per step
game_process.def("step_parallel", &Py_GameWrapper::stepParallel);
game_process.def("step_parallel", &Py_GameProcess::stepParallel);

// Set the current map of the game (should be followed by reset or init)
game_process.def("load_level", &Py_GameWrapper::loadLevel);
game_process.def("load_level_string", &Py_GameWrapper::loadLevelString);
game_process.def("load_level", &Py_GameProcess::loadLevel);
game_process.def("load_level_string", &Py_GameProcess::loadLevelString);

// Get available actions for objects in the current game
game_process.def("get_available_actions", &Py_GameWrapper::getAvailableActionNames);
game_process.def("get_available_action_ids", &Py_GameWrapper::getAvailableActionIds);
game_process.def("build_valid_action_trees", &Py_GameWrapper::buildValidActionTrees);
game_process.def("get_available_actions", &Py_GameProcess::getAvailableActionNames);
game_process.def("get_available_action_ids", &Py_GameProcess::getAvailableActionIds);
game_process.def("build_valid_action_trees", &Py_GameProcess::buildValidActionTrees);

// Width and height of the game grid
game_process.def("get_width", &Py_GameWrapper::getWidth);
game_process.def("get_height", &Py_GameWrapper::getHeight);
// Width and height of the game grid
game_process.def("get_width", &Py_GameProcess::getWidth);
game_process.def("get_height", &Py_GameProcess::getHeight);

// Tile Size (only used in some observers)
game_process.def("get_tile_size", &Py_GameWrapper::getTileSize);
game_process.def("get_tile_size", &Py_GameProcess::getTileSize);

// Observation shapes
game_process.def("get_global_observation_description", &Py_GameWrapper::getGlobalObservationDescription);
game_process.def("get_global_observation_description", &Py_GameProcess::getGlobalObservationDescription);

// Tile size of the global observer
game_process.def("observe", &Py_GameWrapper::observe);
// Enable the history collection mode
game_process.def("enable_history", &Py_GameWrapper::enableHistory);
game_process.def("observe", &Py_GameProcess::observe);

// Enable the history collection mode
game_process.def("enable_history", &Py_GameProcess::enableHistory);

// Create a copy of the game in its current state
game_process.def("clone", &Py_GameWrapper::clone);
game_process.def("clone", &Py_GameProcess::clone);

// Get a dictionary containing the objects in the environment and their variable values
game_process.def("get_state", &Py_GameWrapper::getState);
game_process.def("get_state", &Py_GameProcess::getState);

// Load the state from a state object
game_process.def("load_state", &Py_GameWrapper::loadState);
game_process.def("load_state", &Py_GameProcess::loadState);

// Get a specific variable value
game_process.def("get_global_variable", &Py_GameWrapper::getGlobalVariables);
game_process.def("get_global_variable", &Py_GameProcess::getGlobalVariables);

// Get list of possible object names, ordered by ID
game_process.def("get_object_names", &Py_GameWrapper::getObjectNames);
game_process.def("get_object_names", &Py_GameProcess::getObjectNames);

// Get list of possible variable names, ordered by ID
game_process.def("get_object_variable_names", &Py_GameWrapper::getObjectVariableNames);
game_process.def("get_object_variable_names", &Py_GameProcess::getObjectVariableNames);

// Get a mapping of objects to their variable names
game_process.def("get_object_variable_map", &Py_GameWrapper::getObjectVariableMap);
game_process.def("get_object_variable_map", &Py_GameProcess::getObjectVariableMap);

// Get a list of the global variable names
game_process.def("get_global_variable_names", &Py_GameWrapper::getGlobalVariableNames);
game_process.def("get_global_variable_names", &Py_GameProcess::getGlobalVariableNames);

// Get a list of the events that have happened in the game up to this point
game_process.def("get_history", &Py_GameWrapper::getHistory, py::arg("purge")=true);

// Release resources for vulkan stuff
game_process.def("release", &Py_GameWrapper::release);

game_process.def("seed", &Py_GameWrapper::seedRandomGenerator);
game_process.def("get_history", &Py_GameProcess::getHistory, py::arg("purge") = true);

// Release resources for vulkan stuff
game_process.def("release", &Py_GameProcess::release);

py::class_<Py_StepPlayerWrapper, std::shared_ptr<Py_StepPlayerWrapper>> player(m, "Player");
player.def("step", &Py_StepPlayerWrapper::stepSingle);
player.def("step_multi", &Py_StepPlayerWrapper::stepMulti);
player.def("observe", &Py_StepPlayerWrapper::observe);
player.def("get_observation_description", &Py_StepPlayerWrapper::getObservationDescription);
game_process.def("seed", &Py_GameProcess::seedRandomGenerator);

py::class_<Py_Player, std::shared_ptr<Py_Player>> player(m, "Player");
player.def("observe", &Py_Player::observe);
player.def("get_observation_description", &Py_Player::getObservationDescription);

py::enum_<ObserverType> observer_type(m, "ObserverType");
observer_type.value("SPRITE_2D", ObserverType::SPRITE_2D);
Expand Down
12 changes: 6 additions & 6 deletions bindings/wrapper/GDYWrapper.cpp → bindings/wrapper/GDY.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
#include "../../src/Griddly/Core/GDY/GDYFactory.hpp"
#include "../../src/Griddly/Core/Grid.hpp"
#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp"
#include "GameWrapper.cpp"
#include "StepPlayerWrapper.cpp"
#include "GameProcess.cpp"
#include "Player.cpp"

namespace griddly {

class Py_GDYWrapper {
class Py_GDY {
public:
Py_GDYWrapper(std::shared_ptr<GDYFactory> gdyFactory)
Py_GDY(std::shared_ptr<GDYFactory> gdyFactory)
: gdyFactory_(gdyFactory) {
}

Expand Down Expand Up @@ -83,8 +83,8 @@ class Py_GDYWrapper {
return py_actionInputsDefinitions;
}

std::shared_ptr<Py_GameWrapper> createGame(std::string globalObserverName) {
return std::make_shared<Py_GameWrapper>(Py_GameWrapper(globalObserverName, gdyFactory_));
std::shared_ptr<Py_GameProcess> createGame(std::string globalObserverName) {
return std::make_shared<Py_GameProcess>(Py_GameProcess(globalObserverName, gdyFactory_));
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@
#include "../../src/Griddly/Core/GDY/Objects/ObjectGenerator.hpp"
#include "../../src/Griddly/Core/GDY/TerminationGenerator.hpp"
#include "../../src/Griddly/Core/Grid.hpp"
#include "GDYWrapper.cpp"
#include "GDY.cpp"

namespace griddly {


class Py_GriddlyLoaderWrapper {
class Py_GDYLoader {
public:
Py_GriddlyLoaderWrapper(std::string gdyPath, std::string imagePath, std::string shaderPath)
Py_GDYLoader(std::string gdyPath, std::string imagePath, std::string shaderPath)
: resourceConfig_({gdyPath, imagePath, shaderPath}) {
}

std::shared_ptr<Py_GDYWrapper> loadGDYFile(std::string filename) {
std::shared_ptr<Py_GDY> loadGDYFile(std::string filename) {
auto objectGenerator = std::make_shared<ObjectGenerator>(ObjectGenerator());
auto terminationGenerator = std::make_shared<TerminationGenerator>(TerminationGenerator());
auto gdyFactory = std::make_shared<GDYFactory>(GDYFactory(objectGenerator, terminationGenerator, resourceConfig_));
gdyFactory->initializeFromFile(filename);
return std::make_shared<Py_GDYWrapper>(Py_GDYWrapper(gdyFactory));
return std::make_shared<Py_GDY>(Py_GDY(gdyFactory));
}

std::shared_ptr<Py_GDYWrapper> loadGDYString(std::string string) {
std::shared_ptr<Py_GDY> loadGDYString(std::string string) {
auto objectGenerator = std::make_shared<ObjectGenerator>(ObjectGenerator());
auto terminationGenerator = std::make_shared<TerminationGenerator>(TerminationGenerator());
auto gdyFactory = std::make_shared<GDYFactory>(GDYFactory(objectGenerator, terminationGenerator, resourceConfig_));
std::istringstream s(string);
gdyFactory->parseFromStream(s);
return std::make_shared<Py_GDYWrapper>(Py_GDYWrapper(gdyFactory));
return std::make_shared<Py_GDY>(Py_GDY(gdyFactory));
}

private:
Expand Down
Loading

0 comments on commit 1a5c3a5

Please sign in to comment.