Skip to content

Commit

Permalink
Fix negative position in InlinePropagator (csarofeen#1813)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 11, 2022
1 parent 10a996c commit de6b7ca
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::unordered_set<TensorView*> getPropagationSubgraph(
void ComputeAt::runAt(
TensorView* producer,
TensorView* consumer,
unsigned int consumer_position,
int64_t consumer_position,
ComputeAtMode mode) {
FUSER_PERF_SCOPE("ComputeAt::runAt");

Expand All @@ -176,7 +176,7 @@ void ComputeAt::runAt(
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
consumer_position = consumer->nDims();
consumer_position = -1;
}

FusionGuard fg(producer->fusion());
Expand Down Expand Up @@ -205,7 +205,7 @@ void ComputeAt::runAt(
void ComputeAt::runWith(
TensorView* producer,
TensorView* consumer,
unsigned int producer_position,
int64_t producer_position,
ComputeAtMode mode) {
FUSER_PERF_SCOPE("ComputeAt::runWith");

Expand All @@ -218,7 +218,7 @@ void ComputeAt::runWith(
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
producer_position = producer->nDims();
producer_position = -1;
}

FusionGuard fg(producer->fusion());
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/compute_at.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ struct ComputeAt {
static void runAt(
TensorView* producer,
TensorView* consumer,
unsigned int consumer_position,
int64_t consumer_position,
ComputeAtMode mode = ComputeAtMode::Standard);

// Runs the compute with pass making consumer look like producer, computing
// producer relative to consumer
static void runWith(
TensorView* producer,
TensorView* consumer,
unsigned int producer_position,
int64_t producer_position,
ComputeAtMode mode = ComputeAtMode::Standard);
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ InlinePropagator::InlinePropagator(
: max_pos_calc(mode),
selected_(std::move(selected)),
reference_(reference),
reference_pos_(reference_pos),
mode_(mode) {
if (reference_pos < 0) {
reference_pos += int64_t(reference->nDims()) + 1;
Expand All @@ -192,6 +191,7 @@ InlinePropagator::InlinePropagator(
" and <= ",
reference->nDims(),
".");
reference_pos_ = reference_pos;
}

void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
Expand Down

0 comments on commit de6b7ca

Please sign in to comment.