Skip to content

Commit

Permalink
Added config management
Browse files Browse the repository at this point in the history
  • Loading branch information
mike-gimelfarb committed Oct 30, 2024
1 parent 1dc19c2 commit 54aa649
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 15 deletions.
72 changes: 71 additions & 1 deletion pyRDDLGym_gurobi/core/planner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from ast import literal_eval
import configparser
import os
import sys
from typing import Any, Dict, List, Tuple, Optional

Kwargs = Dict[str, Any]

import gurobipy
from gurobipy import GRB

Expand All @@ -11,6 +17,70 @@

UNBOUNDED = (-GRB.INFINITY, +GRB.INFINITY)

# ***********************************************************************
# CONFIG FILE MANAGEMENT
#
# - read config files from file path
# - extract experiment settings
# - instantiate planner
#
# ***********************************************************************


def _parse_config_file(path: str):
if not os.path.isfile(path):
raise FileNotFoundError(f'File {path} does not exist.')
config = configparser.RawConfigParser()
config.optionxform = str
config.read(path)
args = {k: literal_eval(v)
for section in config.sections()
for (k, v) in config.items(section)}
return config, args


def _parse_config_string(value: str):
config = configparser.RawConfigParser()
config.optionxform = str
config.read_string(value)
args = {k: literal_eval(v)
for section in config.sections()
for (k, v) in config.items(section)}
return config, args


def _getattr_any(packages, item):
for package in packages:
loaded = getattr(package, item, None)
if loaded is not None:
return loaded
return None


def _load_config(config, args):
gurobi_args = {k: args[k] for (k, _) in config.items('Gurobi')}
compiler_args = {k: args[k] for (k, _) in config.items('Optimizer')}

# policy class
plan_method = compiler_args.pop('method')
plan_kwargs = compiler_args.pop('method_kwargs', {})
compiler_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
compiler_args['model_params'] = gurobi_args

return compiler_args


def load_config(path: str) -> Kwargs:
'''Loads a config file at the specified file path.'''
config, args = _parse_config_file(path)
return _load_config(config, args)


def load_config_from_string(value: str) -> Kwargs:
'''Loads config file contents specified explicitly as a string value.'''
config, args = _parse_config_string(value)
return _load_config(config, args)


# ***********************************************************************
# ALL VERSIONS OF GUROBI PLANS
Expand Down Expand Up @@ -237,7 +307,7 @@ def params(self, compiled: GurobiRDDLCompiler,
lb_name = f'lb__{action}__{k}'
ub_name = f'ub__{action}__{k}'
if values is None:
lb, ub = self.state_bounds[states[0]]
lb, ub = self.state_bounds.get(states[0], UNBOUNDED)
var_bounds = UNBOUNDED if is_linear else (lb - 1, ub + 1)
lb_var = compiled._add_var(model, vtype, *var_bounds)
ub_var = compiled._add_var(model, vtype, *var_bounds)
Expand Down
9 changes: 9 additions & 0 deletions pyRDDLGym_gurobi/examples/default.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[Gurobi]
NonConvex=2
OutputFlag=0

[Optimizer]
method='GurobiPiecewisePolicy'
method_kwargs={}
rollout_horizon=5
verbose=1
28 changes: 14 additions & 14 deletions pyRDDLGym_gurobi/examples/run_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,35 @@
<instance> is the instance number
<horizon> is a positive integer representing the lookahead horizon
'''
import os
import sys

import pyRDDLGym
from pyRDDLGym_gurobi.core.planner import (
GurobiStraightLinePlan, GurobiOnlineController
)
from pyRDDLGym_gurobi.core.planner import GurobiOnlineController, load_config


def main(domain, instance, horizon):
def main(domain, instance):

# create the environment
env = pyRDDLGym.make(domain, instance, enforce_action_constraints=True)

# create the controller
controller = GurobiOnlineController(rddl=env.model,
plan=GurobiStraightLinePlan(),
rollout_horizon=horizon,
model_params={'NonConvex': 2, 'OutputFlag': 1})
# load the config
abs_path = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(abs_path, 'default.cfg')
controller_kwargs = load_config(config_path)

# create the controller
controller = GurobiOnlineController(rddl=env.model, **controller_kwargs)
controller.evaluate(env, verbose=True, render=True)

env.close()


if __name__ == "__main__":
args = sys.argv[1:]
if len(args) < 3:
print('python run_plan.py <domain> <instance> <horizon>')
if len(args) < 2:
print('python run_plan.py <domain> <instance>')
exit(1)
domain, instance, horizon = args[:3]
horizon = int(horizon)
main(domain, instance, horizon)
domain, instance = args[:2]
main(domain, instance)

0 comments on commit 54aa649

Please sign in to comment.