Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "primitives: Remove templates from RandomSource APIs" #11718

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions bindings/pydrake/systems/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ drake_pybind_library(
":framework_py",
":module_py",
],
py_srcs = ["_primitives_extra.py"],
)

drake_pybind_library(
Expand Down Expand Up @@ -457,7 +456,6 @@ drake_py_unittest(
":primitives_py",
":test_util_py",
"//bindings/pydrake:trajectories_py",
"//bindings/pydrake/common/test_utilities:deprecation_py",
],
)

Expand Down
54 changes: 0 additions & 54 deletions bindings/pydrake/systems/_primitives_extra.py

This file was deleted.

23 changes: 16 additions & 7 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,22 @@ PYBIND11_MODULE(primitives, m) {
&BarycentricMeshSystem<double>::get_output_values,
doc.BarycentricMeshSystem.get_output_values.doc);

py::class_<RandomSource, LeafSystem<double>>(
m, "RandomSource", doc.RandomSource.doc)
.def(py::init<RandomDistribution, int, double>(), py::arg("distribution"),
py::arg("num_outputs"), py::arg("sampling_interval_sec"),
doc.RandomSource.ctor.doc);
// Docs for typedef not being parsed.
py::class_<UniformRandomSource, LeafSystem<double>>(m, "UniformRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));

// Docs for typedef not being parsed.
py::class_<GaussianRandomSource, LeafSystem<double>>(
m, "GaussianRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));

// Docs for typedef not being parsed.
py::class_<ExponentialRandomSource, LeafSystem<double>>(
m, "ExponentialRandomSource")
.def(py::init<int, double>(), py::arg("num_outputs"),
py::arg("sampling_interval_sec"));

py::class_<TrajectorySource<double>, LeafSystem<double>>(
m, "TrajectorySource", doc.TrajectorySource.doc)
Expand Down Expand Up @@ -278,8 +289,6 @@ PYBIND11_MODULE(primitives, m) {
py_reference, doc.LogOutput.doc);

// TODO(eric.cousineau): Add more systems as needed.

ExecuteExtraPythonCode(m);
}

} // namespace pydrake
Expand Down
27 changes: 12 additions & 15 deletions bindings/pydrake/systems/test/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np

