Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Jul 15, 2022
1 parent 4a0aa5d commit 3cab2a8
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 21 deletions.
14 changes: 1 addition & 13 deletions python_bindings/src/PyGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,7 @@ void define_generator(py::module &m) {
.def("machine_params", &GeneratorContext::machine_params)
#else
.def(py::init<const Target &>(), py::arg("target"))
.def(py::init([](const Target &target, const py::dict &autoscheduler_params_dict) -> GeneratorContext {
// Manually convert the dict into AutoSchedulerParams:
// we want to allow Python to pass in dicts that have non-string values for some keys;
// PyBind will reject these as a type failure. We'll stringify them here explicitly.
AutoSchedulerParams asp;
for (auto item : autoscheduler_params_dict) {
const std::string name = py::str(item.first).cast<std::string>();
const std::string value = py::str(item.second).cast<std::string>();
asp[name] = value;
}
return GeneratorContext(target, asp);
}),
py::arg("target"), py::arg("autoscheduler_params"))
.def(py::init<const Target &, const AutoschedulerParams &>(), py::arg("target"), py::arg("autoscheduler_params"))
.def("target", &GeneratorContext::target)
.def("autoscheduler_params", &GeneratorContext::autoscheduler_params)
#endif
Expand Down
10 changes: 6 additions & 4 deletions python_bindings/src/builtin_helpers_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def autoscheduler(self):
return self._autoscheduler

def using_autoscheduler(self):
return len(self._autoscheduler) > 0
return bool(self._autoscheduler.name)

def natural_vector_size(self, type: Type) -> int:
return self.target().natural_vector_size(type)
Expand Down Expand Up @@ -638,11 +638,13 @@ def _set_generatorparam_value(self, name: str, value: Any):
old_value = gp._value
new_value = GeneratorParam._parse_value(name, type(old_value), value)
gp._value = new_value
elif name == "autoscheduler":
_check(not self.autoscheduler().name, "The GeneratorParam %s cannot be set more than once" % name)
self.autoscheduler().name = value
elif name.startswith("autoscheduler."):
sub_key = name[14:]
asp = self.autoscheduler()
_check(not sub_key in asp, "The GeneratorParam %s cannot be set more than once" % name)
asp[sub_key] = value
_check(not sub_key in self.autoscheduler().extra, "The GeneratorParam %s cannot be set more than once" % name)
self.autoscheduler().extra[sub_key] = value
else:
self._unhandled_generator_params[name] = value

Expand Down
6 changes: 3 additions & 3 deletions python_bindings/test/generators/bilateral_grid_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import halide as hl

@hl.alias(
bilateral_grid_Adams2019={'autoscheduler.name':'Adams2019'},
bilateral_grid_Mullapudi2016={'autoscheduler.name':'Mullapudi2016'},
bilateral_grid_Li2018={'autoscheduler.name':'Li2018'},
bilateral_grid_Adams2019={'autoscheduler':'Adams2019'},
bilateral_grid_Mullapudi2016={'autoscheduler':'Mullapudi2016'},
bilateral_grid_Li2018={'autoscheduler':'Li2018'},
)
@hl.generator()
class bilateral_grid:
Expand Down
2 changes: 1 addition & 1 deletion src/autoschedulers/mullapudi2016/AutoSchedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2791,7 +2791,7 @@ void Partitioner::generate_group_cpu_schedule(
}

if (can_prove(def_par < arch_params.parallelism)) {
user_warning << "Insufficient parallelism for " << f_handle.name() << ": " << def_par << " < " << arch_params.parallelism << "\n";
user_warning << "Insufficient parallelism for " << f_handle.name() << "\n";
}

// Find the level at which group members will be computed.
Expand Down

0 comments on commit 3cab2a8

Please sign in to comment.