Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge GenDB and SchemaHelper; use GenDB in pclean binary #212

Merged
merged 13 commits into from
Oct 2, 2024
Merged
2 changes: 1 addition & 1 deletion cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ cc_library(
":irm",
":observations",
"//distributions:crp",
"//pclean:get_joint_relations",
"//pclean:io",
"//pclean:schema",
"//pclean:schema_helper",
],
)

Expand Down
6 changes: 5 additions & 1 deletion cxx/clean_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <cstdlib>
#include <random>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -159,7 +160,10 @@ class CleanRelation : public Relation<T> {
}

std::vector<int> get_cluster_assignment(const T_items& items) const {
assert(items.size() == domains.size());
if (items.size() != domains.size()) {
printf("Warning: for relation %s, items.size=%ld and domains.size()=%ld\n", name.c_str(), items.size(), domains.size());
std::exit(1);
}
std::vector<int> z(domains.size());
for (int i = 0; i < std::ssize(domains); ++i) {
z[i] = domains[i]->get_cluster_assignment(items[i]);
Expand Down
4 changes: 3 additions & 1 deletion cxx/distributions/stringcat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// See LICENSE.txt

#include <algorithm>
#include <cstdlib>
#include <cassert>
#include <limits>
#include "distributions/stringcat.hh"
Expand All @@ -10,7 +11,8 @@
// Returns the position of `s` within `strings`, located with a linear
// std::find. If `s` is not present the function never returns normally:
// the assert fires when assertions are enabled, otherwise a message naming
// the missing string is printed and the process exits with status 1.
int StringCat::string_to_index(const std::string& s) const {
auto it = std::find(strings.begin(), strings.end(), s);
if (it == strings.end()) {
assert(false);
printf("String %s not in StringCat's list of strings\n", s.c_str());
std::exit(1);
}
// Iterator difference is the zero-based index of the match.
return it - strings.begin();
}
Expand Down
238 changes: 229 additions & 9 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
#include "hirm.hh"
#include "irm.hh"
#include "observations.hh"
#include "pclean/get_joint_relations.hh"
#include "pclean/schema.hh"
#include "pclean/schema_helper.hh"

GenDB::GenDB(std::mt19937* prng, const PCleanSchema& schema_,
bool _only_final_emissions, bool _record_class_is_clean)
: schema(schema_),
schema_helper(schema_, _only_final_emissions, _record_class_is_clean) {
std::map<std::string, std::vector<std::string>>
annotated_domains_for_relation;
T_schema hirm_schema =
schema_helper.make_hirm_schema(&annotated_domains_for_relation);
: schema(schema_), only_final_emissions(_only_final_emissions),
record_class_is_clean(_record_class_is_clean) {
// Note that the domains cache must be populated before the reference
// indices.
compute_domains_cache();
compute_reference_indices_cache();

T_schema hirm_schema = make_hirm_schema();
hirm = new HIRM(hirm_schema, prng);

for (const auto& [class_name, unused_class] : schema.classes) {
Expand Down Expand Up @@ -171,6 +173,7 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item, bool new_rows_have_unique_entities) {
T_items items;
assert(schema.classes.contains(class_name));
PCleanClass c = schema.classes.at(class_name);

for (const auto& [name, var] : c.vars) {
Expand Down Expand Up @@ -203,7 +206,7 @@ void GenDB::get_relation_items(const std::string& rel_name, const int ind,
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
items[ind] = class_item;
auto& ref_indices = schema_helper.relation_reference_indices;
auto& ref_indices = relation_reference_indices;
if (ref_indices.contains(rel_name)) {
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
Expand Down Expand Up @@ -238,7 +241,7 @@ GenDB::unincorporate_reference(const std::string& class_name,
std::vector<size_t> domain_inds;
for (size_t i = 0; i < domains.size(); ++i) {
if (domains[i] == class_name &&
schema_helper.relation_reference_indices.at(rel_name).at(i).contains(
relation_reference_indices.at(rel_name).at(i).contains(
ref_field)) {
domain_inds.push_back(i);
}
Expand Down Expand Up @@ -392,3 +395,220 @@ void GenDB::incorporate_reference_relation(
}

// Frees the HIRM that the constructor allocated with `new`.
GenDB::~GenDB() { delete hirm; }

// Fills the `domains` cache with an entry for every class in the schema.
// Classes that already have an entry (for example because they were reached
// recursively from an earlier class) are left alone.
void GenDB::compute_domains_cache() {
  for (const auto& [class_name, unused_class] : schema.classes) {
    if (domains.contains(class_name)) {
      continue;
    }
    compute_domains_for(class_name);
  }
}

// Runs compute_reference_indices_for() on every class in the schema that
// does not yet have an entry in `class_reference_indices`.
void GenDB::compute_reference_indices_cache() {
  for (const auto& [class_name, unused_class] : schema.classes) {
    if (class_reference_indices.contains(class_name)) {
      continue;
    }
    compute_reference_indices_for(class_name);
  }
}

void GenDB::compute_domains_for(const std::string& name) {
std::vector<std::string> ds;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!domains.contains(cv->class_name)) {
compute_domains_for(cv->class_name);
}
for (const std::string& s : domains[cv->class_name]) {
ds.push_back(s);
}
}
}

// Put the "primary" domain last, so that it survives reordering.
ds.push_back(name);

domains[name] = ds;
}

void GenDB::compute_reference_indices_for(
const std::string& name) {
std::vector<std::string> ds;
int total_offset = 0;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

// Recursively maps the indices of class "name" (and ancestors) in relation
// items to the names and indices (in items) of their parents (reference
// fields).
std::map<int, std::map<std::string, int>> ref_indices;

// Temporarily stores reference fields and indices for class "name";
std::map<std::string, int> class_ref_indices;
for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!class_reference_indices.contains(cv->class_name)) {
compute_reference_indices_for(cv->class_name);
}
// Indices for foreign-key domains are generated by adding an offset
// to their indices in the respective class.
const int offset = total_offset;
total_offset += domains.at(cv->class_name).size();
class_ref_indices[v.first] = total_offset - 1;
std::map<std::string, int> child_class_indices;
if (class_reference_indices.contains(cv->class_name)) {
for (const auto& [ind, ref] :
class_reference_indices.at(cv->class_name)) {
std::map<std::string, int> class_ref_indices;
for (const auto& [field_name, ref_ind] : ref) {
child_class_indices[field_name] = ref_ind + offset;
}
ref_indices[ind + offset] = child_class_indices;
}
}
}
}

// Do not store a `class_reference_indices` entry for classes
// with no reference fields.
if (class_ref_indices.size() > 0) {
ref_indices[total_offset] = class_ref_indices;
class_reference_indices[name] = ref_indices;
}
}

void GenDB::make_relations_for_queryfield(
const QueryField& f, const PCleanClass& record_class, T_schema* tschema) {

// First, find all the vars and classes specified in f.class_path.
std::vector<std::string> var_names;
std::vector<std::string> class_names;
PCleanVariable last_var;
PCleanClass last_class = record_class;
class_names.push_back(record_class.name);
for (size_t i = 0; i < f.class_path.size(); ++i) {
const PCleanVariable& v = last_class.vars[f.class_path[i]];
last_var = v;
var_names.push_back(v.name);
if (i < f.class_path.size() - 1) {
class_names.push_back(std::get<ClassVar>(v.spec).class_name);
last_class = schema.classes.at(class_names.back());
}
}
// Remove the last var_name because it isn't used in making the path_prefix.
var_names.pop_back();

// Get the base relation from the last class and variable name.
std::string base_relation_name = class_names.back() + ":" + last_var.name;

// Handle queries of the record class specially.
if (f.class_path.size() == 1) {
if (record_class_is_clean) {
// Just rename the existing clean relation and set it to be observed.
T_clean_relation cr =
std::get<T_clean_relation>(tschema->at(base_relation_name));
cr.is_observed = true;
(*tschema)[f.name] = cr;
tschema->erase(base_relation_name);
} else {
T_noisy_relation tnr =
get_emission_relation(std::get<ScalarVar>(last_var.spec),
domains[record_class.name], base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (class_reference_indices.contains(record_class.name)) {
relation_reference_indices[f.name] =
class_reference_indices.at(record_class.name);
}
}
return;
}

// Handle only_final_emissions == true.
if (only_final_emissions) {
std::vector<std::string> noisy_domains = domains[class_names.back()];
for (int i = class_names.size() - 2; i >= 0; --i) {
noisy_domains.push_back(class_names[i]);
relation_reference_indices[f.name][noisy_domains.size() - 1]
[var_names[i]] = noisy_domains.size() - 2;
}
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), noisy_domains, base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (relation_reference_indices.contains(base_relation_name)) {
relation_reference_indices[f.name] =
relation_reference_indices.at(base_relation_name);
}
return;
}

// Handle only_final_emissions == false.
std::string& previous_relation = base_relation_name;
std::vector<std::string> current_domains = domains[class_names.back()];
std::map<int, std::map<std::string, int>> ref_indices;
for (int i = f.class_path.size() - 2; i >= 0; --i) {
current_domains.push_back(class_names[i]);
ref_indices[current_domains.size() - 1][var_names[i]] =
current_domains.size() - 2;
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), current_domains, previous_relation);
std::string rel_name;
if (i == 0) {
rel_name = f.name;
tnr.is_observed = true;
} else {
// Intermediate emissions have a name of the form
// "[Observing Class]::[QueryFieldName]"
rel_name = class_names[i] + "::" + f.name;
tnr.is_observed = false;
}
(*tschema)[rel_name] = tnr;
// Since noisy relations have the leftmost domains in common with their base
// relations, they share the reference indices with their base relations as
// well.
if (relation_reference_indices.contains(previous_relation)) {
relation_reference_indices[rel_name] =
relation_reference_indices.at(previous_relation);
}
relation_reference_indices[rel_name].merge(ref_indices);
previous_relation = rel_name;
}
}

// Builds the HIRM schema for the PClean schema held by this GenDB.
//
// Step 1: each scalar variable of each class gets a relation named
// "[ClassName]:[VariableName]" over that class's cached domains; if the class
// has reference indices they are recorded for the new relation as well.
//
// Step 2: each query field adds one or more relations by walking up its
// class_path. At least one of those relations will have name equal to the
// name of the QueryField.
T_schema GenDB::make_hirm_schema() {
  T_schema result;

  for (const auto& [class_name, cls] : schema.classes) {
    for (const auto& [var_name, var] : cls.vars) {
      const ScalarVar* scalar = std::get_if<ScalarVar>(&var.spec);
      if (scalar == nullptr) {
        continue;  // Only scalar variables get a relation here.
      }
      const std::string rel_name = class_name + ':' + var_name;
      result[rel_name] =
          get_distribution_relation(*scalar, domains[class_name]);
      if (class_reference_indices.contains(class_name)) {
        relation_reference_indices[rel_name] =
            class_reference_indices.at(class_name);
      }
    }
  }

  const PCleanClass record_class = schema.classes.at(schema.query.record_class);
  for (const auto& entry : schema.query.fields) {
    make_relations_for_queryfield(entry.second, record_class, &result);
  }

  return result;
}

Loading
Loading