Skip to content

Commit

Permalink
fix: 🐛 improve device handling
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Dec 19, 2023
1 parent 2d237f0 commit 6b1547f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Examples to specific dataset options can be found in `./options/default_dataset_

Results are calculated with:
- **PLCC without any correction**. Although test time value correction is common in IQA papers, we want to use the original value in our benchmark.
- **Full image single input.** We use multi-patch testing only when it is necessary for the model to work.
- **Full image single input.** We **do not** use multi-patch testing unless necessary.

Basically, we use the largest existing datasets for training, and cross dataset evaluation performance for fair comparison. The following models do not provide official weights, and are retrained by our scripts:

Expand Down
2 changes: 1 addition & 1 deletion options/train/train_laion_aes_ava.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ path:
train:
optim:
type: AdamW
lr: !!float 1e-4
lr: !!float 3e-5

scheduler:
type: CosineAnnealingLR
Expand Down
10 changes: 8 additions & 2 deletions pyiqa/archs/laion_aes_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,19 @@ class LAIONAes(nn.Module):
Returns:
A tensor representing the predicted image quality scores.
"""
def __init__(self, pretrained=True) -> None:
def __init__(self,
pretrained=True,
pretrained_model_path=None,
) -> None:
super().__init__()

clip_model, _ = clip.load("ViT-L/14")
self.mlp = MLP(clip_model.visual.output_dim)
self.clip_model = [clip_model]
if pretrained:

if pretrained_model_path is not None:
load_pretrained_network(self, pretrained_model_path, True, weight_keys='params')
elif pretrained:
load_pretrained_network(self.mlp, default_model_urls["url"])

def forward(self, x):
Expand Down
10 changes: 5 additions & 5 deletions pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def __init__(
self.net.eval()
set_random_seed(seed)

self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.dummy_param = torch.nn.Parameter(torch.empty(0)).to(self.device)

def forward(self, target, ref=None, **kwargs):
self.device = self.dummy_param.device
device = self.dummy_param.device

with torch.set_grad_enabled(self.as_loss):

if 'fid' in self.metric_name:
output = self.net(target, ref, device=self.device, **kwargs)
output = self.net(target, ref, device=device, **kwargs)
else:
if not torch.is_tensor(target):
target = imread2tensor(target, rgb=True)
Expand All @@ -77,9 +77,9 @@ def forward(self, target, ref=None, **kwargs):

if self.metric_mode == 'FR':
assert ref is not None, 'Please specify reference image for Full Reference metric'
output = self.net(target.to(self.device), ref.to(self.device), **kwargs)
output = self.net(target.to(device), ref.to(device), **kwargs)
elif self.metric_mode == 'NR':
output = self.net(target.to(self.device), **kwargs)
output = self.net(target.to(device), **kwargs)

if self.as_loss:
if isinstance(output, tuple):
Expand Down

0 comments on commit 6b1547f

Please sign in to comment.