diff --git a/src/main/java/graphql/annotations/ExtensionDataFetcherWrapper.java b/src/main/java/graphql/annotations/ExtensionDataFetcherWrapper.java new file mode 100644 index 00000000..a593755e --- /dev/null +++ b/src/main/java/graphql/annotations/ExtensionDataFetcherWrapper.java @@ -0,0 +1,48 @@ +/** + * Copyright 2016 Yurii Rashkovskii + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ +package graphql.annotations; + +import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; +import graphql.schema.DataFetchingEnvironmentImpl; + +import java.util.Map; + +import static graphql.annotations.ReflectionKit.newInstance; + +public class ExtensionDataFetcherWrapper implements DataFetcher{ + + private final Class declaringClass; + private final DataFetcher dataFetcher; + + public ExtensionDataFetcherWrapper(Class declaringClass, DataFetcher dataFetcher) { + this.declaringClass = declaringClass; + this.dataFetcher = dataFetcher; + } + + @SuppressWarnings("unchecked") + @Override + public T get(DataFetchingEnvironment environment) { + Object source = environment.getSource(); + if (source != null && (!declaringClass.isInstance(source)) && !(source instanceof Map)) { + environment = new DataFetchingEnvironmentImpl(newInstance(declaringClass, source), environment.getArguments(), + environment.getContext(), environment.getFields(), environment.getFieldType(), environment.getParentType(), + environment.getGraphQLSchema(), environment.getFragmentsByName(), environment.getExecutionId(), environment.getSelectionSet()); + } + + return dataFetcher.get(environment); + } + +} diff --git a/src/main/java/graphql/annotations/GraphQLAnnotations.java b/src/main/java/graphql/annotations/GraphQLAnnotations.java index 94be6b19..a0874b73 100644 --- a/src/main/java/graphql/annotations/GraphQLAnnotations.java +++ b/src/main/java/graphql/annotations/GraphQLAnnotations.java @@ -16,46 +16,14 @@ import graphql.TypeResolutionEnvironment; import graphql.relay.Relay; -import graphql.schema.DataFetcher; -import graphql.schema.DataFetchingEnvironment; -import graphql.schema.DataFetchingEnvironmentImpl; -import graphql.schema.FieldDataFetcher; -import graphql.schema.GraphQLArgument; -import graphql.schema.GraphQLFieldDefinition; -import graphql.schema.GraphQLInputObjectField; -import graphql.schema.GraphQLInputObjectType; -import graphql.schema.GraphQLInputType; -import graphql.schema.GraphQLInterfaceType; -import graphql.schema.GraphQLList; +import graphql.schema.*; import graphql.schema.GraphQLNonNull; -import graphql.schema.GraphQLObjectType; -import graphql.schema.GraphQLOutputType; -import graphql.schema.GraphQLTypeReference; -import graphql.schema.GraphQLUnionType; -import graphql.schema.PropertyDataFetcher; -import graphql.schema.TypeResolver; import org.osgi.service.component.annotations.Component; import org.osgi.service.component.annotations.Reference; import javax.validation.constraints.NotNull; -import java.lang.reflect.AccessibleObject; -import java.lang.reflect.AnnotatedElement; -import java.lang.reflect.AnnotatedType; -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.lang.reflect.Parameter; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Stack; -import java.util.TreeMap; +import java.lang.reflect.*; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -83,6 +51,7 @@ public class GraphQLAnnotations implements GraphQLAnnotationsProcessor { private static final Relay RELAY_TYPES = new Relay(); private Map typeRegistry = new HashMap<>(); + private Map, Set>> extensionsTypeRegistry = new HashMap<>(); private final Stack processing = new Stack<>(); public GraphQLAnnotations() { @@ -180,13 +149,18 @@ public GraphQLInterfaceType.Builder getIfaceBuilder(Class iface) throws Graph if (description != null) { builder.description(description.value()); } + List definedFields = new ArrayList<>(); for (Method method : getOrderedMethods(iface)) { boolean valid = !Modifier.isStatic(method.getModifiers()) && method.getAnnotation(GraphQLField.class) != null; if (valid) { - builder.field(getField(method)); + GraphQLFieldDefinition gqlField = getField(method); + definedFields.add(gqlField.getName()); + builder.field(gqlField); } } + builder.fields(getExtensionFields(iface, definedFields)); + GraphQLTypeResolver typeResolver = iface.getAnnotation(GraphQLTypeResolver.class); builder.typeResolver(newInstance(typeResolver.value())); return builder; @@ -321,13 +295,15 @@ public GraphQLObjectType.Builder getObjectBuilder(Class object) throws GraphQ if (description != null) { builder.description(description.value()); } - + List fieldsDefined = new ArrayList<>(); for (Method method : getOrderedMethods(object)) { if (method.isBridge() || method.isSynthetic()) { continue; } if (breadthFirstSearch(method)) { - builder.field(getField(method)); + GraphQLFieldDefinition gqlField = getField(method); + fieldsDefined.add(gqlField.getName()); + builder.field(gqlField); } } @@ -336,18 +312,58 @@ public GraphQLObjectType.Builder getObjectBuilder(Class object) throws GraphQ continue; } if (parentalSearch(field)) { - builder.field(getField(field)); + GraphQLFieldDefinition gqlField = getField(field); + fieldsDefined.add(gqlField.getName()); + builder.field(gqlField); } } for (Class iface : object.getInterfaces()) { if (iface.getAnnotation(GraphQLTypeResolver.class) != null) { builder.withInterface((GraphQLInterfaceType) getInterface(iface)); + builder.fields(getExtensionFields(iface, fieldsDefined)); } } + + builder.fields(getExtensionFields(object, fieldsDefined)); + return builder; } + private List getExtensionFields(Class object, List fieldsDefined) { + List fields = new ArrayList<>(); + if (extensionsTypeRegistry.containsKey(object)) { + for (Class aClass : extensionsTypeRegistry.get(object)) { + for (Method method : getOrderedMethods(aClass)) { + if (method.isBridge() || method.isSynthetic()) { + continue; + } + if (breadthFirstSearch(method)) { + addExtensionField(getField(method), fields, fieldsDefined); + } + } + for (Field field : getAllFields(aClass).values()) { + if (Modifier.isStatic(field.getModifiers())) { + continue; + } + if (parentalSearch(field)) { + addExtensionField(getField(field), fields, fieldsDefined); + } + } + } + } + return fields; + } + + private void addExtensionField(GraphQLFieldDefinition gqlField, List fields, List fieldsDefined) { + if (!fieldsDefined.contains(gqlField.getName())) { + fieldsDefined.add(gqlField.getName()); + fields.add(gqlField); + } else { + throw new GraphQLAnnotationsException("Duplicate field found in extension : " + gqlField.getName(), null); + } + } + public static GraphQLObjectType.Builder objectBuilder(Class object) throws GraphQLAnnotationsException { return getInstance().getObjectBuilder(object); } @@ -434,16 +450,16 @@ protected GraphQLFieldDefinition getField(Field field) throws GraphQLAnnotations if (outputType == GraphQLBoolean || (outputType instanceof GraphQLNonNull && ((GraphQLNonNull) outputType).getWrappedType() == GraphQLBoolean)) { if (checkIfPrefixGetterExists(field.getDeclaringClass(), "is", field.getName()) || checkIfPrefixGetterExists(field.getDeclaringClass(), "get", field.getName())) { - actualDataFetcher = new PropertyDataFetcher(field.getName()); + actualDataFetcher = new ExtensionDataFetcherWrapper(field.getDeclaringClass(), new PropertyDataFetcher(field.getName())); } } else if (checkIfPrefixGetterExists(field.getDeclaringClass(), "get", field.getName())) { - actualDataFetcher = new PropertyDataFetcher(field.getName()); + actualDataFetcher = new ExtensionDataFetcherWrapper(field.getDeclaringClass(), new PropertyDataFetcher(field.getName())); } else if (hasFluentGetter) { actualDataFetcher = new MethodDataFetcher(fluentMethod, typeFunction); } if (actualDataFetcher == null) { - actualDataFetcher = new FieldDataFetcher(field.getName()); + actualDataFetcher = new ExtensionDataFetcherWrapper(field.getDeclaringClass(), new FieldDataFetcher(field.getName())); } } @@ -459,7 +475,7 @@ protected GraphQLFieldDefinition getField(Field field) throws GraphQLAnnotations private DataFetcher constructDataFetcher(String fieldName, GraphQLDataFetcher annotatedDataFetcher) { final String[] args; - if ( annotatedDataFetcher.firstArgIsTargetName() ) { + if (annotatedDataFetcher.firstArgIsTargetName()) { args = Stream.concat(Stream.of(fieldName), stream(annotatedDataFetcher.args())).toArray(String[]::new); } else { args = annotatedDataFetcher.args(); @@ -673,6 +689,31 @@ public void setDefaultTypeFunction(TypeFunction function) { ((DefaultTypeFunction) defaultTypeFunction).setAnnotationsProcessor(this); } + public void registerTypeExtension(Class objectClass) { + GraphQLTypeExtension typeExtension = objectClass.getAnnotation(GraphQLTypeExtension.class); + if (typeExtension == null) { + throw new GraphQLAnnotationsException("Class is not annotated with GraphQLTypeExtension", null); + } else { + Class aClass = typeExtension.value(); + if (!extensionsTypeRegistry.containsKey(aClass)) { + extensionsTypeRegistry.put(aClass, new HashSet<>()); + } + extensionsTypeRegistry.get(aClass).add(objectClass); + } + } + + public void unregisterTypeExtension(Class objectClass) { + GraphQLTypeExtension typeExtension = objectClass.getAnnotation(GraphQLTypeExtension.class); + if (typeExtension == null) { + throw new GraphQLAnnotationsException("Class is not annotated with GraphQLTypeExtension", null); + } else { + Class aClass = typeExtension.value(); + if (extensionsTypeRegistry.containsKey(aClass)) { + extensionsTypeRegistry.get(aClass).remove(objectClass); + } + } + } + public void registerType(TypeFunction typeFunction) { ((DefaultTypeFunction) defaultTypeFunction).register(typeFunction); } diff --git a/src/main/java/graphql/annotations/GraphQLTypeExtension.java b/src/main/java/graphql/annotations/GraphQLTypeExtension.java new file mode 100644 index 00000000..b9d39df9 --- /dev/null +++ b/src/main/java/graphql/annotations/GraphQLTypeExtension.java @@ -0,0 +1,26 @@ +/** + * Copyright 2016 Yurii Rashkovskii + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ +package graphql.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface GraphQLTypeExtension { + Class value(); +} diff --git a/src/main/java/graphql/annotations/MethodDataFetcher.java b/src/main/java/graphql/annotations/MethodDataFetcher.java index c96f341c..b91946e6 100644 --- a/src/main/java/graphql/annotations/MethodDataFetcher.java +++ b/src/main/java/graphql/annotations/MethodDataFetcher.java @@ -52,13 +52,15 @@ public Object get(DataFetchingEnvironment environment) { if (Modifier.isStatic(method.getModifiers())) { obj = null; - } else if (method.getAnnotation(GraphQLInvokeDetached.class) == null) { + } else if (method.getAnnotation(GraphQLInvokeDetached.class) != null) { + obj = newInstance(method.getDeclaringClass()); + } else if (!method.getDeclaringClass().isInstance(environment.getSource())) { + obj = newInstance(method.getDeclaringClass(), environment.getSource()); + } else { obj = environment.getSource(); if (obj == null) { return null; } - } else { - obj = newInstance(method.getDeclaringClass()); } return method.invoke(obj, invocationArgs(environment)); } catch (IllegalAccessException | InvocationTargetException e) { diff --git a/src/main/java/graphql/annotations/ReflectionKit.java b/src/main/java/graphql/annotations/ReflectionKit.java index 8899ffb4..a86a5b0d 100644 --- a/src/main/java/graphql/annotations/ReflectionKit.java +++ b/src/main/java/graphql/annotations/ReflectionKit.java @@ -47,4 +47,16 @@ static Constructor constructor(Class type, Class... parameterTypes) } } + static T newInstance(Class clazz, Object parameter) { + if (parameter != null) { + for (Constructor constructor : (Constructor[]) clazz.getConstructors()) { + if (constructor.getParameterCount() == 1 && constructor.getParameters()[0].getType().isAssignableFrom(parameter.getClass())) { + return constructNewInstance(constructor, parameter); + } + } + } + return null; + } + + } diff --git a/src/test/java/graphql/annotations/GraphQLExtensionsTest.java b/src/test/java/graphql/annotations/GraphQLExtensionsTest.java new file mode 100644 index 00000000..6dc85fbc --- /dev/null +++ b/src/test/java/graphql/annotations/GraphQLExtensionsTest.java @@ -0,0 +1,138 @@ +/** + * Copyright 2016 Yurii Rashkovskii + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ +package graphql.annotations; + +import graphql.ExecutionResult; +import graphql.GraphQL; +import graphql.schema.*; +import org.testng.Assert; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; + +import static graphql.Scalars.GraphQLString; +import static graphql.schema.GraphQLSchema.newSchema; +import static org.testng.Assert.*; + +public class GraphQLExtensionsTest { + + @GraphQLDescription("TestObject object") + @GraphQLName("TestObject") + public static class TestObject { + @GraphQLField + public + String field() { + return "test"; + } + + } + + @GraphQLTypeExtension(GraphQLExtensionsTest.TestObject.class) + public static class TestObjectExtension { + private TestObject obj; + + public TestObjectExtension(TestObject obj) { + this.obj = obj; + this.field4 = obj.field() + " test4"; + } + + @GraphQLField + public String field2() { + return obj.field() + " test2"; + } + + @GraphQLDataFetcher(TestDataFetcher.class) + @GraphQLField + private String field3; + + @GraphQLField + public String field4; + + @GraphQLField + public String field5; + + public String getField5() { + return obj.field() + " test5"; + } + } + + @GraphQLTypeExtension(GraphQLExtensionsTest.TestObject.class) + public static class TestObjectExtensionInvalid { + private TestObject obj; + + public TestObjectExtensionInvalid(TestObject obj) { + this.obj = obj; + } + + @GraphQLField + public String getField() { + return "invalid"; + } + } + + public static class TestDataFetcher implements DataFetcher { + @Override + public Object get(DataFetchingEnvironment environment) { + return ((TestObject)environment.getSource()).field() + " test3"; + } + } + + @Test + public void fields() { + GraphQLAnnotations.getInstance().registerTypeExtension(TestObjectExtension.class); + GraphQLObjectType object = GraphQLAnnotations.object(GraphQLExtensionsTest.TestObject.class); + GraphQLAnnotations.getInstance().unregisterTypeExtension(TestObjectExtension.class); + + List fields = object.getFieldDefinitions(); + assertEquals(fields.size(), 5); + + fields.sort((o1, o2) -> o1.getName().compareTo(o2.getName())); + + assertEquals(fields.get(0).getName(), "field"); + assertEquals(fields.get(1).getName(), "field2"); + assertEquals(fields.get(1).getType(), GraphQLString); + assertEquals(fields.get(2).getName(), "field3"); + assertEquals(fields.get(2).getType(), GraphQLString); + } + + @Test + public void values() { + GraphQLAnnotations.getInstance().registerTypeExtension(TestObjectExtension.class); + GraphQLObjectType object = GraphQLAnnotations.object(GraphQLExtensionsTest.TestObject.class); + GraphQLAnnotations.getInstance().unregisterTypeExtension(TestObjectExtension.class); + + GraphQLSchema schema = newSchema().query(object).build(); + GraphQLSchema schemaInherited = newSchema().query(object).build(); + + ExecutionResult result = GraphQL.newGraphQL(schema).build().execute("{field field2 field3 field4 field5}", new GraphQLExtensionsTest.TestObject()); + Map data = (Map) result.getData(); + assertEquals(data.get("field"), "test"); + assertEquals(data.get("field2"), "test test2"); + assertEquals(data.get("field3"), "test test3"); + assertEquals(data.get("field4"), "test test4"); + assertEquals(data.get("field5"), "test test5"); + } + + @Test + public void testDuplicateField() { + GraphQLAnnotations.getInstance().registerTypeExtension(TestObjectExtensionInvalid.class); + GraphQLAnnotationsException e = expectThrows(GraphQLAnnotationsException.class, () -> GraphQLAnnotations.object(TestObject.class)); + assertTrue(e.getMessage().startsWith("Duplicate field")); + GraphQLAnnotations.getInstance().unregisterTypeExtension(TestObjectExtensionInvalid.class); + } +} diff --git a/src/test/java/graphql/annotations/GraphQLObjectTest.java b/src/test/java/graphql/annotations/GraphQLObjectTest.java index bc910b21..f21f5a63 100644 --- a/src/test/java/graphql/annotations/GraphQLObjectTest.java +++ b/src/test/java/graphql/annotations/GraphQLObjectTest.java @@ -19,7 +19,7 @@ import graphql.Scalars; import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; -import graphql.schema.FieldDataFetcher; + import graphql.schema.GraphQLArgument; import graphql.schema.GraphQLFieldDefinition; import graphql.schema.GraphQLInputObjectType; @@ -28,13 +28,10 @@ import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; import graphql.schema.GraphQLType; -import graphql.schema.PropertyDataFetcher; import org.testng.annotations.Test; import javax.validation.constraints.NotNull; import java.lang.reflect.AnnotatedType; -import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -197,8 +194,8 @@ public void fields() { assertEquals(fields.get(5).getName(), "privateTest"); assertEquals(fields.get(6).getName(), "publicTest"); - assertEquals(fields.get(5).getDataFetcher().getClass(), PropertyDataFetcher.class); - assertEquals(fields.get(6).getDataFetcher().getClass(), FieldDataFetcher.class); + assertEquals(fields.get(5).getDataFetcher().getClass(), ExtensionDataFetcherWrapper.class); + assertEquals(fields.get(6).getDataFetcher().getClass(), ExtensionDataFetcherWrapper.class); assertEquals(fields.get(7).getName(), "z_nonOptionalString"); assertTrue(fields.get(7).getType() instanceof graphql.schema.GraphQLNonNull);