Skip to content

Commit 9bce1a9

Browse files
davleopowoess
authored andcommitted
[GR-69804] Fold UnsafeByteArraySupport.get*Unaligned during creation.
PullRequest: graal/22190
2 parents 6a1c8ec + 9438f34 commit 9bce1a9

File tree

4 files changed

+379
-51
lines changed

4 files changed

+379
-51
lines changed

compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/truffle/test/ByteArraySupportPartialEvaluationTest.java

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -69,6 +69,21 @@ public int execute(VirtualFrame frame) {
6969
}
7070
}
7171

72+
static class GetShortUnalignedNonStableNode extends AbstractTestNode {
73+
@CompilationFinal(dimensions = 0) byte[] bytes;
74+
final int offset;
75+
76+
GetShortUnalignedNonStableNode(String hex, int offset) {
77+
this.bytes = hexToBytes(hex);
78+
this.offset = offset;
79+
}
80+
81+
@Override
82+
public int execute(VirtualFrame frame) {
83+
return BYTES.getShortUnaligned(bytes, offset);
84+
}
85+
}
86+
7287
static class GetIntNode extends AbstractTestNode {
7388
@CompilationFinal(dimensions = 1) byte[] bytes;
7489
final int offset;
@@ -99,6 +114,21 @@ public int execute(VirtualFrame frame) {
99114
}
100115
}
101116

117+
static class GetIntUnalignedNonStableNode extends AbstractTestNode {
118+
@CompilationFinal(dimensions = 0) byte[] bytes;
119+
final int offset;
120+
121+
GetIntUnalignedNonStableNode(String hex, int offset) {
122+
this.bytes = hexToBytes(hex);
123+
this.offset = offset;
124+
}
125+
126+
@Override
127+
public int execute(VirtualFrame frame) {
128+
return BYTES.getIntUnaligned(bytes, offset);
129+
}
130+
}
131+
102132
static class GetLongNode extends LongNode {
103133
@CompilationFinal(dimensions = 1) byte[] bytes;
104134
final int offset;
@@ -129,6 +159,21 @@ public long execute(VirtualFrame frame) {
129159
}
130160
}
131161

162+
static class GetLongUnalignedNonStableNode extends LongNode {
163+
@CompilationFinal(dimensions = 0) byte[] bytes;
164+
final int offset;
165+
166+
GetLongUnalignedNonStableNode(String hex, int offset) {
167+
this.bytes = hexToBytes(hex);
168+
this.offset = offset;
169+
}
170+
171+
@Override
172+
public long execute(VirtualFrame frame) {
173+
return BYTES.getLongUnaligned(bytes, offset);
174+
}
175+
}
176+
132177
private static byte[] hexToBytes(String s) {
133178
int len = s.length();
134179
byte[] data = new byte[len / 2];
@@ -222,4 +267,44 @@ public void testGetLongUnaligned() {
222267
assertPartialEvalEquals(constLongRootNode(0x1122334455667788L), new LongRootNode(new GetLongUnalignedNode("00000000000000008877665544332211", 8)));
223268
assertPartialEvalEquals(constLongRootNode(0x1122334455667788L), new LongRootNode(new GetLongUnalignedNode("008877665544332211", 1)));
224269
}
270+
271+
@Test
272+
public void testGetUnalignedFromNonStableArray() {
273+
assertPartialEvalEquals(readShortRootNode(), new RootTestNode("getShortUnaligned", new GetShortUnalignedNonStableNode("0089abcdef", 1)));
274+
assertPartialEvalEquals(readIntRootNode(), new RootTestNode("getIntUnaligned", new GetIntUnalignedNonStableNode("000089abcdef", 2)));
275+
assertPartialEvalEquals(readLongRootNode(), new LongRootNode(new GetLongUnalignedNonStableNode("008877665544332211", 1)));
276+
}
277+
278+
private static RootTestNode readShortRootNode() {
279+
return new RootTestNode("readShort", new AbstractTestNode() {
280+
private final byte[] bytes = new byte[8];
281+
282+
@Override
283+
public int execute(VirtualFrame frame) {
284+
return BYTES.getShort(bytes, 0);
285+
}
286+
});
287+
}
288+
289+
private static RootTestNode readIntRootNode() {
290+
return new RootTestNode("readInt", new AbstractTestNode() {
291+
private final byte[] bytes = new byte[8];
292+
293+
@Override
294+
public int execute(VirtualFrame frame) {
295+
return BYTES.getInt(bytes, 0);
296+
}
297+
});
298+
}
299+
300+
private static LongRootNode readLongRootNode() {
301+
return new LongRootNode(new LongNode() {
302+
private final byte[] bytes = new byte[8];
303+
304+
@Override
305+
public long execute(VirtualFrame frame) {
306+
return BYTES.getLong(bytes, 0);
307+
}
308+
});
309+
}
225310
}

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/nodes/ObjectLocationIdentity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2013, 2018, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2013, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -35,6 +35,13 @@
3535

