Skip to content

Commit

Permalink
22631: Corrected ablated indices and session training indices when tr…
Browse files Browse the repository at this point in the history
…aining with ablation, MAJOR (#409)
  • Loading branch information
fulpm authored Feb 4, 2025
1 parent 33466bb commit 57e6ccc
Show file tree
Hide file tree
Showing 11 changed files with 760 additions and 141 deletions.
2 changes: 1 addition & 1 deletion howso/analysis.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@
(accum (assoc
analyze_warnings
(associate (concat
"It is recomended to use a \"targetless\" analysis of the data for a time-series Trainee. "
"It is recommended to use a \"targetless\" analysis of the data for a time-series Trainee. "
"Please analyze the data once more with no action features specified and the value \"targetless\" "
"specified for the \"targeted_model\" parameter."
))
Expand Down
9 changes: 5 additions & 4 deletions howso/get_cases.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@
#!GetCaseGivenReplaySession
(get (retrieve_from_entity session ".replay_steps") session_index)

;returns assoc with features and cases - a list of lists of all feature values. Retrieves all feature values for cases for
;all (unordered) sessions in the order they were trained within each session. If a session is specified, only that session's
;cases wil be output.
;returns assoc with features and cases - a list of lists of all feature values. Retrieves all feature values for cases in
;all sessions. If a session is specified, only that session's cases will be output. Session and case order is not guaranteed,
;however, the features ".session" and ".session_training_index" may be requested to get the session id and session train order
;for each case respectively.
;{read_only (true) idempotent (true)}
#get_cases
(declare
Expand All @@ -114,7 +115,7 @@
; }
(assoc
;{type "list" values "string"}
;list of features to retrieve.
;list of features to retrieve. Case values will be output given this feature order.
features (list)
;{type "number"}
;set flag to skip decoding feature values into their nominal values for output.
Expand Down
4 changes: 2 additions & 2 deletions howso/get_sessions.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@
(call !ValidateParameters)
(call !Return (assoc
payload
(map
(sort (map
(lambda (retrieve_from_entity (current_value) !internalLabelSessionTrainingIndex))
;list of all cases trained for specified session
(retrieve_from_entity session ".replay_steps")
)
))
))
)

Expand Down
2 changes: 1 addition & 1 deletion howso/return_types.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@
"ablated_indices" {
type "list"
values "number"
description "The session training indices for the ablated cases."
description "The indices of the ablated input cases."
}
"status" {
type ["string" "null"]
Expand Down
46 changes: 31 additions & 15 deletions howso/train.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@
(null
##.replay_steps (list)
##.indices_map (assoc)
;total count of cases trained
##.trained_instance_count 0
;total count of cases observed (incl ablated, trained as weights, etc)
##.total_instance_count 0
##.metadata (assoc)
)
)))
Expand All @@ -81,6 +84,7 @@

(declare (assoc
trained_instance_count (retrieve_from_entity session ".trained_instance_count")
total_instance_count (retrieve_from_entity session ".total_instance_count")
series_cases (if (!= (null) series) (get !seriesStore series))
status_output (null)
message (null)
Expand Down Expand Up @@ -152,6 +156,11 @@
cases cases
))

;capture these cases into the total observed count
(accum_to_entities session (assoc
".total_instance_count" (size cases)
))

(accum_to_entities (assoc !revision 1))

(conclude
Expand Down Expand Up @@ -327,14 +336,7 @@
(if accumulate_weight_feature
cases
;else only accumulate for cases that were actually trained and not ablated
(unzip
cases
(remove
(indices cases)
;change ablated_indices_list into a 0-based list to match indices of cases
(map (lambda (- (current_value) trained_instance_count)) ablated_indices_list)
)
)
(unzip cases (remove (indices cases) ablated_indices_list))
)
))
))
Expand Down Expand Up @@ -503,6 +505,11 @@
(assign_to_entities (assoc !inactiveFeaturesNeedCaching (true) ))
)

;capture these cases into the total observed count
(accum_to_entities session (assoc
".total_instance_count" (size cases)
))

(accum_to_entities (assoc !revision 1))

;return response
Expand All @@ -511,7 +518,7 @@
payload
(assoc
"num_trained" (size new_case_ids)
"ablated_indices" ablated_indices_list
"ablated_indices" (sort ablated_indices_list)
"status" status_output
)
))
Expand Down Expand Up @@ -562,7 +569,12 @@
))
)

(if (!= (size features) (size (first cases)) )
;verify all row sizes match the number of features
(if
(size (filter
(lambda (!= (size features) (size (current_value))))
cases
))
(conclude (conclude
(call !Return (assoc
errors (list "The number of feature names specified does not match the number of feature values given.")
Expand Down Expand Up @@ -711,7 +723,7 @@
(current_value 1)
)
session (get_value session)
session_training_index (+ trained_instance_count (current_index 1))
session_training_index (+ total_instance_count (current_index 1))
))
)
cases
Expand Down Expand Up @@ -823,7 +835,7 @@
(call !AblateCases (assoc
cases (unzip cases (range input_case_index (+ input_case_index batch_size -1)) )
;ensure that starting training index value is updated for each batch
session_training_index (+ trained_instance_count input_case_index)
session_training_index (+ total_instance_count input_case_index)
))
))

Expand Down Expand Up @@ -945,8 +957,8 @@
(map
(lambda
(if (size ts_ablated_indices_map)
(+ session_training_index (get ts_ablated_indices_map (current_value)))
(+ session_training_index (current_value))
(get ts_ablated_indices_map (+ input_case_index (current_value)))
(+ input_case_index (current_value))
)
)
(remove (indices cases) indices_to_train)
Expand Down Expand Up @@ -996,7 +1008,11 @@
feature_values
)
session (get_value session)
session_training_index (+ session_training_index (current_value 1))
session_training_index
(if (size ts_ablated_indices_map)
(+ total_instance_count (get ts_ablated_indices_map (+ input_case_index (current_value 1))))
(+ session_training_index (current_value 1))
)
))
))
indices_to_train
Expand Down
Loading

0 comments on commit 57e6ccc

Please sign in to comment.