Skip to content

Commit

Permalink
save_properties of a filt
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhammonds committed Jan 14, 2021
1 parent 17593b7 commit 6a5695e
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 65 deletions.
68 changes: 19 additions & 49 deletions neurodsp/filt/checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Checker functions for filtering."""

from warnings import warn
import os
import json

import numpy as np

Expand Down Expand Up @@ -89,8 +91,8 @@ def check_filter_definition(pass_type, f_range):
return f_lo, f_hi


def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
transitions=(-20, -3), filt_type=None, verbose=True):
def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range, transitions=(-20, -3),
filt_type=None, verbose=True, save_properties=None):
"""Check a filters properties, including pass band and transition band.
Parameters
Expand All @@ -117,8 +119,12 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
transitions : tuple of (float, float), optional, default: (-20, -3)
Cutoffs, in dB, that define the transition band.
filt_type : str, optional, {'FIR', 'IIR'}
The type of filter being applied.
verbose : bool, optional, default: True
Whether to print out transition and pass bands.
Whether to print out filter properties.
save_properties : str
Path, including file name, to save filter properites to as a json.
Returns
-------
Expand All @@ -138,8 +144,8 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
"""

# Import utility functions inside function to avoid circular imports
from neurodsp.filt.utils import (compute_frequency_response,
compute_pass_band, compute_transition_band)
from neurodsp.filt.utils import (compute_frequency_response, compute_pass_band,
compute_transition_band, gen_filt_report, save_filt_report)

# Initialize variable to keep track if all checks pass
passes = True
Expand Down Expand Up @@ -173,52 +179,16 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
warn('Transition bandwidth is {:.1f} Hz. This is greater than the desired'\
'pass/stop bandwidth of {:.1f} Hz'.format(transition_bw, pass_bw))

# Print out transition bandwidth and pass bandwidth to the user
# Report filter properties
if verbose or save_properties:
filt_report = gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
transition_bw, f_range, f_range_trans)

if verbose:
print('\n'.join('{} : {}'.format(key, value) for key, value in filt_report.items()))

# Filter type (high-pass, low-pass, band-pass, band-stop, FIR, IIR)
print('Pass Type: {pass_type}'.format(pass_type=pass_type))

# Cutoff frequency (including definition)
cutoff = round(np.min(f_range) + (0.5 * transition_bw), 3)
print('Cutoff (half-amplitude): {cutoff} Hz'.format(cutoff=cutoff))

# Filter order (or length)
print('Filter order: {order}'.format(order=len(f_db)-1))

# Roll-off or transition bandwidth
print('Transition bandwidth: {:.1f} Hz'.format(transition_bw))
print('Pass/stop bandwidth: {:.1f} Hz'.format(pass_bw))

# Passband ripple and stopband attenuation
pb_ripple = np.max(db[:np.where(f_db < f_range_trans[0])[0][-1]])
sb_atten = np.max(db[np.where(f_db > f_range_trans[1])[0][0]:])
print('Passband Ripple: {pb_ripple} db'.format(pb_ripple=pb_ripple))
print('Stopband Attenuation: {sb_atten} db'.format(sb_atten=sb_atten))

# Filter delay (zero-phase, linear-phase, non-linear phase)
if filt_type == 'FIR' and pass_type in ['bandstop', 'lowpass']:
filt_class = 'linear-phase'
elif filt_type == 'FIR' and pass_type in ['bandpass', 'highpass']:
filt_class = 'zero-phase'
elif filt_type == 'IIR':
filt_class = 'non-linear-phase'
else:
filt_class = None

if filt_type is not None:
print('Filter Class: {filt_class}'.format(filt_class=filt_class))

if filt_class == 'linear-phase':
print('Group Delay: {delay}s'.format(delay=(len(f_db)-1) / 2 * fs))
elif filt_class == 'zero-phase':
print('Group Delay: 0s')

# Direction of computation (one-pass forward/reverse, or two-pass forward and reverse)
if filt_type == 'FIR':
print('Direction: one-pass reverse.')
else:
print('Direction: two-pass forward and reverse')
if save_properties is not None:
save_filt_report(save_properties, filt_report)

return passes

Expand Down
17 changes: 9 additions & 8 deletions neurodsp/filt/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
###################################################################################################
###################################################################################################

def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
n_cycles=3, n_seconds=None, remove_edges=True, butterworth_order=None,
print_transitions=False, plot_properties=False, return_filter=False,
verbose=False):
def filter_signal(sig, fs, pass_type, f_range, filter_type='fir', n_cycles=3, n_seconds=None,
remove_edges=True, butterworth_order=None, print_transitions=False,
plot_properties=False, save_properties=None, return_filter=False, verbose=False):
"""Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal.
Parameters
Expand Down Expand Up @@ -51,6 +50,8 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
If True, print out the transition and pass bandwidths.
plot_properties : bool, optional, default: False
If True, plot the properties of the filter, including frequency response and/or kernel.
save_properties : str, optional, default: None
Path, including file name, to save filter properites to as a json.
return_filter : bool, optional, default: False
If True, return the filter coefficients.
verbose : bool, optional, default: False
Expand All @@ -76,13 +77,13 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',