3636
/**
3737
* A {@link LocationIdentity} wrapping an object.
38+
*
39+
* Used by Truffle unsafe accesses to associate unique location identities with DynamicObject and
40+
* FrameWithoutBoxing accesses, backed by non-null object references (compared by identity). The
41+
* compiler can assume that accesses with different location identities (except for "any") do not
42+
* interfere with each other (or when they do are constrained by memory barriers), even when they
43+
* may access the same relative memory address (array or field offset) of objects of the same array
44+
* or instance class.
3845
*/
3946
public final class ObjectLocationIdentity extends LocationIdentity implements JavaConstantFormattable {
4047

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/substitutions/TruffleGraphBuilderPlugins.java

Lines changed: 132 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import static jdk.graal.compiler.replacements.PEGraphDecoder.Options.MaximumLoopExplosionCount;
2929

3030
import java.lang.reflect.Type;
31+
import java.nio.ByteOrder;
3132
import java.util.ArrayList;
3233
import java.util.Collections;
3334
import java.util.LinkedHashMap;
@@ -40,14 +41,17 @@
4041

4142
import com.oracle.truffle.compiler.TruffleCompilationTask;
4243

44+
import jdk.graal.compiler.core.common.NumUtil;
4345
import jdk.graal.compiler.core.common.calc.CanonicalCondition;
4446
import jdk.graal.compiler.core.common.memory.MemoryOrderMode;
47+
import jdk.graal.compiler.core.common.type.IntegerStamp;
4548
import jdk.graal.compiler.core.common.type.ObjectStamp;
4649
import jdk.graal.compiler.core.common.type.Stamp;
4750
import jdk.graal.compiler.core.common.type.StampFactory;
4851
import jdk.graal.compiler.core.common.type.StampPair;
4952
import jdk.graal.compiler.core.common.type.TypeReference;
5053
import jdk.graal.compiler.debug.DebugContext;
54+
import jdk.graal.compiler.debug.GraalError;
5155
import jdk.graal.compiler.graph.Node;
5256
import jdk.graal.compiler.lir.gen.ArithmeticLIRGeneratorTool.RoundingMode;
5357
import jdk.graal.compiler.nodes.CallTargetNode;
@@ -61,6 +65,7 @@
6165
import jdk.graal.compiler.nodes.InvokeNode;
6266
import jdk.graal.compiler.nodes.LogicConstantNode;
6367
import jdk.graal.compiler.nodes.LogicNode;
68+
import jdk.graal.compiler.nodes.NamedLocationIdentity;
6469
import jdk.graal.compiler.nodes.NodeView;
6570
import jdk.graal.compiler.nodes.PiArrayNode;
6671
import jdk.graal.compiler.nodes.PiNode;
@@ -133,6 +138,7 @@
133138
import jdk.vm.ci.meta.DeoptimizationReason;
134139
import jdk.vm.ci.meta.JavaConstant;
135140
import jdk.vm.ci.meta.JavaKind;
141+
import jdk.vm.ci.meta.MemoryAccessProvider;
136142
import jdk.vm.ci.meta.MetaAccessProvider;
137143
import jdk.vm.ci.meta.ResolvedJavaField;
138144
import jdk.vm.ci.meta.ResolvedJavaMethod;
@@ -172,6 +178,7 @@ public static void registerInvocationPlugins(InvocationPlugins plugins, KnownTru
172178
registerDynamicObjectPlugins(plugins, types, canDelayIntrinsification, providers.getConstantReflection());
173179
registerBufferPlugins(plugins, types, canDelayIntrinsification);
174180
registerMemorySegmentPlugins(plugins, types, canDelayIntrinsification);
181+
registerByteArraySupportPlugins(plugins, canDelayIntrinsification);
175182
}
176183

177184
private static void registerTruffleSafepointPlugins(InvocationPlugins plugins, KnownTruffleTypes types, boolean canDelayIntrinsification) {
@@ -1332,18 +1339,17 @@ static class CustomizedUnsafeStorePlugin extends RequiredInvocationPlugin {
13321339

13331340
@Override
13341341
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode object, ValueNode offset, ValueNode value, ValueNode location) {
1335-
ValueNode locationArgument = location;
1336-
if (locationArgument.isConstant()) {
1342+
if (location.isConstant()) {
13371343
LocationIdentity locationIdentity;
13381344
boolean forceLocation;
1339-
if (locationArgument.isNullConstant()) {
1345+
if (location.isNullConstant()) {
13401346
locationIdentity = LocationIdentity.any();
13411347
forceLocation = false;
1342-
} else if (locationArgument.asJavaConstant().equals(anyConstant)) {
1348+
} else if (location.asJavaConstant().equals(anyConstant)) {
13431349
locationIdentity = LocationIdentity.any();
13441350
forceLocation = true;
13451351
} else {
1346-
locationIdentity = ObjectLocationIdentity.create(locationArgument.asJavaConstant());
1352+
locationIdentity = ObjectLocationIdentity.create(location.asJavaConstant());
13471353
forceLocation = true;
13481354
}
13491355
b.add(new RawStoreNode(object, offset, value, kind, locationIdentity, true, null, forceLocation));
@@ -1377,7 +1383,7 @@ static void logPerformanceWarningLocationNotConstant(ValueNode location, Resolve
13771383
debug.dump(DebugContext.VERBOSE_LEVEL, graph, "perf warn: Location argument is not a partial evaluation constant: %s", location);
13781384
}
13791385
} catch (Throwable t) {
1380-
debug.handle(t);
1386+
throw debug.handle(t);
13811387
}
13821388
}
13831389
}
@@ -1411,16 +1417,15 @@ static void logPerformanceWarningUnsafeCastArgNotConst(ResolvedJavaMethod target
14111417
debug.dump(DebugContext.VERBOSE_LEVEL, graph, "perf warn: unsafeCast arguments could not reduce to a constant: %s, %s, %s", type, nonNull, isExactType);
14121418
}
14131419
} catch (Throwable t) {
1414-
debug.handle(t);
1420+
throw debug.handle(t);
14151421
}
14161422
}
14171423
}
14181424

