Skip to content

Commit

Permalink
black update
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvm9 committed Nov 29, 2024
1 parent 1b2be85 commit 2db6598
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions pynapple/process/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,34 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
"""
if isinstance(group, nap.TsdFrame):
newgroup = group.restrict(ep)

if tuning_curves.shape[1] != newgroup.shape[1]:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif isinstance(group, nap.TsGroup):
newgroup = group.restrict(ep)

if tuning_curves.shape[1] != len(newgroup):
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
raise RuntimeError("Different indices for tuning curves and group keys")

# Bin spikes
count = newgroup.count(bin_size, ep, time_units)

elif isinstance(group, dict):
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

# Occupancy
if feature is None:
occupancy = np.ones(tuning_curves.shape[0])
Expand Down Expand Up @@ -172,34 +172,38 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
if type(group) is nap.TsdFrame:
newgroup = group.restrict(ep)
numcells = newgroup.shape[1]

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.columns)):

if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.columns)
):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif type(group) is nap.TsGroup:
newgroup = group.restrict(ep)
numcells = len(newgroup)

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())):
if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())
):
raise RuntimeError("Different indices for tuning curves and group keys")

count = newgroup.count(bin_size, ep, time_units)

elif type(group) is dict:
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

indexes = list(tuning_curves.keys())

# Occupancy
Expand Down

0 comments on commit 2db6598

Please sign in to comment.