Skip to content

Commit

Permalink
pep8 cleanup and fix tests (#147)
Browse files Browse the repository at this point in the history
* pep8

* change order to prevent rgb error

* edit testing config
  • Loading branch information
ieee8023 authored Dec 12, 2023
1 parent 9ff4638 commit e4ea0fe
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 55 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: ['3.8']
torch-version: [1.10.0, 2.0.0]
os: [ubuntu-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]
python-version: ['3.9']
torch-version: [2.1.1]
os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r requirements.txt
wheel
pytest
pydicom>=2.3.1
pydicom>=2.3.1
2 changes: 1 addition & 1 deletion torchxrayvision/autoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def ResNetAE(weights=None):
"""A ResNet based autoencoder.
Possible weights for this class include:
- "101-elastic" trained on PadChest, NIH, CheXpert, and MIMIC. From the paper https://arxiv.org/abs/2102.09475
.. code-block:: python
Expand Down
4 changes: 2 additions & 2 deletions torchxrayvision/baseline_models/chexpert/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def infer(self, x, tasks):
for task in tasks:

idx = self.task_sequence[task]
#task_prob = probs.detach().cpu().numpy()[idx]
# task_prob = probs.detach().cpu().numpy()[idx]
task_prob = probs[idx]
task2results[task] = task_prob

Expand Down Expand Up @@ -226,7 +226,7 @@ def infer(self, img, tasks):
else:
task2ensemble_results[task].append(individual_task2results[task])

assert all([task in task2ensemble_results for task in tasks]),\
assert all([task in task2ensemble_results for task in tasks]), \
"Not all tasks in task2ensemble_results"

task2results = {}
Expand Down
18 changes: 9 additions & 9 deletions torchxrayvision/baseline_models/riken/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,18 @@ class AgeModel(nn.Module):
url = {https://www.nature.com/articles/s43856-022-00220-6},
year = {2022}
}
"""

targets: List[str] = ["Age"]
""""""

def __init__(self):

super(AgeModel, self).__init__()

url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/baseline_models_riken_xray_age_every_model_age_senet154_v2_tl_26_ft_7_fp32.pt"

weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
Expand All @@ -81,17 +81,17 @@ def __init__(self):
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
)

def forward(self, x):
x = x.repeat(1, 3, 1, 1)
x = self.upsample(x)

# expecting values between [-1024,1024]
x = (x + 1024) / 2048
# now between [0,1]

x = self.norm(x)
return self.model(x)

def __repr__(self):
return "riken-age-prediction"
23 changes: 11 additions & 12 deletions torchxrayvision/baseline_models/xinario/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class ViewModel(nn.Module):
"""
The native resolution of the model is 320x320. Images are scaled
automatically.
Expand All @@ -26,7 +26,7 @@ class ViewModel(nn.Module):
pred = model(image)
# tensor([[17.3186, 26.7156]]), grad_fn=<AddmmBackward0>)
model.targets[pred.argmax()]
# Lateral
Expand All @@ -37,13 +37,13 @@ class ViewModel(nn.Module):

targets: List[str] = ['Frontal', 'Lateral']
""""""

def __init__(self):

super(ViewModel, self).__init__()

url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/xinario_chestViewSplit_resnet-50.pt"

weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
Expand All @@ -54,7 +54,6 @@ def __init__(self):
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
xrv.utils.download(url, self.weights_filename_local)


self.model = torchvision.models.resnet.resnet50()
try:
weights = torch.load(self.weights_filename_local)
Expand All @@ -74,17 +73,17 @@ def __init__(self):
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
)

def forward(self, x):
x = x.repeat(1, 3, 1, 1)
x = self.upsample(x)

# expecting values between [-1024,1024]
x = (x + 1024) / 2048
# now between [0,1]

x = self.norm(x)
return self.model(x)[:,:2] # cut off the rest of the outputs
return self.model(x)[:, :2] # cut off the rest of the outputs

def __repr__(self):
return "xinario-view-prediction"
38 changes: 21 additions & 17 deletions torchxrayvision/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Dataset:
metadata file and for some the metadata files are packaged in the library
so only the imgpath needs to be specified.
"""

def __init__(self):
pass

Expand Down Expand Up @@ -262,7 +263,7 @@ def __init__(self, datasets, seed=0, label_concat=False):
print("Could not merge dataframes (.csv not available):", sys.exc_info()[0])

self.csv = self.csv.reset_index(drop=True)

