Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 20099 - Memoize should handle lambdas #7507

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 142 additions & 60 deletions std/functional.d
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,15 @@
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
}

private template getOverloads(alias fun)
{
import std.meta : AliasSeq;
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun))))
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun));
Comment on lines +1251 to +1252
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to pass true as an additional argument to __traits(getOverloads) here in order to include template overloads. As-is, the check for anySatisfy!(isTemplate, overloads) below will always evaluate to false.

else
alias getOverloads = AliasSeq!fun;
}

/**
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
* to avoid repeated computation. The memoization structure is a hash table keyed by a
Expand Down Expand Up @@ -1280,87 +1289,129 @@
*/
template memoize(alias fun)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
// Generic implementation
alias memoize = impl;
}

ReturnType!fun memoize(Parameters!fun args)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
alias Args = Parameters!fun;
import std.typecons : Tuple;
import std.typecons : Tuple, tuple;
import std.traits : Unqual;

static Unqual!(ReturnType!fun)[Tuple!Args] memo;
auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
static if (args.length > 0)
{
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;

auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
}
else
{
static typeof(fun(args)) result = fun(args);
return result;
}
}
}

/// ditto
template memoize(alias fun, uint maxSize)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
ReturnType!fun memoize(Parameters!fun args)
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
static Value[] memo;
static size_t[] initialized;
// Generic implementation
alias memoize = impl;
}

if (!memo.length)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
static if (args.length > 0)
{
import core.memory : GC;
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
alias returnType = typeof(fun(args));
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
static Value[] memo;
static size_t[] initialized;

if (!memo.length)
{
import core.memory : GC;

// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));

enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}

import core.bitop : bt, bts;
import core.lifetime : emplace;
import core.bitop : bt, bts;
import core.lifetime : emplace;

size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))

Check warning on line 1397 in std/functional.d

View check run for this annotation

Codecov / codecov/patch

std/functional.d#L1396-L1397

Added lines #L1396 - L1397 were not covered by tests
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);

Check warning on line 1400 in std/functional.d

View check run for this annotation

Codecov / codecov/patch

std/functional.d#L1399-L1400

Added lines #L1399 - L1400 were not covered by tests
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

Check warning on line 1405 in std/functional.d

View check run for this annotation

Codecov / codecov/patch

std/functional.d#L1402-L1405

Added lines #L1402 - L1405 were not covered by tests

memo[idx1] = Value(args, fun(args));

Check warning on line 1407 in std/functional.d

View check run for this annotation

Codecov / codecov/patch

std/functional.d#L1407

Added line #L1407 was not covered by tests
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))
else
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);
static typeof(fun(args)) result = fun(args);
return result;
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

memo[idx1] = Value(args, fun(args));
return memo[idx1].res;
}
}

Expand Down Expand Up @@ -1420,6 +1471,37 @@
assert(fact(10) == 3628800);
}

// Issue 20099
@system unittest // not @safe due to memoize
{
int i = 3;
alias a = memoize!((n) => i + n);
Biotronic marked this conversation as resolved.
Show resolved Hide resolved
alias b = memoize!((n) => i + n, 3);

assert(a(3) == 6);
assert(b(3) == 6);
}

@system unittest // not @safe due to memoize
{
static Object objNum(int a) { return new Object(); }
Biotronic marked this conversation as resolved.
Show resolved Hide resolved
assert(memoize!objNum(0) is memoize!objNum(0U));
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
}

@system unittest // not @safe due to memoize
{
struct S
{
static int fun() { return 0; }
static int fun(int i) { return 1; }
}
assert(memoize!(S.fun)() == 0);
assert(memoize!(S.fun)(3) == 1);
assert(memoize!(S.fun, 3)() == 0);
assert(memoize!(S.fun, 3)(3) == 1);
}

@system unittest // not @safe due to memoize
{
import core.math : sqrt;
Expand Down