-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Random op #3060
Merged
Merged
Random op #3060
Changes from 14 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
5ad9474
add random op
dzhwinter c110f56
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter 0d554f1
"add template fill function"
dzhwinter 6f80b5f
"move to template function"
dzhwinter d263cce
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter 32c15a2
"random op test"
dzhwinter 30a47fe
"link pybind11"
dzhwinter 2b3e362
"template specialization link include"
dzhwinter 984225e
"fix operator"
dzhwinter 11f9f5f
"fix const dependency hell"
dzhwinter 9a16327
"remove const qualify"
dzhwinter 69b1b26
"cpu only macro"
dzhwinter a22567e
"fix almost equal error"
dzhwinter 5721334
"update the compute kernel"
dzhwinter 36d7e1f
"fix const hell"
dzhwinter 0253f2c
"fix bind python error"
dzhwinter 4d8ece8
"update"
dzhwinter 4755668
"remove unused code"
dzhwinter 4973926
"fix register error"
dzhwinter 933e55e
fix conflict
dzhwinter 2447c34
merge origin/develop
dzhwinter 0f8c9db
device context pointer
dzhwinter 58561d8
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter fcd6f64
"redefine random op"
dzhwinter e2c08d2
"keep style same with uniform operators"
dzhwinter 52d2ebd
"test gaussian random in python side"
dzhwinter 8804b24
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter 555af4d
"format code"
dzhwinter 23ac845
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter 6535a7b
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter d98e299
"keep same with uniform random op"
dzhwinter 6fc6647
Merge remote-tracking branch 'origin/develop' into random_op
dzhwinter 7082550
"remove context random seeding "
dzhwinter df4fe67
"remove attribute"
dzhwinter 6bac3e1
"remove unused test net modified"
dzhwinter bbd7378
"ci job failed weired. restart ci job."
dzhwinter f702e79
"relauch ci"
dzhwinter File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/random_op.h" | ||
#include "glog/logging.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class GaussianRandomOpKernel<platform::CPUPlace, T> | ||
: public framework::OpKernel { | ||
public: | ||
void Compute(const framework::KernelContext& context) const override { | ||
auto mean = context.op_.GetAttr<T>("mean"); | ||
auto std = context.op_.GetAttr<T>("std"); | ||
// auto seed = context.op_.GetAttr<T>("seed"); | ||
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); | ||
T* r = output->mutable_data<T>(context.GetPlace()); | ||
auto ctx = | ||
static_cast<const platform::CPUDeviceContext*>(context.device_context_); | ||
// generator need to modify context | ||
auto g = const_cast<platform::CPUDeviceContext*>(ctx)->RandGenerator(); | ||
std::normal_distribution<T> distribution(mean, std); | ||
for (int i = 0; i < framework::product(output->dims()); ++i) { | ||
r[i] = distribution(g); | ||
} | ||
} | ||
}; | ||
|
||
class GaussianRandomOp : public framework::OperatorWithKernel { | ||
protected: | ||
void InferShape( | ||
const std::vector<const framework::Tensor*>& inputs, | ||
const std::vector<framework::Tensor*>& outputs) const override { | ||
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); | ||
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); | ||
PADDLE_ENFORCE(outputs[0] != nullptr, | ||
"Outputs of RandomOp must all be set."); | ||
outputs[0]->Resize( | ||
framework::make_ddim(this->GetAttr<std::vector<int>>("shape"))); | ||
} | ||
}; | ||
|
||
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
GaussianRandomOpMaker(framework::OpProto* proto, | ||
framework::OpAttrChecker* op_checker) | ||
: framework::OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized"); | ||
// AddAttr<float>("seed", "random seed generator.").SetDefault(1337); | ||
AddAttr<float>("mean", "mean value of random.").SetDefault(.0); | ||
AddAttr<float>("std", "minimum value of random value") | ||
.SetDefault(1.0) | ||
.LargerThan(.0); | ||
AddOutput("Out", "output matrix of random op"); | ||
AddComment(R"DOC( | ||
GaussianRandom Operator fill a matrix in normal distribution. | ||
The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std) | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
REGISTER_OP(gaussian_random, | ||
paddle::operators::GaussianRandomOp, | ||
paddle::operators::GaussianRandomOpMaker); | ||
|
||
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace, | ||
float> | ||
GaussianRandomOpKernel_CPU_float; | ||
REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "paddle/operators/random_op.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template<typename T> | ||
class GaussianRandomOpKernel<platform::GPUPlace, T> : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::KernelContext& context) const override { | ||
auto mean = context.op_.GetAttr<T>("mean"); | ||
auto std = context.op_.GetAttr<T>("std"); | ||
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); | ||
T* r = output->mutable_data<T>(context.GetPlace()); | ||
auto ctx = static_cast<const platform::GPUDeviceContext*> | ||
(context.device_context_); | ||
// generator need to modify context | ||
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator(); | ||
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std); | ||
|
||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
|
||
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, float> | ||
RandomOpKernel_GPU_float; | ||
REGISTER_OP_GPU_KERNEL(random, RandomOpKernel_GPU_float); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#pragma once | ||
#include <random> | ||
#include "glog/logging.h" | ||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T> | ||
class GaussianRandomOpKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::KernelContext& context) const override {} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python | ||
add_op fc_op sgd_op cross_entropy_op) | ||
add_op fc_op sgd_op cross_entropy_op random_op) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import unittest | ||
import paddle.v2.framework.create_op_creation_methods as creation | ||
import paddle.v2.framework.core as core | ||
from op_test_util import OpTestMeta | ||
import numpy | ||
|
||
|
||
class TestRandomOp(unittest.TestCase): | ||
def test_random(self): | ||
scope = core.Scope(None) | ||
# Out = scope.create_var("Out") | ||
op = creation.op_creations.random( | ||
shape=[1000, 1000], mean=5.0, std=1.0, seed=1701, Out="Out") | ||
for out in op.outputs(): | ||
if scope.get_var(out) is None: | ||
scope.create_var(out).get_tensor() | ||
|
||
tensor = scope.get_var("Out").get_tensor() | ||
op.infer_shape(scope) | ||
self.assertEqual([1000, 1000], tensor.shape()) | ||
ctx = core.DeviceContext.cpu_context() | ||
op.run(scope, ctx) | ||
tensor_array = numpy.array(tensor) | ||
self.assertAlmostEqual(numpy.mean(tensor_array), 5.0, places=3) | ||
self.assertAlmostEqual(numpy.std(tensor_array), 1.0, places=3) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So now we are not able to manually specify random seed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deleted the seeding function in context.
it is the place that we need to modify the Execution context. currently, move the set seed to random op. We need a global configuration to unify the random seed in the future. Maybe we can fix it in next PR.