def __setattr__(self, name, value):
if hasattr(self, 'labels'):
# check only if have finished init, otherwise __init__ breaks
Expand Down Expand Up @@ -346,6 +347,7 @@ class SubsetDataset(Dataset):
- of PC_Dataset num_samples=94825 views=['PA', 'AP'] data_aug=None
"""

def __init__(self, dataset, idxs=None):
super(SubsetDataset, self).__init__()
self.dataset = dataset
Expand All @@ -365,7 +367,7 @@ def __setattr__(self, name, value):
# check only if have finished init, otherwise __init__ breaks
if name in ['transform', 'data_aug', 'labels', 'pathologies', 'targets']:
raise NotImplementedError(f'Cannot set {name} on a subset dataset. Set the transforms directly on the dataset object. If it was to be set via this subset dataset it would have to modify the internal dataset which could have unexpected side effects')

object.__setattr__(self, name, value)

def string(self):
Expand Down Expand Up @@ -895,17 +897,17 @@ def __init__(self,
"216840111366964012373310883942009170084120009_00-097-074.png",
"216840111366964012819207061112010315104455352_04-024-184.png",
"216840111366964012819207061112010306085429121_04-020-102.png",
"216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated"
"216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated"
"216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated"
"216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated"
"216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated"
"216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated"
"216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated"
"216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated"
]
"216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated"
"216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated"
"216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated"
"216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated"
"216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated"
"216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated"
"216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated"
"216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated"
]
self.csv = self.csv[~self.csv["ImageID"].isin(missing)]

if unique_patients:
Expand All @@ -920,7 +922,7 @@ def __init__(self,
mask = self.csv["Labels"].str.contains(pathology.lower())
if pathology in mapping:
for syn in mapping[pathology]:
#print("mapping", syn)
# print("mapping", syn)
mask |= self.csv["Labels"].str.contains(syn.lower())
labels.append(mask.values)
self.labels = np.asarray(labels).T
Expand Down Expand Up @@ -1094,7 +1096,7 @@ def __getitem__(self, idx):
sample["lab"] = self.labels[idx]

imgid = self.csv['Path'].iloc[idx]
#clean up path in csv so the user can specify the path
# clean up path in csv so the user can specify the path
imgid = imgid.replace("CheXpert-v1.0-small/", "").replace("CheXpert-v1.0/", "")
img_path = os.path.join(self.imgpath, imgid)
img = imread(img_path)
Expand Down Expand Up @@ -1344,7 +1346,7 @@ def __init__(self, imgpath,
mask = self.csv["labels_automatic"].str.contains(pathology.lower())
if pathology in mapping:
for syn in mapping[pathology]:
#print("mapping", syn)
# print("mapping", syn)
mask |= self.csv["labels_automatic"].str.contains(syn.lower())
labels.append(mask.values)

Expand Down Expand Up @@ -1994,7 +1996,7 @@ def __init__(self,
transform=None,
data_aug=None,
seed=0
):
):
super(ObjectCXR_Dataset, self).__init__()

np.random.seed(seed) # Reset the seed so all runs are the same.
Expand Down Expand Up @@ -2053,6 +2055,7 @@ def __call__(self, x):

class XRayResizer(object):
"""Resize an image to a specific size"""

def __init__(self, size: int, engine="skimage"):
self.size = size
self.engine = engine
Expand All @@ -2076,6 +2079,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray:

class XRayCenterCrop(object):
"""Perform a center crop on the long dimension of the input image"""

def crop_center(self, img: np.ndarray) -> np.ndarray:
_, y, x = img.shape
crop_size = np.min([y, x])
Expand Down
6 changes: 4 additions & 2 deletions torchxrayvision/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@
}

# Just created for documentation


class Model:
"""The library is composed of core and baseline classifiers. Core
classifiers are trained specifically for this library and baseline
Expand Down Expand Up @@ -132,6 +134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
pass


class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
Expand Down Expand Up @@ -190,7 +193,7 @@ class DenseNet(nn.Module):
:param weights: Specify a weight name to load pre-trained weights
:param op_threshs: Specify a weight name to load pre-trained weights
:param apply_sigmoid: Apply a sigmoid
"""

targets: List[str] = [
Expand Down Expand Up @@ -379,7 +382,6 @@ class ResNet(nn.Module):
]
""""""


def __init__(self, weights: str = None, apply_sigmoid: bool = False):
super(ResNet, self).__init__()

Expand Down
17 changes: 9 additions & 8 deletions torchxrayvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def load_image(fname: str):

return img

def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)->ndarray:

def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool = True) -> ndarray:
"""read a dicom-like file and convert to numpy array
Args:
Expand All @@ -98,26 +99,26 @@ def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)->

# get the pixel array
ds = pydicom.dcmread(path, force=True)
data = ds.pixel_array

# we have not tested RGB, YBR_FULL, or YBR_FULL_422 yet.
if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']:
if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']:
raise NotImplementedError(f'PhotometricInterpretation `{ds.PhotometricInterpretation}` is not yet supported.')
# get the max possible pixel value from DCM header
max_possible_pixel_val = (2**ds.BitsStored - 1)

data = ds.pixel_array

# LUT for human friendly view
if voi_lut:
data = pydicom.pixel_data_handlers.util.apply_voi_lut(data, ds, index=0)


# `MONOCHROME1` have an inverted view; Bones are black; background is white
# https://web.archive.org/web/20150920230923/http://www.mccauslandcenter.sc.edu/mricro/dicom/index.html
if fix_monochrome and ds.PhotometricInterpretation == "MONOCHROME1":
warnings.warn(f"Coverting MONOCHROME1 to MONOCHROME2 interpretation for file: {path}. Can be avoided by setting `fix_monochrome=False`")
data = max_possible_pixel_val - data

# normalize data to [-1024, 1024]
# normalize data to [-1024, 1024]
data = normalize(data, max_possible_pixel_val)
return data

Expand All @@ -129,13 +130,13 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4,
batch_size=threads,
num_workers=threads,
)

preds = []
with torch.inference_mode():
for i, batch in enumerate(tqdm(dl)):
output = model(batch["img"].to(device))

output = output.detach().cpu().numpy()
preds.append(output)

return np.concatenate(preds)

0 comments on commit e4ea0fe

Please sign in to comment.