From e95585322497a424fb584706eb6e8ab729cb6d4a Mon Sep 17 00:00:00 2001 From: BBC-Esq Date: Mon, 30 Dec 2024 21:10:29 -0500 Subject: [PATCH 01/11] Update core.py --- ChatTTS/core.py | 56 +++++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index c178a9ad2..b283048fe 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -61,14 +61,16 @@ def has_loaded(self, use_decoder=False): return not not_finish + # Modified def download_models( self, source: Literal["huggingface", "local", "custom"] = "local", force_redownload=False, custom_path: Optional[torch.serialization.FILE_LIKE] = None, + cache_dir: Optional[str] = None, ) -> Optional[str]: if source == "local": - download_path = os.getcwd() + download_path = cache_dir if cache_dir else os.getcwd() if ( not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload @@ -83,32 +85,40 @@ def download_models( ) return None elif source == "huggingface": - hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) - try: - download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") - ) - except: - download_path = None - if download_path is None or force_redownload: - self.logger.log( - logging.INFO, - f"download from HF: https://huggingface.co/2Noise/ChatTTS", - ) + if cache_dir: try: download_path = snapshot_download( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], + cache_dir=cache_dir, + force_download=force_redownload ) except: download_path = None else: - self.logger.log( - logging.INFO, f"load latest snapshot from cache: {download_path}" - ) - if download_path is None: - self.logger.error("download from huggingface failed.") - return None + hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) + try: + download_path = get_latest_modified_file( + os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + ) + except: + download_path = None + if download_path is None or force_redownload: + self.logger.log( + logging.INFO, + f"download from HF: https://huggingface.co/2Noise/ChatTTS", + ) + try: + download_path = snapshot_download( + repo_id="2Noise/ChatTTS", + allow_patterns=["*.yaml", "*.json", "*.safetensors"], + ) + except: + download_path = None + else: + self.logger.log( + logging.INFO, f"load latest snapshot from cache: {download_path}" + ) elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") if not check_all_assets(Path(custom_path), self.sha256_map, update=False): @@ -116,8 +126,13 @@ def download_models( return None download_path = custom_path + if download_path is None: + self.logger.error("Model download failed") + return None + return download_path + # Modified def load( self, source: Literal["huggingface", "local", "custom"] = "local", @@ -129,8 +144,9 @@ def load( use_flash_attn=False, use_vllm=False, experimental: bool = False, + cache_dir: Optional[str] = None, ) -> bool: - download_path = self.download_models(source, force_redownload, custom_path) + download_path = self.download_models(source, force_redownload, custom_path, cache_dir) if download_path is None: return False return self._load( From 272be36d00b3bdf1a9fbb0c8a4f6678a439d0281 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 31 Dec 2024 02:15:53 +0000 Subject: [PATCH 02/11] chore(format): run black on main --- ChatTTS/core.py | 13 +++++++++---- tools/audio/av.py | 14 +++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index b283048fe..4c14b79db 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -91,12 +91,14 @@ def download_models( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], cache_dir=cache_dir, - force_download=force_redownload + force_download=force_redownload, ) except: download_path = None else: - hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) + hf_home = os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ) try: download_path = get_latest_modified_file( os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") @@ -117,7 +119,8 @@ def download_models( download_path = None else: self.logger.log( - logging.INFO, f"load latest snapshot from cache: {download_path}" + logging.INFO, + f"load latest snapshot from cache: {download_path}", ) elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") @@ -146,7 +149,9 @@ def load( experimental: bool = False, cache_dir: Optional[str] = None, ) -> bool: - download_path = self.download_models(source, force_redownload, custom_path, cache_dir) + download_path = self.download_models( + source, force_redownload, custom_path, cache_dir + ) if download_path is None: return False return self._load( diff --git a/tools/audio/av.py b/tools/audio/av.py index 333b423d6..cd3a7d66a 100644 --- a/tools/audio/av.py +++ b/tools/audio/av.py @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str): def load_audio( - file: Union[str, BytesIO, Path], - sr: Optional[int] = None, - format: Optional[str] = None, - mono=True, - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + file: Union[str, BytesIO, Path], + sr: Optional[int] = None, + format: Optional[str] = None, + mono=True, +) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 """ @@ -113,7 +113,7 @@ def frame_iter(container): np.copyto(decoded_audio[..., offset:end_index], frame_data) offset += len(frame_data[0]) - + container.close() # Truncate the array to the actual size @@ -124,4 +124,4 @@ def frame_iter(container): if sr is not None: return decoded_audio - return decoded_audio, rate \ No newline at end of file + return decoded_audio, rate From 2b1bd6db8d2eb38ed186960e1fb2f92e8e416c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:59:52 +0900 Subject: [PATCH 03/11] Update core.py --- ChatTTS/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 4c14b79db..1c164d62a 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -61,7 +61,6 @@ def has_loaded(self, use_decoder=False): return not not_finish - # Modified def download_models( self, source: Literal["huggingface", "local", "custom"] = "local", From 856bfbd230faa6d3284c0ecb21857168a1f5b76c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:07:37 +0900 Subject: [PATCH 04/11] Update av.py --- tools/audio/av.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/audio/av.py b/tools/audio/av.py index cd3a7d66a..e09b0febd 100644 --- a/tools/audio/av.py +++ b/tools/audio/av.py @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str): def load_audio( - file: Union[str, BytesIO, Path], - sr: Optional[int] = None, - format: Optional[str] = None, - mono=True, -) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + file: Union[str, BytesIO, Path], + sr: Optional[int] = None, + format: Optional[str] = None, + mono=True, + ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 """ @@ -113,7 +113,7 @@ def frame_iter(container): np.copyto(decoded_audio[..., offset:end_index], frame_data) offset += len(frame_data[0]) - + container.close() # Truncate the array to the actual size From 21dd6bcce029b2b4fc37e4237c8075f9b27a0af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:07:56 +0900 Subject: [PATCH 05/11] Update av.py From 12318b5c9919b982e5fa96b53316971d7f47034a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 7 Jan 2025 11:08:19 +0000 Subject: [PATCH 06/11] chore(format): run black on dev --- tools/audio/av.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/audio/av.py b/tools/audio/av.py index e09b0febd..cd3a7d66a 100644 --- a/tools/audio/av.py +++ b/tools/audio/av.py @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str): def load_audio( - file: Union[str, BytesIO, Path], - sr: Optional[int] = None, - format: Optional[str] = None, - mono=True, - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + file: Union[str, BytesIO, Path], + sr: Optional[int] = None, + format: Optional[str] = None, + mono=True, +) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 """ @@ -113,7 +113,7 @@ def frame_iter(container): np.copyto(decoded_audio[..., offset:end_index], frame_data) offset += len(frame_data[0]) - + container.close() # Truncate the array to the actual size From fa2104704412960fc812f964817a12ab35ad3e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:12:21 +0900 Subject: [PATCH 07/11] Update core.py --- ChatTTS/core.py | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 1c164d62a..fcc66ea7f 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -84,7 +84,23 @@ def download_models( ) return None elif source == "huggingface": - if cache_dir: + if cache_dir is None: + hf_home = os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ) + else: + hf_home = cache_dir + try: + download_path = get_latest_modified_file( + os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + ) + except: + download_path = None + if download_path is None or force_redownload: + self.logger.log( + logging.INFO, + f"download from HF: https://huggingface.co/2Noise/ChatTTS", + ) try: download_path = snapshot_download( repo_id="2Noise/ChatTTS", @@ -94,28 +110,6 @@ def download_models( ) except: download_path = None - else: - hf_home = os.getenv( - "HF_HOME", os.path.expanduser("~/.cache/huggingface") - ) - try: - download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") - ) - except: - download_path = None - if download_path is None or force_redownload: - self.logger.log( - logging.INFO, - f"download from HF: https://huggingface.co/2Noise/ChatTTS", - ) - try: - download_path = snapshot_download( - repo_id="2Noise/ChatTTS", - allow_patterns=["*.yaml", "*.json", "*.safetensors"], - ) - except: - download_path = None else: self.logger.log( logging.INFO, @@ -134,7 +128,6 @@ def download_models( return download_path - # Modified def load( self, source: Literal["huggingface", "local", "custom"] = "local", From 647133f581b7cf867cf35f4fd306530cec2d8b08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:21:42 +0900 Subject: [PATCH 08/11] Update core.py --- ChatTTS/core.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index fcc66ea7f..1c68e3e0e 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -84,15 +84,13 @@ def download_models( ) return None elif source == "huggingface": - if cache_dir is None: - hf_home = os.getenv( - "HF_HOME", os.path.expanduser("~/.cache/huggingface") - ) - else: - hf_home = cache_dir try: download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + os.path.join(os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ), "hub/models--2Noise--ChatTTS/snapshots") + ) if cache_dir is None else get_latest_modified_file( + os.path.join(cache_dir, "models--2Noise--ChatTTS/snapshots") ) except: download_path = None From 8ea0bb59c682b38d8411e91374baf5bcb53e8b82 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 7 Jan 2025 11:22:05 +0000 Subject: [PATCH 09/11] chore(format): run black on dev --- ChatTTS/core.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 1c68e3e0e..0c03cdbaf 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -85,12 +85,19 @@ def download_models( return None elif source == "huggingface": try: - download_path = get_latest_modified_file( - os.path.join(os.getenv( - "HF_HOME", os.path.expanduser("~/.cache/huggingface") - ), "hub/models--2Noise--ChatTTS/snapshots") - ) if cache_dir is None else get_latest_modified_file( - os.path.join(cache_dir, "models--2Noise--ChatTTS/snapshots") + download_path = ( + get_latest_modified_file( + os.path.join( + os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ), + "hub/models--2Noise--ChatTTS/snapshots", + ) + ) + if cache_dir is None + else get_latest_modified_file( + os.path.join(cache_dir, "models--2Noise--ChatTTS/snapshots") + ) ) except: download_path = None From 467800e6fe291d1d34adc8061dcc2c9cd73b3712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:44:10 +0900 Subject: [PATCH 10/11] Update core.py --- ChatTTS/core.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 0c03cdbaf..46a34be1f 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -66,10 +66,9 @@ def download_models( source: Literal["huggingface", "local", "custom"] = "local", force_redownload=False, custom_path: Optional[torch.serialization.FILE_LIKE] = None, - cache_dir: Optional[str] = None, ) -> Optional[str]: if source == "local": - download_path = cache_dir if cache_dir else os.getcwd() + download_path = custom_path if custom_path is not None else os.getcwd() if ( not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload @@ -94,9 +93,9 @@ def download_models( "hub/models--2Noise--ChatTTS/snapshots", ) ) - if cache_dir is None + if custom_path is None else get_latest_modified_file( - os.path.join(cache_dir, "models--2Noise--ChatTTS/snapshots") + os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots") ) ) except: @@ -110,7 +109,7 @@ def download_models( download_path = snapshot_download( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], - cache_dir=cache_dir, + cache_dir=custom_path, force_download=force_redownload, ) except: @@ -144,10 +143,9 @@ def load( use_flash_attn=False, use_vllm=False, experimental: bool = False, - cache_dir: Optional[str] = None, ) -> bool: download_path = self.download_models( - source, force_redownload, custom_path, cache_dir + source, force_redownload, custom_path ) if download_path is None: return False From 6e0948bbec9f87bc0fb1a66c7fb18db26931baf9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 7 Jan 2025 11:44:30 +0000 Subject: [PATCH 11/11] chore(format): run black on dev --- ChatTTS/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 46a34be1f..28cebd92a 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -144,9 +144,7 @@ def load( use_vllm=False, experimental: bool = False, ) -> bool: - download_path = self.download_models( - source, force_redownload, custom_path - ) + download_path = self.download_models(source, force_redownload, custom_path) if download_path is None: return False return self._load(