Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fast bandpass filter #29

Merged
merged 2 commits into from
Feb 9, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions mne_realtime/rt_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 19 13:30:38 2021
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved

@author: Timon Merk
"""

import numpy as np
import mne

teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
def calc_band_filters(f_ranges, sample_rate, filter_len="1000ms", l_trans_bandwidth=4, h_trans_bandwidth=4):
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
""""Calculate bandpass filters with adjustable length for given frequency ranges.
This function returns for the given frequency band ranges the filter coefficients with length "filter_len".
Thus the filters can be sequentially used for band power estimation.
Parameters
----------
f_ranges : list of lists
frequency ranges.
sample_rate : float
sampling frequency.
filter_len : str, optional
length of the filter. Human readable (e.g."1000ms" or "1s"). Default is "1000ms"
l_trans_bandwidth : int/float, optional
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
Length of the lower transition band. The default is 4.
h_trans_bandwidth : int/float, optional
Length of the higher transition band. The default is 4.
Returns
-------
filter_bank : array
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
filter coefficients stored in array of shape (n_franges, filter_len (in samples))
"""
filter_list = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
filter_list = []
filter_list = list()

for a, f_range in enumerate(f_ranges):
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
h = mne.filter.create_filter(None, sample_rate, l_freq=f_range[0], h_freq=f_range[1], fir_design='firwin',
l_trans_bandwidth=l_trans_bandwidth, h_trans_bandwidth=h_trans_bandwidth,
filter_length=filter_len)
filter_list.append(h)
filter_bank = np.vstack(filter_list)
return filter_bank

teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
def apply_filter(dat_, filter_bank, fs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use data to be clear. usually names with underscores imply that they have been modified (usually in an object like in sklearn)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fs is missing in your parameter list

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

teonbrooks marked this conversation as resolved.
Show resolved Hide resolved
"""Apply previously calculated (bandpass) filters to data.
Parameters
----------
dat_ : array (n_samples, ) or (n_channels, n_samples)
segment of data.
filter_bank : array
output of calc_band_filters.
Returns
-------
filtered : array
(n_chan, n_fbands, filter_len) array conatining the filtered signal
at each freq band, where n_fbands is the number of filter bands used to decompose the signal
"""
if dat_.ndim == 1:
filtered = np.zeros((1, filter_bank.shape[0], fs))
for filt in range(filter_bank.shape[0]):
filtered[0, filt, :] = np.convolve(filter_bank[filt,:], dat_)[int(fs-fs/2):int(fs+fs/2)]
elif dat_.ndim == 2:
filtered = np.zeros((dat_.shape[0], filter_bank.shape[0], fs))
for chan in range(dat_.shape[0]):
for filt in range(filter_bank.shape[0]):
filtered[chan, filt, :] = np.convolve(filter_bank[filt, :], \
dat_[chan,:])[int(fs-fs/2):int(fs+fs/2)] # mode="full"
return filtered
teonbrooks marked this conversation as resolved.
Show resolved Hide resolved