Skip to content

Commit

Permalink
Merge pull request #7406 from fabriziofortino/develop
Browse files Browse the repository at this point in the history
OSQLStaticReflectiveFunction to automatically bind static methods. 

Add `java.lang.Math` static methods as functions with `math_` prefix
  • Loading branch information
luigidellaquila authored May 11, 2017
2 parents 29c4be1 + 79af48e commit b8b556f
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.orientechnologies.orient.core.sql.functions;

import com.orientechnologies.common.exception.OException;
import com.orientechnologies.common.log.OLogManager;
import com.orientechnologies.orient.core.exception.OCommandExecutionException;
import com.orientechnologies.orient.core.sql.functions.coll.*;
import com.orientechnologies.orient.core.sql.functions.geo.OSQLFunctionDistance;
Expand All @@ -27,9 +28,10 @@
import com.orientechnologies.orient.core.sql.functions.text.OSQLFunctionConcat;
import com.orientechnologies.orient.core.sql.functions.text.OSQLFunctionFormat;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.stream.Collectors;

/**
* Default set of SQL function.
Expand All @@ -38,7 +40,7 @@
*/
public final class ODefaultSQLFunctionFactory implements OSQLFunctionFactory {

private static final Map<String, Object> FUNCTIONS = new HashMap<String, Object>();
private static final Map<String, Object> FUNCTIONS = new HashMap<>();
static {
// MISC FUNCTIONS
register(OSQLFunctionAverage.NAME, OSQLFunctionAverage.class);
Expand Down Expand Up @@ -94,13 +96,40 @@ public final class ODefaultSQLFunctionFactory implements OSQLFunctionFactory {
register(OSQLFunctionShortestPath.NAME, OSQLFunctionShortestPath.class);
register(OSQLFunctionDijkstra.NAME, OSQLFunctionDijkstra.class);
register(OSQLFunctionAstar.NAME, OSQLFunctionAstar.class);

// auto-register all Math.<method>() automatically with math_ prefix
registerStaticReflectiveFunctions("math_", Math.class);
}

public static void register(final String iName, final Object iImplementation) {
FUNCTIONS.put(iName.toLowerCase(), iImplementation);
}

private static void registerStaticReflectiveFunctions(final String prefix, final Class<?> clazz) {
final Map<String, List<Method>> methodsMap = Arrays.stream(clazz.getMethods())
.filter(m -> Modifier.isStatic(m.getModifiers()))
.collect(Collectors.groupingBy(Method::getName));

for (Map.Entry<String, List<Method>> entry : methodsMap.entrySet()) {
final String name = prefix + entry.getKey();
if (FUNCTIONS.containsKey(name)) {
OLogManager.instance().warn(null, "Unable to register reflective function with name " + name);
} else {
List<Method> methodsList = methodsMap.get(entry.getKey());
Method[] methods = new Method[methodsList.size()];
int i = 0;
int minParams = 0;
int maxParams = 0;
for (Method m : methodsList) {
methods[i++] = m;
minParams = minParams < m.getParameterTypes().length ? minParams : m.getParameterTypes().length;
maxParams = maxParams > m.getParameterTypes().length ? maxParams : m.getParameterTypes().length;
}
register(name, new OSQLStaticReflectiveFunction(name, minParams, maxParams, methods));
}
}

}

@Override
public Set<String> getFunctionNames() {
return FUNCTIONS.keySet();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package com.orientechnologies.orient.core.sql.functions.misc;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.orientechnologies.orient.core.command.OCommandContext;
import com.orientechnologies.orient.core.db.record.OIdentifiable;
import com.orientechnologies.orient.core.exception.OQueryParsingException;
import com.orientechnologies.orient.core.sql.functions.OSQLFunction;
import com.orientechnologies.orient.core.sql.functions.OSQLFunctionAbstract;

/**
* This {@link OSQLFunction} is able to invoke a static method using reflection. If contains more than one {@link Method} it tries
* to pick the one that better fits the input parameters.
*
* @author Fabrizio Fortino
*/
public class OSQLStaticReflectiveFunction extends OSQLFunctionAbstract {

private static final Map<Class<?>, Class<?>> PRIMITIVE_TO_WRAPPER = new HashMap<>();
static {
PRIMITIVE_TO_WRAPPER.put(Boolean.TYPE, Boolean.class);
PRIMITIVE_TO_WRAPPER.put(Byte.TYPE, Byte.class);
PRIMITIVE_TO_WRAPPER.put(Character.TYPE, Character.class);
PRIMITIVE_TO_WRAPPER.put(Short.TYPE, Short.class);
PRIMITIVE_TO_WRAPPER.put(Integer.TYPE, Integer.class);
PRIMITIVE_TO_WRAPPER.put(Long.TYPE, Long.class);
PRIMITIVE_TO_WRAPPER.put(Double.TYPE, Double.class);
PRIMITIVE_TO_WRAPPER.put(Float.TYPE, Float.class);
PRIMITIVE_TO_WRAPPER.put(Void.TYPE, Void.TYPE);
}

private static final Map<Class<?>, Class<?>> WRAPPER_TO_PRIMITIVE = new HashMap<>();
static {
for (Class<?> primitive : PRIMITIVE_TO_WRAPPER.keySet()) {
Class<?> wrapper = PRIMITIVE_TO_WRAPPER.get(primitive);
if (!primitive.equals(wrapper)) {
WRAPPER_TO_PRIMITIVE.put(wrapper, primitive);
}
}
}

private static final Map<Class<?>, Integer> PRIMITIVE_WEIGHT = new HashMap<>();
static {
PRIMITIVE_WEIGHT.put(boolean.class, 1);
PRIMITIVE_WEIGHT.put(char.class, 2);
PRIMITIVE_WEIGHT.put(byte.class, 3);
PRIMITIVE_WEIGHT.put(short.class, 4);
PRIMITIVE_WEIGHT.put(int.class, 5);
PRIMITIVE_WEIGHT.put(long.class, 6);
PRIMITIVE_WEIGHT.put(float.class, 7);
PRIMITIVE_WEIGHT.put(double.class, 8);
PRIMITIVE_WEIGHT.put(void.class, 9);
}

private Method[] methods;

public OSQLStaticReflectiveFunction(String name, int minParams, int maxParams, Method... methods) {
super(name, minParams, maxParams);
this.methods = methods;
// we need to sort the methods by parameters type to return the closest overloaded method
Arrays.sort(methods, (m1, m2) -> {
Class<?>[] m1Params = m1.getParameterTypes();
Class<?>[] m2Params = m2.getParameterTypes();

int c = m1Params.length - m2Params.length;
if (c == 0) {
for (int i = 0; i < m1Params.length; i++) {
if (m1Params[i].isPrimitive() && m2Params[i].isPrimitive() && !m1Params[i].equals(m2Params[i])) {
c += PRIMITIVE_WEIGHT.get(m1Params[i]) - PRIMITIVE_WEIGHT.get(m2Params[i]);
}
}
}

return c;
});
}

@Override
public Object execute(Object iThis, OIdentifiable iCurrentRecord, Object iCurrentResult, Object[] iParams,
OCommandContext iContext) {

final Supplier<String> paramsPrettyPrint = () ->
Arrays.stream(iParams).map(p -> p + " [ " + p.getClass().getName() + " ]").collect(Collectors.joining(", ", "(", ")"));

Method method = pickMethod(iParams);

if (method == null) {
throw new OQueryParsingException("Unable to find a function for " + name + paramsPrettyPrint.get());
}

try {
return method.invoke(null, iParams);
} catch (ReflectiveOperationException | IllegalArgumentException e) {
e.printStackTrace();
throw new OQueryParsingException("Error executing function " + name + paramsPrettyPrint.get());
}

}

@Override
public String getSyntax() {
return this.getName();
}

private Method pickMethod(Object[] iParams) {
Method method = null;

boolean match = false;
for (Method m : methods) {
Class<?>[] parameterTypes = m.getParameterTypes();
if (iParams.length == parameterTypes.length) {
for (int i = 0; i < parameterTypes.length; i++) {
if (isAssignable(iParams[i].getClass(), parameterTypes[i])) {
match = true;
break;
}
}

if (iParams.length == 0 || match) {
method = m;
break;
}
}
}

return method;
}

private static boolean isAssignable(final Class<?> iFromClass, final Class<?> iToClass) {
// handle autoboxing
final BiFunction<Class<?>, Class<?>, Class<?>> autoboxer = (from, to) -> {
if (from.isPrimitive() && !to.isPrimitive()) {
return PRIMITIVE_TO_WRAPPER.get(from);
} else if (to.isPrimitive() && !from.isPrimitive()) {
return WRAPPER_TO_PRIMITIVE.get(from);
} else return from;
};

final Class<?> fromClass = autoboxer.apply(iFromClass, iToClass);

if (fromClass == null) {
return false;
} else if (fromClass.equals(iToClass)) {
return true;
} else if (fromClass.isPrimitive()) {
if (!iToClass.isPrimitive()) {
return false;
} else if (Integer.TYPE.equals(fromClass)) {
return Long.TYPE.equals(iToClass) || Float.TYPE.equals(iToClass) || Double.TYPE.equals(iToClass);
} else if (Long.TYPE.equals(fromClass)) {
return Float.TYPE.equals(iToClass) || Double.TYPE.equals(iToClass);
} else if (Boolean.TYPE.equals(fromClass)) {
return false;
} else if (Double.TYPE.equals(fromClass)) {
return false;
} else if (Float.TYPE.equals(fromClass)) {
return Double.TYPE.equals(iToClass);
} else if (Character.TYPE.equals(fromClass)) {
return Integer.TYPE.equals(iToClass) || Long.TYPE.equals(iToClass) || Float.TYPE.equals(iToClass)
|| Double.TYPE.equals(iToClass);
} else if (Short.TYPE.equals(fromClass)) {
return Integer.TYPE.equals(iToClass) || Long.TYPE.equals(iToClass) || Float.TYPE.equals(iToClass)
|| Double.TYPE.equals(iToClass);
} else if (Byte.TYPE.equals(fromClass)) {
return Short.TYPE.equals(iToClass) || Integer.TYPE.equals(iToClass) || Long.TYPE.equals(iToClass)
|| Float.TYPE.equals(iToClass) || Double.TYPE.equals(iToClass);
}
// this should never happen
return false;
}
return iToClass.isAssignableFrom(fromClass);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.orientechnologies.orient.core.sql.functions.sql;

import java.util.List;

import com.orientechnologies.orient.core.db.document.ODatabaseDocumentTx;
import com.orientechnologies.orient.core.exception.OQueryParsingException;
import com.orientechnologies.orient.core.record.impl.ODocument;
import com.orientechnologies.orient.core.sql.query.OSQLSynchQuery;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import static org.junit.Assert.*;

public class OSQLMathFunctionsTest {

private static ODatabaseDocumentTx db;

@BeforeClass
public static void beforeClass() {
db = new ODatabaseDocumentTx("memory:" + OSQLMathFunctionsTest.class.getSimpleName());
db.create();
}

@AfterClass
public static void afterClass() {
db.close();
}

@Test
public void testRandom() {
List<ODocument> result = db.query(new OSQLSynchQuery<ODocument>("select math_random() as random"));
assertTrue((Double) result.get(0).field("random") > 0);
}

@Test
public void testLog10() {
List<ODocument> result = db.query(new OSQLSynchQuery<ODocument>("select math_log10(10000) as log10"));
assertEquals((Double) result.get(0).field("log10"), 4.0, 0.0001);
}

@Test
public void testAbsInt() {
List<ODocument> result = db.query(new OSQLSynchQuery<ODocument>("select math_abs(-5) as abs"));
assertTrue((Integer)result.get(0).field("abs") == 5);
}

@Test
public void testAbsDouble() {
List<ODocument> result = db.query(new OSQLSynchQuery<ODocument>("select math_abs(-5.0d) as abs"));
assertTrue((Double)result.get(0).field("abs") == 5.0);
}

@Test
public void testAbsFloat() {
List<ODocument> result = db.query(new OSQLSynchQuery<ODocument>("select math_abs(-5.0f) as abs"));
assertTrue((Float)result.get(0).field("abs") == 5.0);
}

@Test(expected = OQueryParsingException.class)
public void testNonExistingFunction() {
db.query(new OSQLSynchQuery<ODocument>("select math_min('boom', 'boom') as boom"));
}
}

0 comments on commit b8b556f

Please sign in to comment.