Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ilic-mezza authored Dec 16, 2023
1 parent 1432c68 commit 6a536ee
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions parcnet/example_parcnet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def main():
num_valid_nn_packets=num_valid_nn_packets,
model_checkpoint=model_checkpoint,
xfade_len_in=xfade_len_in,
device='cpu'
)

# ----------- Load the reference audio file ----------- #
Expand Down
4 changes: 2 additions & 2 deletions parcnet/parcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
num_valid_nn_packets: int,
model_checkpoint: str,
xfade_len_in: int,
device: str = 'cpu',
):

self.packet_dim = packet_dim
Expand All @@ -40,8 +41,7 @@ def __init__(self,
self.ar_model = ARModel(ar_order, ar_diagonal_load)

# Load the pretrained neural network
self.neural_net = HybridModel.load_from_checkpoint(checkpoint_path=model_checkpoint, channels=1, lite=True)

self.neural_net = HybridModel.load_from_checkpoint(model_checkpoint, channels=1, lite=True).to(device)

def __call__(self, input_signal: np.ndarray, trace: np.ndarray, **kwargs) -> np.ndarray:
self.neural_net.eval()
Expand Down

0 comments on commit 6a536ee

Please sign in to comment.