Skip to content

Commit

Permalink
fix: Generate override methods when parent's parameters have a differ…
Browse files Browse the repository at this point in the history
…ent signature
  • Loading branch information
Christopher-Chianelli authored and triceo committed Sep 26, 2024
1 parent 0801dfd commit 1482036
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package ai.timefold.jpyinterpreter;

import static ai.timefold.jpyinterpreter.PythonBytecodeToJavaBytecodeTranslator.ARGUMENT_SPEC_INSTANCE_FIELD_NAME;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -46,9 +45,11 @@
import ai.timefold.jpyinterpreter.types.PythonString;
import ai.timefold.jpyinterpreter.types.collections.PythonLikeDict;
import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple;
import ai.timefold.jpyinterpreter.types.errors.NotImplementedError;
import ai.timefold.jpyinterpreter.types.wrappers.JavaObjectWrapper;
import ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference;
import ai.timefold.jpyinterpreter.util.JavaPythonClassWriter;
import ai.timefold.jpyinterpreter.util.OverrideMethod;
import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec;

import org.objectweb.asm.ClassWriter;
Expand All @@ -60,6 +61,8 @@
import org.objectweb.asm.signature.SignatureVisitor;
import org.objectweb.asm.signature.SignatureWriter;

import static ai.timefold.jpyinterpreter.PythonBytecodeToJavaBytecodeTranslator.ARGUMENT_SPEC_INSTANCE_FIELD_NAME;

