Skip to content

Commit 165af63

Browse files
Merge branch '4980-get-wsi-at-mpp' of https://github.com/NikolasSchmitz/MONAI into 4980-get-wsi-at-mpp
2 parents 01e60e0 + 1789735 commit 165af63

File tree

4 files changed

+13
-20
lines changed

4 files changed

+13
-20
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<p align="center">
2-
<img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" width="50%" alt='project-monai'>
2+
<img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" width="50%" alt='project-monai'>
33
</p>
44

55
**M**edical **O**pen **N**etwork for **AI**

monai/networks/schedulers/ddim.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
117117
)
118118

119119
self.num_inference_steps = num_inference_steps
120-
step_ratio = self.num_train_timesteps // self.num_inference_steps
121-
if self.steps_offset >= step_ratio:
122-
raise ValueError(
123-
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
124-
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
125-
f" the max train timestep."
126-
)
127-
128-
# creates integer timesteps by multiplying by ratio
129-
# casting to int to avoid issues when num_inference_step is power of 3
130-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
131-
self.timesteps = torch.from_numpy(timesteps).to(device)
120+
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121+
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122+
123+
self.timesteps = (
124+
torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)
125+
.round()
126+
.long()
127+
)
132128
self.timesteps += self.steps_offset
133129

134130
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:

monai/networks/schedulers/ddpm.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from __future__ import annotations
3333

34-
import numpy as np
3534
import torch
3635

3736
from monai.utils import StrEnum
@@ -122,11 +121,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
122121
)
123122

124123
self.num_inference_steps = num_inference_steps
125-
step_ratio = self.num_train_timesteps // self.num_inference_steps
126-
# creates integer timesteps by multiplying by ratio
127-
# casting to int to avoid issues when num_inference_step is power of 3
128-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
129-
self.timesteps = torch.from_numpy(timesteps).to(device)
124+
self.timesteps = (
125+
torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
126+
)
130127

131128
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
132129
"""

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ h5py
5151
nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
5252
optuna
5353
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
54-
onnx>=1.13.0
54+
onnx>=1.13.0, <1.19.1
5555
onnxruntime; python_version <= '3.10'
5656
typeguard<3 # https://github.com/microsoft/nni/issues/5457
5757
filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523

0 commit comments

Comments
 (0)