Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schema inference parameterized types #32757

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaIgnore;
import org.apache.beam.sdk.schemas.utils.AutoValueUtils;
Expand Down Expand Up @@ -61,8 +63,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -143,7 +146,8 @@ public SchemaUserTypeCreator schemaTypeCreator(

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE);
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
Expand Down Expand Up @@ -125,18 +126,20 @@ public static FieldValueTypeInformation forOneOf(
.build();
}

public static FieldValueTypeInformation forField(Field field, int index) {
TypeDescriptor<?> type = TypeDescriptor.of(field.getGenericType());
public static FieldValueTypeInformation forField(
Field field, int index, Map<Type, Type> boundTypes) {
Comment on lines +131 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is a breaking change to a public method (not marked with @Internal). Are we okay with that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with some other changes to method signatures in this class: forGetter() and forSetter()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I think so. This is a logically an internal class, which was only marked public because we needed to access it from other packages (e.g. the protobuf package). The fact that it was not marked @internal was likely a mistake, and I think it's fine to change it in this case.

TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes));
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(field.getName(), field))
.setNumber(getNumberOverride(index, field))
.setNullable(hasNullableAnnotation(field))
.setType(type)
.setRawType(type.getRawType())
.setField(field)
.setElementType(getIterableComponentType(field))
.setMapKeyType(getMapKeyType(field))
.setMapValueType(getMapValueType(field))
.setElementType(getIterableComponentType(field, boundTypes))
.setMapKeyType(getMapKeyType(field, boundTypes))
.setMapValueType(getMapValueType(field, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(field))
.build();
Expand Down Expand Up @@ -184,7 +187,8 @@ public static <T extends AnnotatedElement & Member> String getNameOverride(
return fieldDescription.value();
}

public static FieldValueTypeInformation forGetter(Method method, int index) {
public static FieldValueTypeInformation forGetter(
Method method, int index, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith("get")) {
name = ReflectUtils.stripPrefix(method.getName(), "get");
Expand All @@ -194,7 +198,8 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
throw new RuntimeException("Getter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericReturnType());
TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes));
boolean nullable = hasNullableReturnType(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(name, method))
Expand All @@ -203,9 +208,9 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(method))
.build();
Expand Down Expand Up @@ -252,29 +257,33 @@ private static boolean isNullableAnnotation(Annotation annotation) {
return annotation.annotationType().getSimpleName().equals("Nullable");
}

public static FieldValueTypeInformation forSetter(Method method) {
return forSetter(method, "set");
public static FieldValueTypeInformation forSetter(
Method method, Map<Type, Type> boundParameters) {
return forSetter(method, "set", boundParameters);
}

public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) {
public static FieldValueTypeInformation forSetter(
Method method, String setterPrefix, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith(setterPrefix)) {
name = ReflectUtils.stripPrefix(method.getName(), setterPrefix);
} else {
throw new RuntimeException("Setter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericParameterTypes()[0]);
TypeDescriptor<?> type =
TypeDescriptor.of(
ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes));
boolean nullable = hasSingleNullableParameter(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(name)
.setNullable(nullable)
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand All @@ -283,13 +292,15 @@ public FieldValueTypeInformation withName(String name) {
return toBuilder().setName(name).build();
}

private static FieldValueTypeInformation getIterableComponentType(Field field) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()));
private static FieldValueTypeInformation getIterableComponentType(
Field field, Map<Type, Type> boundTypes) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor<?> valueType) {
static @Nullable FieldValueTypeInformation getIterableComponentType(
TypeDescriptor<?> valueType, Map<Type, Type> boundTypes) {
// TODO: Figure out nullable elements.
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType);
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes);
if (componentType == null) {
return null;
}
Expand All @@ -299,41 +310,43 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(componentType)
.setRawType(componentType.getRawType())
.setElementType(getIterableComponentType(componentType))
.setMapKeyType(getMapKeyType(componentType))
.setMapValueType(getMapValueType(componentType))
.setElementType(getIterableComponentType(componentType, boundTypes))
.setMapKeyType(getMapKeyType(componentType, boundTypes))
.setMapValueType(getMapValueType(componentType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}

// If the Field is a map type, returns the key type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()));
private static @Nullable FieldValueTypeInformation getMapKeyType(
Field field, Map<Type, Type> boundTypes) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapKeyType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 0);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 0, boundTypes);
}

