Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add primary key to create table query in postgres dialect #52

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions src/main/scala/slick/migration/api/Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
case None => quoteIdentifier(t.tableName)
}

protected def quotedColumnNames(ns: Seq[FieldSymbol]) = ns.map(fs => quoteIdentifier(fs.name))
protected def quotedColumnNames(ns: Seq[FieldSymbol]): Seq[String] = ns.map(fs => quoteIdentifier(fs.name))

def columnType(ci: ColumnInfo): String = ci.sqlType

def autoInc(ci: ColumnInfo) = if(ci.autoInc) " AUTOINCREMENT" else ""

def primaryKey(ci: ColumnInfo, newTable: Boolean) =
def primaryKey(ci: ColumnInfo, newTable: Boolean): String =
(if (newTable && ci.isPk) " PRIMARY KEY" else "") + autoInc(ci)

def notNull(ci: ColumnInfo) = if (ci.notNull) " NOT NULL" else ""
Expand All @@ -50,19 +50,19 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
s"$name $typ$default${ notNull(ci) }${ primaryKey(ci, newTable) }"
}

def columnList(columns: Seq[FieldSymbol]) =
def columnList(columns: Seq[FieldSymbol]): String =
quotedColumnNames(columns).mkString("(", ", ", ")")

def createTable(table: TableInfo, columns: Seq[ColumnInfo]): List[String] = List(
def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = List(
s"""create table ${quoteTableName(table)} (
| ${columns map { columnSql(_, newTable = true) } mkString ", "}
|)""".stripMargin
)
) ++ primaryKeys.map(info => createPrimaryKey(table, info.name, info.columns))

def dropTable(table: TableInfo): String =
s"drop table ${quoteTableName(table)}"

def renameTable(table: TableInfo, to: String) =
def renameTable(table: TableInfo, to: String): String =
s"""alter table ${quoteTableName(table)}
| rename to ${quoteIdentifier(to)}""".stripMargin

Expand All @@ -77,15 +77,15 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
def dropConstraint(table: TableInfo, name: String) =
s"alter table ${quoteTableName(table)} drop constraint ${quoteIdentifier(name)}"

def dropForeignKey(sourceTable: TableInfo, name: String) =
def dropForeignKey(sourceTable: TableInfo, name: String): String =
dropConstraint(sourceTable, name)

def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]) =
def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]): String =
s"""alter table ${quoteTableName(table)}
| add constraint ${quoteIdentifier(name)} primary key
| ${columnList(columns)}""".stripMargin

def dropPrimaryKey(table: TableInfo, name: String) =
def dropPrimaryKey(table: TableInfo, name: String): String =
dropConstraint(table, name)

def createIndex(index: IndexInfo) =
Expand All @@ -100,11 +100,11 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
s"alter index ${quoteIdentifier(old.name)} rename to ${quoteIdentifier(newName)}"
)

def addColumn(table: TableInfo, column: ColumnInfo) =
def addColumn(table: TableInfo, column: ColumnInfo): String =
s"""alter table ${quoteTableName(table)}
| add column ${columnSql(column, newTable = false)}""".stripMargin

def addColumnWithInitialValue(table: TableInfo, column: ColumnInfo, rawSqlExpr: String) =
def addColumnWithInitialValue(table: TableInfo, column: ColumnInfo, rawSqlExpr: String): List[String] =
List(addColumn(table, column.copy(default = Some(rawSqlExpr)))) ++
(if (column.default.contains(rawSqlExpr)) Nil else List(alterColumnDefault(table, column)))

Expand All @@ -113,7 +113,7 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
| drop column ${quoteIdentifier(column)}""".stripMargin
)

def renameColumn(table: TableInfo, from: String, to: String) =
def renameColumn(table: TableInfo, from: String, to: String): String =
s"""alter table ${quoteTableName(table)}
| alter column ${quoteIdentifier(from)}
| rename to ${quoteIdentifier(to)}""".stripMargin
Expand All @@ -126,28 +126,32 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers {
| set data type ${column.sqlType}""".stripMargin
)

def alterColumnDefault(table: TableInfo, column: ColumnInfo) =
def alterColumnDefault(table: TableInfo, column: ColumnInfo): String =
s"""alter table ${quoteTableName(table)}
| alter column ${quoteIdentifier(column.name)}
| set default ${column.default getOrElse "null"}""".stripMargin

def alterColumnNullability(table: TableInfo, column: ColumnInfo) =
def alterColumnNullability(table: TableInfo, column: ColumnInfo): String =
s"""alter table ${quoteTableName(table)}
| alter column ${quoteIdentifier(column.name)}
| ${if (column.notNull) "set" else "drop"} not null""".stripMargin

