Skip to content

feat: add JPA @EmbeddedId support #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@
import javax.persistence.metamodel.SingularAttribute;
import javax.persistence.metamodel.Type;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.introproventures.graphql.jpa.query.annotation.GraphQLDescription;
import com.introproventures.graphql.jpa.query.annotation.GraphQLIgnore;
import com.introproventures.graphql.jpa.query.schema.GraphQLSchemaBuilder;
import com.introproventures.graphql.jpa.query.schema.JavaScalars;
import com.introproventures.graphql.jpa.query.schema.NamingStrategy;
import com.introproventures.graphql.jpa.query.schema.impl.PredicateFilter.Criteria;

import graphql.Assert;
import graphql.Scalars;
import graphql.schema.Coercing;
Expand All @@ -65,8 +69,6 @@
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeReference;
import graphql.schema.PropertyDataFetcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* JPA specific schema builder implementation of {code #GraphQLSchemaBuilder} interface
Expand Down Expand Up @@ -95,7 +97,8 @@ public class GraphQLJpaSchemaBuilder implements GraphQLSchemaBuilder {

private Map<Class<?>, GraphQLType> classCache = new HashMap<>();
private Map<EntityType<?>, GraphQLObjectType> entityCache = new HashMap<>();
private Map<EmbeddableType<?>, GraphQLObjectType> embeddableCache = new HashMap<>();
private Map<EmbeddableType<?>, GraphQLObjectType> embeddableOutputCache = new HashMap<>();
private Map<EmbeddableType<?>, GraphQLInputObjectType> embeddableInputCache = new HashMap<>();

private static final Logger log = LoggerFactory.getLogger(GraphQLJpaSchemaBuilder.class);

Expand Down Expand Up @@ -292,13 +295,13 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.EQ.name())
.description("Equals criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.NE.name())
.description("Not Equals criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
);

Expand All @@ -307,25 +310,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
builder.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.LE.name())
.description("Less then or Equals criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.GE.name())
.description("Greater or Equals criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.GT.name())
.description("Greater Then criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.LT.name())
.description("Less Then criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
);
}
Expand All @@ -334,25 +337,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
builder.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.LIKE.name())
.description("Like criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.CASE.name())
.description("Case sensitive match criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.STARTS.name())
.description("Starts with criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.ENDS.name())
.description("Ends with criteria")
.type((GraphQLInputType) getAttributeType(attribute))
.type(getAttributeInputType(attribute))
.build()
);
}
Expand All @@ -373,25 +376,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.IN.name())
.description("In criteria")
.type(new GraphQLList(getAttributeType(attribute)))
.type(new GraphQLList(getAttributeInputType(attribute)))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.NIN.name())
.description("Not In criteria")
.type(new GraphQLList(getAttributeType(attribute)))
.type(new GraphQLList(getAttributeInputType(attribute)))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.BETWEEN.name())
.description("Between criteria")
.type(new GraphQLList(getAttributeType(attribute)))
.type(new GraphQLList(getAttributeInputType(attribute)))
.build()
)
.field(GraphQLInputObjectField.newInputObjectField()
.name(Criteria.NOT_BETWEEN.name())
.description("Not Between criteria")
.type(new GraphQLList(getAttributeType(attribute)))
.type(new GraphQLList(getAttributeInputType(attribute)))
.build()
);

Expand All @@ -404,39 +407,52 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
}

private GraphQLArgument getArgument(Attribute<?,?> attribute) {
GraphQLType type = getAttributeType(attribute);
GraphQLInputType type = getAttributeInputType(attribute);
String description = getSchemaDescription(attribute.getJavaMember());

if (type instanceof GraphQLInputType) {
return GraphQLArgument.newArgument()
.name(attribute.getName())
.type((GraphQLInputType) type)
.description(description)
.build();
}

throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Input Argument");
return GraphQLArgument.newArgument()
.name(attribute.getName())
.type((GraphQLInputType) type)
.description(description)
.build();
}

