[FA2] Add flash attention for opt #26414
Merged
Commits (16):
626276f  added flash attention for opt (susnato)
7800457  added to list (susnato)
689f599  fix use cache (#3) (younesbelkada)
74e9687  style fix (susnato)
cf923e8  fix text (susnato)
db8cf07  test fix2 (susnato)
af67e0f  reverted until 689f599 (susnato)
edf1610  torch fx tests are working now! (susnato)
f34b680  small fix (susnato)
90be210  Merge branch 'main' into flash_attn_opt (susnato)
5cec2ad  added TODO docstring (susnato)
2bde7cf  Merge branch 'main' into flash_attn_opt (susnato)
db18d65  Merge branch 'main' into flash_attn_opt (susnato)
adfbb69  changes (susnato)
10ab9b3  Merge branch 'main' into flash_attn_opt (susnato)
7d4c688  comments and .md file modification (susnato)
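For context, this PR adds a Flash Attention 2 code path to the OPT model. A minimal usage sketch, assuming the `use_flash_attention_2` loading flag that other FA2-enabled models use in this version of transformers (a CUDA GPU and the flash-attn package are required; the checkpoint name is just an example):

```python
import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Load OPT with the Flash Attention 2 attention implementation in half precision.
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")

inputs = tokenizer("Flash attention makes generation", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```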
I think this should not be removed
It seems the fx tests fail - test_torch_fx and test_torch_fx_output_loss. The error is happening for this line. The full error for test_torch_fx:
tests/models/opt/test_modeling_opt.py F [100%]
========================================================= FAILURES =========================================================
________________________________________________ OPTModelTest.test_torch_fx ________________________________________________
self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
config = OPTConfig {
"_remove_final_layer_norm": false,
"activation_function": "relu",
"attention_dropout": 0.1,
"bos_t...d": 1,
"transformers_version": "4.34.0.dev0",
"use_cache": true,
"vocab_size": 99,
"word_embed_proj_dim": 16
}
inputs_dict = {'attention_mask': tensor([[True, True, True, True, True, True, True],
[True, True, True, True, True, True, Tr...91, 54, 98, 57, 55, 2],
[74, 79, 56, 51, 93, 26, 2],
[62, 18, 55, 3, 73, 74, 2]], device='cuda:0')}
output_loss = False
tests/test_modeling_common.py:877:
model = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_names = ['input_ids', 'attention_mask'], disable_check = False, tracer_cls = <class 'transformers.utils.fx.HFTracer'>
src/transformers/utils/fx.py:1250:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
dummy_inputs = None, complete_concrete_args_with_inputs_not_in_dummy_inputs = True
src/transformers/utils/fx.py:1088:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:739:
self = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_ids = Proxy(input_ids), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = None, use_cache = True, output_attentions = False, output_hidden_states = False, return_dict = False
src/transformers/models/opt/modeling_opt.py:1023:
mod = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:717:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
src/transformers/utils/fx.py:987:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:434:
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:710:
self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward_call = <bound method OPTDecoder.forward of OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions)...out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)>
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194:
self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input_ids = Proxy(view), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = Proxy(decoder_embed_tokens), use_cache = True, output_attentions = False, output_hidden_states = False
return_dict = False
src/transformers/models/opt/modeling_opt.py:872:
self = Proxy(attention_mask), key = 0
src/transformers/utils/fx.py:646:
self = tensor(..., device='meta', size=(14, 11), dtype=torch.int64), element = 0
E NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/_tensor.py:983: NotImplementedError
During handling of the above exception, another exception occurred:
self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
tests/test_modeling_common.py:805:
tests/test_modeling_common.py:882: in _create_and_check_torch_fx_tracing
self.fail(f"Couldn't trace module: {e}")
E AssertionError: Couldn't trace module: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
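For reference, the failing call can be exercised outside the test suite with a short symbolic-trace sketch. This is only a sketch: it uses a tiny random OPTConfig mirroring the shapes in the traceback above (not the exact OPTModelTest fixture), and it only hits the error with this PR's changes checked out.

```python
from transformers import OPTConfig, OPTModel
from transformers.utils.fx import symbolic_trace

# Tiny random-weight config roughly matching the repr in the traceback above.
config = OPTConfig(
    vocab_size=99,
    hidden_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=4,
    word_embed_proj_dim=16,
    max_position_embeddings=20,
)
model = OPTModel(config)

# HFTracer traces with meta tensors, so any line that needs a concrete value
# from a tensor (e.g. a Python membership test on attention_mask) raises the
# aten::_local_scalar_dense NotImplementedError shown above.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
```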
EDIT: The errors are fixed now.
Thanks, I see. In that line, can you use torch.isin instead? https://pytorch.org/docs/stable/generated/torch.isin.html - suggestion from @michaelbenayoun
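A minimal sketch of the suggested substitution, assuming the failing line is a Python membership check on the attention mask (which the `self = Proxy(attention_mask), key = 0` frame in the traceback suggests):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])

# Assumed failing pattern: Tensor.__contains__ converts the result to a Python
# bool via aten::_local_scalar_dense, which the meta tensors used during
# HFTracer symbolic tracing cannot evaluate.
# has_padding = 0 in attention_mask

# Suggested alternative: torch.isin keeps the result as a boolean tensor, so
# no concrete scalar is needed at trace time.
padding_present = torch.isin(attention_mask, torch.tensor(0)).any()
print(padding_present)  # tensor(True)
```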