Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish inference_gendb. #220

Merged
merged 9 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 49 additions & 28 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ GenDB::GenDB(std::mt19937* prng, const PCleanSchema& schema_,

for (const auto& [class_name, unused_class] : schema.classes) {
domain_crps[class_name] = CRP();
reference_values[class_name];
}
for (const auto& [rel_name, trel] : hirm->schema) {
const std::vector<std::string>& domains =
Expand Down Expand Up @@ -99,16 +100,15 @@ T_items GenDB::sample_entities_relation(
const std::string& ref_class =
std::get<ClassVar>(schema.classes.at(class_name).vars.at(ref_field).spec)
.class_name;
std::tuple<std::string, std::string, int> ref_key = {class_name, ref_field,
class_item};
if (!reference_values.contains(ref_key)) {
sample_and_incorporate_reference(prng, ref_key, ref_class,
std::pair<std::string, int> ref_key = {ref_field, class_item};
if (!reference_values.at(class_name).contains(ref_key)) {
sample_and_incorporate_reference(prng, class_name, ref_key, ref_class,
new_rows_have_unique_entities);
}
T_items items =
sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values.at(ref_key), new_rows_have_unique_entities);
reference_values.at(class_name).at(ref_key), new_rows_have_unique_entities);
// The order of the items corresponds to the order of the relation's domains,
// with the class (domain) corresponding to the primary key placed last on the
// list.
Expand All @@ -126,10 +126,10 @@ int GenDB::get_reference_id(const std::string& class_name,
}

void GenDB::sample_and_incorporate_reference(
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
std::mt19937* prng, const std::string& class_name,
const std::pair<std::string, int>& ref_key,
const std::string& ref_class, bool new_rows_have_unique_entities) {
auto [class_name, ref_field, class_item] = ref_key;
auto [ref_field, class_item] = ref_key;
int new_val;
if (new_rows_have_unique_entities) {
new_val = domain_crps[ref_class].max_table() + 1;
Expand All @@ -140,7 +140,7 @@ void GenDB::sample_and_incorporate_reference(
// Generate a unique ID for the sample and incorporate it into the
// domain CRP.
int new_id = get_reference_id(class_name, ref_field, class_item);
reference_values[ref_key] = new_val;
reference_values.at(class_name)[ref_key] = new_val;
domain_crps[ref_class].incorporate(new_id, new_val);
}

Expand Down Expand Up @@ -186,15 +186,14 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
if (const ClassVar* cv = std::get_if<ClassVar>(&(var.spec))) {
// If the reference field isn't populated, sample a value from a CRP and
// add it to reference_values.
std::tuple<std::string, std::string, int> ref_key = {class_name, name,
class_item};
if (!reference_values.contains(ref_key)) {
std::pair<std::string, int> ref_key = {name, class_item};
if (!reference_values.at(class_name).contains(ref_key)) {
assert(prng != nullptr);
sample_and_incorporate_reference(
prng, ref_key, cv->class_name, new_rows_have_unique_entities);
prng, class_name, ref_key, cv->class_name, new_rows_have_unique_entities);
}
T_items ref_items = sample_class_ancestors(
prng, cv->class_name, reference_values.at(ref_key),
prng, cv->class_name, reference_values.at(class_name).at(ref_key),
new_rows_have_unique_entities);
items.insert(items.end(), ref_items.begin(), ref_items.end());
}
Expand All @@ -218,7 +217,7 @@ void GenDB::get_relation_items(const std::string& rel_name, const int ind,
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
int rf_item =
reference_values.at({domains.at(ind), rf_name, class_item});
reference_values.at(domains.at(ind)).at({rf_name, class_item});
get_relation_items(rel_name, rf_ind, rf_item, items);
}
}
Expand Down Expand Up @@ -284,10 +283,11 @@ double GenDB::unincorporate_reference(
// Check if any entities need to be removed from IRM domain clusters (after
// they've been unincorporated from relations) and compute the change in logp.
double logp_domain_cluster = 0.;
int ref_val = reference_values.at({class_name, ref_field, class_item});
int ref_val = reference_values.at(class_name).at({ref_field, class_item});
for (auto& [rel_name, inds] : domain_inds) {
for (int d_ind : inds) {
int r_ind = relation_reference_indices.at(rel_name).at(d_ind).at(ref_field);
int r_ind = relation_reference_indices.at(rel_name).at(d_ind).at(
ref_field);
logp_domain_cluster += unincorporate_from_domain_cluster_relation(
rel_name, ref_val, r_ind, unincorporated_from_domains);
}
Expand Down Expand Up @@ -338,11 +338,11 @@ GenDB::update_reference_items(
H_items>>& stored_values,
const std::string& class_name, const std::string& ref_field,
const int class_item, const int new_ref_val) {
int old_ref_val = reference_values.at({class_name, ref_field, class_item});
int old_ref_val = reference_values.at(class_name).at({ref_field, class_item});

// Temporarily associate class_name.ref_field at index class_item with the new
// value.
reference_values[{class_name, ref_field, class_item}] = new_ref_val;
reference_values.at(class_name)[{ref_field, class_item}] = new_ref_val;

std::map<std::string,
std::unordered_map<T_items, ObservationVariant, H_items>>
Expand All @@ -356,7 +356,7 @@ GenDB::update_reference_items(
}
}
// Return reference_values to its original state.
reference_values[{class_name, ref_field, class_item}] = old_ref_val;
reference_values.at(class_name)[{ref_field, class_item}] = old_ref_val;
return new_stored_values;
}

Expand Down Expand Up @@ -435,7 +435,7 @@ double GenDB::unincorporate_from_domain_cluster_relation(
if (relation_reference_indices.contains(r) &&
relation_reference_indices.at(r).contains(ind)) {
for (auto [name, r_ind] : relation_reference_indices.at(r).at(ind)) {
int ref_item = reference_values.at({ref_class, name, item});
int ref_item = reference_values.at(ref_class).at({name, item});
logp_adj += unincorporate_from_domain_cluster_relation(r, ref_item, r_ind,
unincorporated);
}
Expand All @@ -461,7 +461,7 @@ double GenDB::unincorporate_from_entity_cluster(
double logp_adj = 0.;

int ref_id = get_reference_id(class_name, ref_field, class_item);
int ref_item = reference_values.at({class_name, ref_field, class_item});
int ref_item = reference_values.at(class_name).at({ref_field, class_item});

const std::string& ref_class =
std::get<ClassVar>(schema.classes.at(class_name).vars.at(ref_field).spec)
Expand Down Expand Up @@ -551,7 +551,7 @@ double GenDB::unincorporate_singleton(
double logp_refclass = 0.;

std::mt19937* prng = nullptr; // unused
int ref_val = reference_values.at({class_name, ref_field, class_item});
int ref_val = reference_values.at(class_name).at({ref_field, class_item});
T_items base_items = sample_class_ancestors(prng, ref_class, ref_val, false);
logp_refclass +=
unincorporate_from_entity_cluster(class_name, ref_field, class_item,
Expand Down Expand Up @@ -581,9 +581,13 @@ void GenDB::transition_reference(std::mt19937* prng,
const std::string& ref_class =
std::get<ClassVar>(schema.classes.at(class_name).vars.at(ref_field).spec)
.class_name;
int init_refval = reference_values.at({class_name, ref_field, class_item});
int init_refval = reference_values.at(class_name).at({ref_field, class_item});
std::map<int, double> crp_dist =
domain_crps[ref_class].tables_weights_gibbs(init_refval);
domain_crps.at(ref_class).tables_weights_gibbs(init_refval);
if (crp_dist.size() == 1) {
// Can only re-incorporate into the same table.
return;
}

// For each relation, get the indices (in the items vector) of the reference
// value being transitioned.
Expand Down Expand Up @@ -645,7 +649,7 @@ void GenDB::transition_reference(std::mt19937* prng,
entities[i] = table;
logps[i] += log(n_customers);

reference_values.at({class_name, ref_field, class_item}) = table;
reference_values.at(class_name).at({ref_field, class_item}) = table;

if (table == init_refval) {
logps[i++] += logp_current;
Expand Down Expand Up @@ -733,7 +737,7 @@ void GenDB::reincorporate_new_refval(
}

// Update reference_values.
reference_values.at({class_name, ref_field, class_item}) = new_refval;
reference_values.at(class_name).at({ref_field, class_item}) = new_refval;

// Check if the singleton was selected.
bool is_singleton = !domain_crps.at(ref_class).tables.contains(new_refval);
Expand All @@ -751,7 +755,8 @@ void GenDB::reincorporate_new_refval(
} else {
// Remove the singleton from reference_values if it was not selected.
for (auto [k, v] : unincorporated_from_entity_crps) {
reference_values.erase(k);
auto [ref_class, field, item] = k;
reference_values.at(ref_class).erase({field, item});
}
}
int ref_id = get_reference_id(class_name, ref_field, class_item);
Expand All @@ -764,6 +769,22 @@ void GenDB::reincorporate_new_refval(
hirm->cleanup_relation_clusters();
}

void GenDB::transition_reference_class_and_ancestors(
std::mt19937* prng, const std::string& class_name) {
PCleanClass c = schema.classes.at(class_name);

for (const auto& [name, var] : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(var.spec))) {
transition_reference_class_and_ancestors(prng, cv->class_name);
}
}

for (auto [k, v] : reference_values.at(class_name)) {
auto [ref_field, class_item] = k;
transition_reference(prng, class_name, ref_field, class_item);
}
}

// Destructor: frees the HIRM instance held in `hirm`.
GenDB::~GenDB() {
  delete hirm;
}

void GenDB::compute_domains_cache() {
Expand Down
35 changes: 20 additions & 15 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

class GenDB {
public:

GenDB(std::mt19937* prng, const PCleanSchema& schema,
bool _only_final_emissions = false, bool _record_class_is_clean = true);

Expand Down Expand Up @@ -41,22 +40,23 @@ class GenDB {
// Samples a reference value and stores it in reference_values and the
// relevant domain CRP.
void sample_and_incorporate_reference(
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
const std::string& ref_class, bool new_rows_have_unique_entities);

std::mt19937* prng, const std::string& class_name,
const std::pair<std::string, int>& ref_key, const std::string& ref_class,
bool new_rows_have_unique_entities);

// Samples a set of entities in the domains of the relation corresponding to
// class_path.
T_items sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end,
int class_item, bool new_rows_have_unique_entities);
std::vector<std::string>::const_iterator class_path_end, int class_item,
bool new_rows_have_unique_entities);

// Sample items from a class' ancestors (recursive reference fields).
T_items sample_class_ancestors(
std::mt19937* prng, const std::string& class_name, int class_item,
bool new_rows_have_unique_entities);
T_items sample_class_ancestors(std::mt19937* prng,
const std::string& class_name, int class_item,
bool new_rows_have_unique_entities);

// Populates "items" with entities by walking the DAG of reference indices,
// starting with "ind".
Expand Down Expand Up @@ -179,6 +179,10 @@ class GenDB {
void transition_reference(std::mt19937* prng, const std::string& class_name,
const std::string& ref_field, const int class_item);

// Transitions all reference fields and rows for a class and its ancestors.
void transition_reference_class_and_ancestors(std::mt19937* prng,
const std::string& class_name);

~GenDB();

// Disable copying.
Expand Down Expand Up @@ -211,16 +215,18 @@ class GenDB {

// Make the relations associated with QueryField f and put them into
// schema.
void make_relations_for_queryfield(
const QueryField& f, const PCleanClass& record_class, T_schema* schema);
void make_relations_for_queryfield(const QueryField& f,
const PCleanClass& record_class,
T_schema* schema);

// Member variables
const PCleanSchema& schema;

// This data structure contains entity sets and linkages. Semantics are
// map<tuple<class_name, reference_field_name, class_primary_key> ref_val>>,
// where primary_key and ref_val are (integer) entity IDs.
std::map<std::tuple<std::string, std::string, int>, int> reference_values;
// map<class_name, map<pair<reference_field_name, class_primary_key>,
// ref_val>>, where class_primary_key and ref_val are (integer) entity IDs.
std::map<std::string, std::map<std::pair<std::string, int>, int>>
reference_values;

HIRM* hirm; // Owned by the GenDB instance.

Expand All @@ -233,7 +239,6 @@ class GenDB {
bool record_class_is_clean;
std::map<std::string, std::vector<std::string>> domains;


// Maps class names to relations corresponding to attributes of the class.
std::map<std::string, std::vector<std::string>> class_to_relations;

Expand Down
Loading