Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

18891: Adds end-to-end surprisal as distance to all distance opcodes and queries, MINOR #54

Merged
merged 26 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fc9bb21
18891: Adds base ability to change how surprisal is conditioned for c…
howsohazard Jan 7, 2024
390ca6d
18891: Adds check for nonparameterized cyclic
howsohazard Jan 7, 2024
cf7af18
18891: Progress
howsohazard Jan 8, 2024
a78b2d7
18891: Minor comment cleanup
howsohazard Jan 8, 2024
112b1f9
18891: Adds todo
howsohazard Jan 8, 2024
673708d
18891: Updates to - 1 nats instead of 1.5 for surprisal
howsohazard Jan 8, 2024
5509727
18891: Undoes change from -1.5 to -1 nats for continuous
howsohazard Jan 8, 2024
1cc2609
18891: Renames for clarity
howsohazard Jan 11, 2024
38ac26e
18891: Implements surprisal space nominals
howsohazard Jan 11, 2024
a5d92b4
18891: Adds todos
howsohazard Jan 11, 2024
cbce672
18891: Progress
howsohazard Jan 11, 2024
e430716
18891: Clean up
howsohazard Jan 11, 2024
c1f43df
18891: More cleanup
howsohazard Jan 11, 2024
2ac2cdf
18891: More cleanup
howsohazard Jan 11, 2024
561dd7a
18891: API signature cleanup
howsohazard Jan 11, 2024
0d73705
18891: More progress
howsohazard Jan 11, 2024
8e77bdc
18891: Makes surprisal space ubiquitous
howsohazard Jan 12, 2024
861c636
18891: Fixes bugs, adds tests
howsohazard Jan 12, 2024
a51aea9
18891: Updates tests
howsohazard Jan 12, 2024
f56d244
18891: Removes unused variable
howsohazard Jan 12, 2024
ab813e0
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 13, 2024
0cf3c03
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 15, 2024
b43065a
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 15, 2024
e7eb20f
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 18, 2024
63c7cd7
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 18, 2024
8a0c2ef
Merge branch 'main' into 18891-surprisal-flexibility
howsohazard Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/language.js

Large diffs are not rendered by default.

286 changes: 167 additions & 119 deletions src/Amalgam/GeneralizedDistance.h

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/Amalgam/SeparableBoxFilterDataStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,13 +1011,13 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
auto value_found = column->stringIdValueToIndices.find(value.stringID);
if(value_found != end(column->stringIdValueToIndices))
{
double term = dist_params.ComputeDistanceTermNonNominalExactMatch(query_feature_index, high_accuracy);
double term = dist_params.ComputeDistanceTermContinuousExactMatch(query_feature_index, high_accuracy);
AccumulatePartialSums(*(value_found->second), query_feature_index, term);
}
}

//the next closest string will have an edit distance of 1
return dist_params.ComputeDistanceTermNonNominalNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
}
else if(effective_feature_type == GeneralizedDistance::EFDT_CONTINUOUS_CODE)
{
Expand All @@ -1035,7 +1035,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
}

//next most similar code must be at least a distance of 1 edit away
return dist_params.ComputeDistanceTermNonNominalNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
}
//else feature_type == FDT_CONTINUOUS_NUMERIC or FDT_CONTINUOUS_UNIVERSALLY_NUMERIC

Expand All @@ -1052,9 +1052,9 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G

double term = 0.0;
if(exact_index_found)
term = dist_params.ComputeDistanceTermNonNominalExactMatch(query_feature_index, high_accuracy);
term = dist_params.ComputeDistanceTermContinuousExactMatch(query_feature_index, high_accuracy);
else
term = dist_params.ComputeDistanceTermNonNominalNonNullRegular(
term = dist_params.ComputeDistanceTermContinuousNonNullRegular(
value.number - column->sortedNumberValueEntries[value_index]->value.number, query_feature_index, high_accuracy);

size_t num_entities_computed = AccumulatePartialSums(column->sortedNumberValueEntries[value_index]->indicesWithValue, query_feature_index, term);
Expand Down Expand Up @@ -1203,7 +1203,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
break;
}

