Description
🐛 Bug
TorchScript supports the `warnings.warn` syntax. However, it doesn't respect the Python default, which is to print a warning only once per location, with the location chosen by `stacklevel`: https://docs.python.org/3/library/warnings.html#the-warnings-filter
This can cause a lot of log spam, especially for deployed inference models, which execute the same warning-emitting code on every call.
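For comparison, here is a small pure-Python sketch (no torch needed) of the eager-mode behavior the issue expects. The helpers `noisy_op`, `count_warnings`, and `count_two_sites` are hypothetical. Under the `default` filter action (the one Python applies to `UserWarning` when no other filter matches), a warning raised with `stacklevel=2` is shown once per caller location. So ten calls from one line produce one warning, while two different lines produce two.

```python
import warnings


def noisy_op() -> None:
    # stacklevel=2 attributes the warning to the caller's line,
    # the way library functions usually report deprecations.
    warnings.warn("implicit dim is deprecated", UserWarning, stacklevel=2)


def count_warnings(n_calls: int) -> int:
    """Call noisy_op() n_calls times from a single call site under the 'default' action."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default")  # show once per (message, category, location)
        for _ in range(n_calls):
            noisy_op()  # same file and line on every iteration
    return len(caught)


def count_two_sites() -> int:
    """Two distinct call sites are two distinct warning locations."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default")
        noisy_op()  # call site A
        noisy_op()  # call site B
    return len(caught)


if __name__ == "__main__":
    print(count_warnings(10))  # 1
    print(count_two_sites())  # 2
```

This once-per-location deduplication is what the scripted version of the reproduction below loses.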
To Reproduce
Sample notebook: https://colab.research.google.com/drive/1Zcm4F3rk58Q8aEBtfxSgHOKGJ9Ib2Uhs?usp=sharing
```python
import torch

def foo(x):
    for _ in range(10):
        x = torch.nn.functional.softmax(x)  # generates a warning inside
    return x

foo(torch.rand(5))
# warning is printed once
torch.jit.script(foo)(torch.rand(5))
# warning is printed 10 times
```
Expected behavior
Ideally, we'd emulate Python's default behavior directly, i.e. deduplicate based on the stack trace. The `warn` instruction in TorchScript actually does receive `stack_level`, but ignores it.
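A minimal model of stack-trace-based deduplication might look like the following. It is a sketch under stated assumptions, not TorchScript's interpreter: `StackAwareWarner` and the `(file, line)` frame tuples are hypothetical, and the key is simplified to (message, location) rather than Python's full per-module (text, category, line) registry.

```python
from dataclasses import dataclass, field
from typing import List, Set, Tuple

# Hypothetical model: a frame is the (file, line) of an active call,
# innermost frame last. None of these names are TorchScript APIs.
Location = Tuple[str, int]


@dataclass
class StackAwareWarner:
    seen: Set[Tuple[str, Location]] = field(default_factory=set)
    emitted: List[str] = field(default_factory=list)

    def warn(self, message: str, call_stack: List[Location], stack_level: int = 1) -> bool:
        """Show `message` once per (message, location), where the location is the
        frame `stack_level` levels up: 1 = the warn site itself, 2 = its caller.
        Returns True if the warning was shown."""
        level = max(stack_level, 1)
        # Clamp so an overly large stack_level falls back to the outermost frame.
        idx = max(len(call_stack) - level, 0)
        location = call_stack[idx]
        key = (message, location)
        if key in self.seen:
            return False
        self.seen.add(key)
        self.emitted.append(f"{location[0]}:{location[1]}: {message}")
        return True


if __name__ == "__main__":
    warner = StackAwareWarner()
    softmax_site = ("functional.py", 100)  # hypothetical warn location
    for _ in range(10):  # a loop like the one in foo: same caller line every time
        warner.warn("implicit dim", [("model.py", 4), softmax_site], stack_level=2)
    warner.warn("implicit dim", [("model.py", 7), softmax_site], stack_level=2)
    print(len(warner.emitted))  # 2
```

Because the key includes the caller's location, the loop warns once, while a second call site warns again. That matches what eager Python does.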
As a simple fix, though, we can pretend that `stack_level` is 0, i.e. log only the first invocation of each unique instruction (e.g. this can be done by adding a mutable flag to the `Code` instance in `Interpreter` for each `warn` instruction).
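Note that `TORCH_WARN_ONCE` in C++ is not a good fit here: it would attribute all warnings to the same C++ line (the interpreter's single `warn` handler), so only the first TorchScript warning would ever be printed, which is not desired.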
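The difference between a per-instruction flag and a single once-only flag at one C++ call site can be sketched with a toy interpreter. `WarnInstr`, `Code`, `run_per_instruction`, and `run_single_static_flag` are hypothetical names chosen for illustration. They do not mirror PyTorch's actual classes.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class WarnInstr:
    """Hypothetical warn instruction carrying its own 'already warned' flag."""
    message: str
    warned: bool = False  # mutable per-instruction state (the proposed fix)


@dataclass
class Code:
    """Hypothetical compiled function: just a list of warn instructions."""
    instructions: List[WarnInstr] = field(default_factory=list)


def run_per_instruction(code: Code, iterations: int) -> List[str]:
    """Execute the code `iterations` times; each instruction warns at most once."""
    printed: List[str] = []
    for _ in range(iterations):
        for ins in code.instructions:
            if not ins.warned:
                ins.warned = True
                printed.append(ins.message)
    return printed


def run_single_static_flag(code: Code, iterations: int) -> List[str]:
    """Mimic a once-only macro at the interpreter's single warn call site:
    one shared flag, so the first warning suppresses every later one,
    even a different message from a different instruction."""
    already_warned = False  # a C++ static would also persist across calls
    printed: List[str] = []
    for _ in range(iterations):
        for ins in code.instructions:
            if not already_warned:
                already_warned = True
                printed.append(ins.message)
    return printed


if __name__ == "__main__":
    msgs = ["softmax: implicit dim", "another deprecated op"]
    print(run_per_instruction(Code([WarnInstr(m) for m in msgs]), 10))
    # ['softmax: implicit dim', 'another deprecated op']
    print(run_single_static_flag(Code([WarnInstr(m) for m in msgs]), 10))
    # ['softmax: implicit dim']
```

With per-instruction flags, each distinct warning still appears once. With one shared flag, the second, unrelated warning is silently lost.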