private GraphQLObjectType getEmbeddableType(EmbeddableType<?> embeddableType) {
if (embeddableCache.containsKey(embeddableType))
return embeddableCache.get(embeddableType);

String embeddableTypeName = namingStrategy.singularize(embeddableType.getJavaType().getSimpleName())+"EmbeddableType";

GraphQLObjectType objectType = GraphQLObjectType.newObject()
.name(embeddableTypeName)
.description(getSchemaDescription( embeddableType.getJavaType()))
.fields(embeddableType.getAttributes().stream()
.filter(this::isNotIgnored)
.map(this::getObjectField)
.collect(Collectors.toList())
)
.build();

embeddableCache.putIfAbsent(embeddableType, objectType);
private GraphQLType getEmbeddableType(EmbeddableType<?> embeddableType, boolean input) {
if (input && embeddableInputCache.containsKey(embeddableType))
return embeddableInputCache.get(embeddableType);

if (!input && embeddableOutputCache.containsKey(embeddableType))
return embeddableOutputCache.get(embeddableType);
String embeddableTypeName = namingStrategy.singularize(embeddableType.getJavaType().getSimpleName())+ (input ? "Input" : "") +"EmbeddableType";
GraphQLType graphQLType=null;
if (input) {
graphQLType = GraphQLInputObjectType.newInputObject()
.name(embeddableTypeName)
.description(getSchemaDescription(embeddableType.getJavaType()))
.fields(embeddableType.getAttributes().stream()
.filter(this::isNotIgnored)
.map(this::getInputObjectField)
.collect(Collectors.toList())
)
.build();
} else {
graphQLType = GraphQLObjectType.newObject()
.name(embeddableTypeName)
.description(getSchemaDescription(embeddableType.getJavaType()))
.fields(embeddableType.getAttributes().stream()
.filter(this::isNotIgnored)
.map(this::getObjectField)
.collect(Collectors.toList())
)
.build();
}
if (input) {
embeddableInputCache.putIfAbsent(embeddableType, (GraphQLInputObjectType) graphQLType);
} else{
embeddableOutputCache.putIfAbsent(embeddableType, (GraphQLObjectType) graphQLType);
}

return objectType;
return graphQLType;
}


Expand All @@ -462,67 +478,92 @@ private GraphQLObjectType getObjectType(EntityType<?> entityType) {

@SuppressWarnings( { "rawtypes", "unchecked" } )
private GraphQLFieldDefinition getObjectField(Attribute attribute) {
GraphQLType type = getAttributeType(attribute);

if (type instanceof GraphQLOutputType) {
List<GraphQLArgument> arguments = new ArrayList<>();
DataFetcher dataFetcher = PropertyDataFetcher.fetching(attribute.getName());

// Only add the orderBy argument for basic attribute types
if (attribute instanceof SingularAttribute
&& attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC) {
arguments.add(GraphQLArgument.newArgument()
.name(ORDER_BY_PARAM_NAME)
.description("Specifies field sort direction in the query results.")
.type(orderByDirectionEnum)
.build()
);
}

// Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
if (attribute instanceof SingularAttribute
&& attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC) {
ManagedType foreignType = (ManagedType) ((SingularAttribute) attribute).getType();

// TODO fix page count query
arguments.add(getWhereArgument(foreignType));

} // Get Sub-Objects fields queries via DataFetcher
else if (attribute instanceof PluralAttribute
&& (attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ONE_TO_MANY
|| attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY)) {
EntityType declaringType = (EntityType) ((PluralAttribute) attribute).getDeclaringType();
EntityType elementType = (EntityType) ((PluralAttribute) attribute).getElementType();

arguments.add(getWhereArgument(elementType));
dataFetcher = new GraphQLJpaOneToManyDataFetcher(entityManager, declaringType, (PluralAttribute) attribute);
}
GraphQLOutputType type = getAttributeOutputType(attribute);

List<GraphQLArgument> arguments = new ArrayList<>();
DataFetcher dataFetcher = PropertyDataFetcher.fetching(attribute.getName());

// Only add the orderBy argument for basic attribute types
if (attribute instanceof SingularAttribute
&& attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC) {
arguments.add(GraphQLArgument.newArgument()
.name(ORDER_BY_PARAM_NAME)
.description("Specifies field sort direction in the query results.")
.type(orderByDirectionEnum)
.build()
);
}

return GraphQLFieldDefinition.newFieldDefinition()
.name(attribute.getName())
.description(getSchemaDescription(attribute.getJavaMember()))
.type((GraphQLOutputType) type)
.dataFetcher(dataFetcher)
.argument(arguments)
.build();
// Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
if (attribute instanceof SingularAttribute
&& attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC) {
ManagedType foreignType = (ManagedType) ((SingularAttribute) attribute).getType();

// TODO fix page count query
arguments.add(getWhereArgument(foreignType));

} // Get Sub-Objects fields queries via DataFetcher
else if (attribute instanceof PluralAttribute
&& (attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ONE_TO_MANY
|| attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY)) {
EntityType declaringType = (EntityType) ((PluralAttribute) attribute).getDeclaringType();
EntityType elementType = (EntityType) ((PluralAttribute) attribute).getElementType();

arguments.add(getWhereArgument(elementType));
dataFetcher = new GraphQLJpaOneToManyDataFetcher(entityManager, declaringType, (PluralAttribute) attribute);
}

throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Output Argument");
return GraphQLFieldDefinition.newFieldDefinition()
.name(attribute.getName())
.description(getSchemaDescription(attribute.getJavaMember()))
.type(type)
.dataFetcher(dataFetcher)
.argument(arguments)
.build();
}

@SuppressWarnings( { "rawtypes", "unchecked" } )
private GraphQLInputObjectField getInputObjectField(Attribute attribute) {
GraphQLInputType type = getAttributeInputType(attribute);

return GraphQLInputObjectField.newInputObjectField()
.name(attribute.getName())
.description(getSchemaDescription(attribute.getJavaMember()))
.type(type)
.build();
}

private Stream<Attribute<?,?>> findBasicAttributes(Collection<Attribute<?,?>> attributes) {
return attributes.stream().filter(it -> it.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC);
}

@SuppressWarnings( "rawtypes" )
private GraphQLType getAttributeType(Attribute<?,?> attribute) {
private GraphQLInputType getAttributeInputType(Attribute<?,?> attribute) {
try{
return (GraphQLInputType) getAttributeType(attribute, true);
} catch (ClassCastException e){
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Input Argument");
}
}

@SuppressWarnings( "rawtypes" )
private GraphQLOutputType getAttributeOutputType(Attribute<?,?> attribute) {
try {
return (GraphQLOutputType) getAttributeType(attribute, false);
} catch (ClassCastException e){
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Output Argument");
}
}

@SuppressWarnings( "rawtypes" )
private GraphQLType getAttributeType(Attribute<?,?> attribute, boolean input) {

if (isBasic(attribute)) {
return getGraphQLTypeFromJavaType(attribute.getJavaType());
}
else if (isEmbeddable(attribute)) {
EmbeddableType embeddableType = (EmbeddableType) ((SingularAttribute) attribute).getType();
return getEmbeddableType(embeddableType);
return getEmbeddableType(embeddableType, input);
}
else if (isToMany(attribute)) {
EntityType foreignType = (EntityType) ((PluralAttribute) attribute).getElementType();
Expand Down Expand Up @@ -572,7 +613,8 @@ protected final boolean isToOne(Attribute<?,?> attribute) {

protected final boolean isValidInput(Attribute<?,?> attribute) {
return attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC ||
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ELEMENT_COLLECTION;
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ELEMENT_COLLECTION ||
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.EMBEDDED;
}

private String getSchemaDescription(Member member) {
Expand Down
Loading