Fix l2_normalize & add nn.functional.normalize #6940
Merged
Changes from 13 commits (16 commits total):
3e28df7  fix l2_normalize (mosout)
dc5d95b  Merge remote-tracking branch 'upstream/master' into fix_l2_norm (mosout)
8cc1e06  add normalize (mosout)
6f38334  Merge remote-tracking branch 'upstream/master' into fix_l2_norm (mosout)
693f970  add test for normalize (mosout)
7ebdb2d  refine (mosout)
fad58d3  clean l2_normalize and refine normalize (mosout)
fe7536c  Merge remote-tracking branch 'upstream/master' into fix_l2_norm (mosout)
7de74a9  simplify normalize test (mosout)
d1844f6  Fix l2norm block_size (liujuncheng)
2d78b19  Merge branch 'master' into fix_l2_block_size (oneflow-ci-bot)
7d3b10d  Merge remote-tracking branch 'upstream/fix_l2_block_size' into fix_l2… (mosout)
d875f8c  Merge remote-tracking branch 'upstream/master' into fix_l2_norm (mosout)
4f42c87  refine (mosout)
63ad632  Merge remote-tracking branch 'upstream/master' into fix_l2_norm (mosout)
5ca43f9  Merge branch 'master' into fix_l2_norm (mosout)
Diff of the normalize test file:

@@ -16,10 +16,9 @@
 import unittest
 from collections import OrderedDict
 
-import numpy as np
 from test_util import GenArgList
-
 from oneflow.test_utils.automated_test_util import *
+import numpy as np
 import oneflow as flow
 import oneflow.unittest

@@ -32,12 +31,18 @@ def _count(shape, begin_axis, end_axis):
 
 
 def _l2_norm_numpy(x, dim, epsilon=1e-12):
+    axes = [k for k in range(len(list(x.shape)))]
+    axes[0], axes[dim] = axes[dim], axes[0]
+    axes_tuple = tuple(axes)
+
+    x = np.transpose(x, axes_tuple)
+
     square_x_sum_shape = list(x.shape)
-    square_x_sum_shape[dim] = 1
+    square_x_sum_shape[0] = 1
 
-    c = x.shape[dim]
+    c = x.shape[0]
     n = int(x.size / c)
-    d = _count(x.shape, dim + 1, len(x.shape))
+    d = _count(x.shape, 1, len(x.shape))
 
     square_x_sum = np.zeros(square_x_sum_shape)
 

@@ -58,13 +63,21 @@ def _l2_norm_numpy(x, dim, epsilon=1e-12):
 
     square_x_sum = square_x_sum_flatten.reshape(square_x_sum.shape)
     out = out.reshape(x.shape)
-    return out, square_x_sum
+    return np.transpose(out, axes_tuple), np.transpose(square_x_sum, axes_tuple)
 
 
 def _l2_norm_backward_np(dy, y, square_x_sum, dim, epsilon=1e-12):
-    c = dy.shape[dim]
+    axes = [k for k in range(len(list(y.shape)))]
+    axes[0], axes[dim] = axes[dim], axes[0]
+    axes_tuple = tuple(axes)
+
+    dy = np.transpose(dy, axes_tuple)
+    y = np.transpose(y, axes_tuple)
+    square_x_sum = np.transpose(square_x_sum, axes_tuple)
+
+    c = dy.shape[0]
     n = int(dy.size / c)
-    d = _count(dy.shape, dim + 1, len(y.shape))
+    d = _count(dy.shape, 1, len(y.shape))
 
     dx = np.zeros(dy.shape).reshape(-1)
     dy_flatten = dy.reshape(-1)

@@ -89,7 +102,7 @@ def _l2_norm_backward_np(dy, y, square_x_sum, dim, epsilon=1e-12):
             index = offset + j * d
             dx[index] = (1 / norm) * dy_flatten[index]
 
-    return dx.reshape(y.shape)
+    return np.transpose(dx.reshape(y.shape), axes_tuple)
 
 
 def _test_l2_normalize(test_case, device, dim, shape):

@@ -124,5 +137,24 @@ def test_l2_normalize(test_case):
         arg[0](test_case, *arg[1:])
 
 
+@flow.unittest.skip_unless_1n1d()
+class TestFunctionalNormalize(flow.unittest.TestCase):
+    @autotest(check_graph=False)
+    def test_functional_normalize(test_case):
+        device = random_device()
+        ndim = random(low=2)
+
+        shape = list(random_tensor(ndim).value().shape)
+        dim = random(low=0, high=ndim).to(int).value()
+        shape[dim] = random(low=2, high=8).to(int).value()
+        shape = tuple(shape)
+
+        x = random_pytorch_tensor(len(shape), *shape).to(device)
+        m = torch.nn.functional.normalize
+        y = m(x, oneof(2, 3, 4), dim, 1e-12)
+
+        return y
+
+
 if __name__ == "__main__":
     unittest.main()
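The core idea of the fixed reference helper is to handle an arbitrary `dim` by swapping that axis to position 0, normalizing along axis 0, and then applying the same (self-inverse) axis swap to the result. A minimal standalone NumPy sketch of that idea (the helper name is ours, and the `max(norm, eps)` clamp follows the `torch.nn.functional.normalize` convention, which may differ in detail from the patch):

```python
import numpy as np

def l2_normalize_ref(x, dim, epsilon=1e-12):
    # Swap the target axis to position 0, as the patched helper does, so the
    # normalization only ever has to reason about axis 0.
    axes = list(range(x.ndim))
    axes[0], axes[dim] = axes[dim], axes[0]
    xt = np.transpose(x, axes)
    # L2 norm along axis 0, clamped below by epsilon for numerical stability.
    norm = np.maximum(np.sqrt(np.sum(xt * xt, axis=0, keepdims=True)), epsilon)
    out = xt / norm
    # Swapping the same two axes again restores the original layout.
    return np.transpose(out, axes)

x = np.random.rand(2, 3, 4)
y = l2_normalize_ref(x, dim=1)
# The transpose round-trip must agree with normalizing directly along dim=1.
direct = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)
assert np.allclose(y, direct)
```

Because a two-element swap is its own inverse, the same `axes` permutation works for both the forward and the restoring transpose, which is why the patch reuses `axes_tuple` on the way out.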
Review discussion:

Comment: So does this function also need to be exported into the oneflow namespace?

Reply: Yes, it can be exported.

Follow-up: I checked PyTorch's interface: neither l2_normalize nor normalize is exported under torch; both live only under torch.nn.functional. Do we still need to export ours?
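For reference on the interface being discussed: `torch.nn.functional.normalize(input, p, dim, eps)` returns `input / max(||input||_p, eps)` computed along `dim`. A NumPy sketch of those semantics (the helper name is ours):

```python
import numpy as np

def normalize_ref(x, p=2.0, dim=1, eps=1e-12):
    # Lp norm along `dim`, kept for broadcasting; clamped by eps so an
    # all-zero slice divides by eps (staying zero) rather than by zero,
    # matching torch.nn.functional.normalize.
    norm = np.sum(np.abs(x) ** p, axis=dim, keepdims=True) ** (1.0 / p)
    return x / np.maximum(norm, eps)

x = np.array([[3.0, 4.0], [0.0, 0.0]])
print(normalize_ref(x, p=2.0, dim=1))
# First row becomes the unit vector [0.6, 0.8]; the zero row stays zero.
```

The eps clamp is the reason the function never raises a division-by-zero warning on degenerate inputs, which matters for the randomized shapes the autotest above generates.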