forked from mir-aidj/all-in-one
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
executable file
·95 lines (83 loc) · 3.57 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from typing import List
from cog import BasePredictor, BaseModel, Input, Path
import torch
import allin1
import os
class Predictor(BasePredictor):
    """Cog predictor that runs `allin1` music-structure analysis on an audio file.

    Produces a JSON analysis per input and, optionally, a PNG visualization
    (rendered from allin1's PDF output) and an MP3 sonification.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # allin1 loads weights lazily per analyze() call; here we only pick the device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def predict(
        self,
        music_input: Path = Input(
            description="An audio file input to analyze.",
            default=None,
        ),
        visualize: bool = Input(
            description="Save visualizations",
            default=False,
        ),
        sonify: bool = Input(
            description="Save sonifications",
            default=False,
        ),
        activ: bool = Input(
            description="Save frame-level raw activations from sigmoid and softmax",
            default=False,
        ),
        embed: bool = Input(
            description="Save frame-level embeddings",
            default=False,
        ),
        model: str = Input(
            description="Name of the pretrained model to use",
            default="harmonix-all",
            choices=["harmonix-all", "harmonix-fold0", "harmonix-fold1", "harmonix-fold2", "harmonix-fold3", "harmonix-fold4", "harmonix-fold5", "harmonix-fold6", "harmonix-fold7"]
        ),
        include_activations: bool = Input(
            description="Whether to include activations in the analysis results or not.",
            default=False
        ),
        include_embeddings: bool = Input(
            description="Whether to include embeddings in the analysis results or not.",
            default=False,
        ),
    ) -> List[Path]:
        """Run music-structure analysis and return the paths of all produced artifacts.

        Returns a list containing every JSON result found under ``output/``,
        plus (when requested) one PNG per visualization PDF under ``viz/`` and
        every MP3 sonification under ``sonif/``.

        Raises:
            ValueError: if no ``music_input`` was provided.
        """
        if not music_input:
            raise ValueError("Must provide `music_input`.")

        # Remove stale artifacts from a previous prediction so the os.walk
        # scans below only pick up files produced by this run.
        import shutil
        for stale in ('demix', 'spec', 'output'):
            if os.path.isdir(stale):
                shutil.rmtree(stale)

        # Music Structure Analysis. allin1 writes its artifacts to disk
        # ('output/', and 'viz/' / 'sonif/' when enabled); we collect them
        # from disk below, so the return value is not needed here.
        allin1.analyze(paths=music_input, out_dir='output', visualize=visualize, sonify=sonify, model=model, device=self.device, include_activations=include_activations, include_embeddings=include_embeddings)

        def _collect(root: str, extension: str) -> List[str]:
            # All files under `root` whose final dot-suffix equals `extension`.
            return [
                os.path.join(dirpath, filename)
                for dirpath, _, filenames in os.walk(root)
                for filename in filenames
                if filename.rsplit('.', 1)[-1] == extension
            ]

        outputs: List[Path] = [Path(p) for p in _collect("output", "json")]

        if visualize:
            import fitz  # PyMuPDF; hoisted out of the loop (was re-imported per PDF)
            for pdf_path in _collect("viz", "pdf"):
                doc = fitz.open(pdf_path)
                # Guard against zero-page PDFs: the original referenced the
                # rendered image path after the page loop, which raised a
                # NameError when the document had no pages.
                if doc.page_count > 0:
                    # The visualization is single-page; render only page 0.
                    img = doc[0].get_pixmap()
                    png_path = pdf_path.rsplit('.', 1)[0] + '.png'
                    img.save(png_path)
                    outputs.append(Path(png_path))

        if sonify:
            outputs.extend(Path(p) for p in _collect("sonif", "mp3"))

        return outputs