
keras with pytorch backend and mps set to default should use an mps generator in randperm #19436

Closed · ralphrmartin opened this issue Apr 3, 2024 · 13 comments · Fixed by #19618

Labels: backend:torch, stat:awaiting keras-eng (Awaiting response from Keras engineer), type:Bug

Comments

@ralphrmartin

Keras with pytorch backend and mps set to default needs to use an mps generator in randperm

The following code

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch

torch.set_default_device('mps')

import keras
import numpy as np
from keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(xx_train, yy_train), (xx_test, yy_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = xx_train.astype("float32") / 255
x_test = xx_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = torch.from_numpy(np.expand_dims(x_train, -1))
x_test = torch.from_numpy(np.expand_dims(x_test, -1))
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# convert class vectors to binary class matrices
y_train = torch.from_numpy(keras.utils.to_categorical(yy_train, num_classes).astype("float32"))
y_test = torch.from_numpy(keras.utils.to_categorical(yy_test, num_classes).astype("float32"))

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)
batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

produces the following error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 6
      2 epochs = 15
      4 model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
----> 6 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
--> 674     index = self._next_index()  # may raise StopIteration
    675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:621, in _BaseDataLoaderIter._next_index(self)
    620 def _next_index(self):
--> 621     return next(self._sampler_iter)

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/sampler.py:287, in BatchSampler.__iter__(self)
    285 batch = [0] * self.batch_size
    286 idx_in_batch = 0
--> 287 for idx in self.sampler:
    288     batch[idx_in_batch] = idx
    289     idx_in_batch += 1

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/sampler.py:167, in RandomSampler.__iter__(self)
    165 else:
    166     for _ in range(self.num_samples // n):
--> 167         yield from torch.randperm(n, generator=generator).tolist()
    168     yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

RuntimeError: Expected a 'mps:0' generator device but found 'cpu'
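
For context, the failing call can be isolated from the traceback above. A minimal sketch (assuming an Apple Silicon machine with MPS available) of why the sampler's CPU generator clashes with a default MPS device:

import torch

torch.set_default_device("mps")

# torch.utils.data.RandomSampler builds a plain torch.Generator() (a CPU
# generator) and passes it to torch.randperm. With the default device set
# to "mps", randperm is dispatched to MPS and rejects the CPU generator.
gen = torch.Generator()  # CPU generator, matching the sampler's behavior
try:
    torch.randperm(10, generator=gen)
except RuntimeError as e:
    print(e)  # Expected a 'mps:0' generator device but found 'cpu'
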
@SuryanarayanaY
Contributor

Hi @ralphrmartin ,

I have tested the code snippet and I am getting a NotImplementedError, as per the linked gist.

@ralphrmartin
Author

I'm not quite sure who needs to do what here. Is this a matter for the mps team? I'm just an end user trying to use this stuff, and I get the error given in my initial report when running on an Apple Silicon MacBook Pro, using Python 3.12.2 with the following package versions:

absl-py           2.1.0
appnope           0.1.4
asttokens         2.4.1
comm              0.2.2
contourpy         1.2.1
cycler            0.12.1
debugpy           1.8.1
decorator         5.1.1
executing         2.0.1
filelock          3.13.3
fonttools         4.50.0
fsspec            2024.3.1
h5py              3.10.0
ipykernel         6.29.4
ipython           8.23.0
jedi              0.19.1
Jinja2            3.1.3
jupyter_client    8.6.1
jupyter_core      5.7.2
keras             3.1.1
kiwisolver        1.4.5
markdown-it-py    3.0.0
MarkupSafe        2.1.5
matplotlib        3.8.4
matplotlib-inline 0.1.6
mdurl             0.1.2
ml-dtypes         0.3.2
mpmath            1.3.0
namex             0.0.7
nest-asyncio      1.6.0
networkx          3.2.1
numpy             1.26.4
optree            0.11.0
packaging         24.0
parso             0.8.3
pexpect           4.9.0
pillow            10.3.0
pip               24.0
platformdirs      4.2.0
prompt-toolkit    3.0.43
psutil            5.9.8
ptyprocess        0.7.0
pure-eval         0.2.2
Pygments          2.17.2
pyparsing         3.1.2
python-dateutil   2.9.0.post0
pyzmq             25.1.2
rich              13.7.1
six               1.16.0
stack-data        0.6.3
sympy             1.12
torch             2.2.2
torchvision       0.17.2
tornado           6.4
traitlets         5.14.2
typing_extensions 4.10.0
wcwidth           0.2.13

@M7Saad
Contributor

M7Saad commented Apr 5, 2024

Some operations, such as the 'aten::random_' operator, are currently unsupported on the MPS device in the Torch backend. You can find more information about this issue at pytorch/pytorch#77764. As a temporary workaround, I recommend setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1. This enables Keras to utilize the GPU automatically; you don't need to set the default device in torch.
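
For example, a minimal sketch of that workaround (the flag must be set before torch is imported; the one-layer model is just a placeholder):

import os

# Enable the CPU fallback for ops that MPS doesn't implement, and select
# the torch backend, before any torch/keras import.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["KERAS_BACKEND"] = "torch"

import keras

# Per the comment above, Keras picks up the accelerator on its own; no
# torch.set_default_device('mps') call is needed.
model = keras.Sequential([keras.layers.Dense(1)])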

@SuryanarayanaY
Contributor

Hi @ralphrmartin ,

Could you please refer to the above comment from @M7Saad? It seems to be a compatibility issue with PyTorch.

@ralphrmartin
Author

Thank you.

@SuryanarayanaY
Contributor

Hi @ralphrmartin ,

Could you please confirm whether this is a PyTorch compatibility issue? If so, can we mark it as resolved? Thanks!

@ralphrmartin
Author

Setting PYTORCH_ENABLE_MPS_FALLBACK to 1 prevents the issue, thanks.

@SuryanarayanaY
Contributor

@ralphrmartin ,

Thanks for the response. Can we mark this as closed now?

@ralphrmartin
Author

I guess so, but maybe the documentation needs updating to prevent other users from tripping over this.

SuryanarayanaY added the keras-team-review-pending (Pending review by a Keras team member) label on Apr 22, 2024
haifeng-jin self-assigned this on Apr 25, 2024
sachinprasadhs added the stat:awaiting keras-eng (Awaiting response from Keras engineer) label and removed the keras-team-review-pending label on Apr 25, 2024
@grasskin
Member

@ralphrmartin Hi Ralph, looking into this more, it seems that PYTORCH_ENABLE_MPS_FALLBACK might have been an experimental flag that is no longer needed. Have you run into this flag elsewhere in pytorch? Specifically, I see no mention of it here: https://pytorch.org/docs/stable/notes/mps.html.

If so, we can remove this flag check from the backend:

and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "1"
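
For illustration only, a hypothetical sketch of what device selection could look like without the flag (the helper name and logic are assumptions, not the actual Keras source):

import torch

def pick_default_device() -> str:
    # Hypothetical helper: prefer MPS when available, without gating on
    # the experimental PYTORCH_ENABLE_MPS_FALLBACK environment variable.
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"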

haifeng-jin assigned grasskin and unassigned haifeng-jin on Apr 25, 2024
@ralphrmartin
Author

I am lost at this point. Using

Keras: 3.3.2
Torch: 2.3.0

my original comment holds: if I don't set PYTORCH_ENABLE_MPS_FALLBACK to 1 and I do call torch.set_default_device('mps'), as suggested at https://pytorch.org/docs/stable/notes/mps.html, Keras falls over as described in my initial message, failing to use an mps generator in randperm.

If I set PYTORCH_ENABLE_MPS_FALLBACK to 1, then the mps device seems to be used to some extent, but I get

UserWarning: The operator 'aten::_foreach_mul_.Scalar' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.

If I don't call torch.set_default_device('mps'), then it appears that the mps device is not used.
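
To summarize, the combination that at least partially runs on MPS is roughly the following (a sketch, assuming the package versions above):

import os

# The fallback flag must be set before torch is imported.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["KERAS_BACKEND"] = "torch"

import torch

# With the fallback enabled, MPS is used where possible; unsupported ops
# (e.g. aten::_foreach_mul_.Scalar) fall back to CPU with a UserWarning
# and a performance cost.
torch.set_default_device("mps")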

So, now what?

@grasskin
Member

Looks like mps is stable enough that we can remove the experimental flag; I will submit a separate PR. Thank you for flagging this, Ralph.
