Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analysis fix #144

Merged
merged 4 commits into from
Apr 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions keypoint_moseq/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_syllable_names(project_dir, model_name, syllable_ixs):
if os.path.exists(syll_info_path):
syll_info_df = pd.read_csv(syll_info_path, index_col=False).fillna("")

for ix in syllable_ixs:
if len(syll_info_df[syll_info_df.syllable == ix].label.values[0]) > 0:
labels[ix] = (
f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
)
for ix in syllable_ixs:
if len(syll_info_df[syll_info_df.syllable == ix].label.values[0]) > 0:
labels[ix] = (
f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
)
names = [labels[ix] for ix in syllable_ixs]
return names

Expand Down Expand Up @@ -826,7 +826,6 @@ def run_kruskal(
syllable_data = grouped_data.drop(["group", "name"], axis=1).values

N_m, N_s = syllable_data.shape

# Run KW and return H-stats
h_all, real_ranks, X_ties = run_manual_KW_test(
df_usage=df_only_stats,
Expand Down Expand Up @@ -1116,22 +1115,23 @@ def plot_syll_stats_with_sem(

# get significant syllables
sig_sylls = None
if groups is None:
groups = stats_df["group"].unique()

if plot_sig and len(stats_df["group"].unique()) > 1:
# run kruskal wallis and dunn's test
_, _, sig_pairs = run_kruskal(stats_df, statistic=stat, thresh=thresh)
# plot significant syllables for control and experimental group
if ctrl_group is not None and exp_group is not None:
# plot significant syllables for the control and experimental groups when the user specifies both and they exist in `groups`
if ctrl_group in groups and exp_group in groups:
# check if the group pair is in the sig pairs dict
if (ctrl_group, exp_group) in sig_pairs.keys():
sig_sylls = sig_pairs.get((ctrl_group, exp_group))
# flip the order of the groups
else:
sig_sylls = sig_pairs.get((exp_group, ctrl_group))
else:
print(
"No control or experimental group specified. Not plotting significant syllables."
)
# plot everything if no group pair is specified
sig_sylls = sig_pairs

xlabel = f"Syllables sorted by {stat}"
if order == "diff":
Expand Down Expand Up @@ -1172,14 +1172,38 @@ def plot_syll_stats_with_sem(

# if a list of significant syllables is given, mark the syllables above the x-axis
if sig_sylls is not None:
markings = []
for s in sig_sylls:
if s in ordering:
markings.append(np.where(ordering == s)[0])
init_y = -0.05
# plot significant syllables for every group pair when no valid control/experimental group pair is specified
if isinstance(sig_sylls, dict):
for key in sig_sylls.keys():
markings = []
for s in sig_sylls[key]:
markings.append(np.where(ordering == s)[0])
if len(markings) > 0:
markings = np.concatenate(markings)
plt.scatter(
markings, [init_y] * len(markings), color="r", marker="*"
)
plt.text(
plt.xlim()[1],
init_y,
f"{key[0]} vs. {key[1]} - Total {len(sig_sylls[key])} S.S.",
)
init_y += -0.05
else:
print("No significant syllables found.")
else:
markings = []
for s in sig_sylls:
if s in ordering:
markings.append(np.where(ordering == s)[0])
else:
continue
if len(markings) > 0:
markings = np.concatenate(markings)
plt.scatter(markings, [-0.05] * len(markings), color="r", marker="*")
else:
continue
markings = np.concatenate(markings)
plt.scatter(markings, [-0.05] * len(markings), color="r", marker="*")
print("No significant syllables found.")

# manually define a new patch
patch = Line2D(
Expand Down Expand Up @@ -1423,10 +1447,15 @@ def visualize_transition_bigram(
whether to show just syllable indexes (False) or syllable indexes and
names (True)
"""

# syllable info path
syll_info_path = os.path.join(project_dir, model_name, "syll_info.csv")
# initialize syllable names
syll_names = [f"{ix}" for ix in syll_include]

if show_syllable_names:
syll_names = get_syllable_names(project_dir, model_name, syll_include)
else:
syll_names = [f"{ix}" for ix in syll_include]
if os.path.exists(syll_info_path):
syll_names = get_syllable_names(project_dir, model_name, syll_include)

# infer max_syllables
max_syllables = trans_mats[0].shape[0]
Expand Down