diff --git a/make/program b/make/program index ae1cd8a3c2..852c60a3ef 100644 --- a/make/program +++ b/make/program @@ -29,7 +29,7 @@ CXXFLAGS_PROGRAM += -include-pch $(STAN)src/stan/model/model_header$(STAN_FLAGS) $(STAN_TARGETS) examples/bernoulli/bernoulli$(EXE) $(patsubst %.stan,%$(EXE),$(wildcard src/test/test-models/*.stan)) : %$(EXE) : $(STAN)src/stan/model/model_header$(STAN_FLAGS).hpp.gch endif -ifneq ($(findstring allow_undefined,$(STANCFLAGS)),) +ifneq ($(findstring allow_undefined,$(STANCFLAGS))$(findstring allow-undefined,$(STANCFLAGS)),) $(STAN_TARGETS) examples/bernoulli/bernoulli$(EXE) $(patsubst %.stan,%$(EXE),$(wildcard src/test/test-models/*.stan)) : CXXFLAGS_PROGRAM += -include $(USER_HEADER) endif diff --git a/src/cmdstan/arguments/arg_opencl.hpp b/src/cmdstan/arguments/arg_opencl.hpp new file mode 100644 index 0000000000..df64f100e4 --- /dev/null +++ b/src/cmdstan/arguments/arg_opencl.hpp @@ -0,0 +1,22 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_HPP +#define CMDSTAN_ARGUMENTS_ARG_OPENCL_HPP + +#include +#include +#include + +namespace cmdstan { + +class arg_opencl : public categorical_argument { + public: + arg_opencl() { + _name = "opencl"; + _description = "OpenCL options"; + + _subarguments.push_back(new arg_opencl_device()); + _subarguments.push_back(new arg_opencl_platform()); + } +}; + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/arguments/arg_opencl_device.hpp b/src/cmdstan/arguments/arg_opencl_device.hpp new file mode 100644 index 0000000000..14094e00d3 --- /dev/null +++ b/src/cmdstan/arguments/arg_opencl_device.hpp @@ -0,0 +1,26 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_DEVICE_HPP +#define CMDSTAN_ARGUMENTS_ARG_OPENCL_DEVICE_HPP + +#include + +namespace cmdstan { + +class arg_opencl_device : public int_argument { + public: + arg_opencl_device() : int_argument() { + _name = "device"; + _description = "ID of the OpenCL device to use"; + _validity = "device >= 0 or -1 to use the compile-time device ID"; + _default = "-1"; + _default_value = -1; + _constrained = true; + _good_value = 1; + _bad_value = -1.0; + _value = _default_value; + } + + bool is_valid(int value) { return value >= 0 || value == _default_value; } +}; + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/arguments/arg_opencl_platform.hpp b/src/cmdstan/arguments/arg_opencl_platform.hpp new file mode 100644 index 0000000000..8de4c86049 --- /dev/null +++ b/src/cmdstan/arguments/arg_opencl_platform.hpp @@ -0,0 +1,26 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_PLATFORM_HPP +#define CMDSTAN_ARGUMENTS_ARG_OPENCL_PLATFORM_HPP + +#include + +namespace cmdstan { + +class arg_opencl_platform : public int_argument { + public: + arg_opencl_platform() : int_argument() { + _name = "platform"; + _description = "ID of the OpenCL platform to use"; + _validity = "platform >= 0 or -1 to use the compile-time platform ID"; + _default = "-1"; + _default_value = -1; + _constrained = true; + _good_value = 1; + _bad_value = -1.0; + _value = _default_value; + } + + bool is_valid(int value) { return value >= 0 || value == _default_value; } +}; + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/arguments/arg_seed.hpp b/src/cmdstan/arguments/arg_seed.hpp index 541e04c372..72164ef8c7 100644 --- a/src/cmdstan/arguments/arg_seed.hpp +++ b/src/cmdstan/arguments/arg_seed.hpp @@ -12,7 +12,7 @@ class arg_seed : public int_argument { arg_seed() : int_argument() { _name = "seed"; _description = "Random number generator seed"; - _validity = "integer > 0 or -1 to generate seed from system time"; + _validity = "integer >= 0 or -1 to generate seed from system time"; _default = "-1"; _default_value = -1; _constrained = true; @@ -25,7 +25,7 @@ class arg_seed : public int_argument { .total_milliseconds(); } - bool is_valid(int value) { return value > 0 || value == _default_value; } + bool is_valid(int value) { return value >= 0 || value == _default_value; } unsigned int random_value() { if (_value == _default_value) { diff --git a/src/cmdstan/arguments/argument_parser.hpp b/src/cmdstan/arguments/argument_parser.hpp index 36d64e8ef8..3a85d6f304 100644 --- a/src/cmdstan/arguments/argument_parser.hpp +++ b/src/cmdstan/arguments/argument_parser.hpp @@ -95,7 +95,12 @@ class argument_parser { if (!good_arg) { err(cat_name + " is either mistyped or misplaced."); - +#ifndef STAN_OPENCL + if (cat_name == "opencl") { + err("Re-compile the model with STAN_OPENCL to use OpenCL CmdStan " + "arguments."); + } +#endif std::vector valid_paths; for (size_t i = 0; i < _arguments.size(); ++i) { diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index bd527ca953..99bee95c4b 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -116,6 +117,9 @@ int command(int argc, const char *argv[]) { valid_arguments.push_back(new arg_init()); valid_arguments.push_back(new arg_random()); valid_arguments.push_back(new arg_output()); +#ifdef STAN_OPENCL + valid_arguments.push_back(new arg_opencl()); +#endif argument_parser parser(valid_arguments); int err_code = parser.parse_args(argc, argv, info, err); if (err_code != 0) { @@ -129,6 +133,26 @@ int command(int argc, const char *argv[]) { = dynamic_cast(parser.arg("random")->arg("seed")); unsigned int random_seed = random_arg->random_value(); +#ifdef STAN_OPENCL + int_argument *opencl_device_id + = dynamic_cast(parser.arg("opencl")->arg("device")); + int_argument *opencl_platform_id + = dynamic_cast(parser.arg("opencl")->arg("platform")); + + // Either both device and platform are set or neither in which case we default + // to compile-time constants + if ((opencl_device_id->is_default() && !opencl_platform_id->is_default()) + || (!opencl_device_id->is_default() + && opencl_platform_id->is_default())) { + std::cerr << "Please set both device and platform OpenCL IDs." << std::endl; + return err_code; + } else if (!opencl_device_id->is_default() + && !opencl_platform_id->is_default()) { + stan::math::opencl_context.select_device(opencl_platform_id->value(), + opencl_device_id->value()); + } +#endif + parser.print(info); write_parallel_info(info); write_opencl_device(info); diff --git a/src/cmdstan/write_opencl_device.hpp b/src/cmdstan/write_opencl_device.hpp index 58b35cbc3c..89cb744574 100644 --- a/src/cmdstan/write_opencl_device.hpp +++ b/src/cmdstan/write_opencl_device.hpp @@ -15,12 +15,12 @@ void write_opencl_device(stan::callbacks::writer &writer) { && (stan::math::opencl_context.device().size() > 0)) { std::stringstream msg_opencl_platform; msg_opencl_platform - << "opencl_platform = " + << "opencl_platform_name = " << stan::math::opencl_context.platform()[0].getInfo(); writer(msg_opencl_platform.str()); std::stringstream msg_opencl_device; msg_opencl_device - << "opencl_device = " + << "opencl_device_name = " << stan::math::opencl_context.device()[0].getInfo(); writer(msg_opencl_device.str()); } diff --git a/src/test/interface/command_test.cpp b/src/test/interface/command_test.cpp index 25e57c0444..7336de5413 100644 --- a/src/test/interface/command_test.cpp +++ b/src/test/interface/command_test.cpp @@ -477,25 +477,6 @@ TEST(StanUiCommand, random_seed_fail_1) { model_path.push_back("test-models"); model_path.push_back("transformed_data_rng_test"); - std::string command - = convert_model_path(model_path) - + " sample num_samples=10 num_warmup=10 init=0 " + " random seed=0 " - + " data file=src/test/test-models/transformed_data_rng_test.init.R" - + " output refresh=0 file=test/output.csv"; - std::string cmd_output = run_command(command).output; - run_command_output out = run_command(command); - EXPECT_EQ(1, count_matches(expected_message, out.body)); -} - -TEST(StanUiCommand, random_seed_fail_2) { - std::string expected_message = "is not a valid value for \"seed\""; - - std::vector model_path; - model_path.push_back("src"); - model_path.push_back("test"); - model_path.push_back("test-models"); - model_path.push_back("transformed_data_rng_test"); - std::string command = convert_model_path(model_path) + " sample num_samples=10 num_warmup=10 init=0 " + " random seed=-2 " @@ -506,7 +487,7 @@ TEST(StanUiCommand, random_seed_fail_2) { EXPECT_EQ(1, count_matches(expected_message, out.body)); } -TEST(StanUiCommand, random_seed_fail_3) { +TEST(StanUiCommand, random_seed_fail_2) { std::string expected_message = "is not a valid value for \"seed\""; std::vector model_path;