Skip to content

Commit

Permalink
Fix 2D ahead/behind
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Jan 23, 2024
1 parent 0347062 commit ede41e7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
19 changes: 15 additions & 4 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,14 @@ def get_firing_rate(cls, key, time, multiunit=False):
)

def get_ahead_behind_distance(self):
# TODO: allow specification of specific time interval
# TODO: allow specification of track graph
# TODO: Handle decode intervals, store in table

classifier = self.load_model()
results = self.load_results()
posterior = results.acausal_posterior.unstack("state_bins").sum("state")

# TODO: Handle intervals, store in table

if classifier.environments[0].track_graph is not None:
linear_position_info = self.load_linear_position_info(
self.fetch1("KEY")
Expand Down Expand Up @@ -496,9 +498,18 @@ def get_ahead_behind_distance(self):
position_info = self.load_position_info(self.fetch1("KEY"))
map_position = analysis.maximum_a_posteriori_estimate(posterior)

orientation_name = (
"orientation"
if "orientation" in position_info.columns
else "head_orientation"
)
position_variable_names = (
PositionGroup & self.fetch1("KEY")
).fetch1("position_variables")

return analysis.get_ahead_behind_distance2D(
position_info[["position_x", "position_y"]].to_numpy(),
position_info["orientation"].to_numpy(),
position_info[position_variable_names].to_numpy(),
position_info[orientation_name].to_numpy(),
map_position,
classifier.environments[0].track_graphDD,
)
19 changes: 15 additions & 4 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,14 @@ def spike_times_sorted_by_place_field_peak(self, time_slice=None):
]

def get_ahead_behind_distance(self):
# TODO: allow specification of specific time interval
# TODO: allow specification of track graph
# TODO: Handle decode intervals, store in table

classifier = self.load_model()
results = self.load_results()
posterior = results.acausal_posterior.unstack("state_bins").sum("state")

# TODO: Handle intervals, store in table

if classifier.environments[0].track_graph is not None:
linear_position_info = self.load_linear_position_info(
self.fetch1("KEY")
Expand Down Expand Up @@ -515,9 +517,18 @@ def get_ahead_behind_distance(self):
position_info = self.load_position_info(self.fetch1("KEY"))
map_position = analysis.maximum_a_posteriori_estimate(posterior)

orientation_name = (
"orientation"
if "orientation" in position_info.columns
else "head_orientation"
)
position_variable_names = (
PositionGroup & self.fetch1("KEY")
).fetch1("position_variables")

return analysis.get_ahead_behind_distance2D(
position_info[["position_x", "position_y"]].to_numpy(),
position_info["orientation"].to_numpy(),
position_info[position_variable_names].to_numpy(),
position_info[orientation_name].to_numpy(),
map_position,
classifier.environments[0].track_graphDD,
)

0 comments on commit ede41e7

Please sign in to comment.