[MetaSchedule][M3a] Add Sampling Primitive SampleCategorical. #8817

Merged on Aug 24, 2021 (34 commits)
Changes from 32 commits
Commits
3af2d12
Add sampling function SampleInt.
zxybazh Aug 20, 2021
121e658
Add new line.
zxybazh Aug 20, 2021
1ca4aed
Fix comment.
zxybazh Aug 20, 2021
fcf1d67
Update include/tvm/support/random_engine.h
zxybazh Aug 20, 2021
ce6e806
Minor fix.
zxybazh Aug 20, 2021
d546c38
Fix schedules & add SampleCategorical func.
zxybazh Aug 20, 2021
b7542f1
Fix SampleCategorical brief.
zxybazh Aug 20, 2021
cb0d96e
Fix issues and make FFI work.
zxybazh Aug 21, 2021
d735664
Fix ffi name.
zxybazh Aug 21, 2021
35aa00c
Test sample categorical.
zxybazh Aug 21, 2021
91454b1
Add Integer type change.
zxybazh Aug 21, 2021
3b30c93
bugfix for xiyou
junrushao Aug 21, 2021
ef4fda1
Make tests work.
zxybazh Aug 21, 2021
6e740b4
Fixed sampling test.
zxybazh Aug 21, 2021
73f57e8
Add seed value guard in python side.
zxybazh Aug 22, 2021
e937881
Fix ForkSeed func.
zxybazh Aug 22, 2021
0da7894
Remove extra files, add size check and update test.
zxybazh Aug 22, 2021
e75ed2b
Update header and imports.
zxybazh Aug 22, 2021
9f40d16
Move AsVector func.
zxybazh Aug 22, 2021
83cff2f
Move Seed and ForkSeed definition & ffi position.
zxybazh Aug 22, 2021
3d2d5d2
Modify Sample Categorical func to work with discrete distribution.
zxybazh Aug 23, 2021
8efe5a3
Remote unused argument ScheduleState self in function signature.
zxybazh Aug 23, 2021
93878a2
Renable pylint.
zxybazh Aug 23, 2021
b67f14a
Fix test name and debug mask.
zxybazh Aug 23, 2021
72e2456
Fix format clang.
zxybazh Aug 23, 2021
ce8e6bb
Fix non-class template compilation problem.
zxybazh Aug 23, 2021
b2fffa0
Add copy & serialization test.
zxybazh Aug 23, 2021
414f440
Add docs on python-side function.
zxybazh Aug 24, 2021
d7a545e
Minor fix.
zxybazh Aug 24, 2021
58de4a9
Minor fix.
zxybazh Aug 24, 2021
cb25711
Modify ExprRV constructor from int64_t.
zxybazh Aug 24, 2021
f9c5458
Fix docs.
zxybazh Aug 24, 2021
b541d49
Modify docs.
zxybazh Aug 24, 2021
5a6b2d3
Modify tests.
zxybazh Aug 24, 2021
10 changes: 6 additions & 4 deletions include/tvm/support/random_engine.h
@@ -19,7 +19,8 @@

/*!
* \file random_engine.h
- * \brief Random number generator, for Sampler and Sampling functions.
* \brief Random number generator. It provides a generic interface consistent with
* `std::uniform_random_bit_generator`
*/

#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
@@ -41,10 +42,11 @@ namespace support {
* included for simplification. For full member functions of std::minstd_rand, please check out the
* following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
*/

class LinearCongruentialEngine {
public:
/*!
- * \brief The result type is defined as int64_t here for meta_schedule sampler usage.
* \brief The result type is defined as uint64_t here to avoid overflow.
* \note The type name is not in Google style because it is used in STL's distribution interface.
*/
using result_type = uint64_t;
@@ -63,13 +65,13 @@ class LinearCongruentialEngine {
* \brief The minimum possible value of random state here.
* \note The function name is uncapitalized because it is used in STL's distribution interface.
*/
- result_type min() { return 0; }
static constexpr result_type min() { return 0; }

/*!
* \brief The maximum possible value of random state here.
* \note The function name is uncapitalized because it is used in STL's distribution interface.
*/
- result_type max() { return modulus - 1; }
static constexpr result_type max() { return modulus - 1; }

/*!
* \brief Operator to move the random state to the next and return the new random state. According
27 changes: 20 additions & 7 deletions include/tvm/tir/schedule/schedule.h
@@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

@@ -118,9 +119,9 @@ class ScheduleNode : public runtime::Object {
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
- virtual void Seed(int64_t seed = -1) {
- LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed";
- }
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
/*! \brief Fork the random state */
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;

public:
/******** Lookup/Remove random variables ********/
@@ -184,6 +185,16 @@ class ScheduleNode : public runtime::Object {

public:
/******** Schedule: Sampling ********/
/*!
* \brief Sample an integer given the probability distribution
* \param candidates The candidates
* \param probs The probability distribution of the candidates
* \param decision The sampling decision
* \return The random variable sampled from candidates
*/
virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
Optional<Integer> decision = NullOpt) = 0;

/******** Schedule: Get blocks & loops ********/
/*!
* \brief Retrieve a block in a specific function with its name
@@ -356,6 +367,7 @@ class Schedule : public runtime::ObjectRef {
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
@@ -365,11 +377,12 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
- TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask,
- ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
@@ -379,8 +392,8 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
- TVM_DLL static Schedule Traced(IRModule mod, int debug_mask,
- ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

57 changes: 57 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int:
return _ERROR_RENDER_LEVEL.get(error_render_level)


def _parse_seed(seed: Optional[int]) -> int:
    if seed is None:
        return -1
    if not isinstance(seed, int):
        raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}")
    if seed < 1 or seed > 2147483647:
        raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}")
    return seed
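
The guard above can be exercised in isolation. The following is a minimal sketch that copies the logic of `_parse_seed` into a standalone function; the names `parse_seed`, `is_accepted` and `MAX_SEED` are hypothetical and exist only for illustration.

```python
from typing import Optional

MAX_SEED = 2147483647  # 2**31 - 1, the upper bound used by the guard above


def parse_seed(seed: Optional[int]) -> int:
    """Standalone copy of the seed guard, for illustration only."""
    if seed is None:
        return -1  # sentinel passed on to the C++ side
    if not isinstance(seed, int):
        raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}")
    if seed < 1 or seed > MAX_SEED:
        raise ValueError(f"seed must be in the range [1, {MAX_SEED}], but gets: {seed}")
    return seed


def is_accepted(seed) -> bool:
    """Return True if the guard accepts `seed`, False if it raises."""
    try:
        parse_seed(seed)
    except (TypeError, ValueError):
        return False
    return True
```

As written, the guard rejects an explicit -1 (`is_accepted(-1)` is false), so from the constructor only `None` maps to the -1 sentinel, even though the docstring mentions both `None` and -1.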


@_register_object("tir.Schedule")
class Schedule(Object):
"""The user-facing schedule class
@@ -98,6 +108,7 @@ def __init__(
self,
mod: Union[PrimFunc, IRModule],
*,
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> None:
@@ -107,6 +118,10 @@ def __init__(
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
seed: Optional[int]
The seed value for schedule's random state
Note that None and -1 means use device random, otherwise only integer between 1 and
2147483647 is allowed.
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
after calling the Replace method.
@@ -130,6 +145,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)
@@ -138,12 +154,14 @@ def __init__(
def _create_non_traced(
mod: Union[PrimFunc, IRModule],
*,
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)
@@ -190,6 +208,16 @@ def seed(self, seed: int) -> None:
"""
return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member

    def fork_seed(self) -> int:
        """Returns a forked random state as seed for new schedules

        Returns
        -------
        seed : int
            The forked random state, not the same as the current random state
        """
        return _ffi_api.ScheduleForkSeed(self)  # type: ignore # pylint: disable=no-member

def show(self, rand_var: RAND_VAR_TYPE) -> str:
"""Returns a string representation of the value that the random variable evaluates to

@@ -268,6 +296,35 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:

########## Schedule: Sampling ##########

    def sample_categorical(
        self,
        candidates: List[int],
        probs: List[float],
        decision: Optional[int] = None,
    ) -> ExprRV:
        """Sample an integer given the probability distribution

        Parameters
        ----------
        candidates : List[int]
            The candidates to be sampled from
        probs : List[float]
            The probability of each candidate
        decision : Optional[int]
            The sampling decision, if any

        Returns
        -------
        result : ExprRV
            The random variable sampled from candidates
        """
        return _ffi_api.ScheduleSampleCategorical(  # type: ignore # pylint: disable=no-member
            self,
            candidates,
            probs,
            decision,
        )
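
To make the intended behavior concrete, here is a minimal pure-Python sketch of categorical sampling. It is not the TVM implementation: the helper name `sample_categorical_sketch`, the use of `random.Random` as the generator, and the convention that `decision` is an index into `candidates` are all assumptions made for illustration.

```python
import random
from typing import List, Optional, Tuple


def sample_categorical_sketch(
    rng: random.Random,
    candidates: List[int],
    probs: List[float],
    decision: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (value, index): a candidate drawn according to `probs`.

    Assumption: `decision`, when given, is the index of a previously
    chosen candidate, so the draw is replayed instead of re-sampled.
    """
    if len(candidates) != len(probs):
        raise ValueError("candidates and probs must have the same length")
    if decision is None:
        # random.choices draws one index using the given relative weights.
        decision = rng.choices(range(len(candidates)), weights=probs, k=1)[0]
    elif not 0 <= decision < len(candidates):
        raise ValueError(f"decision {decision} is out of range")
    return candidates[decision], decision
```

For example, `sample_categorical_sketch(random.Random(0), [2, 4, 8], [0.0, 1.0, 0.0])` returns `(4, 1)`, because only index 1 carries non-zero weight.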

########## Schedule: Get blocks & loops ##########
def get_block(
self,
68 changes: 68 additions & 0 deletions src/support/array.h
@@ -18,6 +18,7 @@
*/
#ifndef TVM_SUPPORT_ARRAY_H_
#define TVM_SUPPORT_ARRAY_H_
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>

#include <vector>
@@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>
return true;
}

/*!
* \brief Convert a tvm::runtime::Array to std::vector
* \tparam TSrc The type of elements in the source Array
* \tparam TDst The type of elements in the result vector
* \return The result vector
*/
template <class TSrc, class TDst>
std::vector<TDst> AsVector(const Array<TSrc>& vec);

/********** Implementation details of AsVector<TSrc, TDst> **********/
namespace details {

template <class TSrc, class TDst>
struct AsVectorImpl {};

template <class TSrc>
struct AsVectorImpl<TSrc, TSrc> {
inline std::vector<TSrc> operator()(const Array<TSrc>& vec) const {
return std::vector<TSrc>(vec.begin(), vec.end());
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, int> {
inline std::vector<int> operator()(const Array<TSrcObjectRef>& vec) const {
std::vector<int> results;
for (const TSrcObjectRef& x : vec) {
const auto* n = x.template as<IntImmNode>();
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, int64_t> {
inline std::vector<int64_t> operator()(const Array<TSrcObjectRef>& vec) const {
std::vector<int64_t> results;
for (const TSrcObjectRef& x : vec) {
const auto* n = x.template as<IntImmNode>();
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, double> {
inline std::vector<double> operator()(const Array<TSrcObjectRef>& array) const {
std::vector<double> results;
for (const TSrcObjectRef& x : array) {
const auto* n = x.template as<FloatImmNode>();
ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};
} // namespace details

template <class TSrc, class TDst>
inline std::vector<TDst> AsVector(const Array<TSrc>& vec) {
return details::AsVectorImpl<TSrc, TDst>()(vec);
}

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_ARRAY_H_
30 changes: 28 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -18,16 +18,19 @@
*/
#include "./concrete_schedule.h"

#include <random>

namespace tvm {
namespace tir {

- Schedule Schedule::Concrete(IRModule mod, int debug_mask,
- ScheduleErrorRenderLevel error_render_level) {
Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
return Schedule(std::move(n));
}

@@ -208,6 +211,29 @@ Schedule ConcreteScheduleNode::Copy() const {
}

/******** Schedule: Schedule: Sampling ********/

void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
if (seed == -1) {
seed = std::random_device()();
}
support::LinearCongruentialEngine(&rand_state_).Seed(seed);
}

support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
// For reproducibility, we compute the new seed using RNG's random state and a different
// set of parameters. Note that both 32767 and 1999999973 are prime numbers.
return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973;
}
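
The forking rule above can be sketched in a few lines of Python. This assumes the engine uses the `std::minstd_rand` parameters (multiplier 48271, increment 0, modulus 2^31 - 1) that `random_engine.h` refers to; the class and method names are illustrative and are not TVM API.

```python
MULTIPLIER = 48271        # std::minstd_rand multiplier (assumed to match the engine)
MODULUS = 2147483647      # 2**31 - 1
FORK_MULTIPLIER = 32767   # constants taken from ForkSeed above
FORK_MODULUS = 1999999973


class LcgSketch:
    """Toy linear congruential engine whose output equals its state."""

    def __init__(self, seed: int) -> None:
        if not 1 <= seed < MODULUS:
            raise ValueError("seed must be in [1, MODULUS - 1]")
        self.state = seed

    def next(self) -> int:
        """Advance the state once and return it."""
        self.state = (self.state * MULTIPLIER) % MODULUS
        return self.state

    def fork_seed(self) -> int:
        """Advance the parent once, then remap its output to a child seed."""
        return (self.next() * FORK_MULTIPLIER) % FORK_MODULUS
```

Starting from seed 1, the first call to `fork_seed()` advances the parent state to 48271 and returns 48271 * 32767 = 1581695857, which is already below the fork modulus.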
Comment on lines +222 to +226
Contributor

It seems like ForkSeed is analogous to what is called "splitting" in the random number generator literature. I'm not quite an expert on this, but I did do a bit of research into PRNGs for the Threefry implementation we have. Everything I read says that there are no proofs of the validity of splitting LCGs (is the method you use here from a paper?). The paper "Splittable Pseudorandom Number Generators using Cryptographic Hashing" provides some good explanations.

In practice, I expect we will see some issues. If this function somehow perfectly bisects the space of random numbers generated by this PRNG, then we could expect to start seeing repeats of previous random numbers after 31 splits. Given that this splitting does not perfectly bisect the space, I'd assume that we start seeing repeats much sooner. Repeating portions of the search space may mean that we are not able to visit the entire search space during tuning, or that we bias results towards a certain section of the space.

I'd suggest we adopt a splittable PRNG here, as that appears to be what we need. Maybe we can find an existing implementation online, as implementing your own PRNG can have subtle issues.

Member @junrushao, Aug 24, 2021

LCGs are pretty easy to crack in terms of security, but our search isn't a setting where an adversary is working against you, haha.

To be clear, we don't split the RNG too many times. Splitting is only used in multi-threaded search, where we split the RNG once per thread. In practice, we didn't see repetition, or any problem caused by repetition, when running tens of real-world workloads.

Member

Don't most modern machines have at least 31 hyper-threads, i.e., we will split at least 31 times on those machines?

Member @junrushao, Aug 24, 2021

I actually agree with Tristan's theory in general. Thank you for bringing this up! Indeed seeding of parallel PRNG would require some really careful thought to avoid quick repetition. LCG may not be the best candidate to ensure such a property.

Fortunately, in our particular use case it is not a practical problem. Here is a quick example, supposing we have 128 threads and 10k trials: https://gist.github.com/junrushao1994/ea986add81b01b89fd99a5a7d41d087a. The result is that there is no repetition at all. This is a harsher condition than our practical usage.

To further address the issue, architecturally we have designed the PRNG interface to be generic and compliant with the STL, so it can easily be switched to any splittable PRNG in the future if interesting new use cases come up. Therefore, I assume it won't constitute an architecture issue :-)

Thanks again for the discussion!
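
A check of this kind can be sketched as follows. This is not the linked gist; it is a small Python illustration that forks one seed per thread and tests whether any two streams visit the same state. It assumes the `std::minstd_rand` parameters (48271, 2^31 - 1) and the fork constants from `ForkSeed`, and all function names are hypothetical.

```python
from typing import List

MULTIPLIER, MODULUS = 48271, 2147483647          # assumed minstd_rand parameters
FORK_MULTIPLIER, FORK_MODULUS = 32767, 1999999973  # constants from ForkSeed


def lcg_stream(seed: int, n: int) -> List[int]:
    """Return the n states that follow `seed`."""
    states, state = [], seed
    for _ in range(n):
        state = (state * MULTIPLIER) % MODULUS
        states.append(state)
    return states


def fork_seeds(seed: int, num_threads: int) -> List[int]:
    """Fork one child seed per thread, advancing the parent once per fork."""
    seeds, state = [], seed
    for _ in range(num_threads):
        state = (state * MULTIPLIER) % MODULUS
        child = (state * FORK_MULTIPLIER) % FORK_MODULUS
        seeds.append(child or 1)  # precaution added here: avoid the all-zero stream
    return seeds


def streams_overlap(seeds: List[int], n: int) -> bool:
    """True if any state is visited by more than one stream within n steps."""
    seen = set()
    for seed in seeds:
        states = set(lcg_stream(seed, n))
        if seen & states:
            return True
        seen |= states
    return False
```

Calling `streams_overlap(fork_seeds(42, 128), 10_000)` mirrors the shape of the experiment described above (128 threads, 10k trials). The parent seed 42 is arbitrary and the outcome depends on the assumed parameters, so no result is claimed here.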

Member @tqchen, Aug 24, 2021

Coming late to the discussion. I read the thread yesterday evening and wanted
to spend more time thinking, so I did not reply immediately.

Please let me try to summarize and share some thoughts.

As we generate random numbers, the PRNG state cycles through the space of states
and eventually comes back to its initial state, forming a cycle.
When we split the state of a PRNG to start parallel generators, if the
split states are "too close" in terms of their traversal distance,
then we get repetitions, or streams that are correlated with each other.

The following points about PRNGs apply:

  • A0: When we split a PRNG in an ad hoc way, there is a "possibility"
    that repeats can happen in different streams. The nature of such a
    possibility depends on the splitting method, the seed, and the PRNG itself.
  • A1: There is currently no proof of validity for splitting LCGs in general.

My read is that @tkonolige is right about A0 and A1, and it seems we all agree.
The particular number 31, however, might not directly match this particular
scenario (junru's empirical experiment suggested otherwise). Whether repeats occur
depends on the splitting method, the seed, and the PRNG itself, and I personally
cannot tell for sure what will happen in this case.

Because of A0 and A1, it would be helpful to consider the implications of
using possibly correlated PRNG streams. In our case, we use the PRNG
to generate explorations in the search space and perform the following task:

  • T0: explore the space and find possible maxima in the space

To make things simple, let us assume that two of the streams in the 32 threads
are exactly identical to each other. In this case, we waste the computation
of one of the threads, because it will produce exactly the same set of samples.
Because our goal is to explore the space and find a maximum point, this means
that the computation from that particular thread is wasted, and we get less statistical
efficiency from the sampling process.

At a high level, we have two parameters we can tweak: the number of sampling steps n
and the number of threads K. What we are saying is that the statistical efficiency of running
uniform sampling over the search space scales linearly with n, but perhaps less optimally
with respect to K. For other, more complicated tasks, such as estimating the density of certain regions,
we know that sampling and averaging over time (n) can always give us the right estimate.
Correlation across streams would make averaging over streams (K) less efficient because the
random numbers are not independent, but we will still get the correct mean as we increase n.

So in summary:

  • A2: The impact of possibly correlated streams is that we could get less than
    fully linear efficiency in terms of the number of threads K. The sampling will still
    yield effective samples with respect to n (the number of samples we take in each thread).

Note that the real sampling scenario is much more complicated. junru's experiments
showed that repetition did not happen in this particular case for quite a long period of time.
Additionally, the sampling steps are dependent on each other (they form a Markov chain),
so it is hard to tell the end-effect correlation between random sequences (abcde) and (bcdef),
even if they are one step apart, while in most cases they can be many steps apart
empirically. To summarize:

  • A3: In this particular use case, empirically we get non-repeating sequences among the streams.
    This does not rule out correlation (it is not a proof), but it does suggest
    that correlation may not be large in most cases.

The end effect of A2 has a close analogy in parallel computing: as we start to use
K threads, we may not get exactly a Kx speedup. Except that in this case it is not due to
hardware reasons; it is due to the possible correlation. As in parallel computing,
we can run the program longer, in our case by increasing n, to compensate for the possible loss
of efficiency. In short, A2 and A3 together suggest that parallel stream correlation may not be a problem
that we need to solve perfectly, as long as it does not become very bad (e.g. all streams are the same).

Yesterday I did not think of A2 in particular, which might change our perspective, so I would
like to share this summary here. It would be great to get your takes as well.
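
The A2 argument can be illustrated with a toy simulation. This is only a sketch under stated assumptions: Python's `random.Random` stands in for the schedule's PRNG, exploration is modeled as uniform draws from a space of `space_size` points, and the function name is hypothetical.

```python
import random
from typing import List


def distinct_points(seeds: List[int], n: int, space_size: int) -> int:
    """Count distinct search-space points hit by K streams of n uniform draws."""
    visited = set()
    for seed in seeds:
        rng = random.Random(seed)  # one stream per thread, seeded independently
        visited.update(rng.randrange(space_size) for _ in range(n))
    return len(visited)
```

Comparing `distinct_points([0, 1, 2, 3], n, S)` with `distinct_points([0, 1, 2, 2], n, S)` shows the effect: the second call has two identical streams, so it covers exactly what three threads would cover. Increasing `n` still adds new points, which is the sense in which efficiency is lost in K but not in n.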

Member @junrushao, Aug 24, 2021

@tqchen I happened to implement a BSGS (baby-step giant-step) discrete logarithm this morning. This is a simple but effective algorithm (though not effective enough for crypto) that we used in high school competitive programming: https://gist.github.com/junrushao1994/d32f265f5b4815d4b346d6022e95f394.

I used this script to find the minimal number of trials required for a first repeat to happen given num_threads threads, and here is the outcome when I set num_threads=1000.

k = 0, forked_seed = 1, repeat = 0
k = 1, forked_seed = 32767, repeat = 477805293
k = 2, forked_seed = 1073676289, repeat = 955610586
k = 3, forked_seed = 1151436593, repeat = 792763173
k = 4, forked_seed = 1123352159, repeat = 1465074278
k = 5, forked_seed = 880690861, repeat = 1324276493
k = 6, forked_seed = 1597831943, repeat = 1242547212
k = 7, forked_seed = 159983087, repeat = 1577951775
k = 8, forked_seed = 165882496, repeat = 12097717
k = 9, forked_seed = 1471819791, repeat = 670441922
k = 10, forked_seed = 1119742748, repeat = 130415288
k = 11, forked_seed = 611119031, repeat = 268696235
k = 12, forked_seed = 537559101, repeat = 53491731
k = 13, forked_seed = 199300256, repeat = 1422308282
k = 14, forked_seed = 471576507, repeat = 535435921
k = 15, forked_seed = 147613471, repeat = 482454183
k = 16, forked_seed = 850669543, repeat = 1171123831
k = 17, forked_seed = 1889291753, repeat = 1518153595
k = 18, forked_seed = 423706282, repeat = 325709943
k = 19, forked_seed = 1583929701, repeat = 1523336904
k = 20, forked_seed = 625213317, repeat = 2097645560
...
k = 990, forked_seed = 465632523, repeat = 1395274874
k = 991, forked_seed = 1381087097, repeat = 1524683345
k = 992, forked_seed = 81518328, repeat = 874365972
k = 993, forked_seed = 1111089621, repeat = 464689348
k = 994, forked_seed = 1074102788, repeat = 1779776079
k = 995, forked_seed = 1126529515, repeat = 113162479
k = 996, forked_seed = 993116317, repeat = 711897275
k = 997, forked_seed = 1442798429, repeat = 285912163
k = 998, forked_seed = 176761269, repeat = 918045815
k = 999, forked_seed = 1936579488, repeat = 43150205
min repeat: 1407035

In short, in practice a conflict with the 0-th thread won't happen until after 1407035 trials for any of the first 999 threads split this way.
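
For readers unfamiliar with baby-step giant-step, here is a minimal sketch; it is not the linked gist. It solves a^x ≡ b (mod m) for a prime modulus m. Presumably a `repeat` value is obtained this way, as the number of LCG steps after which a stream started at state 1 reaches a given forked seed, assuming multiplier 48271 and modulus 2^31 - 1.

```python
from math import isqrt
from typing import Optional


def bsgs(a: int, b: int, m: int) -> Optional[int]:
    """Smallest x >= 0 with pow(a, x, m) == b % m, or None if there is none.

    Assumes m is prime and a is not a multiple of m.
    """
    n = isqrt(m) + 1
    # Baby steps: remember the smallest j < n for each value a**j mod m.
    table = {}
    value = 1
    for j in range(n):
        table.setdefault(value, j)
        value = (value * a) % m
    # Giant steps: multiply b by a**(-n) repeatedly and look for a baby step.
    factor = pow(pow(a, n, m), m - 2, m)  # inverse of a**n by Fermat's little theorem
    gamma = b % m
    for i in range(n):
        if gamma in table:
            return i * n + table[gamma]
        gamma = (gamma * factor) % m
    return None
```

Under these assumptions, `bsgs(48271, 32767, 2147483647)` is the step count at which the stream seeded with 1 reaches state 32767, which corresponds to the quantity listed as `repeat` for k = 1. The table of roughly 46k baby steps keeps the search cheap compared with stepping the generator directly.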

Contributor

@tqchen @junrushao1994 You both lay out a lot of interesting points here, but I'm not sure I have the expertise to evaluate them. The PRNGs themselves might appear simple, but analysis of their randomness is complicated and non-intuitive. Looking at the paper I linked above, you can get subtle bugs if the PRNG is used incorrectly. I've tested the LCG implemented in TVM with some PRNG test suites (you can try it yourself here: https://github.com/tkonolige/prng-tests), and it fails all of them. This result is unsurprising because LCGs aren't particularly good random number generators, but it just adds a little to my concern.

Given that we want to avoid any potential issues, why don't we just do things the right way and use a splittable PRNG? This page (https://www.pcg-random.org/posts/some-prng-implementations.html) lists some implementations of PRNGs, including SplitMix, which is splittable. (pcg-random appears to be a reputable source; it is run by the creator of the PCG family of PRNGs.) It seems like there is basically no overhead to just dropping this SplitMix implementation into the codebase. And then we won't have to worry about any bugs due to bad randomness.

Contributor

I don't want to block this PR on this, so I'm going to approve. But I would like us to fix this in the future.


ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates,
const Array<FloatImm>& probs,
Optional<Integer> decision) {
TVM_TIR_SCHEDULE_BEGIN();
return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision));
TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_);
throw;
}

/******** Schedule: Get blocks & loops ********/

BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {