Skip to content

Commit d150d3e

Browse files
gmagogsfmfacebook-github-bot
authored andcommittedOct 2, 2020
Make sure each warnings.warn only executes once inside TorchScript. (pytorch#45382)
Summary: * Add a pass at end of runCleanupPasses to annotate `aten::warn` so that each has its unique id * Enhanced interpreter so that it tracks which `aten::warn` has been executed before and skip them * Improved insertInstruction so that it correctly checks for overflow Fixes pytorch#45108 Pull Request resolved: pytorch#45382 Reviewed By: mrshenli Differential Revision: D24060677 Pulled By: gmagogsfm fbshipit-source-id: 9221bc55b9ce36b374bdf614da3fe47496b481c1
1 parent 73e9daa commit d150d3e

File tree

9 files changed

+275
-15
lines changed

9 files changed

+275
-15
lines changed
 

‎aten/src/ATen/core/interned_strings.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ namespace c10 {
360360
_(attr, scope) \
361361
_(attr, keepdims) \
362362
_(attr, cache_id) \
363-
_(attr, new_axis)
363+
_(attr, new_axis) \
364+
_(attr, warn_id)
364365
#else
365366
#define FORALL_NS_SYMBOLS(_) \
366367
_(namespaces, prim) \

‎test/jit/test_warn.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import os
2+
import sys
3+
import io
4+
5+
import torch
6+
import warnings
7+
from contextlib import redirect_stderr
8+
from torch.testing import FileCheck
9+
10+
# Make the helper files in test/ importable
11+
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12+
sys.path.append(pytorch_test_dir)
13+
from torch.testing._internal.jit_utils import JitTestCase
14+
15+
if __name__ == '__main__':
16+
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
17+
"\tpython test/test_jit.py TESTNAME\n\n"
18+
"instead.")
19+
20+
21+
class TestWarn(JitTestCase):
22+
def test_warn(self):
23+
@torch.jit.script
24+
def fn():
25+
warnings.warn("I am warning you")
26+
27+
f = io.StringIO()
28+
with redirect_stderr(f):
29+
fn()
30+
31+
FileCheck() \
32+
.check_count(
33+
str="UserWarning: I am warning you",
34+
count=1,
35+
exactly=True) \
36+
.run(f.getvalue())
37+
38+
def test_warn_only_once(self):
39+
@torch.jit.script
40+
def fn():
41+
for _ in range(10):
42+
warnings.warn("I am warning you")
43+
44+
f = io.StringIO()
45+
with redirect_stderr(f):
46+
fn()
47+
48+
FileCheck() \
49+
.check_count(
50+
str="UserWarning: I am warning you",
51+
count=1,
52+
exactly=True) \
53+
.run(f.getvalue())
54+
55+
def test_warn_only_once_in_loop_func(self):
56+
def w():
57+
warnings.warn("I am warning you")
58+
59+
@torch.jit.script
60+
def fn():
61+
for _ in range(10):
62+
w()
63+
64+
f = io.StringIO()
65+
with redirect_stderr(f):
66+
fn()
67+
68+
FileCheck() \
69+
.check_count(
70+
str="UserWarning: I am warning you",
71+
count=1,
72+
exactly=True) \
73+
.run(f.getvalue())
74+
75+
def test_warn_once_per_func(self):
76+
def w1():
77+
warnings.warn("I am warning you")
78+
79+
def w2():
80+
warnings.warn("I am warning you")
81+
82+
@torch.jit.script
83+
def fn():
84+
w1()
85+
w2()
86+
87+
f = io.StringIO()
88+
with redirect_stderr(f):
89+
fn()
90+
91+
FileCheck() \
92+
.check_count(
93+
str="UserWarning: I am warning you",
94+
count=2,
95+
exactly=True) \
96+
.run(f.getvalue())
97+
98+
def test_warn_once_per_func_in_loop(self):
99+
def w1():
100+
warnings.warn("I am warning you")
101+
102+
def w2():
103+
warnings.warn("I am warning you")
104+
105+
@torch.jit.script
106+
def fn():
107+
for _ in range(10):
108+
w1()
109+
w2()
110+
111+
f = io.StringIO()
112+
with redirect_stderr(f):
113+
fn()
114+
115+
FileCheck() \
116+
.check_count(
117+
str="UserWarning: I am warning you",
118+
count=2,
119+
exactly=True) \
120+
.run(f.getvalue())
121+
122+
def test_warn_multiple_calls_multiple_warnings(self):
123+
@torch.jit.script
124+
def fn():
125+
warnings.warn("I am warning you")
126+
127+
f = io.StringIO()
128+
with redirect_stderr(f):
129+
fn()
130+
fn()
131+
132+
FileCheck() \
133+
.check_count(
134+
str="UserWarning: I am warning you",
135+
count=2,
136+
exactly=True) \
137+
.run(f.getvalue())
138+
139+
def test_warn_multiple_calls_same_func_diff_stack(self):
140+
def warn(caller: str):
141+
warnings.warn("I am warning you from " + caller)
142+
143+
@torch.jit.script
144+
def foo():
145+
warn("foo")
146+
147+
@torch.jit.script
148+
def bar():
149+
warn("bar")
150+
151+
f = io.StringIO()
152+
with redirect_stderr(f):
153+
foo()
154+
bar()
155+
156+
FileCheck() \
157+
.check_count(
158+
str="UserWarning: I am warning you from foo",
159+
count=1,
160+
exactly=True) \
161+
.check_count(
162+
str="UserWarning: I am warning you from bar",
163+
count=1,
164+
exactly=True) \
165+
.run(f.getvalue())

‎test/test_jit.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jit.test_enum import TestEnum # noqa: F401
3333
from jit.test_profiler import TestProfiler # noqa: F401
3434
from jit.test_slice import TestSlice # noqa: F401
35+
from jit.test_warn import TestWarn # noqa: F401
3536

3637
# Torch
3738
from torch import Tensor

‎tools/build_variables.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ core_sources_full = [
148148
"torch/csrc/jit/ir/scope.cpp",
149149
"torch/csrc/jit/ir/subgraph_matcher.cpp",
150150
"torch/csrc/jit/jit_log.cpp",
151+
"torch/csrc/jit/passes/annotate_warns.cpp",
151152
"torch/csrc/jit/passes/bailout_graph.cpp",
152153
"torch/csrc/jit/passes/batch_mm.cpp",
153154
"torch/csrc/jit/passes/canonicalize.cpp",

‎torch/csrc/jit/frontend/ir_emitter.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <torch/csrc/jit/frontend/schema_matching.h>
99
#include <torch/csrc/jit/frontend/script_type_parser.h>
1010
#include <torch/csrc/jit/ir/ir.h>
11+
#include <torch/csrc/jit/passes/annotate_warns.h>
1112
#include <torch/csrc/jit/passes/canonicalize.h>
1213
#include <torch/csrc/jit/passes/constant_pooling.h>
1314
#include <torch/csrc/jit/passes/constant_propagation.h>
@@ -4090,6 +4091,10 @@ void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
40904091

40914092
// For jitter
40924093
CanonicalizeOutputs(to_clean);
4094+
4095+
// Annotate aten::warns so that each has its unique ID. This enables us to
4096+
// mimic Python behavior of only emitting each warning only once.
4097+
AnnotateWarns(to_clean);
40934098
}
40944099

40954100
// we consider _N where N is a number, to be a non-meaningful name
+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <torch/csrc/jit/passes/annotate_warns.h>
2+
3+
#include <atomic>
4+
5+
namespace torch {
6+
namespace jit {
7+
8+
void AnnotateWarns(Block* b) {
9+
static std::atomic<int64_t> idx(0);
10+
for (Node* n : b->nodes()) {
11+
for (Block* child_b : n->blocks()) {
12+
AnnotateWarns(child_b);
13+
}
14+
15+
if (n->kind() != aten::warn) {
16+
continue;
17+
}
18+
19+
n->i_(attr::warn_id, idx);
20+
idx++;
21+
}
22+
}
23+
24+
void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
25+
AnnotateWarns(graph->block());
26+
}
27+
28+
} // namespace jit
29+
} // namespace torch
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <torch/csrc/jit/ir/ir.h>
4+
5+
namespace torch {
6+
namespace jit {
7+
8+
TORCH_API void AnnotateWarns(const std::shared_ptr<Graph>& graph);
9+
10+
} // namespace jit
11+
} // namespace torch

