Skip to content

Commit

Permalink
[api] Promote Internal::Parameter to Halide::Parameter (halide#7829)
Browse files Browse the repository at this point in the history
* Promote Internal::Parameter to Halide::Parameter (to support Serialization API
refactoring)

* Make raw_buffer(), scalar_address(), and scalar_raw_value() methods
protected.

Make Pipeline and Serializer protected friend classes.

* Add Parameter public interface to python bindings.
Remove old stub internal interface from PyParam.

* Remove blank line at start of function

---------

Co-authored-by: Derek Gerstmann <dgerstmann@adobe.com>
Co-authored-by: Steven Johnson <srj@google.com>
  • Loading branch information
3 people authored and ardier committed Mar 3, 2024
1 parent 1d76828 commit 8be6e2f
Show file tree
Hide file tree
Showing 27 changed files with 224 additions and 109 deletions.
1 change: 1 addition & 0 deletions python_bindings/src/halide/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ set(native_sources
PyLoopLevel.cpp
PyModule.cpp
PyParam.cpp
PyParameter.cpp
PyPipeline.cpp
PyRDom.cpp
PyStage.cpp
Expand Down
2 changes: 1 addition & 1 deletion python_bindings/src/halide/halide_/PyGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ using Halide::Internal::AbstractGenerator;
using Halide::Internal::AbstractGeneratorPtr;
using Halide::Internal::GeneratorFactoryProvider;
using ArgInfo = Halide::Internal::AbstractGenerator::ArgInfo;
using Halide::Parameter;
using Halide::Internal::ArgInfoDirection;
using Halide::Internal::ArgInfoKind;
using Halide::Internal::Parameter;

template<typename T>
std::map<std::string, T> dict_to_map(const py::dict &dict) {
Expand Down
2 changes: 2 additions & 0 deletions python_bindings/src/halide/halide_/PyHalide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "PyLambda.h"
#include "PyModule.h"
#include "PyParam.h"
#include "PyParameter.h"
#include "PyPipeline.h"
#include "PyRDom.h"
#include "PyTarget.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) {
define_lambda(m);
define_operators(m);
define_param(m);
define_parameter(m);
define_image_param(m);
define_type(m);
define_derivative(m);
Expand Down
25 changes: 0 additions & 25 deletions python_bindings/src/halide/halide_/PyParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
namespace Halide {
namespace PythonBindings {

using Halide::Internal::Parameter;

namespace {

template<typename TYPE>
Expand Down Expand Up @@ -40,29 +38,6 @@ void add_param_methods(py::class_<Param<>> &param_class) {
} // namespace

void define_param(py::module &m) {
// This is a "just-enough" wrapper around Parameter to let us pass it back
// and forth between Py and C++. It deliberately exposes very few methods,
// and we should keep it that way.
auto parameter_class =
py::class_<Parameter>(m, "InternalParameter")
.def(py::init<const Parameter &>(), py::arg("p"))
.def("defined", &Parameter::defined)
.def("type", &Parameter::type)
.def("dimensions", &Parameter::dimensions)
.def("_to_argument", [](const Parameter &p) -> Argument {
return Argument(p.name(),
p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
p.type(),
p.dimensions(),
p.get_argument_estimates());
})
.def("__repr__", [](const Parameter &p) -> std::string {
std::ostringstream o;
// Don't leak any info but the name into the repr string.
o << "<halide.InternalParameter '" << p.name() << "'>";
return o.str();
});

auto param_class =
py::class_<Param<>>(m, "Param")
.def(py::init<Type>(), py::arg("type"))
Expand Down
106 changes: 106 additions & 0 deletions python_bindings/src/halide/halide_/PyParameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include "PyParameter.h"

#include "PyType.h"

namespace Halide {
namespace PythonBindings {

namespace {

template<typename TYPE>
void add_scalar_methods(py::class_<Parameter> &parameter_class) {
parameter_class
.def("scalar", &Parameter::scalar<TYPE>)
.def(
"set_scalar", [](Parameter &parameter, TYPE value) -> void {
parameter.set_scalar<TYPE>(value);
},
py::arg("value"));
}

} // namespace

void define_parameter(py::module &m) {
// Disambiguate some ambigious methods
void (Parameter::*set_scalar_method)(const Type &t, halide_scalar_value_t val) = &Parameter::set_scalar;

auto parameter_class =
py::class_<Parameter>(m, "Parameter")
.def(py::init<>())
.def(py::init<const Parameter &>(), py::arg("p"))
.def(py::init<const Type &, bool, int>())
.def(py::init<const Type &, bool, int, const std::string &>())
.def(py::init<const Type &, bool, int, const std::string &,
const Buffer<void> &, int, const std::vector<BufferConstraint> &,
MemoryType>())
.def(py::init<const Type &, bool, int, const std::string &,
uint64_t, const Expr &, const Expr &, const Expr &, const Expr &>())
.def("_to_argument", [](const Parameter &p) -> Argument {
return Argument(p.name(),
p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
p.type(),
p.dimensions(),
p.get_argument_estimates());
})
.def("__repr__", [](const Parameter &p) -> std::string {
std::ostringstream o;
o << "<halide.Parameter '" << p.name() << "'";
if (!p.defined()) {
o << " (undefined)";
} else {
// TODO: add dimensions to this
o << " type " << halide_type_to_string(p.type());
}
o << ">";
return o.str();
})
.def("type", &Parameter::type)
.def("dimensions", &Parameter::dimensions)
.def("name", &Parameter::name)
.def("is_buffer", &Parameter::is_buffer)
.def("scalar_expr", &Parameter::scalar_expr)
.def("set_scalar", set_scalar_method, py::arg("value_type"), py::arg("value"))
.def("buffer", &Parameter::buffer)
.def("set_buffer", &Parameter::set_buffer, py::arg("buffer"))
.def("same_as", &Parameter::same_as, py::arg("other"))
.def("defined", &Parameter::defined)
.def("set_min_constraint", &Parameter::set_min_constraint, py::arg("dim"), py::arg("expr"))
.def("set_extent_constraint", &Parameter::set_extent_constraint, py::arg("dim"), py::arg("expr"))
.def("set_stride_constraint", &Parameter::set_stride_constraint, py::arg("dim"), py::arg("expr"))
.def("set_min_constraint_estimate", &Parameter::set_min_constraint_estimate, py::arg("dim"), py::arg("expr"))
.def("set_extent_constraint_estimate", &Parameter::set_extent_constraint_estimate, py::arg("dim"), py::arg("expr"))
.def("set_host_alignment", &Parameter::set_host_alignment, py::arg("bytes"))
.def("min_constraint", &Parameter::min_constraint, py::arg("dim"))
.def("extent_constraint", &Parameter::extent_constraint, py::arg("dim"))
.def("stride_constraint", &Parameter::stride_constraint, py::arg("dim"))
.def("min_constraint_estimate", &Parameter::min_constraint_estimate, py::arg("dim"))
.def("extent_constraint_estimate", &Parameter::extent_constraint_estimate, py::arg("dim"))
.def("host_alignment", &Parameter::host_alignment)
.def("buffer_constraints", &Parameter::buffer_constraints)
.def("set_min_value", &Parameter::set_min_value, py::arg("expr"))
.def("min_value", &Parameter::min_value)
.def("set_max_value", &Parameter::set_max_value, py::arg("expr"))
.def("max_value", &Parameter::max_value)
.def("set_estimate", &Parameter::set_estimate, py::arg("expr"))
.def("estimate", &Parameter::estimate)
.def("set_default_value", &Parameter::set_default_value, py::arg("expr"))
.def("default_value", &Parameter::default_value)
.def("get_argument_estimates", &Parameter::get_argument_estimates)
.def("store_in", &Parameter::store_in, py::arg("memory_type"))
.def("memory_type", &Parameter::memory_type);

add_scalar_methods<bool>(parameter_class);
add_scalar_methods<uint8_t>(parameter_class);
add_scalar_methods<uint16_t>(parameter_class);
add_scalar_methods<uint32_t>(parameter_class);
add_scalar_methods<uint64_t>(parameter_class);
add_scalar_methods<int8_t>(parameter_class);
add_scalar_methods<int16_t>(parameter_class);
add_scalar_methods<int32_t>(parameter_class);
add_scalar_methods<int64_t>(parameter_class);
add_scalar_methods<float>(parameter_class);
add_scalar_methods<double>(parameter_class);
}

} // namespace PythonBindings
} // namespace Halide
14 changes: 14 additions & 0 deletions python_bindings/src/halide/halide_/PyParameter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef HALIDE_PYTHON_BINDINGS_PYPARAMETER_H
#define HALIDE_PYTHON_BINDINGS_PYPARAMETER_H

#include "PyHalide.h"

namespace Halide {
namespace PythonBindings {

void define_parameter(py::module &m);

} // namespace PythonBindings
} // namespace Halide

#endif // HALIDE_PYTHON_BINDINGS_PYPARAMETER_H
1 change: 0 additions & 1 deletion python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ namespace py = pybind11;
namespace Halide {
namespace PythonBindings {

using Parameter = Internal::Parameter;
using ArgInfoKind = Internal::ArgInfoKind;
using ArgInfo = Internal::AbstractGenerator::ArgInfo;
using GeneratorFactory = Internal::GeneratorFactory;
Expand Down
2 changes: 1 addition & 1 deletion src/AbstractGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Internal {

namespace {

Argument to_argument(const Internal::Parameter &param) {
Argument to_argument(const Parameter &param) {
return Argument(param.name(),
param.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
param.type(),
Expand Down
8 changes: 4 additions & 4 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1365,12 +1365,12 @@ Pipeline Deserializer::deserialize(std::istream &in) {

} // namespace Internal

Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Internal::Parameter> &external_params) {
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
return deserializer.deserialize(filename);
}

Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Internal::Parameter> &external_params) {
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
return deserializer.deserialize(in);
}
Expand All @@ -1381,12 +1381,12 @@ Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Inte

namespace Halide {

Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Internal::Parameter> &external_params) {
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Internal::Parameter> &external_params) {
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}
Expand Down
4 changes: 2 additions & 2 deletions src/Deserialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace Halide {
* external_params is an optional map, all parameters in the map
* will be treated as external parameters so won't be deserialized.
*/
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Internal::Parameter> &external_params);
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params);

Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Internal::Parameter> &external_params);
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params);

} // namespace Halide

Expand Down
2 changes: 1 addition & 1 deletion src/Dimension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace Halide {
namespace Internal {

Dimension::Dimension(const Internal::Parameter &p, int d, Func f)
Dimension::Dimension(const Parameter &p, int d, Func f)
: param(p), d(d), f(std::move(f)) {
user_assert(param.defined())
<< "Can't access the dimensions of an undefined Parameter\n";
Expand Down
4 changes: 2 additions & 2 deletions src/Dimension.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ class Dimension {
friend class ::Halide::OutputImageParam;

/** Construct a Dimension representing dimension d of some
* Internal::Parameter p. Only friends may construct
* Parameter p. Only friends may construct
* these. */
Dimension(const Internal::Parameter &p, int d, Func f);
Dimension(const Parameter &p, int d, Func f);

Parameter param;
int d;
Expand Down
4 changes: 2 additions & 2 deletions src/ExternFuncArgument.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct ExternFuncArgument {
Internal::FunctionPtr func;
Buffer<> buffer;
Expr expr;
Internal::Parameter image_param;
Parameter image_param;

ExternFuncArgument(Internal::FunctionPtr f)
: arg_type(FuncArg), func(std::move(f)) {
Expand All @@ -44,7 +44,7 @@ struct ExternFuncArgument {
: arg_type(ExprArg), expr(e) {
}

ExternFuncArgument(const Internal::Parameter &p)
ExternFuncArgument(const Parameter &p)
: arg_type(ImageParamArg), image_param(p) {
// Scalar params come in via the Expr constructor.
internal_assert(p.is_buffer());
Expand Down
4 changes: 2 additions & 2 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,7 @@ Stage &Stage::prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from
return *this;
}

Stage &Stage::prefetch(const Internal::Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) {
Stage &Stage::prefetch(const Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) {
definition.schedule().touched() = true;
PrefetchDirective prefetch = {param.name(), at.name(), from.name(), std::move(offset), strategy, param};
definition.schedule().prefetches().push_back(prefetch);
Expand Down Expand Up @@ -2629,7 +2629,7 @@ Func &Func::prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from,
return *this;
}

Func &Func::prefetch(const Internal::Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) {
Func &Func::prefetch(const Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) {
invalidate_cache();
Stage(func, func.definition(), 0).prefetch(param, at, from, std::move(offset), strategy);
return *this;
Expand Down
4 changes: 2 additions & 2 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ class Stage {

Stage &prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
Stage &prefetch(const Internal::Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
Stage &prefetch(const Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
template<typename T>
Stage &prefetch(const T &image, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
Expand Down Expand Up @@ -1982,7 +1982,7 @@ class Func {
// @{
Func &prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
Func &prefetch(const Internal::Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
Func &prefetch(const Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
template<typename T>
Func &prefetch(const T &image, const VarOrRVar &at, const VarOrRVar &from, Expr offset = 1,
Expand Down
3 changes: 1 addition & 2 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
namespace Halide {

struct ExternFuncArgument;
class Parameter;
class Tuple;

class Var;

/** An enum to specify calling convention for extern stages. */
Expand All @@ -31,7 +31,6 @@ enum class NameMangling {
namespace Internal {

struct Call;
class Parameter;

/** A reference-counted handle to Halide's internal representation of
* a function. Similar to a front-end Func object, but with no
Expand Down
2 changes: 1 addition & 1 deletion src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1645,7 +1645,7 @@ bool GeneratorBase::emit_hlpipe(const std::string &hlpipe_file_path) {
debug(1) << "Applying autoscheduler " << asp.name << " to Generator " << name() << " ...\n";
auto_schedule_results = pipeline.apply_autoscheduler(context.target(), asp);
}
std::map<std::string, Internal::Parameter> params; // FIXME: Remove when API allows this to be optional
std::map<std::string, Parameter> params; // FIXME: Remove when API allows this to be optional
serialize_pipeline(pipeline, hlpipe_file_path, params);
return true;
#else
Expand Down
4 changes: 2 additions & 2 deletions src/ImageParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Halide {

ImageParam::ImageParam(Type t, int d)
: OutputImageParam(
Internal::Parameter(t, true, d, Internal::make_entity_name(this, "Halide:.*:ImageParam", 'p')),
Parameter(t, true, d, Internal::make_entity_name(this, "Halide:.*:ImageParam", 'p')),
Argument::InputBuffer,
Func()) {
// We must call create_func() after the super-ctor has completed.
Expand All @@ -15,7 +15,7 @@ ImageParam::ImageParam(Type t, int d)

ImageParam::ImageParam(Type t, int d, const std::string &n)
: OutputImageParam(
Internal::Parameter(t, true, d, n),
Parameter(t, true, d, n),
Argument::InputBuffer,
Func()) {
// We must call create_func() after the super-ctor has completed.
Expand Down
2 changes: 1 addition & 1 deletion src/ImageParam.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ImageParam : public OutputImageParam {
friend class ::Halide::Internal::GeneratorInput_Buffer;

// Only for use of Generator
ImageParam(const Internal::Parameter &p, Func f)
ImageParam(const Parameter &p, Func f)
: OutputImageParam(p, Argument::InputBuffer, std::move(f)) {
}

Expand Down
6 changes: 3 additions & 3 deletions src/Memoization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class FindParameterDependencies : public IRGraphVisitor {
for (const auto &extern_arg : extern_args) {
if (extern_arg.is_buffer()) {
// Function with an extern definition
record(Halide::Internal::Parameter(extern_arg.buffer.type(), true,
extern_arg.buffer.dimensions(),
extern_arg.buffer.name()));
record(Halide::Parameter(extern_arg.buffer.type(), true,
extern_arg.buffer.dimensions(),
extern_arg.buffer.name()));
} else if (extern_arg.is_image_param()) {
record(extern_arg.image_param);
}
Expand Down
Loading

0 comments on commit 8be6e2f

Please sign in to comment.