Skip to content

Commit

Permalink
Start on custom whitelists for Painless (#23563)
Browse files Browse the repository at this point in the history
We'd like to be able to support context-sensitive whitelists in
Painless but we can't now because the whitelist is a static thing.
This begins to de-static the whitelist, in particular removing
the static keyword from most of the methods on `Definition` and
plumbing the static instance into the appropriate spots as though
it weren't static. Once we de-static all the methods we should be
able to fairly simply build context-sensitive whitelists.

The only "fun" bit of this is that I added another layer in the
chain of methods that bootstraps `def` calls. Instead of running
`invokedynamic` directly on `DefBootstrap` we now `invokedynamic`
`$bootstrapDef` on the script itself loads the `Definition` that
the script was compiled against and then calls `DefBootstrap`.

I chose to put `Definition` into `Locals` so I didn't have to
change the signature of all the `analyze` methods. I could have
do it another way, but that seems ok for now.
  • Loading branch information
nik9000 authored Apr 18, 2017
1 parent 8f54034 commit 0b15fde
Show file tree
Hide file tree
Showing 34 changed files with 318 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@ static <T> T compile(Loader loader, Class<T> iface, String name, String source,
" characters. The passed in script is " + source.length() + " characters. Consider using a" +
" plugin if a script longer than this length is a requirement.");
}
ScriptInterface scriptInterface = new ScriptInterface(iface);
Definition definition = Definition.BUILTINS;
ScriptInterface scriptInterface = new ScriptInterface(definition, iface);

SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, null);
SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, definition,
null);

root.analyze();
root.analyze(definition);
root.write();

try {
Class<? extends PainlessScript> clazz = loader.define(CLASS_NAME, root.getBytes());
clazz.getField("$DEFINITION").set(null, definition);
java.lang.reflect.Constructor<? extends PainlessScript> constructor =
clazz.getConstructor(String.class, String.class, BitSet.class);

Expand All @@ -131,11 +134,13 @@ static byte[] compile(Class<?> iface, String name, String source, CompilerSettin
" characters. The passed in script is " + source.length() + " characters. Consider using a" +
" plugin if a script longer than this length is a requirement.");
}
ScriptInterface scriptInterface = new ScriptInterface(iface);
Definition definition = Definition.BUILTINS;
ScriptInterface scriptInterface = new ScriptInterface(definition, iface);

SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, debugStream);
SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, definition,
debugStream);

root.analyze();
root.analyze(definition);
root.write();