public class PythonClassTranslator {
static Map<FunctionSignature, InterfaceDeclaration> functionSignatureToInterfaceName = new HashMap<>();

Expand Down Expand Up @@ -527,7 +530,8 @@ private static Class<?> createBytecodeForMethodAndSetOnClass(String className, P
PythonCompiledFunction function = methodEntry.getValue();
pythonLikeType.clearMethod(methodEntry.getKey());
String javaMethodDescriptor = Arrays.stream(generatedClass.getDeclaredMethods())
.filter(method -> method.getName().equals(getJavaMethodName(methodEntry.getKey())))
.filter(method -> method.getName().equals(getJavaMethodName(methodEntry.getKey()))
&& !method.isAnnotationPresent(OverrideMethod.class))
.map(Type::getMethodDescriptor)
.findFirst().orElseThrow();
ArgumentSpec<?> argumentSpec = function.getArgumentSpecMapper()
Expand Down Expand Up @@ -982,6 +986,25 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri
interfaceDeclaration.methodDescriptor, function,
interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor);

Set<String> overrides = new HashSet<>();
for (var parent : pythonLikeType.getParentList()) {
try {
var parentType = parent.getJavaClass();
for (var method : parentType.getMethods()) {
var parentMethodDescriptor = Type.getMethodDescriptor(method);
if (method.getName().equals(javaMethodName)
&& !Modifier.isStatic(method.getModifiers())
&& !parentMethodDescriptor.equals(javaMethodDescriptor)
&& !overrides.contains(parentMethodDescriptor)) {
overrides.add(parentMethodDescriptor);
createOverrideMethod(classWriter, internalClassName, method, javaMethodDescriptor, javaParameterTypes);
}
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}

pythonLikeType.addMethod(methodName,
new PythonFunctionSignature(new MethodDescriptor(internalClassName, MethodDescriptor.MethodType.VIRTUAL,
javaMethodName, javaMethodDescriptor),
Expand Down Expand Up @@ -1102,9 +1125,41 @@ private static void createInstanceOrStaticMethodBody(String internalClassName, S
methodVisitor.visitEnd();
}

private static void createOverrideMethod(ClassWriter classWriter, String internalClassName, Method overridenMethod,
String overrideMethodDescriptor, Type[] overrideParameterTypes) {
var methodVisitor = classWriter.visitMethod(Modifier.PUBLIC, overridenMethod.getName(),
Type.getMethodDescriptor(overridenMethod), null, null);
methodVisitor.visitAnnotation(Type.getDescriptor(OverrideMethod.class), true).visitEnd();

methodVisitor.visitCode();

if (overridenMethod.getParameterCount() != overrideParameterTypes.length) {
methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(NotImplementedError.class));
methodVisitor.visitInsn(Opcodes.DUP);
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(NotImplementedError.class),
"<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false);
methodVisitor.visitInsn(Opcodes.ATHROW);
} else {
methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
for (int i = 0; i < overrideParameterTypes.length; i++) {
methodVisitor.visitVarInsn(Opcodes.ALOAD, i + 1);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, overrideParameterTypes[i].getInternalName());
}
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, internalClassName, overridenMethod.getName(),
overrideMethodDescriptor, false);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(overridenMethod.getReturnType()));
methodVisitor.visitInsn(Opcodes.ARETURN);
}

methodVisitor.visitMaxs(-1, -1);
methodVisitor.visitEnd();
}

public static Type getVirtualFunctionReturnType(PythonCompiledFunction function) {
return Type.getType('L' + function.getReturnType().map(PythonLikeType::getJavaTypeInternalName)
.orElseGet(() -> getPythonReturnTypeOfFunction(function, true).getJavaTypeInternalName()) + ';');
// Do not determine return type from method body if type annotation absent,
// since overrides might return a different type
var returnType = function.getReturnType().orElse(BuiltinTypes.BASE_TYPE);
return Type.getType(returnType.getJavaTypeDescriptor());
}

public static String getFunctionSignature(PythonCompiledFunction function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil
for (int i = 1; i < argumentCount; i++) {
javaParameterTypes[i - 1] = methodType.getArgumentTypes()[i];
}
String javaMethodDescriptor = Type.getMethodDescriptor(methodType.getReturnType(), javaParameterTypes);
var methodReturnType =
PythonClassTranslator.getVirtualFunctionReturnType(compiledClass.instanceFunctionNameToPythonBytecode
.get(interfaceMethod.getName()));
String javaMethodDescriptor = Type.getMethodDescriptor(methodReturnType, javaParameterTypes);

interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, internalClassName,
PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package ai.timefold.jpyinterpreter.test;

import ai.timefold.jpyinterpreter.types.PythonString;
import ai.timefold.jpyinterpreter.types.numeric.PythonInteger;

/**
* Not a real interface; in main sources instead of test sources
* so a Python test can use it.
*/
public interface TestdataExtendedInterface {
PythonString stringMethod(PythonString name);

PythonInteger intMethod(PythonInteger value);

static String getString(TestdataExtendedInterface instance, String name) {
return instance.stringMethod(PythonString.valueOf(name)).value;
}

static int getInt(TestdataExtendedInterface instance, int value) {
return instance.intMethod(PythonInteger.valueOf(value)).value.intValue();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ai.timefold.jpyinterpreter.util;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

/**
* Marks a generated method as an override implementation.
* Needed since {@link Override} is not retained at runtime.
*/
@Retention(RetentionPolicy.RUNTIME)
public @interface OverrideMethod {
}
13 changes: 9 additions & 4 deletions python/jpyinterpreter/src/main/python/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import dis
import inspect
import sys
from jpype import JInt, JBoolean, JProxy, JClass, JArray
from typing import Protocol

from jpype import JInt, JBoolean, JProxy, JClass, JArray

MINIMUM_SUPPORTED_PYTHON_VERSION = (3, 10)
MAXIMUM_SUPPORTED_PYTHON_VERSION = (3, 12)

Expand Down Expand Up @@ -754,7 +755,11 @@ def translate_python_class_to_java_class(python_class):
python_compiled_class.staticAttributeNameToClassInstance = static_attributes_to_class_instance_map
python_compiled_class.staticAttributeDescriptorNames = static_attribute_descriptor_names

out = PythonClassTranslator.translatePythonClass(python_compiled_class, prepared_class_info)
PythonClassTranslator.setSelfStaticInstances(python_compiled_class, out.getJavaClass(), out,
CPythonBackedPythonInterpreter.pythonObjectIdToConvertedObjectMap)
try:
out = PythonClassTranslator.translatePythonClass(python_compiled_class, prepared_class_info)
PythonClassTranslator.setSelfStaticInstances(python_compiled_class, out.getJavaClass(), out,
CPythonBackedPythonInterpreter.pythonObjectIdToConvertedObjectMap)
except Exception as e:
e.printStackTrace()
raise e
return out
35 changes: 34 additions & 1 deletion python/jpyinterpreter/tests/test_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from typing import Type

import pytest

from .conftest import verifier_for


Expand Down Expand Up @@ -1098,6 +1099,38 @@ def applyAsInt(self, argument: int):
assert java_object.applyAsInt(1) == 2


def test_extend_interface_wrapper():
from ai.timefold.jpyinterpreter.test import TestdataExtendedInterface
from jpyinterpreter import translate_python_class_to_java_class, add_java_interface

@add_java_interface(TestdataExtendedInterface)
class A:
def stringMethod(self, name):
return self.string_method(name)

def intMethod(self, value):
return self.int_method(value)

def string_method(self, name):
raise NotImplementedError

def int_method(self, value):
raise NotImplementedError

class B(A):
def string_method(self, name: str) -> str:
return f'Hello {name}!'

def int_method(self, value: int) -> int:
return value + 1

translated_class = translate_python_class_to_java_class(B).getJavaClass()
assert TestdataExtendedInterface.class_.isAssignableFrom(translated_class)
java_object = translated_class.getConstructor().newInstance()
assert TestdataExtendedInterface.getString(java_object, 'World') == 'Hello World!'
assert TestdataExtendedInterface.getInt(java_object, 1) == 2


def test_python_java_type_mapping():
from java.lang import String
from jpyinterpreter import (translate_python_class_to_java_class,
Expand Down
21 changes: 11 additions & 10 deletions python/python-core/src/main/python/domain/_variable_listener.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from ..score import ScoreDirector
from _jpyinterpreter import add_java_interface
from typing import TYPE_CHECKING, TypeVar

from ..score import ScoreDirector

if TYPE_CHECKING:
from ai.timefold.solver.core.api.domain.variable import VariableListener
pass

Solution_ = TypeVar('Solution_')
Entity_ = TypeVar('Entity_')
Expand Down Expand Up @@ -57,48 +58,48 @@ def requires_unique_entity_events(self) -> bool:
if not TYPE_CHECKING: # We do not want these methods to appear in the API
def afterEntityAdded(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).after_entity_added(self, score_director, entity)
self.after_entity_added(score_director, entity)

VariableListener.afterEntityAdded = afterEntityAdded

def afterEntityRemoved(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).after_entity_removed(self, score_director, entity)
self.after_entity_removed(score_director, entity)

VariableListener.afterEntityRemoved = afterEntityRemoved

def beforeEntityAdded(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).before_entity_added(self, score_director, entity)
self.before_entity_added(score_director, entity)

VariableListener.beforeEntityAdded = beforeEntityAdded

def beforeEntityRemoved(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).before_entity_removed(self, score_director, entity)
self.before_entity_removed(score_director, entity)

VariableListener.beforeEntityRemoved = beforeEntityRemoved

def resetWorkingSolution(self, java_score_director) -> None:
score_director = ScoreDirector(java_score_director)
type(self).reset_working_solution(self, score_director)
self.reset_working_solution(score_director)

VariableListener.resetWorkingSolution = resetWorkingSolution

def afterVariableChanged(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).after_variable_changed(self, score_director, entity)
self.after_variable_changed(score_director, entity)

VariableListener.afterVariableChanged = afterVariableChanged

def beforeVariableChanged(self, java_score_director, entity) -> None:
score_director = ScoreDirector(java_score_director)
type(self).before_variable_changed(self, score_director, entity)
self.before_variable_changed(score_director, entity)

VariableListener.beforeVariableChanged = beforeVariableChanged

def requiresUniqueEntityEvents(self) -> bool:
return type(self).requires_unique_entity_events(self)
return self.requires_unique_entity_events()

VariableListener.requiresUniqueEntityEvents = requiresUniqueEntityEvents

Expand Down
Loading

0 comments on commit 1482036

Please sign in to comment.