Skip to content

Commit

Permalink
fix breaking dnn input formatting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Jan 2, 2025
1 parent aec881f commit a59ecf7
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random

from re import X
from typing import List, Tuple, Dict, Optional, Any

import sklearn.preprocessing
Expand Down Expand Up @@ -202,8 +203,11 @@ def zscore_trial_segs(
Normalized train data, normalized rest features, and normalization parameters.
"""
is_2D = train[0].ndim == 1
concat_train = train if is_2D else np.concatenate(train)
train_mean = normparams['X_train_mean'] if normparams is not None else bn.nanmean(concat_train, axis=0)
concat_train = train if is_2D else np.concatenate(train).astype(float)
train_mean = (
normparams['X_train_mean'] if normparams is not None
else bn.nanmean(concat_train, axis=0)
)
train_std = normparams['X_train_std'] if normparams is not None else bn.nanstd(concat_train, axis=0)

train_notnan_cols = train_std != 0
Expand All @@ -213,15 +217,17 @@ def zscore_trial_segs(
# if train is not jagged, it gets converted completely to object
# np.ndarray. Hence, cannot exclusively use normed_train.loc
if isinstance(normed_train, pd.DataFrame):
normed_train = normed_train.loc
normed_train[:, train_nan_cols] = 0
normed_train.loc[:, train_nan_cols] = 0
else:
normed_train[:, train_nan_cols] = 0
else:
normed_train = np.empty_like(train)
for i, nsvstseg in enumerate(train):
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_train[i] = zscored

normed_rest_feats = []
Expand All @@ -230,16 +236,18 @@ def zscore_trial_segs(
if is_2D:
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
if isinstance(normed_feats, pd.DataFrame):
normed_feats = normed_feats.loc
normed_feats[:, train_nan_cols] = 0
normed_feats.loc[:, train_nan_cols] = 0
else:
normed_feats[:, train_nan_cols] = 0
normed_rest_feats.append(normed_feats)
else:
normed_feats = np.empty_like(feats)
for i, trialSegROI in enumerate(feats):
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_feats[i] = zscored
normed_rest_feats.append(normed_feats)

Expand Down Expand Up @@ -351,6 +359,13 @@ def minibatchify(
"""
g_seed = torch.Generator()
g_seed.manual_seed(seed)
if Xtrain.ndim == 2: # handle object arrays
Xtrain = Xtrain.astype(np.float32)
Xval = Xval.astype(np.float32)
Xtest = Xtest.astype(np.float32)
ytrain = ytrain.astype(np.float32)
yval = yval.astype(np.float32)
ytest = ytest.astype(np.float32)
train = torch.utils.data.TensorDataset(
torch.from_numpy(Xtrain).type(torch.float32),
torch.from_numpy(ytrain).type(torch.float32))
Expand Down

1 comment on commit a59ecf7

@ryanharvey1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have linting errors on this commit.

Please sign in to comment.