Skip to content

Commit

Permalink
Change the model parsing to support generic types.
Browse files Browse the repository at this point in the history
Purpose:
1. Support generic types in method return type.
2. Support generic types in super class.
3. Support generic array types.
When parsing the generic types, the type parameters can be successfully
replaced by the actual type rather than just Objects.

The changes:
1. Use ClassWrapper to parse the super class and method return types
2. Use ClassWrapper to parse the type string in SwaggerContext.loadClass
3. Change ComplexTypeMatcher to match built-in container types only
4. In ModelConverters, include all type parameters when addRecursive
5. Change ModelPropertyParser to use generic types when parsing methods
and fields

Minor changes:
1. In JavaDateTimeOverride test case, revert the change to ModelConverts
to avoid impacting other test cases. Otherwise if other test cases with
date fields run after this case, they will fail.
2. In swagger UI, changed SwaggerOperation.prototype.isListType to check
arrays rather than any type contains "[".
 SwaggerOperation.prototype.isListType = function(type) {
 -  if (type && type.indexOf('[') >= 0) {
 +  if (type && type.indexOf('array[') >= 0) {
3. In ReaderUtil fixed groupByResourcePath to merge models
  • Loading branch information
Shu Zhang committed Mar 30, 2014
1 parent a4b6cd0 commit 207307d
Show file tree
Hide file tree
Showing 17 changed files with 135 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package com.wordnik.swagger.converter

import com.wordnik.swagger.core.SwaggerSpec
import com.wordnik.swagger.annotations.ApiModel
import com.wordnik.swagger.core.util.ClassWrapper

trait BaseConverter {
def toDescriptionOpt(cls: Class[_]): Option[String] = {
def toDescriptionOpt(cls: ClassWrapper): Option[String] = {
var description: Option[String] = None
for(anno <- cls.getAnnotations) {
anno match {
Expand All @@ -17,14 +18,14 @@ trait BaseConverter {
description
}

def toName(cls: Class[_]): String = {
def toName(cls: ClassWrapper): String = {
import javax.xml.bind.annotation._

val xmlRootElement = cls.getAnnotation(classOf[XmlRootElement])
val xmlEnum = cls.getAnnotation(classOf[XmlEnum])

if (xmlEnum != null && xmlEnum.value != null)
toName(xmlEnum.value())
toName(ClassWrapper(xmlEnum.value()))
else if (xmlRootElement != null) {
if ("##default".equals(xmlRootElement.name())) {
cls.getSimpleName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package com.wordnik.swagger.converter
import com.wordnik.swagger.model._

import org.slf4j.LoggerFactory
import com.wordnik.swagger.core.util.ClassWrapper

class JodaDateTimeConverter extends ModelConverter with BaseConverter {
private val LOGGER = LoggerFactory.getLogger(this.getClass)

def read(cls: Class[_], typeMap: Map[String, String]): Option[Model] = None
def read(cls: ClassWrapper, typeMap: Map[String, String]): Option[Model] = None

// map DateTime to Date, which is serialized as such:
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import com.wordnik.swagger.model._
import org.slf4j.LoggerFactory

import scala.collection.mutable.{ ListBuffer, LinkedHashMap, HashSet, HashMap }
import com.wordnik.swagger.core.util._

object ModelConverters {
private val LOGGER = LoggerFactory.getLogger(ModelConverters.getClass)
val ComplexTypeMatcher = "([a-zA-Z]*)\\[([a-zA-Z\\.\\-]*)\\].*".r

val converters = new ListBuffer[ModelConverter]() ++ List(
new JodaDateTimeConverter,
Expand All @@ -27,7 +27,11 @@ object ModelConverters {
else converters += c
}

def read(cls: Class[_], t: Map[String, String] = Map.empty): Option[Model] = {
def removeConverter(c: ModelConverter) = {
converters -= c
}

def read(cls: ClassWrapper, t: Map[String, String] = Map.empty): Option[Model] = {
val types = {
if(t.isEmpty)typeMap
else t
Expand All @@ -46,7 +50,7 @@ object ModelConverters {
model
}

def readAll(cls: Class[_]): List[Model] = {
def readAll(cls: ClassWrapper): List[Model] = {
val output = new HashMap[String, Model]
var model = read(cls, typeMap)
val propertyNames = new HashSet[String]
Expand All @@ -70,24 +74,24 @@ object ModelConverters {
model.map(m => {
output += cls.getName -> m
val checkedNames = new HashSet[String]
addRecursive(m, checkedNames, output)
addRecursive(cls, m, checkedNames, output)
})
output.values.toList
}

def addRecursive(model: Model, checkedNames: HashSet[String], output: HashMap[String, Model]): Unit = {
def addRecursive(modelCls: ClassWrapper, model: Model, checkedNames: HashSet[String], output: HashMap[String, Model]): Unit = {
if(!checkedNames.contains(model.name)) {
val propertyNames = new HashSet[String]
val typeParams = modelCls.getRawClass.getTypeParameters.map(t => modelCls.getTypeArgument(t.getName))
for (typeParam <- typeParams) {
propertyNames += typeParam.getName
}
for((name, property) <- model.properties) {
val propertyName = property.items match {
case Some(item) => item.qualifiedType.getOrElse(item.`type`)
case None => property.qualifiedType
}
val name = propertyName match {
case ComplexTypeMatcher(containerType, basePart) => basePart
case e: String => e
}
propertyNames += name
propertyNames += propertyName
}
for(typeRef <- propertyNames) {
if(ignoredPackages.contains(getPackage(typeRef))) None
Expand All @@ -102,7 +106,7 @@ object ModelConverters {
ModelConverters.read(cls, typeMap) match {
case Some(model) => {
output += typeRef -> model
addRecursive(model, checkedNames, output)
addRecursive(cls, model, checkedNames, output)
}
case None =>
}
Expand All @@ -115,7 +119,7 @@ object ModelConverters {
}
}

def toName(cls: Class[_]): String = {
def toName(cls: ClassWrapper): String = {
var name: String = null
val itr = converters.iterator
while(name == null && itr.hasNext) {
Expand Down Expand Up @@ -145,9 +149,9 @@ object ModelConverters {
}

trait ModelConverter {
def read(cls: Class[_], typeMap: Map[String, String]): Option[Model]
def toName(cls: Class[_]): String
def toDescriptionOpt(cls: Class[_]): Option[String]
def read(cls: ClassWrapper, typeMap: Map[String, String]): Option[Model]
def toName(cls: ClassWrapper): String
def toDescriptionOpt(cls: ClassWrapper): Option[String]

def ignoredPackages: Set[String] = Set("java.lang")
def ignoredClasses: Set[String] = Set("java.util.Date")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.wordnik.swagger.converter

import com.wordnik.swagger.model._
import com.wordnik.swagger.core.{ SwaggerSpec, SwaggerTypes }
import com.wordnik.swagger.core.util.TypeUtil
import com.wordnik.swagger.core.util.{ClassWrapper, TypeUtil}
import com.wordnik.swagger.annotations.ApiModelProperty

import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty}
Expand All @@ -18,7 +18,7 @@ import javax.xml.bind.annotation._
import scala.collection.mutable.{ LinkedHashMap, ListBuffer, HashSet, HashMap }
import com.wordnik.swagger.reader.{PropertyMetaInfo, ModelReaders}

class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (implicit properties: LinkedHashMap[String, ModelProperty]) {
class ModelPropertyParser(cls: ClassWrapper, t: Map[String, String] = Map.empty) (implicit properties: LinkedHashMap[String, ModelProperty]) {
private val LOGGER = LoggerFactory.getLogger(classOf[ModelPropertyParser])

val typeMap = {
Expand All @@ -33,16 +33,16 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im

def parse = Option(cls).map(parseRecursive(_))

def parseRecursive(hostClass: Class[_]): Unit = {
def parseRecursive(hostClass: ClassWrapper): Unit = {
if(!hostClass.isEnum) {
LOGGER.debug("processing class " + hostClass)
for (method <- hostClass.getDeclaredMethods) {
if (Modifier.isPublic(method.getModifiers()) && !Modifier.isStatic(method.getModifiers()))
parseMethod(method)
parseMethod(hostClass, method)
}
for (field <- hostClass.getDeclaredFields) {
if (Modifier.isPublic(field.getModifiers()) && !Modifier.isStatic(field.getModifiers()))
parseField(field)
parseField(hostClass, field)
}
Option(hostClass.getSuperclass).map(parseRecursive(_))
}
Expand All @@ -51,20 +51,20 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im
}
}

def parseField(field: Field) = {
LOGGER.debug("processing field " + field)
def parseField(hostClass: ClassWrapper, field: Field) = {
LOGGER.debug("processing field " + hostClass + "." + field)

val propertyMetaInfo = new PropertyMetaInfo(field.getDeclaringClass, field.getName, field.getAnnotations, field.getGenericType, field.getType)
val newMetaInfo = ModelReaders.reader.parseField(field, propertyMetaInfo)
val propertyMetaInfo = new PropertyMetaInfo(hostClass.getFieldType(field), field.getName, field.getAnnotations)
val newMetaInfo = ModelReaders.reader.parseField(hostClass, field, propertyMetaInfo)
parsePropertyAnnotations(newMetaInfo)
}

def parseMethod(method: Method) = {
def parseMethod(hostClass: ClassWrapper, method: Method) = {
if (method.getParameterTypes == null || method.getParameterTypes.length == 0) {
LOGGER.debug("processing method " + method)

val propertyMetaInfo = new PropertyMetaInfo(method.getReturnType, method.getName, method.getAnnotations, method.getGenericReturnType, method.getReturnType)
val newMetaInfo = ModelReaders.reader.parseMethod(method, propertyMetaInfo)
val propertyMetaInfo = new PropertyMetaInfo(hostClass.getMethodReturnType(method), method.getName, method.getAnnotations)
val newMetaInfo = ModelReaders.reader.parseMethod(hostClass, method, propertyMetaInfo)
parsePropertyAnnotations(newMetaInfo)
}
}
Expand All @@ -85,11 +85,11 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im

def parsePropertyAnnotations(metaInfo : PropertyMetaInfo): Any = {
if (metaInfo != null) {
parsePropertyAnnotations(metaInfo.returnClass, metaInfo.propertyName, metaInfo.propertyAnnotations, metaInfo.genericReturnType, metaInfo.returnType)
parsePropertyAnnotations(metaInfo.returnClass, metaInfo.propertyName, metaInfo.propertyAnnotations)
}
}

def parsePropertyAnnotations(returnClass: Class[_], propertyName: String, propertyAnnotations: Array[Annotation], genericReturnType: Type, returnType: Type): Any = {
def parsePropertyAnnotations(returnClass: ClassWrapper, propertyName: String, propertyAnnotations: Array[Annotation]): Any = {
val e = extractGetterProperty(propertyName)
var originalName = e._1
var isGetter = e._2
Expand Down Expand Up @@ -153,9 +153,9 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im
isTransient = true

if (!(isTransient && !isXmlElement && !isJsonProperty) && name != null && (isFieldExists || isGetter || isDocumented)) {
var paramType = getDataType(genericReturnType, returnType, false)
var paramType = getDataType(returnClass, false)
LOGGER.debug("inspecting " + paramType)
var simpleName = getDataType(genericReturnType, returnType, true)
var simpleName = getDataType(returnClass, true)

if (!"void".equals(paramType) && null != paramType && !processedFields.contains(name)) {
if(!excludedFieldTypes.contains(paramType)) {
Expand Down Expand Up @@ -305,12 +305,12 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im
else s.trim
}

def getDeclaredField(inputClass: Class[_], fieldName: String): Field = {
def getDeclaredField(inputClass: ClassWrapper, fieldName: String): Field = {
try {
inputClass.getDeclaredField(fieldName)
} catch {
case t: NoSuchFieldException => {
if (inputClass.getSuperclass != null && inputClass.getSuperclass.getName != "Object") {
if (inputClass.getSuperclass != null && inputClass.getSuperclass.getName != "java.lang.Object") {
getDeclaredField(inputClass.getSuperclass, fieldName)
} else {
throw t
Expand Down Expand Up @@ -356,50 +356,33 @@ class ModelPropertyParser(cls: Class[_], t: Map[String, String] = Map.empty) (im
}
}

def getDataType(genericReturnType: Type, returnType: Type, isSimple: Boolean = false): String = {
if (TypeUtil.isParameterizedList(genericReturnType)) {
val parameterizedType = genericReturnType.asInstanceOf[java.lang.reflect.ParameterizedType]
val valueType = parameterizedType.getActualTypeArguments.head
"List[" + getDataType(valueType, valueType, isSimple) + "]"
} else if (TypeUtil.isParameterizedSet(genericReturnType)) {
val parameterizedType = genericReturnType.asInstanceOf[java.lang.reflect.ParameterizedType]
val valueType = parameterizedType.getActualTypeArguments.head
"Set[" + getDataType(valueType, valueType, isSimple) + "]"
} else if (TypeUtil.isParameterizedMap(genericReturnType)) {
val parameterizedType = genericReturnType.asInstanceOf[java.lang.reflect.ParameterizedType]
val typeArgs = parameterizedType.getActualTypeArguments
val keyType = typeArgs(0)
val valueType = typeArgs(1)

val keyName: String = getDataType(keyType, keyType, isSimple)
val valueName: String = getDataType(valueType, valueType, isSimple)
"Map[" + keyName + "," + valueName + "]"
} else if (!returnType.getClass.isAssignableFrom(classOf[ParameterizedTypeImpl]) && returnType.isInstanceOf[Class[_]] && returnType.asInstanceOf[Class[_]].isArray) {
var arrayClass = returnType.asInstanceOf[Class[_]].getComponentType
"Array[" + readName(arrayClass, isSimple) + "]"
def getDataType(returnType: ClassWrapper, isSimple: Boolean = false): String = {
if (TypeUtil.isParameterizedList(returnType.getRawType)) {
val typeParameters = returnType.getRawClass.getTypeParameters
val types = typeParameters.map(t => getDataType(returnType.getTypeArgument(t.getName), isSimple))
"List" + types.mkString("[", ",", "]")
} else if (TypeUtil.isParameterizedSet(returnType.getRawType)) {
val typeParameters = returnType.getRawClass.getTypeParameters
val types = typeParameters.map(t => getDataType(returnType.getTypeArgument(t.getName), isSimple))
"Set" + types.mkString("[", ",", "]")
} else if (TypeUtil.isParameterizedMap(returnType.getRawType)) {
val typeParameters = returnType.getRawClass.getTypeParameters
val types = typeParameters.map(t => getDataType(returnType.getTypeArgument(t.getName), isSimple))
"Map" + types.mkString("[", ",", "]")
} else if (returnType.isArray) {
var arrayClass = returnType.getArrayComponent
"Array[" + getDataType(arrayClass, isSimple) + "]"
} else if (returnType.getRawClass == classOf[Option[_]]) {
val valueType = returnType.getTypeArgument(returnType.getRawClass.getTypeParameters.head.getName)
getDataType(valueType, isSimple)
} else if (classOf[Class[_]].isAssignableFrom(returnType.getRawClass)) {
// ignore Class
null
} else {
if (genericReturnType.getClass.isAssignableFrom(classOf[TypeVariableImpl[_]])) {
genericReturnType.asInstanceOf[TypeVariableImpl[_]].getName
}
else if (!genericReturnType.getClass.isAssignableFrom(classOf[ParameterizedTypeImpl])) {
if(genericReturnType.isInstanceOf[Class[_]])
readName(genericReturnType.asInstanceOf[Class[_]], isSimple)
else{
LOGGER.debug("can't get type info for " + genericReturnType.toString)
genericReturnType.toString
}
} else {
val parameterizedType = genericReturnType.asInstanceOf[java.lang.reflect.ParameterizedType]
if (parameterizedType.getRawType == classOf[Option[_]]) {
val valueType = parameterizedType.getActualTypeArguments.head
getDataType(valueType, valueType, isSimple)
}
else {
genericReturnType.toString match {
case "java.lang.Class<?>" => null
case e: String => e
}
}
val typeParameters = returnType.getRawClass.getTypeParameters
val types = typeParameters.map(t => getDataType(returnType.getTypeArgument(t.getName), isSimple))
readName(returnType.getRawClass, isSimple) + {
if (types.length > 0) types.mkString("[", ",", "]") else ""
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.json4s.jackson.JsonMethods._
import org.json4s.jackson.Serialization.{read, write}

import scala.collection.mutable.HashMap
import com.wordnik.swagger.core.util.ClassWrapper

class OverrideConverter
extends ModelConverter
Expand Down Expand Up @@ -41,7 +42,7 @@ class OverrideConverter
}
}

def read(cls: Class[_], typeMap: Map[String, String]): Option[Model] = {
def read(cls: ClassWrapper, typeMap: Map[String, String]): Option[Model] = {
overrides.getOrElse(cls.getName, None)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import org.slf4j.LoggerFactory
import java.lang.reflect.Modifier
import scala.collection.mutable
import scala.collection.mutable.LinkedHashMap
import com.wordnik.swagger.core.util.ClassWrapper

class SwaggerSchemaConverter
extends ModelConverter
with BaseConverter {

def read(cls: Class[_], typeMap: Map[String, String]): Option[Model] = {
def read(cls: ClassWrapper, typeMap: Map[String, String]): Option[Model] = {
Option(cls).flatMap({
cls => {
implicit val properties = new LinkedHashMap[String, ModelProperty]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.wordnik.swagger.core

import collection.mutable.ListBuffer
import org.slf4j.{ LoggerFactory, Logger }
import com.wordnik.swagger.core.util.ClassWrapper

object SwaggerContext {
private val LOGGER = LoggerFactory.getLogger("com.wordnik.swagger.core.SwaggerContext")
Expand All @@ -13,12 +14,18 @@ object SwaggerContext {

def registerClassLoader(cl: ClassLoader) = this.classLoaders += cl

def loadClass(name: String) = {
var cls: Class[_] = null
def loadClass(name: String): ClassWrapper = {
var cls: ClassWrapper = null
val itr = classLoaders.reverse.iterator
while (cls == null && itr.hasNext) {
try {
cls = Class.forName(name.trim, true, itr.next)
val classLoader = itr.next
cls = ClassWrapper.forName(name.trim, name => name match {
case "List" => classOf[List[_]]
case "Array" => classOf[Array[_]]
case "Set" => classOf[Set[_]]
case name => Class.forName(name, true, classLoader)
})
} catch {
case e: ClassNotFoundException => {
LOGGER.debug("Class %s not found in classLoader".format(name))
Expand Down
Loading

0 comments on commit 207307d

Please sign in to comment.