From 9f7c9285c28fc4c898234868db7f0b24a31a8bb5 Mon Sep 17 00:00:00 2001 From: Biotronic Date: Fri, 29 May 2020 21:24:18 +0200 Subject: [PATCH] Fix issue 20099 - Memoize should handle lambdas --- std/functional.d | 202 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 142 insertions(+), 60 deletions(-) diff --git a/std/functional.d b/std/functional.d index cec61fef575..fbf779b0578 100644 --- a/std/functional.d +++ b/std/functional.d @@ -1245,6 +1245,15 @@ alias pipe(fun...) = compose!(Reverse!(fun)); assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5); } +private template getOverloads(alias fun) +{ + import std.meta : AliasSeq; + static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun)))) + alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun)); + else + alias getOverloads = AliasSeq!fun; +} + /** * $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as * to avoid repeated computation. The memoization structure is a hash table keyed by a @@ -1280,87 +1289,129 @@ Note: */ template memoize(alias fun) { - import std.traits : ReturnType; - // https://issues.dlang.org/show_bug.cgi?id=13580 - // alias Args = Parameters!fun; + import std.traits : Parameters; + import std.meta : anySatisfy; + + // Specific overloads: + alias overloads = getOverloads!fun; + static foreach (fn; overloads) + static if (is(Parameters!fn)) + alias memoize = impl!(Parameters!fn); + + enum isTemplate(alias a) = __traits(isTemplate, a); + static if (anySatisfy!(isTemplate, overloads)) + { + // Generic implementation + alias memoize = impl; + } - ReturnType!fun memoize(Parameters!fun args) + auto impl(Args...)(Args args) if (is(typeof(fun(args)))) { - alias Args = Parameters!fun; - import std.typecons : Tuple; + import std.typecons : Tuple, tuple; import std.traits : Unqual; - static Unqual!(ReturnType!fun)[Tuple!Args] memo; - auto t = Tuple!Args(args); - if (auto p = t in memo) - return *p; - auto r = fun(args); - memo[t] = r; - return r; + static if (args.length > 0) + { + static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo; + + auto t = Tuple!Args(args); + if (auto p = t in memo) + return *p; + auto r = fun(args); + memo[t] = r; + return r; + } + else + { + static typeof(fun(args)) result = fun(args); + return result; + } } } /// ditto template memoize(alias fun, uint maxSize) { - import std.traits : ReturnType; - // https://issues.dlang.org/show_bug.cgi?id=13580 - // alias Args = Parameters!fun; - ReturnType!fun memoize(Parameters!fun args) + import std.traits : Parameters; + import std.meta : anySatisfy; + + // Specific overloads: + alias overloads = getOverloads!fun; + static foreach (fn; overloads) + static if (is(Parameters!fn)) + alias memoize = impl!(Parameters!fn); + + enum isTemplate(alias a) = __traits(isTemplate, a); + static if (anySatisfy!(isTemplate, overloads)) { - import std.meta : staticMap; - import std.traits : hasIndirections, Unqual; - import std.typecons : tuple; - static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; } - static Value[] memo; - static size_t[] initialized; + // Generic implementation + alias memoize = impl; + } - if (!memo.length) + auto impl(Args...)(Args args) if (is(typeof(fun(args)))) + { + static if (args.length > 0) { - import core.memory : GC; + import std.meta : staticMap; + import std.traits : hasIndirections, Unqual; + import std.typecons : tuple; + alias returnType = typeof(fun(args)); + static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; } + static Value[] memo; + static size_t[] initialized; + + if (!memo.length) + { + import core.memory : GC; - // Ensure no allocation overflows - static assert(maxSize < size_t.max / Value.sizeof); - static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1)); + // Ensure no allocation overflows + static assert(maxSize < size_t.max / Value.sizeof); + static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1)); - enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN); - memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize]; - enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof); - initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords]; - } + enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN); + memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize]; + enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof); + initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords]; + } - import core.bitop : bt, bts; - import core.lifetime : emplace; + import core.bitop : bt, bts; + import core.lifetime : emplace; - size_t hash; - foreach (ref arg; args) - hash = hashOf(arg, hash); - // cuckoo hashing - immutable idx1 = hash % maxSize; - if (!bt(initialized.ptr, idx1)) - { - emplace(&memo[idx1], args, fun(args)); - // only set to initialized after setting args and value - // https://issues.dlang.org/show_bug.cgi?id=14025 - bts(initialized.ptr, idx1); + size_t hash; + foreach (ref arg; args) + hash = hashOf(arg, hash); + // cuckoo hashing + immutable idx1 = hash % maxSize; + if (!bt(initialized.ptr, idx1)) + { + emplace(&memo[idx1], args, fun(args)); + // only set to initialized after setting args and value + // https://issues.dlang.org/show_bug.cgi?id=14025 + bts(initialized.ptr, idx1); + return memo[idx1].res; + } + else if (memo[idx1].args == args) + return memo[idx1].res; + // FNV prime + immutable idx2 = (hash * 16_777_619) % maxSize; + if (!bt(initialized.ptr, idx2)) + { + emplace(&memo[idx2], memo[idx1]); + bts(initialized.ptr, idx2); + } + else if (memo[idx2].args == args) + return memo[idx2].res; + else if (idx1 != idx2) + memo[idx2] = memo[idx1]; + + memo[idx1] = Value(args, fun(args)); return memo[idx1].res; } - else if (memo[idx1].args == args) - return memo[idx1].res; - // FNV prime - immutable idx2 = (hash * 16_777_619) % maxSize; - if (!bt(initialized.ptr, idx2)) + else { - emplace(&memo[idx2], memo[idx1]); - bts(initialized.ptr, idx2); + static typeof(fun(args)) result = fun(args); + return result; } - else if (memo[idx2].args == args) - return memo[idx2].res; - else if (idx1 != idx2) - memo[idx2] = memo[idx1]; - - memo[idx1] = Value(args, fun(args)); - return memo[idx1].res; } } @@ -1420,6 +1471,37 @@ unittest assert(fact(10) == 3628800); } +// Issue 20099 +@system unittest // not @safe due to memoize +{ + int i = 3; + alias a = memoize!((n) => i + n); + alias b = memoize!((n) => i + n, 3); + + assert(a(3) == 6); + assert(b(3) == 6); +} + +@system unittest // not @safe due to memoize +{ + static Object objNum(int a) { return new Object(); } + assert(memoize!objNum(0) is memoize!objNum(0U)); + assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U)); +} + +@system unittest // not @safe due to memoize +{ + struct S + { + static int fun() { return 0; } + static int fun(int i) { return 1; } + } + assert(memoize!(S.fun)() == 0); + assert(memoize!(S.fun)(3) == 1); + assert(memoize!(S.fun, 3)() == 0); + assert(memoize!(S.fun, 3)(3) == 1); +} + @system unittest // not @safe due to memoize { import core.math : sqrt;