// If the Field is a map type, returns the value type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapValueType(Field field) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1);
private static @Nullable FieldValueTypeInformation getMapValueType(
Field field, Map<Type, Type> boundTypes) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapValueType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 1);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 1, boundTypes);
}

// If the Field is a map type, returns the key or value type (0 is key type, 1 is value).
// Otherwise returns a null reference.
@SuppressWarnings("unchecked")
private static @Nullable FieldValueTypeInformation getMapType(
TypeDescriptor<?> valueType, int index) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index);
TypeDescriptor<?> valueType, int index, Map<Type, Type> boundTypes) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes);
if (mapType == null) {
return null;
}
Expand All @@ -342,9 +355,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(mapType)
.setRawType(mapType.getRawType())
.setElementType(getIterableComponentType(mapType))
.setMapKeyType(getMapKeyType(mapType))
.setMapValueType(getMapValueType(mapType))
.setElementType(getIterableComponentType(mapType, boundTypes))
.setMapKeyType(getMapKeyType(mapType, boundTypes))
.setMapValueType(getMapValueType(mapType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
Expand Down Expand Up @@ -67,8 +69,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,10 +114,11 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier {

@Override
public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream()
.filter(ReflectUtils::isSetter)
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.map(FieldValueTypeInformation::forSetter)
.map(m -> FieldValueTypeInformation.forSetter(m, boundTypes))
.map(
t -> {
if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) {
Expand Down Expand Up @@ -156,8 +160,10 @@ public boolean equals(@Nullable Object obj) {

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
Schema schema =
JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE);
JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes);

// If there are no creator methods, then validate that we have setters for every field.
// Otherwise, we will have no way of creating instances of the class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -62,9 +64,11 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
ReflectUtils.getFields(typeDescriptor.getRawType()).stream()
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());

List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(fields.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < fields.size(); ++i) {
types.add(FieldValueTypeInformation.forField(fields.get(i), i));
types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,7 +115,9 @@ private static void validateFieldNumbers(List<FieldValueTypeInformation> types)

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE);
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return POJOUtils.schemaFromPojoClass(
typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ public interface SchemaProvider extends Serializable {
* Given a type, return a function that converts that type to a {@link Row} object If no schema
* exists, returns null.
*/
@Nullable
<T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);
<T> @Nullable SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);

/**
* Given a type, returns a function that converts from a {@link Row} object to that type. If no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
providers.put(typeDescriptor, schemaProvider);
}

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
private <T> @Nullable SchemaProvider schemaProviderFor(TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return schemaProvider.schemaFor(type);
return schemaProvider;
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
Expand All @@ -92,38 +91,24 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
} while (true);
}

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null;
}

@Override
public <T> @Nullable SerializableFunction<T, Row> toRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<T, Row>) schemaProvider.toRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null;
}

@Override
public <T> @Nullable SerializableFunction<Row, T> fromRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<Row, T>) schemaProvider.fromRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.fromRowFunction(typeDescriptor) : null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ public FunctionAndType(Type outputType, Function<Row, Object> function) {

public FunctionAndType(TypeDescriptor<?> outputType, Function<Row, Object> function) {
this(
StaticSchemaInference.fieldFromType(outputType, new EmptyFieldValueTypeSupplier()),
StaticSchemaInference.fieldFromType(
outputType, new EmptyFieldValueTypeSupplier(), Collections.emptyMap()),
function);
}

Expand Down
Loading
Loading