Skip to content

Commit

Permalink
Add is_sham_change column.
Browse files Browse the repository at this point in the history
Move trials_id and simplify calculation.
Add active calculation.
Add loading trails to presentations table creation code.
Add unittests for new stimulus calculations.
Change test for trials_id to account for gaps in trials.
  • Loading branch information
morriscb committed Aug 21, 2023
1 parent a3aa0cf commit b950704
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 59 deletions.
1 change: 1 addition & 0 deletions allensdk/brain_observatory/behavior/behavior_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def from_json(
path=session_data["stim_table_file"],
behavior_session_id=session_data["behavior_session_id"],
exclude_columns=stimulus_presentation_exclude_columns,
trials=trials
),
templates=Templates.from_stimulus_file(
stimulus_file=stimulus_file_lookup.behavior_stimulus_file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
Trials,
)
from allensdk.brain_observatory.behavior.stimulus_processing import (
add_active_flag,
compute_is_sham_change,
compute_trials_id_for_stimulus,
fix_omitted_end_frame,
get_flashes_since_change,
Expand Down Expand Up @@ -58,22 +60,63 @@ def __init__(
columns_to_rename: Optional[Dict[str, str]] = None,
column_list: Optional[List[str]] = None,
sort_columns: bool = True,
trials: Optional[Trials] = None
):
"""
Parameters
----------
presentations: The stimulus presentations table
columns_to_rename: Optional dict mapping
old column name -> new column name
presentations : pandas.DataFrame
The stimulus presentations table
columns_to_rename : Optional dict mapping
Mapping to rename columns. old column name -> new column name
column_list: Optional list of columns to include.
This will reorder the columns.
sort_columns: Whether to sort the columns by name
sort_columns: bool
Whether to sort the columns by name
trials : Optional Trials object.
allensdk Trials object for the same session as the presentations
table.
"""
if columns_to_rename is not None:
presentations = presentations.rename(columns=columns_to_rename)
if column_list is not None:
presentations = presentations[column_list]
presentations = enforce_df_int_typing(
presentations,
[
"flashes_since_change",
"image_index",
"movie_frame_index",
"repeat",
"stimulus_index",
],
)
presentations = presentations.reset_index(drop=True)
presentations.index = pd.Index(
range(presentations.shape[0]),
name="stimulus_presentations_id",
dtype="int",
)
if trials is not None:
if "active" not in presentations.columns:
# Add column marking where the mouse is engaged in active,
# trained behavior.
presentations = add_active_flag(
presentations, trials.data
)
if "trials_id" not in presentations.columns:
# Add trials_id to presentations df to allow for joining of the
# two tables.
presentations['trials_id'] = compute_trials_id_for_stimulus(
presentations, trials.data
)
if "is_sham_change" not in presentations.columns:
# Mark changes in active and replay stimulus that are
# #sham-changes
presentations = compute_is_sham_change(
presentations, trials.data
)
if sort_columns:
presentations = enforce_df_column_order(
presentations,
Expand All @@ -97,23 +140,6 @@ def __init__(
"trials_id",
],
)
presentations = presentations.reset_index(drop=True)
presentations = enforce_df_int_typing(
presentations,
[
"flashes_since_change",
"image_index",
"movie_frame_index",
"movie_repeat",
"stimulus_index",
],
)
presentations.index = pd.Index(
range(presentations.shape[0]),
name="stimulus_presentations_id",
dtype="int",
)

super().__init__(name="presentations", value=presentations)

def to_nwb(
Expand Down Expand Up @@ -181,6 +207,7 @@ def from_nwb(
cls,
nwbfile: NWBFile,
add_is_change: bool = True,
add_trials_dependent_values: bool = True,
column_list: Optional[List[str]] = None,
) -> "Presentations":
"""
Expand Down Expand Up @@ -240,7 +267,13 @@ def from_nwb(
table["flashes_since_change"] = get_flashes_since_change(
stimulus_presentations=table
)
return Presentations(presentations=table, column_list=column_list)
trials = None
if add_trials_dependent_values and nwbfile.trials is not None:
trials = Trials.from_nwb(nwbfile)

return Presentations(presentations=table,
column_list=column_list,
trials=trials)

@classmethod
def from_stimulus_file(
Expand Down Expand Up @@ -394,14 +427,10 @@ def from_stimulus_file(
stim_pres_df, stimulus_file.session_type, project_code.value
)

# Add trials_id to presentations df to allow for joining of the two
# tables.
stim_pres_df["trials_id"] = compute_trials_id_for_stimulus(
stim_pres_df, trials.data
)

return Presentations(
presentations=stim_pres_df, column_list=column_list
presentations=stim_pres_df,
column_list=column_list,
trials=trials
)

@classmethod
Expand All @@ -412,6 +441,7 @@ def from_path(
exclude_columns: Optional[List[str]] = None,
columns_to_rename: Optional[Dict[str, str]] = None,
sort_columns: bool = True,
trials: Optional[Trials] = None
) -> "Presentations":
"""
Reads the table directly from a precomputed csv
Expand Down Expand Up @@ -446,6 +476,7 @@ def from_path(
presentations=df,
columns_to_rename=columns_to_rename,
sort_columns=sort_columns,
trials=trials,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
from pynwb import NWBFile

from allensdk.core.dataframe_utils import (
enforce_df_int_typing
)
from allensdk.brain_observatory import dict_to_indexed_array
from allensdk.brain_observatory.behavior.data_files import (
BehaviorStimulusFile, SyncFile)
Expand Down Expand Up @@ -50,6 +53,7 @@ def __init__(
"""
trials = trials.rename(columns={'stimulus_change': 'is_change'})
super().__init__(name='trials', value=None, is_value_self=True)
trials = enforce_df_int_typing(trials, ["change_frame"])

self._trials = trials
self._response_window_start = response_window_start
Expand Down
141 changes: 114 additions & 27 deletions allensdk/brain_observatory/behavior/stimulus_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,44 @@ def get_flashes_since_change(
return flashes_since_change


def add_active_flag(
stim_pres_table: pd.DataFrame,
trials: pd.DataFrame
) -> pd.DataFrame:
"""Mark the active stimuli by lining up the stimulus times with the
trials times.
Parameters
----------
stim_pres_table : pandas.DataFrame
Stimulus table to add active column to.
trials : pandas.DataFrame
Trials table to align with the stimulus table.
Returns
-------
stimulus_table : pandas.DataFrame
Copy of ``stim_pres_table`` with added acive column.
"""
if "active" in stim_pres_table.columns:
return stim_pres_table
else:
active = pd.Series(
data=np.zeros(len(stim_pres_table), dtype=bool),
index=stim_pres_table.index,
name="active",
)
stim_mask = (stim_pres_table.start_time > trials.start_time.min()) & (
stim_pres_table.start_time < trials.stop_time.max()
) & (~stim_pres_table.image_name.isna())
active[stim_mask] = True
stim_pres_table['active'] = active
return stim_pres_table


def compute_trials_id_for_stimulus(
stim_pres_table: pd.DataFrame, trials_table: pd.DataFrame
stim_pres_table: pd.DataFrame,
trials_table: pd.DataFrame
) -> pd.Series:
"""Add an id to allow for merging of the stimulus presentations
table with the trials table.
Expand All @@ -692,47 +728,37 @@ def compute_trials_id_for_stimulus(
passive stimulus/replay blocks that contain the same image ordering and
length.
"""
stim_pres_sorted = stim_pres_table.sort_values("start_time")
trials_sorted = trials_table.sort_values("start_time")
# Create a placeholder for the trials_id.
trials_ids = pd.Series(
data=np.full(len(stim_pres_sorted), INT_NULL, dtype=int),
index=stim_pres_sorted.index,
data=np.full(len(stim_pres_table), INT_NULL, dtype=int),
index=stim_pres_table.index,
name="trials_id",
).astype('int')
# Return an empty trials_id if the stimulus block is not available.
if "stimulus_block" not in stim_pres_sorted.columns:
# Return input frame if the stimulus_block or active is not available.
if "stimulus_block" not in stim_pres_table.columns \
or "active" not in stim_pres_table.columns:
return trials_ids

if "active" in stim_pres_sorted.columns:
has_active = True
active_sorted = stim_pres_sorted.active
else:
has_active = False
active_sorted = pd.Series(
data=np.zeros(len(stim_pres_sorted), dtype=bool),
index=stim_pres_sorted.index,
name="active",
)
active_sorted = stim_pres_table.active

# Find stimulus blocks that start within a trial. Copy the trial_id
# into our new trials_ids series.
for idx, trial in trials_sorted.iterrows():
stim_mask = (stim_pres_sorted.start_time > trial.start_time) & (
stim_pres_sorted.start_time < trial.stop_time
) & (~stim_pres_sorted.image_name.isna())
# into our new trials_ids series. For some sessions there are gaps in
# between one trial's end and the next's stop time so we account for this
# by only using the max time for all trials as the limit.
max_trials_stop = trials_table.stop_time.max()
for idx, trial in trials_table.iterrows():
stim_mask = (stim_pres_table.start_time > trial.start_time) & (
stim_pres_table.start_time < max_trials_stop
) & (~stim_pres_table.image_name.isna())
trials_ids[stim_mask] = idx
if not has_active:
active_sorted[stim_mask] = True

# The code below finds all stimulus blocks that contain images/trials
# and attempts to detect blocks that are identical to copy the associated
# trials_ids into those blocks. In the parlance of the data this is
# copying the active stimulus block data into the passive stimulus block.

# Get the block ids for the behavior trial presentations
stim_blocks = stim_pres_sorted.stimulus_block
stim_image_names = stim_pres_sorted.image_name
stim_blocks = stim_pres_table.stimulus_block
stim_image_names = stim_pres_table.image_name
active_stim_blocks = stim_blocks[active_sorted].unique()
# Find passive blocks that show images for potential copying of the active
# into a passive stimulus block.
Expand Down Expand Up @@ -835,3 +861,64 @@ def produce_stimulus_block_names(
] = vbo_map[block_id]

return stim_df


def compute_is_sham_change(
stim_df: pd.DataFrame,
trials: pd.DataFrame
) -> pd.DataFrame:
"""Add is_sham_change to stimulus presentation table.
Parameters
----------
stim_df : pandas.DataFrame
Stimulus presentations table to add is_sham_change to.
trials : pandas.DataFrame
Trials data frame to pull info from to create
Returns
-------
stimulus_presentations : pandas.DataFrame
Input ``stim_df`` DataFrame with the is_sham_change column added.
"""
if "trials_id" not in stim_df.columns \
or "active" not in stim_df.columns \
or "stimulus_block" not in stim_df.columns:
return stim_df
stim_trials = stim_df.merge(trials,
left_on='trials_id',
right_index=True,
how='left')
catch_frames = stim_trials[
stim_trials['catch'].fillna(False)]['change_frame'].unique()

stim_df['is_sham_change'] = False
catch_flashes = stim_df[
stim_df['start_frame'].isin(catch_frames)].index.values
stim_df.loc[catch_flashes, 'is_sham_change'] = True

stim_blocks = stim_df.stimulus_block
stim_image_names = stim_df.image_name
active_stim_blocks = stim_blocks[stim_df.active].unique()
# Find passive blocks that show images for potential copying of the active
# into a passive stimulus block.
passive_stim_blocks = stim_blocks[
np.logical_and(~stim_df.active, ~stim_image_names.isna())
].unique()

# Copy the trials_id into the passive block if it exists.
if len(passive_stim_blocks) > 0:
for active_stim_block in active_stim_blocks:
active_block_mask = stim_blocks == active_stim_block
active_images = stim_image_names[active_block_mask].values
for passive_stim_block in passive_stim_blocks:
passive_block_mask = stim_blocks == passive_stim_block
if np.array_equal(
active_images,
stim_image_names[passive_block_mask].values
):
stim_df.loc[passive_block_mask, 'is_sham_change'] = \
stim_df[active_block_mask][
'is_sham_change'].values

return stim_df.sort_index()
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def data(self):
{
"start_time": [300.0, 330.0, 360.0],
"stop_time": [330.0, 360.0, 360.0],
"catch": [False, True, False],
"change_frame": [-99, 99, -99]
}
)

Expand Down
Loading

0 comments on commit b950704

Please sign in to comment.