Skip to content

Commit

Permalink
slight modularization
Browse files Browse the repository at this point in the history
  • Loading branch information
azrael417 committed Apr 22, 2024
1 parent 0bd9475 commit 35fec7b
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/csrc/include/internal/rl/off_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RLOffPolicySystem {
RLOffPolicySystem(const RLOffPolicySystem&) = delete;

// empty constructor:
RLOffPolicySystem() : train_step_count_(0) {}
RLOffPolicySystem(int model_device, int rb_device);

// some important functions which have to be implemented by the base class
virtual void updateReplayBuffer(torch::Tensor, torch::Tensor, torch::Tensor, float, bool) = 0;
Expand All @@ -80,6 +80,7 @@ class RLOffPolicySystem {
virtual std::shared_ptr<ModelState> getSystemState_() = 0;
virtual std::shared_ptr<Comm> getSystemComm_() = 0;
size_t train_step_count_;
torch::Device model_device_, rb_device_;
};

// Declaration of external global variables
Expand Down
5 changes: 3 additions & 2 deletions src/csrc/include/internal/rl/on_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class RLOnPolicySystem {
// disable copy constructor
RLOnPolicySystem(const RLOnPolicySystem&) = delete;

// empty constructor:
RLOnPolicySystem() : train_step_count_(0) {}
// default constructor:
RLOnPolicySystem(int model_device, int rb_device);

// some important functions which have to be implemented by the base class
virtual void updateRolloutBuffer(torch::Tensor, torch::Tensor, float, float, float, bool) = 0;
Expand All @@ -81,6 +81,7 @@ class RLOnPolicySystem {
virtual std::shared_ptr<ModelState> getSystemState_() = 0;
virtual std::shared_ptr<Comm> getSystemComm_() = 0;
size_t train_step_count_;
torch::Device model_device_, rb_device_;
};

// Declaration of external global variables
Expand Down
4 changes: 0 additions & 4 deletions src/csrc/include/internal/rl/on_policy/ppo.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,6 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this<R

std::shared_ptr<Comm> getSystemComm_();

// device
torch::Device model_device_;
torch::Device rb_device_;

// models
ACPolicyPack p_model_;
ModelPack q_model_;
Expand Down
1 change: 1 addition & 0 deletions src/csrc/rl/off_policy/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ RLOffPolicySystem::RLOffPolicySystem(int model_device, int rb_device) : train_st

} // namespace off_policy

} // namespace rl
} // namespace torchfort

torchfort_result_t torchfort_rl_off_policy_create_system(const char* name, const char* config_fname,
Expand Down
8 changes: 8 additions & 0 deletions src/csrc/rl/on_policy/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ namespace rl {
namespace on_policy {
// Global variables
std::unordered_map<std::string, std::shared_ptr<RLOnPolicySystem>> registry;

// default constructor:
RLOnPolicySystem::RLOnPolicySystem(int model_device, int rb_device) : train_step_count_(0), model_device_(get_device(model_device)), rb_device_(get_device(rb_device)) {
if ( !(torchfort::rl::validate_devices(model_device, rb_device)) ) {
THROW_INVALID_USAGE("The parameters model_device and rb_device have to specify the same GPU or one has to specify a GPU and the other the CPU.");
}
}

} // namespace on_policy

} // namespace rl
Expand Down
2 changes: 1 addition & 1 deletion src/csrc/rl/on_policy/ppo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace on_policy {

PPOSystem::PPOSystem(const char* name, const YAML::Node& system_node,
int model_device, int rb_device)
: model_device_(get_device(model_device)), rb_device_(get_device(rb_device)) {
: RLOnPolicySystem(model_device, rb_device) {

// get basic parameters first
auto algo_node = system_node["algorithm"];
Expand Down

0 comments on commit 35fec7b

Please sign in to comment.