Skip to content

Commit

Permalink
df: Add option to run only 1st stage of model via commandline
Browse files Browse the repository at this point in the history
  • Loading branch information
Rikorose committed Jun 12, 2023
1 parent 1775426 commit 80a88f5
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion DeepFilterNet/df/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def main(args):
log_level=args.log_level,
config_allow_defaults=True,
epoch=args.epoch,
mask_only=args.no_df_stage,
)
suffix = suffix if args.suffix else None
if args.output_dir is None:
Expand Down Expand Up @@ -105,6 +106,7 @@ def init_df(
config_allow_defaults: bool = False,
epoch: Union[str, int, None] = "best",
default_model: str = DEFAULT_MODEL,
mask_only: bool = False,
) -> Tuple[nn.Module, DF, str]:
"""Initializes and loads config, model and deep filtering state.
Expand Down Expand Up @@ -161,7 +163,9 @@ def init_df(
load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none")
if not load_cp:
checkpoint_dir = None
mask_only = config("mask_only", cast=bool, section="train", default=False, save=False)
mask_only = mask_only or config(
"mask_only", cast=bool, section="train", default=False, save=False
)
model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only)
if (epoch is None or epoch == 0) and load_cp:
logger.error("Could not find a checkpoint")
Expand Down Expand Up @@ -361,6 +365,7 @@ def run():
dest="suffix",
help="Don't add the model suffix to the enhanced audio files",
)
parser.add_argument("--no-df-stage", action="store_true")
args = parser.parse_args()
main(args)

Expand Down

0 comments on commit 80a88f5

Please sign in to comment.