from pydrake.autodiffutils import AutoDiffXd
from pydrake.common import RandomDistribution
from pydrake.common.test_utilities.deprecation import catch_drake_warnings
from pydrake.symbolic import Expression, Variable
from pydrake.systems.analysis import Simulator
from pydrake.systems.framework import (
Expand Down Expand Up @@ -38,7 +36,6 @@
Multiplexer, Multiplexer_,
ObservabilityMatrix,
PassThrough, PassThrough_,
RandomSource,
Saturation, Saturation_,
SignalLogger, SignalLogger_,
Sine, Sine_,
Expand Down Expand Up @@ -365,24 +362,24 @@ def test_multiplexer(self):
value = output.get_vector_data(0)
self.assertTrue(isinstance(value, MyVector2))

def test_random_source(self):
source = RandomSource(distribution=RandomDistribution.kUniform,
num_outputs=2, sampling_interval_sec=0.01)
self.assertEqual(source.get_output_port(0).size(), 2)
def test_random_sources(self):
uniform_source = UniformRandomSource(num_outputs=2,
sampling_interval_sec=0.01)
self.assertEqual(uniform_source.get_output_port(0).size(), 2)

gaussian_source = GaussianRandomSource(num_outputs=3,
sampling_interval_sec=0.01)
self.assertEqual(gaussian_source.get_output_port(0).size(), 3)

exponential_source = ExponentialRandomSource(num_outputs=4,
sampling_interval_sec=0.1)
self.assertEqual(exponential_source.get_output_port(0).size(), 4)

builder = DiagramBuilder()
# Note: There are no random inputs to add to the empty diagram, but it
# confirms the API works.
AddRandomInputs(sampling_interval_sec=0.01, builder=builder)

def test_random_sources_deprecated(self):
with catch_drake_warnings(expected_count=1):
UniformRandomSource(num_outputs=2, sampling_interval_sec=0.01)
with catch_drake_warnings(expected_count=1):
GaussianRandomSource(num_outputs=3, sampling_interval_sec=0.01)
with catch_drake_warnings(expected_count=1):
ExponentialRandomSource(num_outputs=4, sampling_interval_sec=0.1)

def test_ctor_api(self):
"""Tests construction of systems for systems whose executions semantics
are not tested above.
Expand Down
4 changes: 2 additions & 2 deletions systems/analysis/test/monte_carlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ GTEST_TEST(RandomSimulationTest, WithRandomInputs) {
DiagramBuilder<double> builder;
const int kNumOutputs = 1;
const double sampling_interval = 0.1;
auto random_source = builder.AddSystem<RandomSource>(
RandomDistribution::kUniform, kNumOutputs, sampling_interval);
auto random_source =
builder.AddSystem<UniformRandomSource>(kNumOutputs, sampling_interval);
auto pass_through = builder.template AddSystem<PassThrough>(kNumOutputs);
builder.Connect(random_source->get_output_port(0),
pass_through->get_input_port());
Expand Down
147 changes: 33 additions & 114 deletions systems/primitives/random_source.cc
Original file line number Diff line number Diff line change
@@ -1,122 +1,21 @@
#include "drake/systems/primitives/random_source.h"

#include <atomic>
#include <random>

#include "drake/common/never_destroyed.h"

namespace drake {
namespace systems {
namespace {

using Seed = RandomSource::Seed;

// Stores exactly one of the three supported distribution objects. Note that
// the distribution objects hold computational state; they are not just pure
// mathematical functions.
using DistributionVariant = variant<
std::uniform_real_distribution<double>,
std::normal_distribution<double>,
std::exponential_distribution<double>>;

// Creates a distribution object from the distribution enumeration.
DistributionVariant MakeDistributionVariant(RandomDistribution which) {
switch (which) {
case RandomDistribution::kUniform:
return std::uniform_real_distribution<double>();
case RandomDistribution::kGaussian:
return std::normal_distribution<double>();
case RandomDistribution::kExponential:
return std::exponential_distribution<double>();
}
DRAKE_UNREACHABLE();
}

// Generates real-valued (i.e., `double`) samples from some distribution. This
// serves as the abstract state of a RandomSource, which encompasses all of the
// source's state *except* for the currently-sampled output values which are
// stored as discrete state.
class SampleGenerator {
public:
DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(SampleGenerator)

SampleGenerator() = default;
SampleGenerator(Seed seed, RandomDistribution which)
: generator_(seed), distribution_(MakeDistributionVariant(which)) {}

double GenerateNext() {
switch (distribution_.index()) {
case 0: return get<0>(distribution_)(generator_);
case 1: return get<1>(distribution_)(generator_);
case 2: return get<2>(distribution_)(generator_);
}
DRAKE_UNREACHABLE();
}

private:
RandomGenerator generator_;
DistributionVariant distribution_;
};

// Returns a monotonically increasing integer on each call.
Seed get_next_seed() {
static never_destroyed<std::atomic<Seed>> seed(
RandomGenerator::default_seed);
namespace internal {
template<typename Generator>
typename Generator::result_type generate_unique_seed() {
static never_destroyed<typename Generator::result_type> seed(
Generator::default_seed);
return seed.access()++;
}

} // namespace
template RandomGenerator::result_type generate_unique_seed<RandomGenerator>();

RandomSource::RandomSource(
RandomDistribution distribution, int num_outputs,
double sampling_interval_sec)
: distribution_(distribution), seed_(get_next_seed()) {
this->DeclareDiscreteState(num_outputs);
this->DeclareAbstractState(Value<SampleGenerator>().Clone());
this->DeclarePeriodicUnrestrictedUpdateEvent(
sampling_interval_sec, 0., &RandomSource::UpdateSamples);
this->DeclareVectorOutputPort(
"output", BasicVector<double>(num_outputs),
[](const Context<double>& context, BasicVector<double>* output) {
const auto& values = context.get_discrete_state(0);
output->SetFrom(values);
});
}

RandomSource::~RandomSource() {}

void RandomSource::SetDefaultState(
const Context<double>& context, State<double>* state) const {
SetSeed(seed_, context, state);
}

void RandomSource::SetRandomState(
const Context<double>& context, State<double>* state,
RandomGenerator* seed_generator) const {
const Seed fresh_seed = (*seed_generator)();
SetSeed(fresh_seed, context, state);
}

// Writes the given seed into abstract state (replacing the existing
// SampleGenerator) and then does `UpdateSamples`.
void RandomSource::SetSeed(
Seed seed, const Context<double>& context, State<double>* state) const {
state->template get_mutable_abstract_state<SampleGenerator>(0) =
SampleGenerator(seed, distribution_);
UpdateSamples(context, state);
}

// Samples random values into the discrete state, using the SampleGenerator
// from the abstract state. (Note that the generator's abstract state is also
// mutated as a side effect of this method.)
void RandomSource::UpdateSamples(
const Context<double>&, State<double>* state) const {
auto& source = state->template get_mutable_abstract_state<SampleGenerator>(0);
auto& samples = state->get_mutable_discrete_state(0);
for (int i = 0; i < samples.size(); ++i) {
samples[i] = source.GenerateNext();
}
}
} // namespace internal

int AddRandomInputs(double sampling_interval_sec,
DiagramBuilder<double>* builder) {
Expand All @@ -125,28 +24,48 @@ int AddRandomInputs(double sampling_interval_sec,
// there is (currently) no builder->GetSystems() method.
for (const auto* system : builder->GetMutableSystems()) {
for (int i = 0; i < system->num_input_ports(); i++) {
const systems::InputPort<double>& port = system->get_input_port(i);
const systems::InputPort<double>& port =
system->get_input_port(i);
// Check for the random label.
if (!port.is_random()) {
continue;
}

using InputPortLocator = Diagram<double>::InputPortLocator;
typedef typename Diagram<double>::InputPortLocator InputPortLocator;
// Check if the input is already wired up.
InputPortLocator id{&port.get_system(), port.get_index()};
if (builder->connection_map_.count(id) > 0 ||
builder->diagram_input_set_.count(id) > 0) {
continue;
}

const auto* const source = builder->AddSystem<RandomSource>(
port.get_random_type().value(), port.size(), sampling_interval_sec);
builder->Connect(source->get_output_port(0), port);
++count;
count++;
switch (port.get_random_type().value()) {
case RandomDistribution::kUniform: {
const auto* uniform = builder->AddSystem<UniformRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(uniform->get_output_port(0), port);
continue;
}
case RandomDistribution::kGaussian: {
const auto* gaussian = builder->AddSystem<GaussianRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(gaussian->get_output_port(0), port);
continue;
}
case RandomDistribution::kExponential: {
const auto* exponential = builder->AddSystem<ExponentialRandomSource>(
port.size(), sampling_interval_sec);
builder->Connect(exponential->get_output_port(0), port);
continue;
}
}
DRAKE_UNREACHABLE();
}
}
return count;
}

} // namespace systems
} // namespace drake

Loading