Skip to content

Commit

Permalink
Merge pull request #3311 from stan-dev/bugfix/3301-stan-csv-reader
Browse files Browse the repository at this point in the history
Bugfix/3301 stan csv reader
  • Loading branch information
mitzimorris authored Oct 3, 2024
2 parents 6773989 + 24cfc7e commit 51dbaa4
Show file tree
Hide file tree
Showing 7 changed files with 4,392 additions and 72 deletions.
115 changes: 58 additions & 57 deletions src/stan/io/stan_csv_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) {
}
}

// FIXME: should consolidate with the options from
// the command line in stan::lang
struct stan_csv_metadata {
int stan_version_major;
int stan_version_minor;
Expand All @@ -47,6 +45,7 @@ struct stan_csv_metadata {
bool save_warmup;
size_t thin;
bool append_samples;
std::string method;
std::string algorithm;
std::string engine;
int max_depth;
Expand All @@ -64,8 +63,9 @@ struct stan_csv_metadata {
num_samples(0),
num_warmup(0),
save_warmup(false),
thin(0),
thin(1),
append_samples(false),
method(""),
algorithm(""),
engine(""),
max_depth(10) {}
Expand Down Expand Up @@ -101,13 +101,12 @@ class stan_csv_reader {
stan_csv_reader() {}
~stan_csv_reader() {}

static bool read_metadata(std::istream& in, stan_csv_metadata& metadata,
std::ostream* out) {
static void read_metadata(std::istream& in, stan_csv_metadata& metadata) {
std::stringstream ss;
std::string line;

if (in.peek() != '#')
return false;
return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << '\n';
Expand Down Expand Up @@ -161,9 +160,15 @@ class stan_csv_reader {
metadata.model = value;
} else if (name.compare("num_samples") == 0) {
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("output_samples") == 0) { // ADVI config name
std::stringstream(value) >> metadata.num_samples;
} else if (name.compare("num_warmup") == 0) {
std::stringstream(value) >> metadata.num_warmup;
} else if (name.compare("save_warmup") == 0) {
// cmdstan args can be "true" and "false", was "1", "0"
if (value.compare("true") == 0) {
value = "1";
}
std::stringstream(value) >> metadata.save_warmup;
} else if (name.compare("thin") == 0) {
std::stringstream(value) >> metadata.thin;
Expand All @@ -177,6 +182,8 @@ class stan_csv_reader {
metadata.random_seed = false;
} else if (name.compare("append_samples") == 0) {
std::stringstream(value) >> metadata.append_samples;
} else if (name.compare("method") == 0) {
metadata.method = value;
} else if (name.compare("algorithm") == 0) {
metadata.algorithm = value;
} else if (name.compare("engine") == 0) {
Expand All @@ -185,14 +192,10 @@ class stan_csv_reader {
std::stringstream(value) >> metadata.max_depth;
}
}
if (ss.good() == true)
return false;

return true;
} // read_metadata

static bool read_header(std::istream& in, std::vector<std::string>& header,
std::ostream* out, bool prettify_name = true) {
bool prettify_name = true) {
std::string line;

if (!std::isalpha(in.peek()))
Expand All @@ -216,81 +219,71 @@ class stan_csv_reader {
return true;
}

static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation,
std::ostream* out) {
static void read_adaptation(std::istream& in,
stan_csv_adaptation& adaptation) {
std::stringstream ss;
std::string line;
int lines = 0;

if (in.peek() != '#' || in.good() == false)
return false;

return;
while (in.peek() == '#') {
std::getline(in, line);
ss << line << std::endl;
lines++;
}
ss.seekg(std::ios_base::beg);
if (lines < 2)
return;

if (lines < 4)
return false;

char comment; // Buffer for comment indicator, #
std::getline(ss, line); // comment adaptation terminated

// Skip first two lines
std::getline(ss, line);

// Stepsize
std::getline(ss, line, '=');
// parse stepsize
std::getline(ss, line, '='); // stepsize
boost::trim(line);
ss >> adaptation.step_size;
if (lines == 2) // ADVI reports stepsize, no metric
return;

// Metric parameters
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line);
std::getline(ss, line); // consume end of stepsize line
std::getline(ss, line); // comment elements of mass matrix
std::getline(ss, line); // diagonal metric or row 1 of dense metric

int rows = lines - 3;
int cols = std::count(line.begin(), line.end(), ',') + 1;
adaptation.metric.resize(rows, cols);
char comment; // Buffer for comment indicator, #

// parse metric, row by row, element by element
for (int row = 0; row < rows; row++) {
std::stringstream line_ss;
line_ss.str(line);
line_ss >> comment;

for (int col = 0; col < cols; col++) {
std::string token;
std::getline(line_ss, token, ',');
boost::trim(token);
std::stringstream(token) >> adaptation.metric(row, col);
}
std::getline(ss, line); // Read in next line
std::getline(ss, line);
}

if (ss.good())
return false;
else
return true;
}

static bool read_samples(std::istream& in, Eigen::MatrixXd& samples,
stan_csv_timing& timing, std::ostream* out) {
stan_csv_timing& timing) {
std::stringstream ss;
std::string line;

int rows = 0;
int cols = -1;

if (in.peek() == '#' || in.good() == false)
return false;
return false; // need at least one data row

while (in.good()) {
bool comment_line = (in.peek() == '#');
bool empty_line = (in.peek() == '\n');

std::getline(in, line);

if (empty_line)
continue;
if (!line.length())
Expand All @@ -316,11 +309,10 @@ class stan_csv_reader {
if (cols == -1) {
cols = current_cols;
} else if (cols != current_cols) {
if (out)
*out << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1
<< std::endl;
return false;
std::stringstream msg;
msg << "Error: expected " << cols << " columns, but found "
<< current_cols << " instead for row " << rows + 1;
throw std::invalid_argument(msg.str());
}
rows++;
}
Expand Down Expand Up @@ -348,36 +340,45 @@ class stan_csv_reader {
/**
* Parses the file.
*
* Throws exception if contents can't be parsed into header + data rows.
*
* Emits warning message
*
* @param[in] in input stream to parse
* @param[out] out output stream to send messages
*/
static stan_csv parse(std::istream& in, std::ostream* out) {
stan_csv data;
std::string line;

if (!read_metadata(in, data.metadata, out)) {
if (out)
*out << "Warning: non-fatal error reading metadata" << std::endl;
read_metadata(in, data.metadata);
if (!read_header(in, data.header)) {
throw std::invalid_argument("Error: no column names found in csv file");
}

if (!read_header(in, data.header, out)) {
if (out)
*out << "Error: error reading header" << std::endl;
throw std::invalid_argument("Error with header of input file in parse");
// skip warmup draws, if any
if (data.metadata.algorithm != "fixed_param" && data.metadata.num_warmup > 0
&& data.metadata.save_warmup) {
while (in.peek() != '#') {
std::getline(in, line);
}
}

if (!read_adaptation(in, data.adaptation, out)) {
if (out)
*out << "Warning: non-fatal error reading adaptation data" << std::endl;
if (data.metadata.algorithm != "fixed_param") {
read_adaptation(in, data.adaptation);
}

data.timing.warmup = 0;
data.timing.sampling = 0;

if (!read_samples(in, data.samples, data.timing, out)) {
if (out)
*out << "Warning: non-fatal error reading samples" << std::endl;
if (data.metadata.method == "variational") {
std::getline(in, line); // discard variational estimate
}

if (!read_samples(in, data.samples, data.timing)) {
if (out)
*out << "Unable to parse sample" << std::endl;
}
return data;
}
};
Expand Down
Loading

0 comments on commit 51dbaa4

Please sign in to comment.