if filter_type.lower() == 'fir':
return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds,
remove_edges, print_transitions,
plot_properties, return_filter, verbose=verbose)
remove_edges, print_transitions, plot_properties,
save_properties, return_filter, verbose)
elif filter_type.lower() == 'iir':
_iir_checks(n_seconds, butterworth_order, remove_edges)
return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order,
print_transitions, plot_properties,
return_filter, verbose=verbose)
print_transitions, plot_properties, save_properties,
return_filter, verbose)
else:
raise ValueError('Filter type not understood.')

Expand Down
9 changes: 6 additions & 3 deletions neurodsp/filt/fir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
###################################################################################################

def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, remove_edges=True,
print_transitions=False, plot_properties=False, return_filter=False,
verbose=False):
print_transitions=False, plot_properties=False, save_properties=None,
return_filter=False, verbose=False, file_path=None, file_name=None):
"""Apply an FIR filter to a signal.
Parameters
Expand Down Expand Up @@ -47,6 +47,8 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r
If True, print out the transition and pass bandwidths.
plot_properties : bool, optional, default: False
If True, plot the properties of the filter, including frequency response and/or kernel.
save_properties : str
Path, including file name, to save filter properites to as a json.
return_filter : bool, optional, default: False
If True, return the filter coefficients of the FIR filter.
verbose : bool, optional, default: False
Expand Down Expand Up @@ -82,7 +84,8 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r

# Check filter properties: compute transition bandwidth & run checks
check_filter_properties(filter_coefs, 1, fs, pass_type, f_range, filt_type="FIR",
verbose=np.any([print_transitions, verbose]))
verbose=np.any([print_transitions, verbose]),
save_properties=save_properties)

# Remove any NaN on the edges of 'sig'
sig, sig_nans = remove_nans(sig)
Expand Down
8 changes: 6 additions & 2 deletions neurodsp/filt/iir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
###################################################################################################

def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_transitions=False,
plot_properties=False, return_filter=False, verbose=False):
plot_properties=False, save_properties=None, return_filter=False,
verbose=False):
"""Apply an IIR filter to a signal.
Parameters
Expand Down Expand Up @@ -40,6 +41,8 @@ def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_tran
If True, print out the transition and pass bandwidths.
plot_properties : bool, optional, default: False
If True, plot the properties of the filter, including frequency response and/or kernel.
save_properties : str
Path, including file name, to save filter properites to as a json.
return_filter : bool, optional, default: False
If True, return the second order series coefficients of the IIR filter.
verbose : bool, optional, default: False
Expand Down Expand Up @@ -69,7 +72,8 @@ def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_tran

# Check filter properties: compute transition bandwidth & run checks
check_filter_properties(sos, None, fs, pass_type, f_range, filt_type="IIR",
verbose=np.any([print_transitions, verbose]))
verbose=np.any([print_transitions, verbose]),
save_properties=save_properties)

