Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge GenDB and SchemaHelper; use GenDB in pclean binary #212

Merged
merged 13 commits into from
Oct 2, 2024
Merged
2 changes: 1 addition & 1 deletion cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ cc_library(
":irm",
":observations",
"//distributions:crp",
"//pclean:get_joint_relations",
"//pclean:io",
"//pclean:schema",
"//pclean:schema_helper",
],
)

Expand Down
238 changes: 229 additions & 9 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
#include "hirm.hh"
#include "irm.hh"
#include "observations.hh"
#include "pclean/get_joint_relations.hh"
#include "pclean/schema.hh"
#include "pclean/schema_helper.hh"

GenDB::GenDB(std::mt19937* prng, const PCleanSchema& schema_,
bool _only_final_emissions, bool _record_class_is_clean)
: schema(schema_),
schema_helper(schema_, _only_final_emissions, _record_class_is_clean) {
std::map<std::string, std::vector<std::string>>
annotated_domains_for_relation;
T_schema hirm_schema =
schema_helper.make_hirm_schema(&annotated_domains_for_relation);
: schema(schema_), only_final_emissions(_only_final_emissions),
record_class_is_clean(_record_class_is_clean) {
// Note that the domains cache must be populated before the reference
// indices.
compute_domains_cache();
compute_reference_indices_cache();

T_schema hirm_schema = make_hirm_schema();
hirm = new HIRM(hirm_schema, prng);

for (const auto& [class_name, unused_class] : schema.classes) {
Expand Down Expand Up @@ -152,6 +154,7 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item) {
T_items items;
assert(schema.classes.contains(class_name));
PCleanClass c = schema.classes.at(class_name);

for (const auto& [name, var] : c.vars) {
Expand Down Expand Up @@ -182,7 +185,7 @@ void GenDB::get_relation_items(const std::string& rel_name, const int ind,
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
items[ind] = class_item;
auto& ref_indices = schema_helper.relation_reference_indices;
auto& ref_indices = relation_reference_indices;
if (ref_indices.contains(rel_name)) {
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
Expand Down Expand Up @@ -217,7 +220,7 @@ GenDB::unincorporate_reference(const std::string& class_name,
std::vector<size_t> domain_inds;
for (size_t i = 0; i < domains.size(); ++i) {
if (domains[i] == class_name &&
schema_helper.relation_reference_indices.at(rel_name).at(i).contains(
relation_reference_indices.at(rel_name).at(i).contains(
ref_field)) {
domain_inds.push_back(i);
}
Expand Down Expand Up @@ -371,3 +374,220 @@ void GenDB::incorporate_reference_relation(
}

GenDB::~GenDB() { delete hirm; }

void GenDB::compute_domains_cache() {
for (const auto& c : schema.classes) {
if (!domains.contains(c.first)) {
compute_domains_for(c.first);
}
}
}

void GenDB::compute_reference_indices_cache() {
for (const auto& c : schema.classes) {
if (!class_reference_indices.contains(c.first)) {
compute_reference_indices_for(c.first);
}
}
}

void GenDB::compute_domains_for(const std::string& name) {
std::vector<std::string> ds;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!domains.contains(cv->class_name)) {
compute_domains_for(cv->class_name);
}
for (const std::string& s : domains[cv->class_name]) {
ds.push_back(s);
}
}
}

// Put the "primary" domain last, so that it survives reordering.
ds.push_back(name);

domains[name] = ds;
}

void GenDB::compute_reference_indices_for(
const std::string& name) {
std::vector<std::string> ds;
int total_offset = 0;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

// Recursively maps the indices of class "name" (and ancestors) in relation
// items to the names and indices (in items) of their parents (reference
// fields).
std::map<int, std::map<std::string, int>> ref_indices;

// Temporarily stores reference fields and indices for class "name";
std::map<std::string, int> class_ref_indices;
for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!class_reference_indices.contains(cv->class_name)) {
compute_reference_indices_for(cv->class_name);
}
// Indices for foreign-key domains are generated by adding an offset
// to their indices in the respective class.
const int offset = total_offset;
total_offset += domains.at(cv->class_name).size();
class_ref_indices[v.first] = total_offset - 1;
std::map<std::string, int> child_class_indices;
if (class_reference_indices.contains(cv->class_name)) {
for (const auto& [ind, ref] :
class_reference_indices.at(cv->class_name)) {
std::map<std::string, int> class_ref_indices;
for (const auto& [field_name, ref_ind] : ref) {
child_class_indices[field_name] = ref_ind + offset;
}
ref_indices[ind + offset] = child_class_indices;
}
}
}
}

// Do not store a `class_reference_indices` entry for classes
// with no reference fields.
if (class_ref_indices.size() > 0) {
ref_indices[total_offset] = class_ref_indices;
class_reference_indices[name] = ref_indices;
}
}

void GenDB::make_relations_for_queryfield(
const QueryField& f, const PCleanClass& record_class, T_schema* tschema) {

// First, find all the vars and classes specified in f.class_path.
std::vector<std::string> var_names;
std::vector<std::string> class_names;
PCleanVariable last_var;
PCleanClass last_class = record_class;
class_names.push_back(record_class.name);
for (size_t i = 0; i < f.class_path.size(); ++i) {
const PCleanVariable& v = last_class.vars[f.class_path[i]];
last_var = v;
var_names.push_back(v.name);
if (i < f.class_path.size() - 1) {
class_names.push_back(std::get<ClassVar>(v.spec).class_name);
last_class = schema.classes.at(class_names.back());
}
}
// Remove the last var_name because it isn't used in making the path_prefix.
var_names.pop_back();

// Get the base relation from the last class and variable name.
std::string base_relation_name = class_names.back() + ":" + last_var.name;

// Handle queries of the record class specially.
if (f.class_path.size() == 1) {
if (record_class_is_clean) {
// Just rename the existing clean relation and set it to be observed.
T_clean_relation cr =
std::get<T_clean_relation>(tschema->at(base_relation_name));
cr.is_observed = true;
(*tschema)[f.name] = cr;
tschema->erase(base_relation_name);
} else {
T_noisy_relation tnr =
get_emission_relation(std::get<ScalarVar>(last_var.spec),
domains[record_class.name], base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (class_reference_indices.contains(record_class.name)) {
relation_reference_indices[f.name] =
class_reference_indices.at(record_class.name);
}
}
return;
}

// Handle only_final_emissions == true.
if (only_final_emissions) {
std::vector<std::string> noisy_domains = domains[class_names.back()];
for (int i = class_names.size() - 2; i >= 0; --i) {
noisy_domains.push_back(class_names[i]);
relation_reference_indices[f.name][noisy_domains.size() - 1]
[var_names[i]] = noisy_domains.size() - 2;
}
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), noisy_domains, base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (relation_reference_indices.contains(base_relation_name)) {
relation_reference_indices[f.name] =
relation_reference_indices.at(base_relation_name);
}
return;
}

// Handle only_final_emissions == false.
std::string& previous_relation = base_relation_name;
std::vector<std::string> current_domains = domains[class_names.back()];
std::map<int, std::map<std::string, int>> ref_indices;
for (int i = f.class_path.size() - 2; i >= 0; --i) {
current_domains.push_back(class_names[i]);
ref_indices[current_domains.size() - 1][var_names[i]] =
current_domains.size() - 2;
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), current_domains, previous_relation);
std::string rel_name;
if (i == 0) {
rel_name = f.name;
tnr.is_observed = true;
} else {
// Intermediate emissions have a name of the form
// "[Observing Class]::[QueryFieldName]"
rel_name = class_names[i] + "::" + f.name;
tnr.is_observed = false;
}
(*tschema)[rel_name] = tnr;
// Since noisy relations have the leftmost domains in common with their base
// relations, they share the reference indices with their base relations as
// well.
if (relation_reference_indices.contains(previous_relation)) {
relation_reference_indices[rel_name] =
relation_reference_indices.at(previous_relation);
}
relation_reference_indices[rel_name].merge(ref_indices);
previous_relation = rel_name;
}
}

T_schema GenDB::make_hirm_schema() {
T_schema tschema;

// For every scalar variable, make a clean relation with the name
// "[ClassName]:[VariableName]".
for (const auto& c : schema.classes) {
for (const auto& v : c.second.vars) {
std::string rel_name = c.first + ':' + v.first;
if (const ScalarVar* dv = std::get_if<ScalarVar>(&(v.second.spec))) {
tschema[rel_name] = get_distribution_relation(*dv, domains[c.first]);
if (class_reference_indices.contains(c.first)) {
relation_reference_indices[rel_name] =
class_reference_indices.at(c.first);
}
}
}
}

// For every query field, make one or more relations by walking up
// the class_path. At least one of those relations will have name equal
// to the name of the QueryField.
const PCleanClass record_class = schema.classes.at(schema.query.record_class);
for (const auto& [unused_name, f] : schema.query.fields) {
make_relations_for_queryfield(f, record_class, &tschema);
}

return tschema;
}

39 changes: 34 additions & 5 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,11 @@
#include "hirm.hh"
#include "observations.hh"
#include "pclean/schema.hh"
#include "pclean/schema_helper.hh"

class GenDB {
public:
const PCleanSchema& schema;

// TODO(emilyaf): Merge PCleanSchemaHelper and GenDB.
PCleanSchemaHelper schema_helper;

// This data structure contains entity sets and linkages. Semantics are
// map<tuple<class_name, reference_field_name, class_primary_key> ref_val>>,
// where primary_key and ref_val are (integer) entity IDs.
Expand Down Expand Up @@ -101,6 +97,9 @@ class GenDB {
const std::string& class_name, const std::string& ref_field,
const int class_item, const int new_ref_val);

// Translate the PCleanSchema into an HIRM T_schema.
T_schema make_hirm_schema();

// Incorporates the items and values from stored_values (generally an output
// of update_reference_items).
void incorporate_reference(
Expand All @@ -125,4 +124,34 @@ class GenDB {
// Disable copying.
GenDB& operator=(const GenDB&) = delete;
GenDB(const GenDB&) = delete;
};

// The rest of these methods are conceptually private, but actually
// public for testing.

ThomasColthurst marked this conversation as resolved.
Show resolved Hide resolved
void compute_domains_cache();

void compute_domains_for(const std::string& name);

void compute_reference_indices_cache();

void compute_reference_indices_for(const std::string& name);

void make_relations_for_queryfield(
const QueryField& f, const PCleanClass& c, T_schema* schema);

bool only_final_emissions;
bool record_class_is_clean;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment for "domains"?

std::map<std::string, std::vector<std::string>> domains;

// Map keys are relation name, item index of a class, and reference field
// name. The values in the inner map are the item index of the reference
// class. (See tests for more intuition.)
std::map<std::string, std::map<int, std::map<std::string, int>>>
relation_reference_indices;

// Map keys are class name, item index of a class, and reference field
// name. The values in the inner map are the item index of the reference
// class. (See tests for more intuition.)
std::map<std::string, std::map<int, std::map<std::string, int>>>
class_reference_indices;
};
Loading
Loading