Skip to content

Commit

Permalink
added cuda support for pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Gabteni committed May 14, 2024
1 parent ee424d8 commit 8803979
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 44 deletions.
8 changes: 8 additions & 0 deletions app/translation/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device for translation selected: {device}")

if device.type == 'cuda':
print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
print(f"CUDA Device Count: {torch.cuda.device_count()}")
else:
print("CUDA is not available. Using CPU instead.")

model_name = "facebook/m2m100_1.2B" # You can also use larger models for better accuracy

def translate_to_german(text, src_lang):
Expand Down
102 changes: 61 additions & 41 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ packages = [{ include = "app" }]

[[tool.poetry.source]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cpu"
url = "https://download.pytorch.org/whl/cu117"
priority = "explicit"

[tool.poetry.dependencies]
Expand All @@ -33,8 +33,8 @@ torch = [
{markers = "sys_platform == 'linux' and platform_machine == 'arm64'", url="https://download.pytorch.org/whl/cpu/torch-1.13.1-cp310-none-macosx_11_0_arm64.whl"},
{markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", url = "https://download.pytorch.org/whl/cpu/torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl"},
{markers = "sys_platform == 'linux' and platform_machine == 'aarch64'", url="https://download.pytorch.org/whl/torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl"},
{markers = "sys_platform == 'linux' and platform_machine == 'x86_64'", url="https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp310-cp310-linux_x86_64.whl"},
{markers = "sys_platform == 'win' and platform_machine == 'amd64'", url="https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp310-cp310-win_amd64.whl"},
{markers = "sys_platform == 'linux' and platform_machine == 'x86_64'", url="https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-linux_x86_64.whl"},
{markers = "sys_platform == 'win' and platform_machine == 'amd64'", url="https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-win_amd64.whl"},
]
transformers = "^4.40.2"
sentencepiece = "^0.2.0"
Expand Down

0 comments on commit 8803979

Please sign in to comment.