@@ -428,6 +428,16 @@ struct TLSCurrentInterpreterGuard {
428
428
InterpreterStateImpl* prev_state_;
429
429
};
430
430
431
+ template <class Ttarget , class Tsource >
432
+ Ttarget safe_narrow_cast (Tsource v) {
433
+ Ttarget res = static_cast <Ttarget>(v);
434
+ // Casting it back to check whether it overflew.
435
+ if (static_cast <Tsource>(res) != v) {
436
+ throw std::runtime_error (" safe_narrow_cast<>() failed due to overflow" );
437
+ }
438
+ return res;
439
+ }
440
+
431
441
struct CodeImpl {
432
442
friend struct InterpreterState ;
433
443
std::vector<Instruction> instructions_;
@@ -535,7 +545,10 @@ struct CodeImpl {
535
545
}
536
546
537
547
void insertInstruction (OpCode op, int64_t X = 0 , uint64_t N = 0 ) {
538
- instructions_.emplace_back (op, X, N);
548
+ instructions_.emplace_back (
549
+ op,
550
+ safe_narrow_cast<int32_t , int64_t >(X),
551
+ safe_narrow_cast<int16_t , int64_t >(N));
539
552
instructions_source_.emplace_back (current_node_);
540
553
541
554
// check that we didn't accidentally emit nodes out of topological order
@@ -873,7 +886,11 @@ struct CodeImpl {
873
886
874
887
void emitWarn (Node* node) {
875
888
emitLoadInputs (node->inputs ());
876
- insertInstruction (WARN);
889
+ int32_t idx = -1 ;
890
+ if (node->hasAttribute (attr::warn_id)) {
891
+ idx = static_cast <int32_t >(node->i (attr::warn_id));
892
+ }
893
+ insertInstruction (WARN, idx);
877
894
}
878
895
879
896
void emitEnter (Node* node) {
@@ -1017,6 +1034,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
1017
1034
}
1018
1035
1019
1036
private:
1037
+ struct WarnedNodes {
1038
+ public:
1039
+ // Inserts idx into warned_nodes_, returns a boolean indicates whether
1040
+ // insertion actually happened (idx wasn't originally in the set).
1041
+ bool insert (int32_t idx) {
1042
+ std::unique_lock<std::mutex> lock (mutex_);
1043
+ return warned_nodes_.insert (idx).second ;
1044
+ }
1045
+
1046
+ private:
1047
+ std::mutex mutex_;
1048
+ std::unordered_set<int32_t > warned_nodes_;
1049
+ };
1050
+
1051
+ WarnedNodes warned_nodes_;
1052
+
1020
1053
// if we need to suspend, where do we reset the stack?
1021
1054
// answer: to where it was when we were called, not
1022
1055
// including any inputs to this function
@@ -1487,21 +1520,35 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
1487
1520
++frame.pc ;
1488
1521
} break ;
1489
1522
case WARN: {
1490
- Node* node = frame.function ->instructions_source_ .at (frame.pc );
1523
+ // Keeps track of which WARN instruction has been executed before,
1524
+ // we only want to execute each WARN once to match default Python
1525
+ // warning behavior.
1526
+ bool need_warn = true ;
1527
+ if (inst.X != -1 ) {
1528
+ need_warn = warned_nodes_.insert (inst.X );
1529
+ }
1530
+
1531
+ Node* node =
1532
+ frames.back ().function ->instructions_source_ .at (frame.pc );
1491
1533
auto range = node->sourceRange ().source ();
1492
1534
if (range->filename ()) {
1493
- auto line = range->starting_line_no () +
1494
- range->lineno_for_offset (node->sourceRange ().start ());
1495
1535
drop (stack, 1 );
1496
- c10::SourceLocation location{
1497
- " " , range->filename ()->c_str (), uint32_t (line)};
1498
- // Sends the warning to the warning handler with the
1499
- // "verbatim" flag. This flag ensures the warning handler
1500
- // will print the exception as configured.
1501
- c10::Warning::warn (
1502
- location, pop (stack).toStringRef (), /* verbatim=*/ true );
1536
+ const auto msg = pop (stack).toStringRef ();
1537
+ if (need_warn) {
1538
+ auto line = range->starting_line_no () +
1539
+ range->lineno_for_offset (node->sourceRange ().start ());
1540
+ c10::SourceLocation location{
1541
+ " " , range->filename ()->c_str (), uint32_t (line)};
1542
+ // Sends the warning to the warning handler with the
1543
+ // "verbatim" flag. This flag ensures the warning handler
1544
+ // will print the exception as configured.
1545
+ c10::Warning::warn (location, msg, /* verbatim=*/ true );
1546
+ }
1503
1547
} else {
1504
- TORCH_WARN (pop (stack).toStringRef ());
1548
+ const auto msg = pop (stack).toStringRef ();
1549
+ if (need_warn) {
1550
+ TORCH_WARN (msg);
1551
+ }
1505
1552
}
1506
1553
++frame.pc ;
1507
1554
} break ;
0 commit comments