diff --git a/ChatTTS/core.py b/ChatTTS/core.py index c178a9ad2..28cebd92a 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -68,7 +68,7 @@ def download_models( custom_path: Optional[torch.serialization.FILE_LIKE] = None, ) -> Optional[str]: if source == "local": - download_path = os.getcwd() + download_path = custom_path if custom_path is not None else os.getcwd() if ( not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload @@ -83,10 +83,20 @@ def download_models( ) return None elif source == "huggingface": - hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) try: - download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + download_path = ( + get_latest_modified_file( + os.path.join( + os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ), + "hub/models--2Noise--ChatTTS/snapshots", + ) + ) + if custom_path is None + else get_latest_modified_file( + os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots") + ) ) except: download_path = None @@ -99,16 +109,16 @@ def download_models( download_path = snapshot_download( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], + cache_dir=custom_path, + force_download=force_redownload, ) except: download_path = None - else: - self.logger.log( - logging.INFO, f"load latest snapshot from cache: {download_path}" - ) - if download_path is None: - self.logger.error("download from huggingface failed.") - return None + else: + self.logger.log( + logging.INFO, + f"load latest snapshot from cache: {download_path}", + ) elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") if not check_all_assets(Path(custom_path), self.sha256_map, update=False): @@ -116,6 +126,10 @@ def download_models( return None download_path = custom_path + if download_path is None: + self.logger.error("Model download failed") + return None + return download_path def load( diff --git a/tools/audio/av.py b/tools/audio/av.py index 333b423d6..cd3a7d66a 100644 --- a/tools/audio/av.py +++ b/tools/audio/av.py @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str): def load_audio( - file: Union[str, BytesIO, Path], - sr: Optional[int] = None, - format: Optional[str] = None, - mono=True, - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + file: Union[str, BytesIO, Path], + sr: Optional[int] = None, + format: Optional[str] = None, + mono=True, +) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 """ @@ -113,7 +113,7 @@ def frame_iter(container): np.copyto(decoded_audio[..., offset:end_index], frame_data) offset += len(frame_data[0]) - + container.close() # Truncate the array to the actual size @@ -124,4 +124,4 @@ def frame_iter(container): if sr is not None: return decoded_audio - return decoded_audio, rate \ No newline at end of file + return decoded_audio, rate