Skip to content

Commit

Permalink
Add npt types to sdr.plot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Jul 13, 2023
1 parent 949036c commit d019209
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
23 changes: 12 additions & 11 deletions src/sdr/plot/_filter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
A module containing filter-related plotting functions.
"""
from typing import Optional
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import scipy.signal
from typing_extensions import Literal

Expand All @@ -13,7 +14,7 @@


@export
def impulse_response(b: np.ndarray, a: np.ndarray = 1, N: Optional[int] = None, **kwargs):
def impulse_response(b: npt.ArrayLike, a: npt.ArrayLike = 1, N: int | None = None, **kwargs):
r"""
Plots the impulse response $h[n]$ of a filter.
Expand Down Expand Up @@ -75,7 +76,7 @@ def impulse_response(b: np.ndarray, a: np.ndarray = 1, N: Optional[int] = None,


@export
def step_response(b: np.ndarray, a: np.ndarray = 1, N: Optional[int] = None, **kwargs):
def step_response(b: npt.ArrayLike, a: npt.ArrayLike = 1, N: int | None = None, **kwargs):
r"""
Plots the step response $s[n]$ of a filter.
Expand Down Expand Up @@ -136,7 +137,7 @@ def step_response(b: np.ndarray, a: np.ndarray = 1, N: Optional[int] = None, **k


@export
def zeros_poles(b: np.ndarray, a: np.ndarray = 1, **kwargs):
def zeros_poles(b: npt.ArrayLike, a: npt.ArrayLike = 1, **kwargs):
r"""
Plots the zeros and poles of the filter.
Expand Down Expand Up @@ -178,8 +179,8 @@ def zeros_poles(b: np.ndarray, a: np.ndarray = 1, **kwargs):

@export
def frequency_response(
b: np.ndarray,
a: np.ndarray = 1,
b: npt.ArrayLike,
a: npt.ArrayLike = 1,
sample_rate: float = 1.0,
N: int = 1024,
x_axis: Literal["one-sided", "two-sided", "log"] = "two-sided",
Expand Down Expand Up @@ -237,8 +238,8 @@ def frequency_response(

@export
def group_delay(
b: np.ndarray,
a: np.ndarray = 1,
b: npt.ArrayLike,
a: npt.ArrayLike = 1,
sample_rate: float = 1.0,
N: int = 1024,
x_axis: Literal["one-sided", "two-sided", "log"] = "two-sided",
Expand Down Expand Up @@ -296,10 +297,10 @@ def group_delay(

@export
def filter( # pylint: disable=redefined-builtin
b: np.ndarray,
a: np.ndarray = 1,
b: npt.ArrayLike,
a: npt.ArrayLike = 1,
sample_rate: float = 1.0,
N_time: Optional[int] = None,
N_time: int | None = None,
N_freq: int = 1024,
x_axis: Literal["one-sided", "two-sided", "log"] = "two-sided",
decades: int = 4,
Expand Down
2 changes: 2 additions & 0 deletions src/sdr/plot/_rc_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
A module containing :obj:`sdr`'s default matplotlib rcParams.
"""
from __future__ import annotations

import matplotlib.pyplot as plt

from .._helper import export
Expand Down
5 changes: 4 additions & 1 deletion src/sdr/plot/_time_domain.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""
A module containing time-domain plotting functions.
"""
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt

from .._helper import export
from ._rc_params import RC_PARAMS


@export
def time_domain(x: np.ndarray, sample_rate: float = 1.0, **kwargs):
def time_domain(x: npt.ArrayLike, sample_rate: float = 1.0, **kwargs):
"""
Plots a time-domain signal $x[n]$.
Expand Down

0 comments on commit d019209

Please sign in to comment.