Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch20 lib #78

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions dockerfiles/Dockerfile_cuda121_torch2.1
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ RUN pip install \
torchvision \
torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

RUN TORCH_VERSION=`(python -c "import torch;print(torch.__version__)")` &&
pip install \
torch_geometric \
torch_scatter \
torch_sparse \
torch_cluster \
torch_spline_conv --extra-index-url https://data.pyg.org/whl/torch-${TORCH_VERSION}.html

RUN GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') && \
pip install GDAL==$GDAL_VERSION --no-binary=gdal

Expand Down
2 changes: 1 addition & 1 deletion src/cultionet/augment/augmenters.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ class Augmenters:
each labeled parcel in `y`.

aug_args: Additional keyword arguments passed to the
`torch_geometric.data.Data` object.
`Data` object.

Example:
>>> augmenters = Augmenters(augmentations=['tswarp'])
Expand Down
2 changes: 1 addition & 1 deletion src/cultionet/utils/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def transform(self, batch: Data) -> Data:
r"""Normalizes data by the Dynamic World log method or by z-scores.

Args:
batch (Data): A `torch_geometric` data object.
batch (Data): A tensor data object.
data_means (Tensor): The data feature-wise means.
data_stds (Tensor): The data feature-wise standard deviations.

Expand Down
2 changes: 1 addition & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import joblib
import pytorch_lightning as pl
import torch
from torch_geometric.data import Data

import cultionet
from cultionet.data.data import Data
from cultionet.data.datasets import EdgeDataset
from cultionet.enums import AttentionTypes, ModelTypes, ResBlockTypes
from cultionet.model import CultionetParams
Expand Down