term = dist_params.ComputeDistanceTermNonNominalNonNullRegular(next_closest_diff, query_feature_index, high_accuracy);
term = dist_params.ComputeDistanceTermContinuousNonNullRegular(next_closest_diff, query_feature_index, high_accuracy);
num_entities_computed += AccumulatePartialSums(
column->sortedNumberValueEntries[next_closest_index]->indicesWithValue, query_feature_index, term);

Expand Down
10 changes: 5 additions & 5 deletions src/Amalgam/SeparableBoxFilterDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SeparableBoxFilterDataStore
{
double max_diff = columnData[absolute_feature_index]->GetMaxDifferenceTermFromValue(
dist_params.featureParams[query_feature_index], value_type, value);
return dist_params.ComputeDistanceTermNonNominalNonNullRegular(max_diff, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonNullRegular(max_diff, query_feature_index, high_accuracy);
}

//gets the matrix cell index for the specified index
Expand Down Expand Up @@ -737,7 +737,7 @@ class SeparableBoxFilterDataStore
case GeneralizedDistance::EFDT_CONTINUOUS_UNIVERSALLY_NUMERIC:
{
const size_t column_index = target_label_indices[query_feature_index];
return dist_params.ComputeDistanceTermNonNominalNonCyclicOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousNonCyclicOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
}
Expand All @@ -754,7 +754,7 @@ class SeparableBoxFilterDataStore
const size_t column_index = target_label_indices[query_feature_index];
auto &column_data = columnData[column_index];
if(column_data->numberIndices.contains(entity_index))
return dist_params.ComputeDistanceTermNonNominalNonCyclicOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousNonCyclicOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
else
Expand All @@ -766,7 +766,7 @@ class SeparableBoxFilterDataStore
const size_t column_index = target_label_indices[query_feature_index];
auto &column_data = columnData[column_index];
if(column_data->numberIndices.contains(entity_index))
return dist_params.ComputeDistanceTermNonNominalOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
else
Expand Down Expand Up @@ -922,7 +922,7 @@ class SeparableBoxFilterDataStore
else
effective_feature_type = GeneralizedDistance::EFDT_CONTINUOUS_NUMERIC_PRECOMPUTED;

dist_params.ComputeAndStoreInternedNumberValuesAndDistanceTerms(query_feature_index, position_value_numeric, &column_data->internedNumberIndexToNumberValue);
dist_params.ComputeAndStoreInternedNumberValuesAndDistanceTerms(position_value_numeric, query_feature_index, &column_data->internedNumberIndexToNumberValue);
}
else
{
Expand Down
23 changes: 15 additions & 8 deletions src/Amalgam/amlg_code/full_test.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,15 @@
;should print 3ish
(print "35 " (generalized_distance (list 1 1) (list "continuous_code" "nominal_string") (list 0 5) (null) 1 (list (list 1.5 2 3 4 5) "s") (list (list 1 2 3) "s") ) "\n")

;surprisal
;should both be 0
(print "36 " (generalized_distance (list 1 1) (list "continuous_numeric" "continuous_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 1 1) (null) (true) ) "\n" )
(print "37 " (generalized_distance (list 1 1) (list "nominal_numeric" "nominal_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 1 1) (null) (true) ) "\n" )

;surprisal
(print "38 " (generalized_distance (list 1 1) (list "continuous_numeric" "continuous_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 2 2) (null) (true) ) "\n" )
(print "39 " (generalized_distance (list 1 1) (list "nominal_numeric" "nominal_numeric") (list 2 2) (list 0.25 0.25) 1 (list 1 1) (list 2 2) (null) (true) ) "\n" )

(print "--entropy--\n")
(print (entropy (list 0.5 0.5)) "\n")
(print (entropy (list 0.5 0.5) (list 0.25 0.75) -1 1) "\n")
Expand Down Expand Up @@ -3889,7 +3898,6 @@

;should be:
;(list "vert0" "vert1" "vert2" "vert3")
;(list 0.049787068367863944 0.049787068367863944 0.01831563888873418 0.006737946999085467)
(print "probabilities: "
(compute_on_contained_entities "SurprisalTransformContainer" (list
(query_nearest_generalized_distance
Expand All @@ -3899,9 +3907,9 @@
(null) ; context_weights
(list "continuous_numeric") ; types
(null) ; attributes
(null) ; context_deviations
(list 0.25) ; context_deviations
1 ; p_parameter
"surprisal_to_prob" ; dwe = 1 means return computed distance to each case
"surprisal_to_prob" ; distance transform
(null) ; weight
(rand)
(null)
Expand All @@ -3913,7 +3921,6 @@

;should be
;(list "vert0" "vert2" "vert3" "vert1")
;(list 0.09709538455906153 0.01831563888873418 0.006737946999085467 0)
(print "weighted probabilities: "
(compute_on_contained_entities "SurprisalTransformContainer" (list
(query_nearest_generalized_distance
Expand All @@ -3923,9 +3930,9 @@
(null) ; context_weights
(list "continuous_numeric") ; types
(null) ; attributes
(null) ; context_deviations
(list 0.25) ; context_deviations
1 ; p_parameter
"surprisal_to_prob" ; dwe = 1 means return computed distance to each case
"surprisal_to_prob" ; distance transform
"weight" ; weight
(rand)
(null)
Expand All @@ -3941,12 +3948,12 @@

;should be approx 2.123
(print "surprisal contribution: " (compute_on_contained_entities "SurprisalTransformContainer" (list
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (null) 1 "surprisal_to_prob" (null) "fixed_seed" (null) "precise")
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (list 0.25) 1 "surprisal_to_prob" (null) "fixed_seed" (null) "precise")
)))

;should be approx 2.123
(print "weighted surprisal contribution: " (compute_on_contained_entities "SurprisalTransformContainer" (list
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (null) 1 "surprisal_to_prob" "weight" "fixed_seed" (null) "precise")
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (list 0.25) 1 "surprisal_to_prob" "weight" "fixed_seed" (null) "precise")
)))

(print "--concurrency tests--\n")
Expand Down
26 changes: 1 addition & 25 deletions src/Amalgam/amlg_code/test.amlg
Original file line number Diff line number Diff line change
@@ -1,28 +1,4 @@
(seq
(create_entities "BoxConvictionTestContainer" (null) )
(print "17 " (generalized_distance (null) (list "nominal_numeric") (list 1) (null) 1 (list 1 2 3) (list 10 2 4) ) "\n")

(create_entities (list "BoxConvictionTestContainer" "vert0") (lambda
(null ##x 0 ##y 0 ##weight 2)
) )

(create_entities (list "BoxConvictionTestContainer" "vert1") (lambda
(null ##x 0 ##y 1 ##weight 1)
) )

(create_entities (list "BoxConvictionTestContainer" "vert2") (lambda
(null ##x 1 ##y 0 ##weight 1)
) )

(create_entities (list "BoxConvictionTestContainer" "vert3") (lambda
(null ##x 2 ##y 1 ##weight 1)
) )

;should print:
;dc: (list
;(list "vert0" "vert1" "vert2" "vert3")
;(list 1 1 1 1.4142135623730951)
;)
(print "dc: " (compute_on_contained_entities "BoxConvictionTestContainer" (list
(compute_entity_distance_contributions 1 (list "x" "y") (list "vert3") (null) (null) (null) (null) 2.0 -1 (null) "fixed_seed" (null) "recompute_precise" (true))
)))
)
4 changes: 2 additions & 2 deletions src/Amalgam/entity/EntityQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain
}

//transform distances as appropriate
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(distParams.computeSurprisal,
distanceWeightExponent, weightLabel != StringInternPool::NOT_A_STRING_ID,
[this](Entity *e, double &weight_value) { return e->GetValueAtLabelAsNumber(weightLabel, weight_value); });

Expand Down Expand Up @@ -775,7 +775,7 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain
entity_values.push_back(DistanceReferencePair<Entity *>(GetConditionDistanceMeasure(matching_entities[i], high_accuracy), matching_entities[i]));

//transform distances as appropriate
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(distParams.computeSurprisal,
distanceWeightExponent, weightLabel != StringInternPool::NOT_A_STRING_ID,
[this](Entity *e, double &weight_value) { return e->GetValueAtLabelAsNumber(weightLabel, weight_value); });

Expand Down
3 changes: 0 additions & 3 deletions src/Amalgam/entity/EntityQueries.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class EntityQueryCondition
//only applicable when transformSuprisalToProb is false
double distanceWeightExponent;

//if true, the values will be transformed from surprisal to probability; if false, will perform a distance transform
bool transformSuprisalToProb;

//if ENT_QUERY_SELECT has a start offset
bool hasStartOffset;

Expand Down
4 changes: 2 additions & 2 deletions src/Amalgam/entity/EntityQueryBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,15 @@ namespace EntityQueryBuilder
cur_condition->distParams.pValue = p_value;

//value transforms for whatever is measured as "distance"
cur_condition->transformSuprisalToProb = false;
cur_condition->distanceWeightExponent = 1.0;
cur_condition->distParams.computeSurprisal = false;
if(ocn.size() > DISTANCE_VALUE_TRANSFORM)
{
EvaluableNode *dwe_param = ocn[DISTANCE_VALUE_TRANSFORM];
if(!EvaluableNode::IsNull(dwe_param))
{
if(dwe_param->GetType() == ENT_STRING && dwe_param->GetStringIDReference() == ENBISI_surprisal_to_prob)
cur_condition->transformSuprisalToProb = true;
cur_condition->distParams.computeSurprisal = true;
else //try to convert to number
cur_condition->distanceWeightExponent = EvaluableNode::ToNumber(dwe_param, 1.0);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Amalgam/entity/EntityQueryCaches.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ void EntityQueryCaches::GetMatchingEntities(EntityQueryCondition *cond, BitArray
weight_column = sbfds.GetColumnIndexFromLabelId(cond->weightLabel);

auto get_weight = sbfds.GetNumberValueFromEntityIndexFunction(weight_column);
EntityQueriesStatistics::DistanceTransform<size_t> distance_transform(cond->transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<size_t> distance_transform(cond->distParams.computeSurprisal,
cond->distanceWeightExponent, use_entity_weights, get_weight);

//if first, need to populate with all entities
Expand Down
8 changes: 6 additions & 2 deletions src/Amalgam/interpreter/InterpreterOpcodesMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,9 @@ EvaluableNodeReference Interpreter::InterpretNode_ENT_GENERALIZED_DISTANCE(Evalu

//get value_names if applicable
std::vector<StringInternPool::StringID> value_names;
if(ocn.size() > 8)
if(ocn.size() > 7)
{
EvaluableNodeReference value_names_node = InterpretNodeForImmediateUse(ocn[8]);
EvaluableNodeReference value_names_node = InterpretNodeForImmediateUse(ocn[7]);
if(!EvaluableNode::IsNull(value_names_node))
{
//extract the names for each value into value_names
Expand All @@ -1034,6 +1034,10 @@ EvaluableNodeReference Interpreter::InterpretNode_ENT_GENERALIZED_DISTANCE(Evalu
evaluableNodeManager->FreeNodeTreeIfPossible(value_names_node);
}

dist_params.computeSurprisal = false;
if(ocn.size() > 8)
dist_params.computeSurprisal = InterpretNodeIntoBoolValue(ocn[8], false);

//get the origin and destination
std::vector<EvaluableNodeImmediateValue> location;
std::vector<EvaluableNodeImmediateValueType> location_types;
Expand Down
Loading