Skip to content

Commit

Permalink
nested or 2 dimensional array support #386
Browse files Browse the repository at this point in the history
  • Loading branch information
tminglei committed Feb 8, 2018
1 parent 7e6bde8 commit 658792b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import slick.jdbc.{JdbcTypesComponent, PostgresProfile}

trait PgArrayJdbcTypes extends JdbcTypesComponent { driver: PostgresProfile =>

class SimpleArrayJdbcType[T] private[slickpg] (sqlBaseType: String,
class SimpleArrayJdbcType[T] private[slickpg](sqlBaseType: String,
tmap: Any => T,
tcomap: T => Any,
zero: Seq[T] = null.asInstanceOf[Seq[T]])(
Expand All @@ -21,7 +21,7 @@ trait PgArrayJdbcTypes extends JdbcTypesComponent { driver: PostgresProfile =>

override def sqlType: Int = java.sql.Types.ARRAY

override def sqlTypeName(size: Option[FieldSymbol]): String = s"$sqlBaseType ARRAY"
override def sqlTypeName(size: Option[FieldSymbol]): String = s"$sqlBaseType []"

override def getValue(r: ResultSet, idx: Int): Seq[T] = {
val value = r.getArray(idx)
Expand Down Expand Up @@ -72,7 +72,7 @@ trait PgArrayJdbcTypes extends JdbcTypesComponent { driver: PostgresProfile =>

override def sqlType: Int = java.sql.Types.ARRAY

override def sqlTypeName(size: Option[FieldSymbol]): String = s"$sqlBaseType ARRAY"
override def sqlTypeName(size: Option[FieldSymbol]): String = s"$sqlBaseType []"

override def getValue(r: ResultSet, idx: Int): Seq[T] = {
val value = r.getString(idx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,25 @@ object PgTokenHelper {

def createString(root: Token): String = {
val MARK_REQUIRED_CHAR_LIST = List('\\', '"', ',', '(', ')', '{', '}')
val rootIsArray = root match {

///
def isArray(token: Token): Boolean = token match {
case GroupToken(mList) =>
mList match {
case Open("{", _) :: tail => true
case Open("{", _) :: _ => true
case _ => false
}
case _ => false
}

///
def isMarkRequired(token: Token): Boolean = token match {
case g: GroupToken => true
def isMarkRequired(token: Token, parentIsArray: Boolean): Boolean = token match {
case _: GroupToken => !parentIsArray || !isArray(token)
case Chunk(v) => v.isEmpty || v.trim.length < v.length || "NULL".equalsIgnoreCase(v) || v.find(MARK_REQUIRED_CHAR_LIST.contains).isDefined
case _ => false
}

def appendMark(buf: mutable.StringBuilder, level: Int) =
val rootIsArray = isArray(root)
def appendMark(buf: mutable.StringBuilder, level: Int) =
if (level >= 0) {
val markLen = math.pow(2, level).toInt
level match {
Expand All @@ -120,16 +122,16 @@ object PgTokenHelper {
}
}

def mergeString(buf: mutable.StringBuilder, token: Token, level: Int): Unit = {
val markRequired = isMarkRequired(token)
def mergeString(buf: mutable.StringBuilder, token: Token, level: Int, parentIsArray: Boolean): Unit = {
val markRequired = isMarkRequired(token, parentIsArray)
if (markRequired) appendMark(buf, level)
token match {
case GroupToken(mList) => {
buf append mList(0).value
var isFirst = true
for(i <- 1 to (mList.length -2)) {
if (isFirst) isFirst = false else buf append ","
mergeString(buf, mList(i), level +1)
mergeString(buf, mList(i), level +1, isArray(token))
}
buf append mList.last.value
}
Expand All @@ -141,7 +143,7 @@ object PgTokenHelper {

///
val buf = StringBuilder.newBuilder
mergeString(buf, root, -1)
mergeString(buf, root, -1, isArray(root))
buf.toString
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ object SimpleArrayUtils {
})
}

def mkString[T](ToString: T => String)(value: Seq[T]): String =
def mkString[T](toString: T => String)(value: Seq[Any]): String = {
def toGroupToken(vList: Seq[Any]): Token = GroupToken(Open("{") +: vList.map {
case null => Null
case v if v.isInstanceOf[Seq[_]] => toGroupToken(v.asInstanceOf[Seq[_]])
case v => Chunk(toString(v.asInstanceOf[T]))
} :+ Close("}"))

createString (value match {
case null => Null
case vList => {
val members = Open("{") +: vList.map {
case v => if (v == null) Null else Chunk(ToString(v))
} :+ Close("}")
GroupToken(members)
}
case vList => toGroupToken(vList)
})
}

def mkArray[T : ClassTag](mkString: (Seq[T] => String))(sqlBaseType: String, vList: Seq[T]): java.sql.Array =
new SimpleArray(sqlBaseType, vList, mkString)
Expand All @@ -37,7 +39,7 @@ object SimpleArrayUtils {
/** !!! NOTE: only used to transfer array data into driver/preparedStatement. !!! */
private class SimpleArray[T : ClassTag](sqlBaseTypeName: String, vList: Seq[T], mkString: (Seq[T] => String)) extends java.sql.Array {

override def getBaseTypeName = sqlBaseTypeName
override def getBaseTypeName = sqlBaseTypeName.replace("[]", "").trim

override def getBaseType(): Int = ???

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class PgTokenHelperSuite extends FunSuite {

val expected = """{"(201,\"(101,\"\"(test1'\"\",\"\"2001-01-03 13:21:00\"\",\"\"[\\\\\"\"2010-01-01 14:30:00\\\\\"\",\\\\\"\"2010-01-03 15:30:00\\\\\"\")\"\")\",t)"}"""

assert(expected === pgStr)
assert(pgStr === expected)

///
val input1 =
Expand Down Expand Up @@ -280,6 +280,25 @@ class PgTokenHelperSuite extends FunSuite {
val pgStr1 = createString(input1)

val expected1 = """(201,"(101,""(test1'"",""2001-01-03 13:21:00"",""[\\""2010-01-01 14:30:00\\"",\\""2010-01-03 15:30:00\\"")"")",t)"""
assert(expected1 === pgStr1)
assert(pgStr1 === expected1)

///
val input2 =
GroupToken(Open("{") +: List(
GroupToken(Open("{") +: List(
Chunk("11"),
Chunk("12"),
Chunk("13")
) :+ Close("}")),
GroupToken(Open("{") +: List(
Chunk("21"),
Chunk("22"),
Chunk("23")
) :+ Close("}"))
) :+ Close("}"))
val pgStr2 = createString(input2)

val expected2 = """{{11,12,13},{21,22,23}}"""
assert(pgStr2 === expected2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class PgArraySupportSuite extends FunSuite {
///
implicit val advancedStringListTypeMapper = new AdvancedArrayJdbcType[String]("text",
fromString(identity)(_).orNull, mkString(identity))
///
implicit val longlongWitness = ElemWitness.AnyWitness.asInstanceOf[ElemWitness[List[Long]]]
implicit val simpleLongLongListTypeMapper = new SimpleArrayJdbcType[List[Long]]("int8[]")
.to(_.asInstanceOf[Seq[Array[Any]]].toList.map(_.toList.asInstanceOf[List[Long]]))
}
}
object MyPostgresProfile1 extends MyPostgresProfile1
Expand All @@ -44,6 +48,7 @@ class PgArraySupportSuite extends FunSuite {
id: Long,
intArr: List[Int],
longArr: Buffer[Long],
longlongArr: List[List[Long]],
shortArr: List[Short],
strList: List[String],
strArr: Option[Vector[String]],
Expand All @@ -56,14 +61,15 @@ class PgArraySupportSuite extends FunSuite {
def id = column[Long]("id", O.AutoInc, O.PrimaryKey)
def intArr = column[List[Int]]("intArray", O.Default(Nil))
def longArr = column[Buffer[Long]]("longArray")
def longlongArr = column[List[List[Long]]]("longlongArray")
def shortArr = column[List[Short]]("shortArray")
def strList = column[List[String]]("stringList")
def strArr = column[Option[Vector[String]]]("stringArray")
def uuidArr = column[List[UUID]]("uuidArray")
def institutions = column[List[Institution]]("institutions")
def mktFinancialProducts = column[Option[List[MarketFinancialProduct]]]("mktFinancialProducts")

def * = (id, intArr, longArr, shortArr, strList, strArr, uuidArr, institutions, mktFinancialProducts) <> (ArrayBean.tupled, ArrayBean.unapply)
def * = (id, intArr, longArr, longlongArr, shortArr, strList, strArr, uuidArr, institutions, mktFinancialProducts) <> (ArrayBean.tupled, ArrayBean.unapply)
}
val ArrayTests = TableQuery[ArrayTestTable]

Expand All @@ -73,11 +79,11 @@ class PgArraySupportSuite extends FunSuite {
val uuid2 = UUID.randomUUID()
val uuid3 = UUID.randomUUID()

val testRec1 = ArrayBean(33L, List(101, 102, 103), Buffer(1L, 3L, 5L, 7L), List(1,7), List("robert}; drop table students--", "NULL"),
val testRec1 = ArrayBean(33L, List(101, 102, 103), Buffer(1L, 3L, 5L, 7L), List(List(11L, 12L, 13L)), List(1,7), List("robert}; drop table students--", "NULL"),
Some(Vector("str1", "str3", "", " ")), List(uuid1, uuid2), List(Institution(113)), None)
val testRec2 = ArrayBean(37L, List(101, 103), Buffer(11L, 31L, 5L), Nil, List(""),
val testRec2 = ArrayBean(37L, List(101, 103), Buffer(11L, 31L, 5L), List(List(21L, 22L, 23L)), Nil, List(""),
Some(Vector("str11", "str3")), List(uuid1, uuid2, uuid3), List(Institution(579)), Some(List(MarketFinancialProduct("product1"))))
val testRec3 = ArrayBean(41L, List(103, 101), Buffer(11L, 5L, 31L), List(35,77), Nil,
val testRec3 = ArrayBean(41L, List(103, 101), Buffer(11L, 5L, 31L), List(List(31L, 32L, 33L)), List(35,77), Nil,
Some(Vector("(s)", "str5", "str3")), List(uuid1, uuid3), Nil, Some(List(MarketFinancialProduct("product3"), MarketFinancialProduct("product x"))))

test("Array Lifted support") {
Expand Down

0 comments on commit 658792b

Please sign in to comment.