return root.getBytes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,18 @@ static MethodHandle arrayLengthGetter(Class<?> arrayType) {
* until it finds a matching whitelisted method. If one is not found, it throws an exception.
* Otherwise it returns the matching method.
* <p>
* @params definition the whitelist
* @param receiverClass Class of the object to invoke the method on.
* @param name Name of the method.
* @param arity arity of method
* @return matching method to invoke. never returns null.
* @throws IllegalArgumentException if no matching whitelisted method was found.
*/
static Method lookupMethodInternal(Class<?> receiverClass, String name, int arity) {
static Method lookupMethodInternal(Definition definition, Class<?> receiverClass, String name, int arity) {
Definition.MethodKey key = new Definition.MethodKey(name, arity);
// check whitelist for matching method
for (Class<?> clazz = receiverClass; clazz != null; clazz = clazz.getSuperclass()) {
RuntimeClass struct = Definition.getRuntimeClass(clazz);
RuntimeClass struct = definition.getRuntimeClass(clazz);

if (struct != null) {
Method method = struct.methods.get(key);
Expand All @@ -195,7 +196,7 @@ static Method lookupMethodInternal(Class<?> receiverClass, String name, int arit
}

for (Class<?> iface : clazz.getInterfaces()) {
struct = Definition.getRuntimeClass(iface);
struct = definition.getRuntimeClass(iface);

if (struct != null) {
Method method = struct.methods.get(key);
Expand All @@ -220,6 +221,7 @@ static Method lookupMethodInternal(Class<?> receiverClass, String name, int arit
* until it finds a matching whitelisted method. If one is not found, it throws an exception.
* Otherwise it returns a handle to the matching method.
* <p>
* @param definition the whitelist
* @param lookup caller's lookup
* @param callSiteType callsite's type
* @param receiverClass Class of the object to invoke the method on.
Expand All @@ -229,13 +231,13 @@ static Method lookupMethodInternal(Class<?> receiverClass, String name, int arit
* @throws IllegalArgumentException if no matching whitelisted method was found.
* @throws Throwable if a method reference cannot be converted to an functional interface
*/
static MethodHandle lookupMethod(Lookup lookup, MethodType callSiteType,
static MethodHandle lookupMethod(Definition definition, Lookup lookup, MethodType callSiteType,
Class<?> receiverClass, String name, Object args[]) throws Throwable {
String recipeString = (String) args[0];
int numArguments = callSiteType.parameterCount();
// simple case: no lambdas
if (recipeString.isEmpty()) {
return lookupMethodInternal(receiverClass, name, numArguments - 1).handle;
return lookupMethodInternal(definition, receiverClass, name, numArguments - 1).handle;
}

// convert recipe string to a bitset for convenience (the code below should be refactored...)
Expand All @@ -258,7 +260,7 @@ static MethodHandle lookupMethod(Lookup lookup, MethodType callSiteType,

// lookup the method with the proper arity, then we know everything (e.g. interface types of parameters).
// based on these we can finally link any remaining lambdas that were deferred.
Method method = lookupMethodInternal(receiverClass, name, arity);
Method method = lookupMethodInternal(definition, receiverClass, name, arity);
MethodHandle handle = method.handle;

int replaced = 0;
Expand All @@ -282,7 +284,8 @@ static MethodHandle lookupMethod(Lookup lookup, MethodType callSiteType,
if (signature.charAt(0) == 'S') {
// the implementation is strongly typed, now that we know the interface type,
// we have everything.
filter = lookupReferenceInternal(lookup,
filter = lookupReferenceInternal(definition,
lookup,
interfaceType,
type,
call,
Expand All @@ -292,7 +295,8 @@ static MethodHandle lookupMethod(Lookup lookup, MethodType callSiteType,
// this is dynamically based on the receiver type (and cached separately, underneath
// this cache). It won't blow up since we never nest here (just references)
MethodType nestedType = MethodType.methodType(interfaceType.clazz, captures);
CallSite nested = DefBootstrap.bootstrap(lookup,
CallSite nested = DefBootstrap.bootstrap(definition,
lookup,
call,
nestedType,
0,
Expand All @@ -319,21 +323,23 @@ static MethodHandle lookupMethod(Lookup lookup, MethodType callSiteType,
* This is just like LambdaMetaFactory, only with a dynamic type. The interface type is known,
* so we simply need to lookup the matching implementation method based on receiver type.
*/
static MethodHandle lookupReference(Lookup lookup, String interfaceClass,
Class<?> receiverClass, String name) throws Throwable {
Definition.Type interfaceType = Definition.getType(interfaceClass);
static MethodHandle lookupReference(Definition definition, Lookup lookup, String interfaceClass,
Class<?> receiverClass, String name) throws Throwable {
Definition.Type interfaceType = definition.getType(interfaceClass);
Method interfaceMethod = interfaceType.struct.getFunctionalMethod();
if (interfaceMethod == null) {
throw new IllegalArgumentException("Class [" + interfaceClass + "] is not a functional interface");
}
int arity = interfaceMethod.arguments.size();
Method implMethod = lookupMethodInternal(receiverClass, name, arity);
return lookupReferenceInternal(lookup, interfaceType, implMethod.owner.name, implMethod.name, receiverClass);
Method implMethod = lookupMethodInternal(definition, receiverClass, name, arity);
return lookupReferenceInternal(definition, lookup, interfaceType, implMethod.owner.name,
implMethod.name, receiverClass);
}

/** Returns a method handle to an implementation of clazz, given method reference signature. */
private static MethodHandle lookupReferenceInternal(Lookup lookup, Definition.Type clazz, String type,
String call, Class<?>... captures) throws Throwable {
private static MethodHandle lookupReferenceInternal(Definition definition, Lookup lookup,
Definition.Type clazz, String type, String call, Class<?>... captures)
throws Throwable {
final FunctionRef ref;
if ("this".equals(type)) {
// user written method
Expand Down Expand Up @@ -361,7 +367,7 @@ private static MethodHandle lookupReferenceInternal(Lookup lookup, Definition.Ty
ref = new FunctionRef(clazz, interfaceMethod, handle, captures.length);
} else {
// whitelist lookup
ref = new FunctionRef(clazz, type, call, captures.length);
ref = new FunctionRef(definition, clazz, type, call, captures.length);
}
final CallSite callSite;
if (ref.needsBridges()) {
Expand Down Expand Up @@ -411,15 +417,16 @@ public static String getUserFunctionHandleFieldName(String name, int arity) {
* until it finds a matching whitelisted getter. If one is not found, it throws an exception.
* Otherwise it returns a handle to the matching getter.
* <p>
* @param definition the whitelist
* @param receiverClass Class of the object to retrieve the field from.
* @param name Name of the field.
* @return pointer to matching field. never returns null.
* @throws IllegalArgumentException if no matching whitelisted field was found.
*/
static MethodHandle lookupGetter(Class<?> receiverClass, String name) {
static MethodHandle lookupGetter(Definition definition, Class<?> receiverClass, String name) {
// first try whitelist
for (Class<?> clazz = receiverClass; clazz != null; clazz = clazz.getSuperclass()) {
RuntimeClass struct = Definition.getRuntimeClass(clazz);
RuntimeClass struct = definition.getRuntimeClass(clazz);

if (struct != null) {
MethodHandle handle = struct.getters.get(name);
Expand All @@ -429,7 +436,7 @@ static MethodHandle lookupGetter(Class<?> receiverClass, String name) {
}

for (final Class<?> iface : clazz.getInterfaces()) {
struct = Definition.getRuntimeClass(iface);
struct = definition.getRuntimeClass(iface);

if (struct != null) {
MethodHandle handle = struct.getters.get(name);
Expand Down Expand Up @@ -481,15 +488,16 @@ static MethodHandle lookupGetter(Class<?> receiverClass, String name) {
* until it finds a matching whitelisted setter. If one is not found, it throws an exception.
* Otherwise it returns a handle to the matching setter.
* <p>
* @param definition the whitelist
* @param receiverClass Class of the object to retrieve the field from.
* @param name Name of the field.
* @return pointer to matching field. never returns null.
* @throws IllegalArgumentException if no matching whitelisted field was found.
*/
static MethodHandle lookupSetter(Class<?> receiverClass, String name) {
static MethodHandle lookupSetter(Definition definition, Class<?> receiverClass, String name) {
// first try whitelist
for (Class<?> clazz = receiverClass; clazz != null; clazz = clazz.getSuperclass()) {
RuntimeClass struct = Definition.getRuntimeClass(clazz);
RuntimeClass struct = definition.getRuntimeClass(clazz);

if (struct != null) {
MethodHandle handle = struct.setters.get(name);
Expand All @@ -499,7 +507,7 @@ static MethodHandle lookupSetter(Class<?> receiverClass, String name) {
}

for (final Class<?> iface : clazz.getInterfaces()) {
struct = Definition.getRuntimeClass(iface);
struct = definition.getRuntimeClass(iface);

if (struct != null) {
MethodHandle handle = struct.setters.get(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,19 @@ static final class PIC extends MutableCallSite {
/** maximum number of types before we go megamorphic */
static final int MAX_DEPTH = 5;

private final Definition definition;
private final Lookup lookup;
private final String name;
private final int flavor;
private final Object[] args;
int depth; // pkg-protected for testing

PIC(Lookup lookup, String name, MethodType type, int initialDepth, int flavor, Object[] args) {
PIC(Definition definition, Lookup lookup, String name, MethodType type, int initialDepth, int flavor, Object[] args) {
super(type);
if (type.parameterType(0) != Object.class) {
throw new BootstrapMethodError("The receiver type (1st arg) of invokedynamic descriptor must be Object.");
}
this.definition = definition;
this.lookup = lookup;
this.name = name;
this.flavor = flavor;
Expand Down Expand Up @@ -142,19 +144,19 @@ static boolean checkClass(Class<?> clazz, Object receiver) {
private MethodHandle lookup(int flavor, String name, Class<?> receiver) throws Throwable {
switch(flavor) {
case METHOD_CALL:
return Def.lookupMethod(lookup, type(), receiver, name, args);
return Def.lookupMethod(definition, lookup, type(), receiver, name, args);
case LOAD:
return Def.lookupGetter(receiver, name);
return Def.lookupGetter(definition, receiver, name);
case STORE:
return Def.lookupSetter(receiver, name);
return Def.lookupSetter(definition, receiver, name);
case ARRAY_LOAD:
return Def.lookupArrayLoad(receiver);
case ARRAY_STORE:
return Def.lookupArrayStore(receiver);
case ITERATOR:
return Def.lookupIterator(receiver);
case REFERENCE:
return Def.lookupReference(lookup, (String) args[0], receiver, name);
return Def.lookupReference(definition, lookup, (String) args[0], receiver, name);
case INDEX_NORMALIZE:
return Def.lookupIndexNormalize(receiver);
default: throw new AssertionError();
Expand Down Expand Up @@ -237,7 +239,7 @@ Object fallback(final Object[] callArgs) throws Throwable {
*/
static final class MIC extends MutableCallSite {
private boolean initialized;

private final String name;
private final int flavor;
private final int flags;
Expand Down Expand Up @@ -419,16 +421,18 @@ static boolean checkBoth(Class<?> left, Class<?> right, Object leftObject, Objec
/**
* invokeDynamic bootstrap method
* <p>
* In addition to ordinary parameters, we also take some static parameters:
* In addition to ordinary parameters, we also take some parameters defined at the call site:
* <ul>
* <li>{@code initialDepth}: initial call site depth. this is used to exercise megamorphic fallback.
* <li>{@code flavor}: type of dynamic call it is (and which part of whitelist to look at).
* <li>{@code args}: flavor-specific args.
* </ul>
* And we take the {@link Definition} used to compile the script for whitelist checking.
* <p>
* see https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-6.html#jvms-6.5.invokedynamic
*/
public static CallSite bootstrap(Lookup lookup, String name, MethodType type, int initialDepth, int flavor, Object... args) {
public static CallSite bootstrap(Definition definition, Lookup lookup, String name, MethodType type, int initialDepth, int flavor,
Object... args) {
// validate arguments
switch(flavor) {
// "function-call" like things get a polymorphic cache
Expand All @@ -447,7 +451,7 @@ public static CallSite bootstrap(Lookup lookup, String name, MethodType type, in
if (args.length != numLambdas + 1) {
throw new BootstrapMethodError("Illegal number of parameters: expected " + numLambdas + " references");
}
return new PIC(lookup, name, type, initialDepth, flavor, args);
return new PIC(definition, lookup, name, type, initialDepth, flavor, args);
case LOAD:
case STORE:
case ARRAY_LOAD:
Expand All @@ -457,15 +461,15 @@ public static CallSite bootstrap(Lookup lookup, String name, MethodType type, in
if (args.length > 0) {
throw new BootstrapMethodError("Illegal static bootstrap parameters for flavor: " + flavor);
}
return new PIC(lookup, name, type, initialDepth, flavor, args);
return new PIC(definition, lookup, name, type, initialDepth, flavor, args);
case REFERENCE:
if (args.length != 1) {
throw new BootstrapMethodError("Invalid number of parameters for reference call");
}
if (args[0] instanceof String == false) {
throw new BootstrapMethodError("Illegal parameter for reference call: " + args[0]);
}
return new PIC(lookup, name, type, initialDepth, flavor, args);
return new PIC(definition, lookup, name, type, initialDepth, flavor, args);

// operators get monomorphic cache, with a generic impl for a fallback
case UNARY_OPERATOR:
Expand Down
Loading

0 comments on commit 0b15fde

Please sign in to comment.