Skip to content

Commit

Permalink
primitives: Remove templates from RandomSource APIs (RobotLocomotion#…
Browse files Browse the repository at this point in the history
…11670)

Templates make compilation times long and error messages bad.
  • Loading branch information
jwnimmer-tri authored Jun 24, 2019
1 parent c6f0525 commit 12a8432
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 199 deletions.
2 changes: 2 additions & 0 deletions bindings/pydrake/systems/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ drake_pybind_library(
":framework_py",
":module_py",
],
py_srcs = ["_primitives_extra.py"],
)

drake_pybind_library(
Expand Down Expand Up @@ -456,6 +457,7 @@ drake_py_unittest(
":primitives_py",
":test_util_py",
"//bindings/pydrake:trajectories_py",
"//bindings/pydrake/common/test_utilities:deprecation_py",
],
)

Expand Down
54 changes: 54 additions & 0 deletions bindings/pydrake/systems/_primitives_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# See `ExecuteExtraPythonCode` in `pydrake_pybind.h` for usage details and
# rationale.

import warnings


_DEPRECATION_MESSAGE = (
"Use primitives.RandomSource(RandomDistribution.kFoo, ...) instead of " +
"primitives.FooRandomSource. This class will be removed on 2019-10-01.")


def UniformRandomSource(num_outputs, sampling_interval_sec):
"""Deprecated constructor that desugars to
primitives.RandomSource(RandomDistribution.kUniform, **args, **kwargs).
This constructor will be removed on 2019-10-01."""
from pydrake.common import RandomDistribution
from pydrake.common.deprecation import DrakeDeprecationWarning
from pydrake.systems.primitives import RandomSource
warnings.warn(_DEPRECATION_MESSAGE, category=DrakeDeprecationWarning,
stacklevel=2)
return RandomSource(
distribution=RandomDistribution.kUniform,
num_outputs=num_outputs,
sampling_interval_sec=sampling_interval_sec)


def GaussianRandomSource(num_outputs, sampling_interval_sec):
"""Deprecated constructor that desugars to
primitives.RandomSource(RandomDistribution.kGaussian, **args, **kwargs).
This constructor will be removed on 2019-10-01."""
from pydrake.common import RandomDistribution
from pydrake.common.deprecation import DrakeDeprecationWarning
from pydrake.systems.primitives import RandomSource
warnings.warn(_DEPRECATION_MESSAGE, category=DrakeDeprecationWarning,
stacklevel=2)
return RandomSource(
distribution=RandomDistribution.kGaussian,
num_outputs=num_outputs,
sampling_interval_sec=sampling_interval_sec)


def ExponentialRandomSource(num_outputs, sampling_interval_sec):
"""Deprecated constructor that desugars to
primitives.RandomSource(RandomDistribution.kExponential, **args, **kwargs).
This constructor will be removed on 2019-10-01."""
from pydrake.common import RandomDistribution
from pydrake.common.deprecation import DrakeDeprecationWarning
from pydrake.systems.primitives import RandomSource
warnings.warn(_DEPRECATION_MESSAGE, category=DrakeDeprecationWarning,
stacklevel=2)
return RandomSource(
distribution=RandomDistribution.kExponential,
num_outputs=num_outputs,
sampling_interval_sec=sampling_interval_sec)
23 changes: 7 additions & 16 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,22 +228,11 @@ PYBIND11_MODULE(primitives, m) {
&BarycentricMeshSystem<double>::get_output_values,
doc.BarycentricMeshSystem.get_output_values.doc);

