When running through tutorial.ipynb with device='mps', configured as follows:
import torch

if useGPU and torch.cuda.is_available():
    device = 'cuda:0'
    torch.cuda.empty_cache()
# elif (not a second if), so a CUDA choice is not overwritten by the CPU fallback
elif useGPU and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Notify:
print("Device selected: %s" % device)
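One way a precision option could work is to choose a float precision together with the device. The sketch below is hypothetical: select_device is not part of the tutorial or library, the availability flags are passed in as arguments so the logic can be checked on any machine, and keeping float64 on CUDA and CPU is an assumption (the error suggests the tutorial defaults to float64). It also uses an if/elif/else chain, because with two separate if statements the second check can overwrite 'cuda:0' with 'cpu'.

```python
def select_device(use_gpu, cuda_available, mps_available):
    """Hypothetical helper: return (device string, float dtype name).

    MPS gets float32 because the MPS backend rejects float64 tensors;
    float64 is kept elsewhere on the assumption that it is the tutorial's default.
    """
    if use_gpu and cuda_available:
        return "cuda:0", "float64"
    elif use_gpu and mps_available:
        return "mps", "float32"
    else:
        return "cpu", "float64"


# In the notebook, the flags would come from torch:
#   import torch
#   device, precision = select_device(useGPU,
#                                     torch.cuda.is_available(),
#                                     torch.backends.mps.is_available())
print(select_device(True, False, True))  # ('mps', 'float32')
```

Returning the precision alongside the device keeps the two decisions in one place, so the rest of the setup could read the precision instead of hard-coding float64.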
the following cell:
thisName = hParamsAggGNN['name']

#\\\ Architecture
thisArchit = archit.AggregationGNN(
    # Linear
    hParamsAggGNN['F'],
    hParamsAggGNN['K'],
    hParamsAggGNN['bias'],
    # Nonlinearity
    hParamsAggGNN['sigma'],
    # Pooling
    hParamsAggGNN['rho'],
    hParamsAggGNN['alpha'],
    # MLP in the end
    hParamsAggGNN['dimLayersMLP'],
    # Structure
    G.S/np.max(np.diag(G.E)), # Normalize the adjacency matrix
    order = hParamsAggGNN['order'],
    maxN = hParamsAggGNN['Nmax'],
    nNodes = hParamsAggGNN['nNodes'])

#\\\ Optimizer
thisOptim = optim.Adam(thisArchit.parameters(), lr = learningRate, betas = (beta1, beta2))

#\\\ Model
AggGNN = model.Model(thisArchit,
                     lossFunction(),
                     thisOptim,
                     trainer,
                     evaluator,
                     device,
                     thisName,
                     saveDir)

#\\\ Add model to the dictionary
modelsGNN[thisName] = AggGNN
raises the error:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
I believe MPS (the Apple M1 GPU backend) does not currently support float64, so every tensor needs to be converted to float32. Could a precision option be added to the setup?
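As a rough illustration of what such an option might do on the NumPy side, the sketch below casts an array to a precision chosen by name before it is handed to the architecture. cast_to_precision and PRECISIONS are hypothetical names, not part of the library, and the small S and E matrices are made-up stand-ins for G.S and G.E. It only covers arrays passed in from outside; any float64 tensors the library creates internally (for example from the dataset) would presumably need the same treatment.

```python
import numpy as np

# Hypothetical precision option: maps a setup-level string to a NumPy dtype.
PRECISIONS = {"float32": np.float32, "float64": np.float64}


def cast_to_precision(array, precision="float64"):
    """Return the array in the requested float dtype (no copy if it already matches)."""
    if precision not in PRECISIONS:
        raise ValueError(f"unsupported precision: {precision!r}")
    return np.asarray(array).astype(PRECISIONS[precision], copy=False)


# Made-up 2-node graph: adjacency S and a diagonal eigenvalue matrix E.
S = np.array([[0.0, 2.0], [2.0, 0.0]])
E = np.diag([2.0, -2.0])

# Normalize as in the cell above, then cast so the result could live on MPS.
S_norm = cast_to_precision(S / np.max(np.diag(E)), "float32")
print(S_norm.dtype)  # float32
```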