# Remove any NaN on the edges of 'sig'
sig, sig_nans = remove_nans(sig)
Expand Down
105 changes: 105 additions & 0 deletions neurodsp/filt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Utility functions for filtering."""

import os
import json

import numpy as np
from scipy.signal import freqz, sosfreqz

Expand Down Expand Up @@ -253,3 +256,105 @@ def remove_filter_edges(sig, filt_len):
sig[-n_rmv:] = np.nan

return sig


def gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
transition_bw, f_range, f_range_trans):
"""Create a filter report.
Parameters
----------
pass_type : {'bandpass', 'bandstop', 'lowpass', 'highpass'}
Which type of filter was applied.
filt_type : str, {'FIR', 'IIR'}
The type of filter being applied.
fs : float
Sampling rate, in Hz.
f_db : 1d array
Frequency vector corresponding to attenuation decibels, in Hz.
db : 1d array
Degree of attenuation for each frequency specified in `f_db`, in dB.
pass_bw : float
The pass bandwidth of the filter.
transition_band : float
The transition bandwidth of the filter.
f_range : tuple of (float, float) or float
Cutoff frequency(ies) used for filter, specified as f_lo & f_hi.
f_range_trans : tuple of (float, float)
The lower and upper frequencies of the transition band.
Returns
-------
filt_report : dict
A dicionary of filter parameter keys and corresponding values.
"""
filt_report = {}

# Filter type (high-pass, low-pass, band-pass, band-stop, FIR, IIR)
filt_report['Pass Type'] = '{pass_type}'.format(pass_type=pass_type)

# Cutoff frequenc(ies) (including definition)
filt_report['Cutoff (half-amplitude)'] = '{cutoff} Hz'.format(cutoff=f_range)

# Filter order (or length)
filt_report['Filter order'] = '{order}'.format(order=len(f_db)-1)

# Roll-off or transition bandwidth
filt_report['Transition bandwidth'] = '{:.1f} Hz'.format(transition_bw)
filt_report['Pass/stop bandwidth'] = '{:.1f} Hz'.format(pass_bw)

# Passband ripple and stopband attenuation
pb_ripple = np.max(db[:np.where(f_db < f_range_trans[0])[0][-1]])
sb_atten = np.max(db[np.where(f_db > f_range_trans[1])[0][0]:])
filt_report['Passband Ripple'] = '{pb_ripple} db'.format(pb_ripple=pb_ripple)
filt_report['Stopband Attenuation'] = '{sb_atten} db'.format(sb_atten=sb_atten)

# Filter delay (zero-phase, linear-phase, non-linear phase)
filt_report['Filter Type'] = filt_type

if filt_type == 'FIR' and pass_type in ['bandstop', 'lowpass']:

filt_report['Filter Class'] = '{filt_class}'.format(filt_class='linear-phase')
filt_report['Group Delay'] = '{delay}s'.format(delay=(len(f_db)-1) / 2 * fs)

elif filt_type == 'FIR' and pass_type in ['bandpass', 'highpass']:

filt_report['Filter Class'] = '{filt_class}'.format(filt_class='zero-phase')
filt_report['Group Delay'] = '0s'

elif filt_type == 'IIR':

# Group delay isn't reported for IIR since it varies from sample to sample
filt_report['Filter Class'] = '{filt_class}'.format(filt_class='non-linear-phase')

# Direction of computation (one-pass forward/reverse, or two-pass forward and reverse)
if filt_type == 'FIR':
filt_report['Direction'] = 'one-pass reverse'
else:
filt_report['Direction'] = 'two-pass forward and reverse'

return filt_report


def save_filt_report(save_properties, filt_report):
"""Save filter properties as a json file.
Parameters
----------
save_properties : str
Path, including file name, to save filter properites to as a json.
filt_report : dict
Contains filter report info.
"""

# Ensure parents exists
if not os.path.isdir(os.path.dirname(save_properties)):
raise ValueError("Unable to save properties. Parent directory does not exist.")

# Enforce file extension
if not save_properties.endswith('.json'):
save_properties = save_properties + '.json'

# Save
with open(save_properties, 'w') as file_path:
json.dump(filt_report, file_path)
7 changes: 6 additions & 1 deletion neurodsp/tests/filt/test_checks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for filter check functions."""

import tempfile
from pytest import raises

from neurodsp.tests.settings import FS
Expand Down Expand Up @@ -46,7 +47,7 @@ def test_check_filter_definition():
def test_check_filter_properties():

filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12))

passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
assert passes is True

Expand All @@ -58,6 +59,10 @@ def test_check_filter_properties():
passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
assert passes is False

temp_path = tempfile.NamedTemporaryFile()
check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12),
verbose=True, save_properties=temp_path.name)
temp_path.close()

def test_check_filter_length():

Expand Down
38 changes: 36 additions & 2 deletions neurodsp/tests/filt/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Tests for filter utilities."""

from pytest import raises
from neurodsp.tests.settings import FS
import tempfile
from pytest import raises, mark, param

import numpy as np

from neurodsp.tests.settings import FS
from neurodsp.filt.utils import *
from neurodsp.filt.fir import design_fir_filter, compute_filter_length

Expand Down Expand Up @@ -53,3 +56,34 @@ def test_remove_filter_edges():
assert np.all(np.isnan(dropped_sig[:n_rmv]))
assert np.all(np.isnan(dropped_sig[-n_rmv:]))
assert np.all(~np.isnan(dropped_sig[n_rmv:-n_rmv]))


@mark.parametrize("pass_type", ['bandpass', 'bandstop', 'lowpass', 'highpass'])
@mark.parametrize("filt_type", ['IIR', 'FIR'])
def test_gen_filt_report(pass_type, filt_type):

fs = 1000
f_db = np.arange(0, 50)
db = np.random.rand(50)
pass_bw = 10
transition_bw = 4
f_range = (10, 40)
f_range_trans = (40, 44)

report = gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
transition_bw, f_range, f_range_trans)

assert pass_type in report.values()
assert filt_type in report.values()


@mark.parametrize("dir_exists", [True, param(False, marks=mark.xfail)])
def test_save_filt_report(dir_exists):

filt_report = {'Pass Type': 'bandpass', 'Cutoff (half-amplitude)': 50}
temp_path = tempfile.NamedTemporaryFile()
if not dir_exists:
save_filt_report('/bad/path/', filt_report)
else:
save_filt_report(temp_path.name, filt_report)
temp_path.close()

0 comments on commit 6a5695e

Please sign in to comment.