-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dropout functor error #9808
Conversation
…fix_dropout_functor_error
@@ -2846,31 +2846,29 @@ Maybe<Tensor> DropoutImpl(const std::shared_ptr<one::Tensor>& input, const float | |||
if (p == 1) { | |||
std::shared_ptr<Tensor> other = | |||
JUST(Constant(*input->shape(), Scalar(0.0), input->dtype(), JUST(input->device()))); | |||
return InplaceMul(input, other); | |||
return Mul(input, other); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
若使用inplace版mul,则后向计算检查会报错。
} | ||
std::shared_ptr<Tensor> noise = JUST(MakeFeatureNoise(input)); | ||
noise = | ||
JUST(BernoulliProb(noise, 1.0 - p, noise->dtype(), JUST(one::DefaultAutoGenerator()), false)); | ||
noise = JUST(InplaceScalarDiv(noise, Scalar(1.0 - p))); | ||
noise = JUST(InplaceMul(input, noise)); | ||
return noise; | ||
return JUST(Mul(input, noise)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false) | ||
<< "dropout probability has to be between 0 and 1, but got " << p; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里之前应该是写错了,结果应该是 false
"Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " | ||
"spatial dimension, a channel dimension, and an optional batch dimension " | ||
"(i.e. 2D or 3D inputs)."; | ||
bool is_batched = (input_dim == 3); | ||
std::shared_ptr<one::Tensor> result; | ||
std::shared_ptr<one::Tensor> result = input; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要先给result
赋值为input
,否则is_batched
为true
时,DropoutImpl
输入为空。
LOG(WARNING) | ||
<< "dropout2d: Received a " << input_dim | ||
<< "-D input to dropout2d, which is deprecated " | ||
"and will result in an error in a future release. To retain the behavior " | ||
"and silence this warning, please use dropout instead. Note that dropout2d " | ||
"exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " | ||
"a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里对齐pytorch,弹warning,不报错
"a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."; | ||
} | ||
if (input_dim == 3) { | ||
LOG(WARNING) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
"exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " | ||
"a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."; | ||
if (input_dim != 4 && input_dim != 5) { | ||
LOG(WARNING) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
device = random_device() | ||
x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device) | ||
m = torch.nn.Dropout(p=0, inplace=random_bool()) | ||
m = torch.nn.Dropout(p=0, inplace=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
设置inplace为false,因为不能对 requires_grad = True 的叶子节点,做 inplace 修改
@autotest(n=5, check_graph=False) | ||
def test_dropout_eval(test_case): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在graph模式下(build函数中),m.eval()时非法的,故这几个用例关闭graph测试
Static analysis with clang failed. PR label automerge has been removed |
Static analysis with clang failed. PR label automerge has been removed |
Static analysis with clang failed. PR label automerge has been removed |
CI failed when running job: cuda-module. PR label automerge has been removed |
Speed stats:
|
Speed stats:
|
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9808/ |
修复dropout1d/2d/3d中的错误,打开相关测试, 修复 https://github.com/Oneflow-Inc/OneTeam/issues/1893 中的问题。