-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add experimental MPS support #74
base: main
Are you sure you want to change the base?
Conversation
awesome! A few little things likely to need clean up
|
Ok, I added a quick fix using the suggested env variable. Let me know if it works, if you have time ? Thanks a lot |
|
I guess need to globally set |
if _is_mps_available(torch): | ||
from os import environ | ||
|
||
environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It ought to be set globally as soon as MPS is available... Maybe if I move it; anyway this is not ideal, I don't want you to lose too much time in tests. Does server 6 have MPS ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it does, but let's focus on the other items first (unifying naming to wnet3d, etc :).
@MMathisLab Okay, I tried adding it directly in workers; if this does not work let's save it for when I have an MPS-capable device to test on |
Updating 3a0dd11..a1077bc
Fast-forward
.github/workflows/test_and_deploy.yml | 1 -
napari_cellseg3d/code_models/model_framework.py | 3 ---
napari_cellseg3d/code_models/worker_inference.py | 5 +++++
napari_cellseg3d/code_models/worker_training.py | 5 +++++
setup.cfg | 4 ++--
5 files changed, 12 insertions(+), 6 deletions(-)
(napari_cellseg3d_m1) mackenzie@mackenzies-macbook-air CellSeg3d % napari
14:49:56 INFO pydensecrf not installed, CRF post-processing will not be available. Please install by running : pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=masterThis is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step.
14:49:56 INFO wandb not installed, wandb config will not be taken into account
14:50:05 INFO Starting...
14:50:05 INFO ********************
14:50:05 INFO Worker started at 14:50:05
14:50:05 INFO Saving results to : /Users/mackenzie/cellseg3d/inference
14:50:05 INFO Worker is running...
14:50:05 INFO Number of threads has been set to 1 for macOS
14:50:05 INFO MODEL DIMS : 64
14:50:05 INFO Model name : WNet
14:50:05 INFO Instantiating model...
14:50:06 INFO ********************
14:50:06 INFO Loading weights...
14:50:06 INFO Weight file wnet_latest.pth already exists, skipping download
14:50:06 INFO Weights status : None
14:50:06 INFO Done
14:50:06 INFO --------------------
14:50:06 INFO Parameters summary :
14:50:06 INFO Model is : WNet
14:50:06 INFO Window inference is enabled
14:50:06 INFO Window size is 64
14:50:06 INFO Window overlap is 0.25
14:50:06 INFO Dataset loaded on mps device
2024-05-03 14:50:06,431 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'function', transform is not lazy
2024-05-03 14:50:06,431 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy
14:50:06 INFO --------------------
14:50:06 INFO Loading layer
2024-05-03 14:50:06,443 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'EnsureType', transform is not lazy
14:50:06 INFO Done
14:50:06 INFO ----------
14:50:06 INFO Inference started on layer...
14:50:06 ERROR The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
Traceback (most recent call last):
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/worker_inference.py", line 408, in model_output
outputs = sliding_window_inference(
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/monai/inferers/utils.py", line 229, in sliding_window_inference
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/worker_inference.py", line 370, in model_output_wrapper
result = model(inputs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/models/model_WNet.py", line 46, in forward
return super().forward(norm_x)
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/models/wnet/model.py", line 39, in forward
return self.encoder(x)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/models/wnet/model.py", line 133, in forward
c1 = self.conv1(self.max_pool(in_b))
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/modules/pooling.py", line 241, in forward
return F.max_pool3d(input, self.kernel_size, self.stride,
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/_jit_internal.py", line 497, in fn
return if_false(*args, **kwargs)
File "/Users/mackenzie/anaconda3/envs/napari_cellseg3d_m1/lib/python3.9/site-packages/torch/nn/functional.py", line 882, in _max_pool3d
return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
NotImplementedError: The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
14:50:06 ERROR local variable 'outputs' referenced before assignment
Traceback (most recent call last):
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/worker_inference.py", line 427, in model_output
logger.debug(f"Inference output shape: {outputs.shape}")
UnboundLocalError: local variable 'outputs' referenced before assignment
14:50:06 ERROR 'NoneType' object has no attribute 'shape'
Traceback (most recent call last):
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/worker_inference.py", line 987, in inference
yield self.inference_on_layer(
File "/Users/mackenzie/Documents/CellSeg3d/napari_cellseg3d/code_models/worker_inference.py", line 770, in inference_on_layer
logger.debug(f"Inference on layer result shape : {out.shape}")
AttributeError: 'NoneType' object has no attribute 'shape'
14:50:06 INFO
Worker finished at 14:50:06
14:50:06 INFO ******************** |
No description provided.