// Docs for typedef not being parsed.
py::class_<UniformRandomSource, LeafSystem<double>>(m, "UniformRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));

// Docs for typedef not being parsed.
py::class_<GaussianRandomSource, LeafSystem<double>>(
m, "GaussianRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));

// Docs for typedef not being parsed.
py::class_<ExponentialRandomSource, LeafSystem<double>>(
m, "ExponentialRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));
py::class_<RandomSource, LeafSystem<double>>(
m, "RandomSource", doc.RandomSource.doc)
.def(py::init<RandomDistribution, int, double>(), py::arg("distribution"),
py::arg("num_outputs"), py::arg("sampling_interval_sec"),
doc.RandomSource.ctor.doc);

py::class_<TrajectorySource<double>, LeafSystem<double>>(
m, "TrajectorySource", doc.TrajectorySource.doc)
Expand Down Expand Up @@ -289,6 +278,8 @@ PYBIND11_MODULE(primitives, m) {
py_reference, doc.LogOutput.doc);

// TODO(eric.cousineau): Add more systems as needed.

ExecuteExtraPythonCode(m);
}

} // namespace pydrake
Expand Down
27 changes: 15 additions & 12 deletions bindings/pydrake/systems/test/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np

from pydrake.autodiffutils import AutoDiffXd
from pydrake.common import RandomDistribution
from pydrake.common.test_utilities.deprecation import catch_drake_warnings
from pydrake.symbolic import Expression, Variable
from pydrake.systems.analysis import Simulator
from pydrake.systems.framework import (
Expand Down Expand Up @@ -36,6 +38,7 @@
Multiplexer, Multiplexer_,
ObservabilityMatrix,
PassThrough, PassThrough_,
RandomSource,
Saturation, Saturation_,
SignalLogger, SignalLogger_,
Sine, Sine_,
Expand Down Expand Up @@ -362,24 +365,24 @@ def test_multiplexer(self):
value = output.get_vector_data(0)
self.assertTrue(isinstance(value, MyVector2))

def test_random_sources(self):
uniform_source = UniformRandomSource(num_outputs=2,
sampling_interval_sec=0.01)
self.assertEqual(uniform_source.get_output_port(0).size(), 2)

gaussian_source = GaussianRandomSource(num_outputs=3,
sampling_interval_sec=0.01)
self.assertEqual(gaussian_source.get_output_port(0).size(), 3)

exponential_source = ExponentialRandomSource(num_outputs=4,
sampling_interval_sec=0.1)
self.assertEqual(exponential_source.get_output_port(0).size(), 4)
def test_random_source(self):
source = RandomSource(distribution=RandomDistribution.kUniform,
num_outputs=2, sampling_interval_sec=0.01)
self.assertEqual(source.get_output_port(0).size(), 2)

builder = DiagramBuilder()
# Note: There are no random inputs to add to the empty diagram, but it
# confirms the API works.
AddRandomInputs(sampling_interval_sec=0.01, builder=builder)

def test_random_sources_deprecated(self):
with catch_drake_warnings(expected_count=1):
UniformRandomSource(num_outputs=2, sampling_interval_sec=0.01)
with catch_drake_warnings(expected_count=1):
GaussianRandomSource(num_outputs=3, sampling_interval_sec=0.01)
with catch_drake_warnings(expected_count=1):
ExponentialRandomSource(num_outputs=4, sampling_interval_sec=0.1)

def test_ctor_api(self):
"""Tests construction of systems for systems whose executions semantics
are not tested above.
Expand Down
4 changes: 2 additions & 2 deletions systems/analysis/test/monte_carlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ GTEST_TEST(RandomSimulationTest, WithRandomInputs) {
DiagramBuilder<double> builder;
const int kNumOutputs = 1;
const double sampling_interval = 0.1;
auto random_source =
builder.AddSystem<UniformRandomSource>(kNumOutputs, sampling_interval);
auto random_source = builder.AddSystem<RandomSource>(
RandomDistribution::kUniform, kNumOutputs, sampling_interval);
auto pass_through = builder.template AddSystem<PassThrough>(kNumOutputs);
builder.Connect(random_source->get_output_port(0),
pass_through->get_input_port());
Expand Down
147 changes: 114 additions & 33 deletions systems/primitives/random_source.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,122 @@
#include "drake/systems/primitives/random_source.h"

#include <atomic>
#include <random>

#include "drake/common/never_destroyed.h"

namespace drake {
namespace systems {
namespace {

using Seed = RandomSource::Seed;

// Stores exactly one of the three supported distribution objects. Note that
// the distribution objects hold computational state; they are not just pure
// mathematical functions.
using DistributionVariant = variant<
std::uniform_real_distribution<double>,
std::normal_distribution<double>,
std::exponential_distribution<double>>;

// Creates a distribution object from the distribution enumeration.
DistributionVariant MakeDistributionVariant(RandomDistribution which) {
switch (which) {
case RandomDistribution::kUniform:
return std::uniform_real_distribution<double>();
case RandomDistribution::kGaussian:
return std::normal_distribution<double>();
case RandomDistribution::kExponential:
return std::exponential_distribution<double>();
}
DRAKE_UNREACHABLE();
}

// Generates real-valued (i.e., `double`) samples from some distribution. This
// serves as the abstract state of a RandomSource, which encompasses all of the
// source's state *except* for the currently-sampled output values which are
// stored as discrete state.
class SampleGenerator {
public:
DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(SampleGenerator)

namespace internal {
template<typename Generator>
typename Generator::result_type generate_unique_seed() {
static never_destroyed<typename Generator::result_type> seed(
Generator::default_seed);
SampleGenerator() = default;
SampleGenerator(Seed seed, RandomDistribution which)
: generator_(seed), distribution_(MakeDistributionVariant(which)) {}

double GenerateNext() {
switch (distribution_.index()) {
case 0: return get<0>(distribution_)(generator_);
case 1: return get<1>(distribution_)(generator_);
case 2: return get<2>(distribution_)(generator_);
}
DRAKE_UNREACHABLE();
}

private:
RandomGenerator generator_;
DistributionVariant distribution_;
};

// Returns a monotonically increasing integer on each call.
Seed get_next_seed() {
static never_destroyed<std::atomic<Seed>> seed(
RandomGenerator::default_seed);
return seed.access()++;
}

template RandomGenerator::result_type generate_unique_seed<RandomGenerator>();
} // namespace

} // namespace internal
RandomSource::RandomSource(
RandomDistribution distribution, int num_outputs,
double sampling_interval_sec)
: distribution_(distribution), seed_(get_next_seed()) {
this->DeclareDiscreteState(num_outputs);
this->DeclareAbstractState(Value<SampleGenerator>().Clone());
this->DeclarePeriodicUnrestrictedUpdateEvent(
sampling_interval_sec, 0., &RandomSource::UpdateSamples);
this->DeclareVectorOutputPort(
"output", BasicVector<double>(num_outputs),
[](const Context<double>& context, BasicVector<double>* output) {
const auto& values = context.get_discrete_state(0);
output->SetFrom(values);
});
}

RandomSource::~RandomSource() {}

void RandomSource::SetDefaultState(
const Context<double>& context, State<double>* state) const {
SetSeed(seed_, context, state);
}

void RandomSource::SetRandomState(
const Context<double>& context, State<double>* state,
RandomGenerator* seed_generator) const {
const Seed fresh_seed = (*seed_generator)();
SetSeed(fresh_seed, context, state);
}

// Writes the given seed into abstract state (replacing the existing
// SampleGenerator) and then does `UpdateSamples`.
void RandomSource::SetSeed(
Seed seed, const Context<double>& context, State<double>* state) const {
state->template get_mutable_abstract_state<SampleGenerator>(0) =
SampleGenerator(seed, distribution_);
UpdateSamples(context, state);
}

// Samples random values into the discrete state, using the SampleGenerator
// from the abstract state. (Note that the generator's abstract state is also
// mutated as a side effect of this method.)
void RandomSource::UpdateSamples(
const Context<double>&, State<double>* state) const {
auto& source = state->template get_mutable_abstract_state<SampleGenerator>(0);
auto& samples = state->get_mutable_discrete_state(0);
for (int i = 0; i < samples.size(); ++i) {
samples[i] = source.GenerateNext();
}
}

int AddRandomInputs(double sampling_interval_sec,
DiagramBuilder<double>* builder) {
Expand All @@ -24,48 +125,28 @@ int AddRandomInputs(double sampling_interval_sec,
// there is (currently) no builder->GetSystems() method.
for (const auto* system : builder->GetMutableSystems()) {
for (int i = 0; i < system->num_input_ports(); i++) {
const systems::InputPort<double>& port =
system->get_input_port(i);
const systems::InputPort<double>& port = system->get_input_port(i);
// Check for the random label.
if (!port.is_random()) {
continue;
}

typedef typename Diagram<double>::InputPortLocator InputPortLocator;
using InputPortLocator = Diagram<double>::InputPortLocator;
// Check if the input is already wired up.
InputPortLocator id{&port.get_system(), port.get_index()};
if (builder->connection_map_.count(id) > 0 ||
builder->diagram_input_set_.count(id) > 0) {
continue;
}

count++;
switch (port.get_random_type().value()) {
case RandomDistribution::kUniform: {
const auto* uniform = builder->AddSystem<UniformRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(uniform->get_output_port(0), port);
continue;
}
case RandomDistribution::kGaussian: {
const auto* gaussian = builder->AddSystem<GaussianRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(gaussian->get_output_port(0), port);
continue;
}
case RandomDistribution::kExponential: {
const auto* exponential = builder->AddSystem<ExponentialRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(exponential->get_output_port(0), port);
continue;
}
}
DRAKE_UNREACHABLE();
const auto* const source = builder->AddSystem<RandomSource>(
port.get_random_type().value(), port.size(), sampling_interval_sec);
builder->Connect(source->get_output_port(0), port);
++count;
}
}
return count;
}

} // namespace systems
} // namespace drake

Loading

0 comments on commit 12a8432

Please sign in to comment.