Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Commit

Permalink
Implement random elements selection from a stream
Browse files Browse the repository at this point in the history
via the optimal reservoir sampling algorithm.
  • Loading branch information
xandkar committed Jun 13, 2022
1 parent 87906a3 commit 55d3486
Showing 1 changed file with 128 additions and 1 deletion.
129 changes: 128 additions & 1 deletion src/data/data_stream.erl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
lazy_map/2,
lazy_filter/2,
pmap_to_bag/2,
pmap_to_bag/3
pmap_to_bag/3,
random_elements/2
]).

-define(T, ?MODULE).

-type reservoir(A) :: #{pos_integer() => A}.

-type filter(A, B)
:: {map, fun((A) -> B)}
| {test, fun((A) -> boolean())}
Expand Down Expand Up @@ -190,8 +193,87 @@ pmap_to_bag(T, F, J) when is_function(F), is_integer(J), J > 0 ->
error({data_stream_scheduler_crashed_before_sending_results, Reason})
end.

-spec random_elements(t(A), non_neg_integer()) -> [A].
random_elements(_, 0) -> [];
random_elements(T, K) when K > 0 ->
{_N, Reservoir} = reservoir_sample(T, #{}, K),
[X || {_, X} <- maps:to_list(Reservoir)].

%% Internal ===================================================================

%% @doc
%% The optimal reservoir sampling algorithm. Known as "Algorithm L" in:
%% https://dl.acm.org/doi/pdf/10.1145/198429.198435
%% https://en.wikipedia.org/wiki/Reservoir_sampling#An_optimal_algorithm
%% @end
-spec reservoir_sample(t(A), reservoir(A), pos_integer()) ->
{pos_integer(), reservoir(A)}.
reservoir_sample(T0, R0, K) ->
case reservoir_sample_init(T0, R0, 1, K) of
{none, R1, I} ->
{I, R1};
{{some, T1}, R1, I} ->
W = random_weight_init(K),
J = random_index_next(I, W),
reservoir_sample_update(T1, R1, W, I, J, K)
end.

-spec reservoir_sample_init(t(A), reservoir(A), pos_integer(), pos_integer()) ->
{none | {some, A}, reservoir(A), pos_integer()}.
reservoir_sample_init(T0, R, I, K) ->
case I > K of
true ->
{{some, T0}, R, I - 1};
false ->
case next(T0) of
{some, {X, T1}} ->
reservoir_sample_init(T1, R#{I => X}, I + 1, K);
none ->
{none, R, I - 1}
end
end.

-spec random_weight_init(pos_integer()) -> float().
random_weight_init(K) ->
math:exp(math:log(rand:uniform()) / K).

-spec random_weight_next(float(), pos_integer()) -> float().
random_weight_next(W, K) ->
W * random_weight_init(K).

-spec random_index_next(pos_integer(), float()) -> pos_integer().
random_index_next(I, W) ->
I + floor(math:log(rand:uniform()) / math:log(1 - W)) + 1.

-spec reservoir_sample_update(
t(A),
reservoir(A),
float(),
pos_integer(),
pos_integer(),
pos_integer()
) ->
{pos_integer(), reservoir(A)}.
reservoir_sample_update(T0, R0, W0, I0, J0, K) ->
case next(T0) of
none ->
{I0, R0};
{some, {X, T1}} ->
I1 = I0 + 1,
case I0 =:= J0 of
true ->
R1 = R0#{rand:uniform(K) => X},
W1 = random_weight_next(W0, K),
J1 = random_index_next(J0, W0),
reservoir_sample_update(T1, R1, W1, I1, J1, K);
false ->
% Here is where the big win takes place over the simple
% Algorithm R. We skip computing random numbers for an
% element that will not be picked.
reservoir_sample_update(T1, R0, W0, I1, J0, K)
end
end.

-spec sched(#sched{}) -> [any()].
sched(#sched{id=_, producers=[], consumers=[], consumers_free=[], work=[], results=Ys}) ->
Ys;
Expand Down Expand Up @@ -396,4 +478,49 @@ fold_test_() ->
]
].

random_elements_test_() ->
TestCases =
[
?_assertMatch([a], random_elements(from_list([a]), 1)),
?_assertEqual(0, length(random_elements(from_list([]), 1))),
?_assertEqual(0, length(random_elements(from_list([]), 10))),
?_assertEqual(0, length(random_elements(from_list([]), 100))),
?_assertEqual(1, length(random_elements(from_list(lists:seq(1, 100)), 1))),
?_assertEqual(2, length(random_elements(from_list(lists:seq(1, 100)), 2))),
?_assertEqual(3, length(random_elements(from_list(lists:seq(1, 100)), 3))),
?_assertEqual(5, length(random_elements(from_list(lists:seq(1, 100)), 5)))
|
[
(fun () ->
Trials = 10,
K = floor(N * KF),
L = lists:seq(1, N),
S = from_list(L),
Rands =
[
random_elements(S, K)
||
_ <- lists:duplicate(Trials, {})
],
Head = lists:sublist(L, K),
Unique = lists:usort(Rands) -- [Head],
Name =
lists:flatten(io_lib:format(
"At least 1/~p of trials makes a new sequence. "
"N:~p K:~p KF:~p length(Unique):~p",
[Trials, N, K, KF, length(Unique)]
)),
{Name, ?_assertMatch([_|_], Unique)}
end)()
||
N <- lists:seq(10, 100),
KF <- [
0.25,
0.50,
0.75
]
]
],
{inparallel, TestCases}.

-endif.

0 comments on commit 55d3486

Please sign in to comment.