‎torch/csrc/jit/runtime/instruction.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace jit {
5252
_(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */ \
5353
_(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */ \
5454
_(FORK, "CN") /* launch a thread to run code entry x with N inputs */ \
55-
_(WARN, "") /* emit a warning with line information */ \
55+
_(WARN, "I") /* emit a warning with line information */ \
5656
_(ENTER, "EN") /* enter scope of a contextmanager */ \
5757
_(EXIT, "EX") /* exit the last entered contextmanager */
5858

‎torch/csrc/jit/runtime/interpreter.cpp

+60-13
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,16 @@ struct TLSCurrentInterpreterGuard {
428428
InterpreterStateImpl* prev_state_;
429429
};
430430

431+
template <class Ttarget, class Tsource>
432+
Ttarget safe_narrow_cast(Tsource v) {
433+
Ttarget res = static_cast<Ttarget>(v);
434+
// Casting it back to check whether it overflew.
435+
if (static_cast<Tsource>(res) != v) {
436+
throw std::runtime_error("safe_narrow_cast<>() failed due to overflow");
437+
}
438+
return res;
439+
}
440+
431441
struct CodeImpl {
432442
friend struct InterpreterState;
433443
std::vector<Instruction> instructions_;
@@ -535,7 +545,10 @@ struct CodeImpl {
535545
}
536546

537547
void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
538-
instructions_.emplace_back(op, X, N);
548+
instructions_.emplace_back(
549+
op,
550+
safe_narrow_cast<int32_t, int64_t>(X),
551+
safe_narrow_cast<int16_t, int64_t>(N));
539552
instructions_source_.emplace_back(current_node_);
540553

541554
// check that we didn't accidentally emit nodes out of topological order
@@ -873,7 +886,11 @@ struct CodeImpl {
873886

874887
void emitWarn(Node* node) {
875888
emitLoadInputs(node->inputs());
876-
insertInstruction(WARN);
889+
int32_t idx = -1;
890+
if (node->hasAttribute(attr::warn_id)) {
891+
idx = static_cast<int32_t>(node->i(attr::warn_id));
892+
}
893+
insertInstruction(WARN, idx);
877894
}
878895

879896
void emitEnter(Node* node) {
@@ -1017,6 +1034,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
10171034
}
10181035

10191036
private:
1037+
struct WarnedNodes {
1038+
public:
1039+
// Inserts idx into warned_nodes_, returns a boolean indicates whether
1040+
// insertion actually happened (idx wasn't originally in the set).
1041+
bool insert(int32_t idx) {
1042+
std::unique_lock<std::mutex> lock(mutex_);
1043+
return warned_nodes_.insert(idx).second;
1044+
}
1045+
1046+
private:
1047+
std::mutex mutex_;
1048+
std::unordered_set<int32_t> warned_nodes_;
1049+
};
1050+
1051+
WarnedNodes warned_nodes_;
1052+
10201053
// if we need to suspend, where do we reset the stack?
10211054
// answer: to where it was when we were called, not
10221055
// including any inputs to this function
@@ -1487,21 +1520,35 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
14871520
++frame.pc;
14881521
} break;
14891522
case WARN: {
1490-
Node* node = frame.function->instructions_source_.at(frame.pc);
1523+
// Keeps track of which WARN instruction has been executed before,
1524+
// we only want to execute each WARN once to match default Python
1525+
// warning behavior.
1526+
bool need_warn = true;
1527+
if (inst.X != -1) {
1528+
need_warn = warned_nodes_.insert(inst.X);
1529+
}
1530+
1531+
Node* node =
1532+
frames.back().function->instructions_source_.at(frame.pc);
14911533
auto range = node->sourceRange().source();
14921534
if (range->filename()) {
1493-
auto line = range->starting_line_no() +
1494-
range->lineno_for_offset(node->sourceRange().start());
14951535
drop(stack, 1);
1496-
c10::SourceLocation location{
1497-
"", range->filename()->c_str(), uint32_t(line)};
1498-
// Sends the warning to the warning handler with the
1499-
// "verbatim" flag. This flag ensures the warning handler
1500-
// will print the exception as configured.
1501-
c10::Warning::warn(
1502-
location, pop(stack).toStringRef(), /*verbatim=*/true);
1536+
const auto msg = pop(stack).toStringRef();
1537+
if (need_warn) {
1538+
auto line = range->starting_line_no() +
1539+
range->lineno_for_offset(node->sourceRange().start());
1540+
c10::SourceLocation location{
1541+
"", range->filename()->c_str(), uint32_t(line)};
1542+
// Sends the warning to the warning handler with the
1543+
// "verbatim" flag. This flag ensures the warning handler
1544+
// will print the exception as configured.
1545+
c10::Warning::warn(location, msg, /*verbatim=*/true);
1546+
}
15031547
} else {
1504-
TORCH_WARN(pop(stack).toStringRef());
1548+
const auto msg = pop(stack).toStringRef();
1549+
if (need_warn) {
1550+
TORCH_WARN(msg);
1551+
}
15051552
}
15061553
++frame.pc;
15071554
} break;

0 commit comments

Comments
 (0)
Please sign in to comment.