private def partition[A, B](xs: List[A])(toB: PartialFunction[A, B]): (List[B], List[A]) =
xs.foldLeft((List.empty[B], List.empty[A])) {
case ((bs, as), a) =>
toB.andThen(b => (b :: bs, as)).applyOrElse(a, (_: A) => (bs, a :: as))
private def partition(xs: List[TableMigration.Action]): (List[AddColumn], List[AddPrimaryKey], List[TableMigration.Action]) =
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you have to change partition?

xs.foldLeft((List.empty[AddColumn], List.empty[AddPrimaryKey], List.empty[TableMigration.Action])) {
case ((cs, bs, as), a) =>
a match {
case ac: AddColumn => (ac :: cs, bs, as)
case ap: AddPrimaryKey => (cs, ap :: bs, as)
case _ => (cs, bs, a :: as)
}
}

def migrateTable(table: TableInfo, actions: List[TableMigration.Action]): List[String] = {
def loop(actions: List[TableMigration.Action]): List[String] = actions match {
case Nil => Nil
case CreateTable :: rest =>
val (cols, other) = partition(rest) { case a: AddColumn => a }
createTable(table, cols.map(_.info)) ::: loop(other)
val (cols, pKeys, other) = partition(rest)
createTable(table, cols.map(_.info), pKeys.map(_.info)) ::: loop(other)
case AlterColumnType(info) :: rest => alterColumnType(table, info) ::: loop(rest)
case DropTable :: rest => dropTable(table) :: loop(rest)
case RenameTableTo(to) :: rest => renameTable(table, to) :: loop(rest)
Expand Down Expand Up @@ -189,7 +193,7 @@ class DerbyDialect extends Dialect[DerbyProfile] {
override def autoInc(ci: ColumnInfo) =
if(ci.autoInc) " GENERATED BY DEFAULT AS IDENTITY" else ""

override def alterColumnType(table: TableInfo, column: ColumnInfo) = {
override def alterColumnType(table: TableInfo, column: ColumnInfo): List[String] = {
val tmpColumnName = "temp_column"+(math.random*1000000).toInt
val tmpColumn = column.copy(name = tmpColumnName)

Expand All @@ -202,15 +206,15 @@ class DerbyDialect extends Dialect[DerbyProfile] {
override def renameColumn(table: TableInfo, from: String, to: String) =
s"rename column ${quoteTableName(table)}.${quoteIdentifier(from)} to ${quoteIdentifier(to)}"

override def alterColumnNullability(table: TableInfo, column: ColumnInfo) =
override def alterColumnNullability(table: TableInfo, column: ColumnInfo): String =
s"""alter table ${quoteTableName(table)}
| alter column ${quoteIdentifier(column.name)}
| ${if (column.notNull) "not" else ""} null""".stripMargin

override def renameTable(table: TableInfo, to: String) =
s"rename table ${quoteTableName(table)} to ${quoteIdentifier(to)}"

override def renameIndex(old: IndexInfo, newName: String) = List(
override def renameIndex(old: IndexInfo, newName: String): List[String] = List(
s"rename index ${quoteIdentifier(old.name)} to ${quoteIdentifier(newName)}"
)
}
Expand All @@ -234,7 +238,7 @@ class SQLiteDialect extends Dialect[SQLiteProfile] with SimulatedRenameIndex[SQL
class HsqldbDialect extends Dialect[HsqldbProfile] {
override def autoInc(ci: ColumnInfo) =
if(ci.autoInc) " GENERATED BY DEFAULT AS IDENTITY" else ""
override def primaryKey(ci: ColumnInfo, newTable: Boolean) =
override def primaryKey(ci: ColumnInfo, newTable: Boolean): String =
autoInc(ci) + (if (newTable && ci.isPk) " PRIMARY KEY" else "")
override def notNull(ci: ColumnInfo) =
if (ci.notNull && !ci.isPk) " NOT NULL" else ""
Expand All @@ -258,23 +262,23 @@ class MySQLDialect extends Dialect[MySQLProfile] with SimulatedRenameIndex[MySQL
override def renameColumn(table: TableInfo, from: String, to: String) =
s"ALTER TABLE ${quoteTableName(table)} RENAME COLUMN ${quoteIdentifier(from)} TO ${quoteIdentifier(to)}"

override def renameColumn(table: TableInfo, from: ColumnInfo, to: String) = {
override def renameColumn(table: TableInfo, from: ColumnInfo, to: String): String = {
val newCol = from.copy(name = to)
s"""alter table ${quoteTableName(table)}
| change ${quoteIdentifier(from.name)}
| ${columnSql(newCol, newTable = false)}""".stripMargin
}

override def alterColumnNullability(table: TableInfo, column: ColumnInfo) =
override def alterColumnNullability(table: TableInfo, column: ColumnInfo): String =
renameColumn(table, column, column.name)

override def alterColumnType(table: TableInfo, column: ColumnInfo) =
override def alterColumnType(table: TableInfo, column: ColumnInfo): List[String] =
List(renameColumn(table, column, column.name))

override def dropForeignKey(table: TableInfo, name: String) =
s"alter table ${quoteTableName(table)} drop foreign key ${quoteIdentifier(name)}"

override def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]) =
override def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]): String =
s"""alter table ${quoteTableName(table)}
| add constraint primary key
| ${columnList(columns)}""".stripMargin
Expand All @@ -291,17 +295,25 @@ class PostgresDialect extends Dialect[PostgresProfile] {
case (true, "BIGINT") => "BIGSERIAL"
case (true, _) => throw new RuntimeException("Unsupported autoincrement type")
}

override def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = List(
s"""create table ${quoteTableName(table)} (
| ${columns map { columnSql(_, newTable = true) } mkString("", ", ", if (columns.nonEmpty && primaryKeys.nonEmpty) "," else "")}
| ${if (primaryKeys.nonEmpty) primaryKeys.map{ ci => quoteIdentifier(ci.name) }.mkString("primary key (", ", ", ")") else ""}
|)""".stripMargin
)

override def autoInc(ci: ColumnInfo) = ""
override def renameColumn(table: TableInfo, from: String, to: String) =
override def renameColumn(table: TableInfo, from: String, to: String): String =
s"""alter table ${quoteTableName(table)}
| rename column ${quoteIdentifier(from)}
| to ${quoteIdentifier(to)}""".stripMargin
}

class OracleDialect extends Dialect[OracleProfile] {

override def createTable(table: TableInfo, columns: Seq[ColumnInfo]): List[String] = {
super.createTable(table, columns) ++
override def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = {
super.createTable(table, columns, primaryKeys) ++
columns.filter(_.autoInc).flatMap(addAutoInc(table, _, 1L))
}

Expand Down