Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runtime OpenCL device selection, allow seed=0 and fix allow-undefined user_header bug #954

Merged
merged 11 commits into from
Dec 8, 2020
2 changes: 1 addition & 1 deletion make/program
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ CXXFLAGS_PROGRAM += -include-pch $(STAN)src/stan/model/model_header$(STAN_FLAGS)
$(STAN_TARGETS) examples/bernoulli/bernoulli$(EXE) $(patsubst %.stan,%$(EXE),$(wildcard src/test/test-models/*.stan)) : %$(EXE) : $(STAN)src/stan/model/model_header$(STAN_FLAGS).hpp.gch
endif

ifneq ($(findstring allow_undefined,$(STANCFLAGS)),)
ifneq ($(findstring allow_undefined,$(STANCFLAGS))$(findstring allow-undefined,$(STANCFLAGS)),)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a minor fix for #953

$(STAN_TARGETS) examples/bernoulli/bernoulli$(EXE) $(patsubst %.stan,%$(EXE),$(wildcard src/test/test-models/*.stan)) : CXXFLAGS_PROGRAM += -include $(USER_HEADER)
endif

Expand Down
22 changes: 22 additions & 0 deletions src/cmdstan/arguments/arg_opencl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_HPP
#define CMDSTAN_ARGUMENTS_ARG_OPENCL_HPP

#include <cmdstan/arguments/arg_opencl_device.hpp>
#include <cmdstan/arguments/arg_opencl_platform.hpp>
#include <cmdstan/arguments/categorical_argument.hpp>

namespace cmdstan {

class arg_opencl : public categorical_argument {
public:
arg_opencl() {
_name = "opencl";
_description = "OpenCL options";

_subarguments.push_back(new arg_opencl_device());
_subarguments.push_back(new arg_opencl_platform());
}
};

} // namespace cmdstan
#endif
26 changes: 26 additions & 0 deletions src/cmdstan/arguments/arg_opencl_device.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_DEVICE_HPP
#define CMDSTAN_ARGUMENTS_ARG_OPENCL_DEVICE_HPP

#include <cmdstan/arguments/singleton_argument.hpp>

namespace cmdstan {

class arg_opencl_device : public int_argument {
public:
arg_opencl_device() : int_argument() {
_name = "device";
_description = "ID of the OpenCL device to use";
_validity = "device >= 0 or -1 to use the compile-time device ID";
_default = "-1";
_default_value = -1;
_constrained = true;
_good_value = 1;
_bad_value = -1.0;
_value = _default_value;
}

bool is_valid(int value) { return value >= 0 || value == _default_value; }
};

} // namespace cmdstan
#endif
26 changes: 26 additions & 0 deletions src/cmdstan/arguments/arg_opencl_platform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef CMDSTAN_ARGUMENTS_ARG_OPENCL_PLATFORM_HPP
#define CMDSTAN_ARGUMENTS_ARG_OPENCL_PLATFORM_HPP

#include <cmdstan/arguments/singleton_argument.hpp>

namespace cmdstan {

class arg_opencl_platform : public int_argument {
public:
arg_opencl_platform() : int_argument() {
_name = "platform";
_description = "ID of the OpenCL platform to use";
_validity = "platform >= 0 or -1 to use the compile-time platform ID";
_default = "-1";
_default_value = -1;
_constrained = true;
_good_value = 1;
_bad_value = -1.0;
_value = _default_value;
}

bool is_valid(int value) { return value >= 0 || value == _default_value; }
};

} // namespace cmdstan
#endif
4 changes: 2 additions & 2 deletions src/cmdstan/arguments/arg_seed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class arg_seed : public int_argument {
arg_seed() : int_argument() {
_name = "seed";
_description = "Random number generator seed";
_validity = "integer > 0 or -1 to generate seed from system time";
_validity = "integer >= 0 or -1 to generate seed from system time";
_default = "-1";
_default_value = -1;
_constrained = true;
Expand All @@ -25,7 +25,7 @@ class arg_seed : public int_argument {
.total_milliseconds();
}

bool is_valid(int value) { return value > 0 || value == _default_value; }
bool is_valid(int value) { return value >= 0 || value == _default_value; }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to this file is a fix for #941


unsigned int random_value() {
if (_value == _default_value) {
Expand Down
7 changes: 6 additions & 1 deletion src/cmdstan/arguments/argument_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ class argument_parser {

if (!good_arg) {
err(cat_name + " is either mistyped or misplaced.");

#ifndef STAN_OPENCL
if (cat_name == "opencl") {
err("Re-compile the model with STAN_OPENCL to use OpenCL CmdStan "
"arguments.");
}
#endif
std::vector<std::string> valid_paths;

for (size_t i = 0; i < _arguments.size(); ++i) {
Expand Down
24 changes: 24 additions & 0 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cmdstan/arguments/arg_init.hpp>
#include <cmdstan/arguments/arg_output.hpp>
#include <cmdstan/arguments/arg_random.hpp>
#include <cmdstan/arguments/arg_opencl.hpp>
#include <cmdstan/arguments/argument_parser.hpp>
#include <cmdstan/io/json/json_data.hpp>
#include <cmdstan/write_model_compile_info.hpp>
Expand Down Expand Up @@ -116,6 +117,9 @@ int command(int argc, const char *argv[]) {
valid_arguments.push_back(new arg_init());
valid_arguments.push_back(new arg_random());
valid_arguments.push_back(new arg_output());
#ifdef STAN_OPENCL
valid_arguments.push_back(new arg_opencl());
#endif
argument_parser parser(valid_arguments);
int err_code = parser.parse_args(argc, argv, info, err);
if (err_code != 0) {
Expand All @@ -129,6 +133,26 @@ int command(int argc, const char *argv[]) {
= dynamic_cast<arg_seed *>(parser.arg("random")->arg("seed"));
unsigned int random_seed = random_arg->random_value();

#ifdef STAN_OPENCL
int_argument *opencl_device_id
= dynamic_cast<int_argument *>(parser.arg("opencl")->arg("device"));
int_argument *opencl_platform_id
= dynamic_cast<int_argument *>(parser.arg("opencl")->arg("platform"));

// Either both device and platform are set or neither in which case we default
// to compile-time constants
if ((opencl_device_id->is_default() && !opencl_platform_id->is_default())
|| (!opencl_device_id->is_default()
&& opencl_platform_id->is_default())) {
std::cerr << "Please set both device and platform OpenCL IDs." << std::endl;
return err_code;
} else if (!opencl_device_id->is_default()
&& !opencl_platform_id->is_default()) {
stan::math::opencl_context.select_device(opencl_platform_id->value(),
opencl_device_id->value());
}
#endif

parser.print(info);
write_parallel_info(info);
write_opencl_device(info);
Expand Down
4 changes: 2 additions & 2 deletions src/cmdstan/write_opencl_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ void write_opencl_device(stan::callbacks::writer &writer) {
&& (stan::math::opencl_context.device().size() > 0)) {
std::stringstream msg_opencl_platform;
msg_opencl_platform
<< "opencl_platform = "
<< "opencl_platform_name = "
<< stan::math::opencl_context.platform()[0].getInfo<CL_PLATFORM_NAME>();
writer(msg_opencl_platform.str());
std::stringstream msg_opencl_device;
msg_opencl_device
<< "opencl_device = "
<< "opencl_device_name = "
<< stan::math::opencl_context.device()[0].getInfo<CL_DEVICE_NAME>();
writer(msg_opencl_device.str());
}
Expand Down
21 changes: 1 addition & 20 deletions src/test/interface/command_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,25 +477,6 @@ TEST(StanUiCommand, random_seed_fail_1) {
model_path.push_back("test-models");
model_path.push_back("transformed_data_rng_test");

std::string command
= convert_model_path(model_path)
+ " sample num_samples=10 num_warmup=10 init=0 " + " random seed=0 "
+ " data file=src/test/test-models/transformed_data_rng_test.init.R"
+ " output refresh=0 file=test/output.csv";
std::string cmd_output = run_command(command).output;
run_command_output out = run_command(command);
EXPECT_EQ(1, count_matches(expected_message, out.body));
}

TEST(StanUiCommand, random_seed_fail_2) {
std::string expected_message = "is not a valid value for \"seed\"";

std::vector<std::string> model_path;
model_path.push_back("src");
model_path.push_back("test");
model_path.push_back("test-models");
model_path.push_back("transformed_data_rng_test");

std::string command
= convert_model_path(model_path)
+ " sample num_samples=10 num_warmup=10 init=0 " + " random seed=-2 "
Expand All @@ -506,7 +487,7 @@ TEST(StanUiCommand, random_seed_fail_2) {
EXPECT_EQ(1, count_matches(expected_message, out.body));
}

TEST(StanUiCommand, random_seed_fail_3) {
TEST(StanUiCommand, random_seed_fail_2) {
std::string expected_message = "is not a valid value for \"seed\"";

std::vector<std::string> model_path;
Expand Down