14191425
static BailoutException failPEConstant(GraphBuilderContext b, ValueNode value) {
14201426
StringBuilder sb = new StringBuilder();
14211427
sb.append(value);
1422-
if (value instanceof ValuePhiNode) {
1423-
ValuePhiNode valuePhi = (ValuePhiNode) value;
1428+
if (value instanceof ValuePhiNode valuePhi) {
14241429
sb.append(" (");
14251430
for (Node n : valuePhi.inputs()) {
14261431
sb.append(n);
@@ -1443,8 +1448,7 @@ private PEConstantPlugin(boolean canDelayIntrinsification, Type... argumentTypes
14431448
@Override
14441449
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value) {
14451450
ValueNode curValue = value;
1446-
if (curValue instanceof BoxNode) {
1447-
BoxNode boxNode = (BoxNode) curValue;
1451+
if (curValue instanceof BoxNode boxNode) {
14481452
curValue = boxNode.getValue();
14491453
}
14501454
if (curValue.isConstant()) {
@@ -1457,4 +1461,121 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
14571461
}
14581462

14591463
}
1464+
1465+
private static void registerByteArraySupportPlugins(InvocationPlugins plugins, boolean canDelayIntrinsification) {
1466+
Registration r = new Registration(plugins, "com.oracle.truffle.api.memory.UnsafeByteArraySupport");
1467+
r.register(new UnsafeGetUnalignedPlugin("unsafeGetShortUnaligned", JavaKind.Short, canDelayIntrinsification));
1468+
r.register(new UnsafeGetUnalignedPlugin("unsafeGetIntUnaligned", JavaKind.Int, canDelayIntrinsification));
1469+
r.register(new UnsafeGetUnalignedPlugin("unsafeGetLongUnaligned", JavaKind.Long, canDelayIntrinsification));
1470+
}
1471+
1472+
private static class UnsafeGetUnalignedPlugin extends OptionalInvocationPlugin {
1473+
private final boolean canDelayIntrinsification;
1474+
private final JavaKind resultKind;
1475+
1476+
UnsafeGetUnalignedPlugin(String name, JavaKind resultKind, boolean canDelayIntrinsification) {
1477+
super(name, byte[].class, long.class);
1478+
this.canDelayIntrinsification = canDelayIntrinsification;
1479+
this.resultKind = resultKind;
1480+
assert resultKind == JavaKind.Short || resultKind == JavaKind.Int || resultKind == JavaKind.Long : resultKind;
1481+
GraalError.guarantee(ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN, "only supported on little-endian architecture");
1482+
}
1483+
1484+
@Override
1485+
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver,
1486+
ValueNode bufferNode, ValueNode byteOffsetNode) {
1487+
if (bufferNode instanceof ConstantNode bufferConstNode && byteOffsetNode.isConstant()) {
1488+
if (bufferConstNode.getStableDimension() == 1) { // implies non-null
1489+
JavaConstant bufferConst = bufferConstNode.asJavaConstant();
1490+
long byteOffset = byteOffsetNode.asJavaConstant().asLong();
1491+
JavaConstant value = readUnaligned(b, resultKind, bufferConst, byteOffset);
1492+
if (value != null && (bufferConstNode.isDefaultStable() || !value.isDefaultForKind())) {
1493+
b.addPush(resultKind, ConstantNode.forPrimitive(value, b.getGraph()));
1494+
return true;
1495+
}
1496+
}
1497+
} else if (canDelayIntrinsification) {
1498+
return false;
1499+
}
1500+
b.addPush(resultKind, new RawLoadNode(bufferNode, byteOffsetNode, resultKind, NamedLocationIdentity.getArrayLocation(JavaKind.Byte), MemoryOrderMode.PLAIN));
1501+
return true;
1502+
}
1503+
1504+
/**
1505+
* Reads a short, int, or long value from a potentially unaligned offset in a byte[] array.
1506+
* Performs a single aligned read if the address is aligned, otherwise combines the results
1507+
* of multiple reads of the next narrower naturally aligned width or individual bytes.
1508+
*
1509+
* @param resultKind value kind, either short, int, or long
1510+
* @param base byte[] array constant, with stable dimensions = 1
1511+
* @param byteOffset byte[] index, not including array base offset
1512+
* @return result value constant or {@code null} if out of bounds
1513+
*/
1514+
@SuppressWarnings("fallthrough")
1515+
private static JavaConstant readUnaligned(GraphBuilderContext b, JavaKind resultKind, JavaConstant base, long byteOffset) {
1516+
ConstantReflectionProvider constantReflection = b.getConstantReflection();
1517+
MemoryAccessProvider memoryAccessProvider = constantReflection.getMemoryAccessProvider();
1518+
long displacement = b.getMetaAccess().getArrayBaseOffset(JavaKind.Byte) + byteOffset;
1519+
int resultBytes = resultKind.getByteCount();
1520+
if (displacement % resultBytes == 0) {
1521+
// Already aligned, so we can read the value directly.
1522+
IntegerStamp accessStamp = StampFactory.forInteger(resultKind.getBitCount());
1523+
return (JavaConstant) accessStamp.readConstant(memoryAccessProvider, base, displacement);
1524+
}
1525+
1526+
// Figure out if we can read the value in wider-than-byte aligned parts.
1527+
JavaKind alignedKind = null;
1528+
switch (resultKind) {
1529+
case Long:
1530+
if (displacement % Integer.BYTES == 0) {
1531+
alignedKind = JavaKind.Int;
1532+
break;
1533+
}
1534+
// fallthrough
1535+
case Int:
1536+
if (displacement % Short.BYTES == 0) {
1537+
alignedKind = JavaKind.Short;
1538+
break;
1539+
}
1540+
break;
1541+
}
1542+
if (alignedKind != null) {
1543+
long value = 0;
1544+
long mask = NumUtil.getNbitNumberLong(alignedKind.getBitCount());
1545+
IntegerStamp accessStamp = StampFactory.forInteger(alignedKind.getBitCount());
1546+
for (int byteCount = 0; byteCount < resultBytes; byteCount += alignedKind.getByteCount()) {
1547+
var part = (JavaConstant) accessStamp.readConstant(memoryAccessProvider, base, displacement + byteCount);
1548+
if (part == null) {
1549+
/*
1550+
* Should not normally happen if base+displacement is aligned and in bounds;
1551+
* but in the unexpected case that the read fails, handle it gracefully.
1552+
*/
1553+
return null;
1554+
}
1555+
value |= ((part.asLong() & mask) << (byteCount * Byte.SIZE));
1556+
}
1557+
return JavaConstant.forPrimitive(resultKind, value);
1558+
}
1559+
1560+
// Displacement is odd, so we have to read the value byte-by-byte.
1561+
assert displacement % 2 != 0 : displacement;
1562+
long value = 0;
1563+
int byteOffsetAsInt = NumUtil.safeToInt(byteOffset);
1564+
for (int byteCount = 0; byteCount < resultBytes; byteCount += 2) {
1565+
JavaConstant b0 = constantReflection.readArrayElement(base, byteOffsetAsInt + byteCount);
1566+
JavaConstant b1 = constantReflection.readArrayElement(base, byteOffsetAsInt + byteCount + 1);
1567+
if (b0 == null || b1 == null) {
1568+
/*
1569+
* Byte offset is out of bounds. This is not necessarily an error since it
1570+
* depends on control flow / bounds checks if this read is actually reachable,
1571+
* so we must not fail compilation. We can either deoptimize here or fall back
1572+
* to a normal unsafe read.
1573+
*/
1574+
return null;
1575+
}
1576+
value |= (b0.asInt() & 0xffL | ((b1.asInt() & 0xffL) << Byte.SIZE)) << (byteCount * Byte.SIZE);
1577+
}
1578+
return JavaConstant.forPrimitive(resultKind, value);
1579+
}
1580+
}
14601581
}

0 commit comments

Comments
 (0)