diff --git a/src/main/java/net/starlark/java/eval/Dict.java b/src/main/java/net/starlark/java/eval/Dict.java index 9be147b8a0aff3..b1a0c160adfbda 100644 --- a/src/main/java/net/starlark/java/eval/Dict.java +++ b/src/main/java/net/starlark/java/eval/Dict.java @@ -15,6 +15,7 @@ package net.starlark.java.eval; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import java.util.ArrayList; @@ -24,6 +25,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; +import java.util.function.BiConsumer; import javax.annotation.Nullable; import net.starlark.java.annot.Param; import net.starlark.java.annot.StarlarkBuiltin; @@ -109,29 +111,43 @@ public class Dict StarlarkIndexable, StarlarkIterable { - // TODO(adonovan): for dicts that are born frozen, use ImmutableMap, which is also - // insertion-ordered and has smaller Entries (singly linked, no hash). - private final LinkedHashMap contents; + private final Map contents; private int iteratorCount; // number of active iterators (unused once frozen) /** Final except for {@link #unsafeShallowFreeze}; must not be modified any other way. */ private Mutability mutability; - private Dict(@Nullable Mutability mutability, LinkedHashMap contents) { - this.mutability = mutability == null ? Mutability.IMMUTABLE : mutability; + private Dict(Mutability mutability, LinkedHashMap contents) { + Preconditions.checkNotNull(mutability); + Preconditions.checkState(mutability != Mutability.IMMUTABLE); + this.mutability = mutability; + // TODO(bazel-team): Memory optimization opportunity: Make it so that a call to + // `mutability.freeze()` causes `contents` here to become an ImmutableMap. this.contents = contents; } - private Dict(@Nullable Mutability mutability) { - this(mutability, new LinkedHashMap<>()); + private Dict(ImmutableMap contents) { + // An immutable dict might as well store its contents as an ImmutableMap, since ImmutableMap + // both is more memory-efficient than LinkedHashMap and also it has the requisite deterministic + // iteration order. + this.mutability = Mutability.IMMUTABLE; + this.contents = contents; } /** * Takes ownership of the supplied LinkedHashMap and returns a new Dict that wraps it. The caller * must not subsequently modify the map, but the Dict may do so. */ - static Dict wrap(@Nullable Mutability mutability, LinkedHashMap contents) { - return new Dict<>(mutability, contents); + static Dict wrap(@Nullable Mutability mu, LinkedHashMap contents) { + if (mu == null) { + mu = Mutability.IMMUTABLE; + } + if (mu == Mutability.IMMUTABLE && contents.isEmpty()) { + return empty(); + } + // #wrap is used in situations where the resulting Dict isn't necessarily retained [forever]. + // So, don't make an ImmutableMap copy of `contents`, as #copyOf would do. + return new Dict<>(mu, contents); } @Override @@ -394,7 +410,7 @@ public StarlarkList keys(StarlarkThread thread) throws EvalException { return StarlarkList.wrap(thread.mutability(), array); } - private static final Dict EMPTY = of(Mutability.IMMUTABLE); + private static final Dict EMPTY = new Dict<>(ImmutableMap.of()); /** Returns an immutable empty dict. */ // Safe because the empty singleton is immutable. @@ -405,24 +421,49 @@ public static Dict empty() { /** Returns a new empty dict with the specified mutability. */ public static Dict of(@Nullable Mutability mu) { - return new Dict<>(mu); + if (mu == null) { + mu = Mutability.IMMUTABLE; + } + if (mu == Mutability.IMMUTABLE) { + return empty(); + } else { + return new Dict<>(mu, new LinkedHashMap<>()); + } } /** Returns a new dict with the specified mutability containing the entries of {@code m}. */ public static Dict copyOf(@Nullable Mutability mu, Map m) { - if (mu == null && m instanceof Dict && ((Dict) m).isImmutable()) { + if (mu == null) { + mu = Mutability.IMMUTABLE; + } + + if (mu == Mutability.IMMUTABLE && m instanceof Dict && ((Dict) m).isImmutable()) { @SuppressWarnings("unchecked") Dict dict = (Dict) m; // safe return dict; } - Dict dict = new Dict<>(mu); - for (Map.Entry e : m.entrySet()) { - dict.contents.put( - Starlark.checkValid(e.getKey()), // - Starlark.checkValid(e.getValue())); + if (mu == Mutability.IMMUTABLE) { + if (m.isEmpty()) { + return empty(); + } + ImmutableMap.Builder immutableMapBuilder = + ImmutableMap.builderWithExpectedSize(m.size()); + for (Map.Entry e : m.entrySet()) { + immutableMapBuilder.put( + Starlark.checkValid(e.getKey()), // + Starlark.checkValid(e.getValue())); + } + return new Dict<>(immutableMapBuilder.buildOrThrow()); + } else { + LinkedHashMap linkedHashMap = new LinkedHashMap<>(); + for (Map.Entry e : m.entrySet()) { + linkedHashMap.put( + Starlark.checkValid(e.getKey()), // + Starlark.checkValid(e.getValue())); + } + return new Dict<>(mu, linkedHashMap); } - return dict; } /** Returns an immutable dict containing the entries of {@code m}. */ @@ -462,7 +503,7 @@ public Dict buildImmutable() { /** Returns a new {@link ImmutableKeyTrackingDict} containing the entries added so far. */ public ImmutableKeyTrackingDict buildImmutableWithKeyTracking() { - return new ImmutableKeyTrackingDict<>(buildMap()); + return new ImmutableKeyTrackingDict<>(buildImmutableMap()); } /** @@ -470,19 +511,42 @@ public ImmutableKeyTrackingDict buildImmutableWithKeyTracking() { * mutability; null means immutable. */ public Dict build(@Nullable Mutability mu) { - return wrap(mu, buildMap()); + if (mu == null) { + mu = Mutability.IMMUTABLE; + } + + if (mu == Mutability.IMMUTABLE) { + if (items.isEmpty()) { + return empty(); + } + return new Dict<>(buildImmutableMap()); + } else { + return new Dict<>(mu, buildLinkedHashMap()); + } } - private LinkedHashMap buildMap() { - int n = items.size() / 2; - LinkedHashMap map = Maps.newLinkedHashMapWithExpectedSize(n); + private void populateMap(int n, BiConsumer mapEntryConsumer) { for (int i = 0; i < n; i++) { @SuppressWarnings("unchecked") K k = (K) items.get(2 * i); // safe @SuppressWarnings("unchecked") V v = (V) items.get(2 * i + 1); // safe - map.put(k, v); + mapEntryConsumer.accept(k, v); } + } + + private ImmutableMap buildImmutableMap() { + int n = items.size() / 2; + ImmutableMap.Builder immutableMapBuilder = ImmutableMap.builderWithExpectedSize(n); + populateMap(n, immutableMapBuilder::put); + // Respect the desired semantics of Builder#put. + return immutableMapBuilder.buildKeepingLast(); + } + + private LinkedHashMap buildLinkedHashMap() { + int n = items.size() / 2; + LinkedHashMap map = Maps.newLinkedHashMapWithExpectedSize(n); + populateMap(n, map::put); return map; } } @@ -695,8 +759,8 @@ public V remove(Object key) { public static final class ImmutableKeyTrackingDict extends Dict { private final ImmutableSet.Builder accessedKeys = ImmutableSet.builder(); - private ImmutableKeyTrackingDict(LinkedHashMap contents) { - super(Mutability.IMMUTABLE, contents); + private ImmutableKeyTrackingDict(ImmutableMap contents) { + super(contents); } public ImmutableSet getAccessedKeys() { diff --git a/src/test/java/net/starlark/java/eval/StarlarkMutableTest.java b/src/test/java/net/starlark/java/eval/StarlarkMutableTest.java index e7cc32c79fa1e2..66d02a851f2069 100644 --- a/src/test/java/net/starlark/java/eval/StarlarkMutableTest.java +++ b/src/test/java/net/starlark/java/eval/StarlarkMutableTest.java @@ -107,7 +107,7 @@ public void testDictBuilder() throws Exception { Dict.builder() .put("one", "1") .put("two", "2.0") - .put("two", "2") // overrwrites previous entry + .put("two", "2") // overwrites previous entry .put("three", "3") .buildImmutable(); assertThat(dict1.toString()).isEqualTo("{\"one\": \"1\", \"two\": \"2\", \"three\": \"3\"}"); @@ -125,19 +125,21 @@ public void testDictBuilder() throws Exception { // builder reuse and mutability Dict.Builder builder = Dict.builder().putAll(dict1); + builder.put("three", "33"); // overwrites previous entry Mutability mu = Mutability.create("test"); Dict dict3 = builder.build(mu); - dict3.putEntry("four", "4"); + dict3.putEntry("four", "4"); // new entry + dict3.putEntry("two", "22"); // overwrites previous entry assertThat(dict3.toString()) - .isEqualTo("{\"one\": \"1\", \"two\": \"2\", \"three\": \"3\", \"four\": \"4\"}"); + .isEqualTo("{\"one\": \"1\", \"two\": \"22\", \"three\": \"33\", \"four\": \"4\"}"); mu.close(); assertThrows(EvalException.class, dict1::clearEntries); // frozen builder.put("five", "5"); // keep building Dict dict4 = builder.buildImmutable(); assertThat(dict4.toString()) - .isEqualTo("{\"one\": \"1\", \"two\": \"2\", \"three\": \"3\", \"five\": \"5\"}"); + .isEqualTo("{\"one\": \"1\", \"two\": \"2\", \"three\": \"33\", \"five\": \"5\"}"); assertThat(dict3.toString()) .isEqualTo( - "{\"one\": \"1\", \"two\": \"2\", \"three\": \"3\", \"four\": \"4\"}"); // unchanged + "{\"one\": \"1\", \"two\": \"22\", \"three\": \"33\", \"four\": \"4\"}"); // unchanged } }