diff --git a/CHANGELOG.md b/CHANGELOG.md index a2c1a7b..dd32bc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# hpqtypes-extras-1.15.0.0 (2022-??-??) +* Add support for triggers and trigger functions. + # hpqtypes-extras-1.14.2.0 (2022-??-??) * Add support for GHC 9.2. * Drop support for GHC < 8.8. diff --git a/hpqtypes-extras.cabal b/hpqtypes-extras.cabal index 3cbe203..c7a1615 100644 --- a/hpqtypes-extras.cabal +++ b/hpqtypes-extras.cabal @@ -1,6 +1,6 @@ cabal-version: 2.2 name: hpqtypes-extras -version: 1.14.2.0 +version: 1.15.0.0 synopsis: Extra utilities for hpqtypes library description: The following extras for hpqtypes library: . @@ -68,6 +68,7 @@ library , Database.PostgreSQL.PQTypes.Model.Migration , Database.PostgreSQL.PQTypes.Model.PrimaryKey , Database.PostgreSQL.PQTypes.Model.Table + , Database.PostgreSQL.PQTypes.Model.Trigger , Database.PostgreSQL.PQTypes.SQL.Builder , Database.PostgreSQL.PQTypes.Versions @@ -111,6 +112,7 @@ test-suite hpqtypes-extras-tests ghc-options: -Wall build-depends: base + , containers , exceptions , hpqtypes , hpqtypes-extras diff --git a/src/Database/PostgreSQL/PQTypes/Checks.hs b/src/Database/PostgreSQL/PQTypes/Checks.hs index c426117..f0c2933 100644 --- a/src/Database/PostgreSQL/PQTypes/Checks.hs +++ b/src/Database/PostgreSQL/PQTypes/Checks.hs @@ -419,12 +419,15 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) indexes <- fetchMany fetchTableIndex runQuery_ $ sqlGetForeignKeys table fkeys <- fetchMany fetchForeignKey + triggers <- getDBTriggers return $ mconcat [ checkColumns 1 tblColumns desc , checkPrimaryKey tblPrimaryKey pk , checkChecks tblChecks checks , checkIndexes tblIndexes indexes , checkForeignKeys tblForeignKeys fkeys + , checkTriggers tblTriggers $ + filter (\Trigger{..} -> triggerTable == tblName) triggers ] where fetchTableColumn @@ -541,6 +544,9 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) , checkNames (fkName tblName) fkeys ] + checkTriggers :: [Trigger] -> [Trigger] -> ValidationResult + checkTriggers = checkEquality "TRIGGERs" + -- | Checks whether database is consistent, performing migrations if -- necessary. Requires all table names to be in lower case. -- @@ -607,6 +613,11 @@ checkDBConsistency options domains tablesWithVersions migrations = do expectedMigrationVersions = reverse $ take (length presentMigrationVersions) $ reverse [0 .. tblVersion table - 1] + -- -- TODO: File a separate issue about this with a reproducer! + -- = if null presentMigrationVersions + -- then [] + -- else [0 .. tblVersion table - 1] + checkMigrationsListValidity table presentMigrationVersions expectedMigrationVersions @@ -814,6 +825,14 @@ checkDBConsistency options domains tablesWithVersions migrations = do runSQL_ "COMMIT" runQuery_ (sqlDropIndexConcurrently tname idx) `finally` begin updateTableVersion + + CreateTriggerMigration trigger@Trigger{..} -> do + logInfo_ $ " Creating function" <+> (unRawSQL $ tfName triggerFunction) + runQuery_ $ sqlCreateTriggerFunction triggerFunction + logInfo_ $ " Creating trigger" <+> (unRawSQL $ triggerMakeName triggerName triggerTable) + runQuery_ $ sqlCreateTrigger trigger + updateTableVersion + where logMigration = do logInfo_ $ arrListTable mgrTableName diff --git a/src/Database/PostgreSQL/PQTypes/Migrate.hs b/src/Database/PostgreSQL/PQTypes/Migrate.hs index f2cf818..f67aa6c 100644 --- a/src/Database/PostgreSQL/PQTypes/Migrate.hs +++ b/src/Database/PostgreSQL/PQTypes/Migrate.hs @@ -1,7 +1,8 @@ module Database.PostgreSQL.PQTypes.Migrate ( createDomain, createTable, - createTableConstraints + createTableConstraints, + createTableTriggers ) where import Control.Monad @@ -28,6 +29,8 @@ createTable withConstraints table@Table{..} = do forM_ tblIndexes $ runQuery_ . sqlCreateIndexMaybeDowntime tblName -- Add all the other constraints if applicable. when withConstraints $ createTableConstraints table + -- Create triggers. + createTableTriggers table -- Register the table along with its version. runQuery_ . sqlInsert "table_versions" $ do sqlSet "name" (tblNameText table) @@ -42,3 +45,8 @@ createTableConstraints Table{..} = when (not $ null addConstraints) $ do , map sqlAddValidCheckMaybeDowntime tblChecks , map (sqlAddValidFKMaybeDowntime tblName) tblForeignKeys ] + +createTableTriggers :: MonadDB m => Table -> m () +createTableTriggers Table{..} = forM_ tblTriggers $ \t -> do + runQuery_ . sqlCreateTriggerFunction $ triggerFunction t + runQuery_ $ sqlCreateTrigger t diff --git a/src/Database/PostgreSQL/PQTypes/Model.hs b/src/Database/PostgreSQL/PQTypes/Model.hs index f1e0aa3..978ea3c 100644 --- a/src/Database/PostgreSQL/PQTypes/Model.hs +++ b/src/Database/PostgreSQL/PQTypes/Model.hs @@ -9,6 +9,7 @@ module Database.PostgreSQL.PQTypes.Model ( , module Database.PostgreSQL.PQTypes.Model.Migration , module Database.PostgreSQL.PQTypes.Model.PrimaryKey , module Database.PostgreSQL.PQTypes.Model.Table + , module Database.PostgreSQL.PQTypes.Model.Trigger ) where import Database.PostgreSQL.PQTypes.Model.Check @@ -21,3 +22,4 @@ import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.Migration import Database.PostgreSQL.PQTypes.Model.PrimaryKey import Database.PostgreSQL.PQTypes.Model.Table +import Database.PostgreSQL.PQTypes.Model.Trigger diff --git a/src/Database/PostgreSQL/PQTypes/Model/Migration.hs b/src/Database/PostgreSQL/PQTypes/Model/Migration.hs index 6ea8cf8..d6815ae 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Migration.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Migration.hs @@ -33,6 +33,7 @@ import Data.Int import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.Table +import Database.PostgreSQL.PQTypes.Model.Trigger import Database.PostgreSQL.PQTypes.SQL.Raw -- | Migration action to run, either an arbitrary 'MonadDB' action, or @@ -57,6 +58,9 @@ data MigrationAction m = (RawSQL ()) -- ^ Table name TableIndex -- ^ Index + -- | Migration for creating a trigger. + | CreateTriggerMigration Trigger + -- | Migration object. data Migration m = Migration { @@ -78,6 +82,7 @@ isStandardMigration Migration{..} = DropTableMigration{} -> False CreateIndexConcurrentlyMigration{} -> False DropIndexConcurrentlyMigration{} -> False + CreateTriggerMigration{} -> False isDropTableMigration :: Migration m -> Bool isDropTableMigration Migration{..} = @@ -86,3 +91,4 @@ isDropTableMigration Migration{..} = DropTableMigration{} -> True CreateIndexConcurrentlyMigration{} -> False DropIndexConcurrentlyMigration{} -> False + CreateTriggerMigration{} -> False diff --git a/src/Database/PostgreSQL/PQTypes/Model/Table.hs b/src/Database/PostgreSQL/PQTypes/Model/Table.hs index 266e6f2..55ee4d8 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Table.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Table.hs @@ -25,6 +25,7 @@ import Database.PostgreSQL.PQTypes.Model.ColumnType import Database.PostgreSQL.PQTypes.Model.ForeignKey import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.PrimaryKey +import Database.PostgreSQL.PQTypes.Model.Trigger data TableColumn = TableColumn { colName :: RawSQL () @@ -69,6 +70,7 @@ data Table = , tblChecks :: [Check] , tblForeignKeys :: [ForeignKey] , tblIndexes :: [TableIndex] +, tblTriggers :: [Trigger] , tblInitialSetup :: Maybe TableInitialSetup } @@ -86,6 +88,7 @@ tblTable = Table { , tblChecks = [] , tblForeignKeys = [] , tblIndexes = [] +, tblTriggers = [] , tblInitialSetup = Nothing } diff --git a/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs new file mode 100644 index 0000000..1282017 --- /dev/null +++ b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs @@ -0,0 +1,284 @@ +-- | +-- Module: Database.PostgreSQL.PQTypes.Model.Trigger +-- +-- Trigger name must be unique among triggers of same table. Only @CONTRAINT@ triggers are +-- supported. They can only be run @AFTER@ an event. The associated functions are always +-- created with no arguments and always @RETURN NULL@. +-- +-- For details, see . + +module Database.PostgreSQL.PQTypes.Model.Trigger ( + -- * Trigger functions + TriggerFunction(..) + , sqlCreateTriggerFunction + -- * Triggers + , TriggerEvent(..) + , Trigger(..) + , triggerMakeName + , triggerBaseName + , sqlCreateTrigger + , getDBTriggers + -- TODO testing; remove when PR is ready + -- , testDB + -- , testGetDBTriggers + -- , testTrigger1 + -- , testTrigger2 + ) where + +import Data.Bits (testBit) +import Data.Int +import Data.Monoid.Utils +import Data.Set (Set) +import Data.String +import Database.PostgreSQL.PQTypes +import Database.PostgreSQL.PQTypes.SQL.Builder +import qualified Data.Set as Set +import qualified Data.Text as Text + +-- | Function associated to a trigger. +-- +-- @since 1.15.0.0 +data TriggerFunction = TriggerFunction { + tfName :: RawSQL () + -- ^ The function's name. + , tfSource :: RawSQL () + -- ^ The functions's body source code. +} deriving (Show) + +instance Eq TriggerFunction where + -- Since the functions have no arguments, it's impossible to create two functions with + -- the same name. Therefore comparing functions only by their names is enough in this + -- case. The assumption is, of course, that the database schema is only changed using + -- this framework. + f1 == f2 = tfName f1 == tfName f2 + +-- | Build an SQL statement for creating a trigger function. +-- +-- Since we only support @CONSTRAINT@ triggers, the function will always @RETURN TRIGGER@ +-- and will have no parameters. +-- +-- @since 1.15.0.0 +sqlCreateTriggerFunction :: TriggerFunction -> RawSQL () +sqlCreateTriggerFunction TriggerFunction{..} = + "CREATE FUNCTION" + <+> tfName + <> "()" + <+> "RETURNS TRIGGER" + <+> "AS $$" + <+> tfSource + <+> "$$" + <+> "LANGUAGE PLPGSQL" + <+> "VOLATILE" + <+> "RETURNS NULL ON NULL INPUT;" + +-- | Trigger event name. +-- +-- @since 1.15.0.0 +data TriggerEvent + = TriggerInsert + -- ^ The @INSERT@ event. + | TriggerUpdate + -- ^ The @UPDATE@ event. + | TriggerDelete + -- ^ The @DELETE@ event. + deriving (Eq, Ord, Show) + +-- | Trigger. +-- +-- @since 1.15.0.0 +data Trigger = Trigger { + triggerTable :: RawSQL () + -- ^ The table that the trigger is associated with. + , triggerName :: RawSQL () + -- ^ The internal name without any prefixes. Trigger name must be unique among + -- triggers of same table. See 'triggerMakeName'. + , triggerEvents :: Set TriggerEvent + -- ^ The set of events. Corresponds to the @{ __event__ [ OR ... ] }@ in the trigger + -- definition. The order in which they are defined doesn't matter and there can + -- only be one of each. + , triggerDeferrable :: Bool + -- ^ Is the trigger @DEFERRABLE@ or @NOT DEFERRABLE@ ? + , triggerInitiallyDeferred :: Bool + -- ^ Is the trigger @INITIALLY DEFERRED@ or @INITIALLY IMMEDIATE@ ? + , triggerWhen :: Maybe (RawSQL ()) + -- ^ The condition that specifies whether the trigger should fire. Corresponds to the + -- @WHEN ( __condition__ )@ in the trigger definition. + , triggerFunction :: TriggerFunction + -- ^ The function to execute when the trigger fires. +} deriving Show + +instance Eq Trigger where + -- There is no comparison for the WHEN clause. It's not possible to have two triggers + -- that only differ in triggerWhen. + t1 == t2 = triggerTable t1 == triggerTable t2 + && triggerName t1 == triggerName t2 + && triggerEvents t1 == triggerEvents t2 + && triggerDeferrable t1 == triggerDeferrable t2 + && triggerInitiallyDeferred t1 == triggerInitiallyDeferred t2 + && triggerFunction t1 == triggerFunction t2 + +-- | Make a trigger name that can be used in SQL. +-- +-- Given a base @name@ and @tableName@, return a new name that will be used as the +-- actually name of the trigger in an SQL query. The returned name is in the format +-- @trg\__\\__\@. +-- +-- @since 1.15.0 +triggerMakeName :: RawSQL () -> RawSQL () -> RawSQL () +triggerMakeName name tableName = "trg__" <> tableName <> "__" <> name + +-- | Return the trigger's base name. +-- +-- Given the trigger's actual @name@ and @tableName@, return the base name of the +-- trigger. This is basically the reverse of what 'triggerMakeName' does. +-- +-- @since 1.15.0 +triggerBaseName :: RawSQL () -> RawSQL () -> RawSQL () +triggerBaseName name tableName = + rawSQL (snd . Text.breakOnEnd (unRawSQL tableName <> "__") $ unRawSQL name) () + +triggerEventName :: TriggerEvent -> RawSQL () +triggerEventName = \case + TriggerInsert -> "INSERT" + TriggerUpdate -> "UPDATE" + TriggerDelete -> "DELETE" + +-- | Build an SQL statement that creates a trigger. +-- +-- Only supports @CONSTRAINT@ triggers which can only run @AFTER@. +-- +-- @since 1.15.0 +sqlCreateTrigger :: Trigger -> RawSQL () +sqlCreateTrigger Trigger{..} = + "CREATE CONSTRAINT TRIGGER" <+> trgName + <+> "AFTER" <+> trgEvents + <+> "ON" <+> triggerTable + <+> trgTiming + <+> "FOR EACH ROW" + <+> trgWhen + <+> "EXECUTE FUNCTION" <+> trgFunction + <+> "();" + where + trgName + | triggerName == "" = error "Trigger must have a name." + | otherwise = triggerMakeName triggerName triggerTable + trgEvents + | triggerEvents == Set.empty = error "Trigger must have at least one event." + | otherwise = mintercalate " OR " . map triggerEventName $ Set.toList triggerEvents + trgTiming = let deferrable = (if triggerDeferrable then "" else "NOT") <+> "DEFERRABLE" + deferred = if triggerInitiallyDeferred + then "INITIALLY DEFERRED" + else "INITIALLY IMMEDIATE" + in deferrable <+> deferred + trgWhen = maybe "" ("WHEN" <+>) triggerWhen + trgFunction = tfName triggerFunction + +-- | Get all noninternal triggers from the database. +-- +-- Run a query that returns all database triggers marked as @tgisinternal = false@. +-- +-- Note that, in the background, to get the trigger's @WHEN@ clause and the source code of +-- the attached function, the entire query that had created the trigger is received using +-- @pg_get_triggerdef(t.oid)::text@ and then parsed. The result of that call will be +-- decompiled and normalized, which means that it's likely not what the user had +-- originally actually typed. Therefore, 'triggerWhen' and 'tfSource' of 'triggerFunction' +-- should not be relied upon when comparing the original 'Trigger' that had been used for +-- creating in the database and the one received by this function. +-- +-- @since 1.15.0 +getDBTriggers :: forall m. MonadDB m => m [Trigger] +getDBTriggers = do + runQuery_ . sqlSelect "pg_trigger t" $ do + sqlResult "t.tgname::text" -- name + sqlResult "t.tgtype" -- smallint == int2 => (2 bytes) + sqlResult "t.tgdeferrable" -- boolean + sqlResult "t.tginitdeferred"-- boolean + -- This gets the entire query that created this trigger. Note that it's decompiled and + -- normalized, which means that it's likely not what the user actually typed. For + -- example, if the original query had excessive whitespace in it, it won't be in this + -- result. + -- TODO: Do we want to remove one layer of () similarly to how Checks.hs does that? + --sqlResult "regexp_replace(pg_get_triggerdef(t.oid)::text, 'WHEN \\((.*)\\)', 'WHEN \\1')" + sqlResult "pg_get_triggerdef(t.oid)::text" + sqlResult "p.proname::text" -- name + sqlResult "p.prosrc" -- text + sqlResult "c.relname::text" + sqlJoinOn "pg_proc p" "t.tgfoid = p.oid" + sqlJoinOn "pg_class c" "c.oid = t.tgrelid" + sqlWhereEq "t.tgisinternal" False + fetchMany getTrigger + where + getTrigger :: (String, Int16, Bool, Bool, String, String, String, String) -> Trigger + getTrigger (tgname, tgtype, tgdeferrable, tginitdeferrable, triggerdef, proname, prosrc, tblName) = + Trigger { triggerTable = tableName + , triggerName = triggerBaseName (fromString tgname) tableName + , triggerEvents = getEvents tgtype + , triggerDeferrable = tgdeferrable + , triggerInitiallyDeferred = tginitdeferrable + , triggerWhen = tgrWhen + , triggerFunction = TriggerFunction (fromString proname) (fromString prosrc) + } + where + tableName = fromString tblName + -- Get the WHEN part of the query. Anything between WHEN and EXECUTE is what we + -- want. The Postgres' grammar guarantees that WHEN and EXECUTE are always next to + -- each other and in that order. + tgrWhen :: Maybe (RawSQL ()) + tgrWhen = + let (_, match) = Text.breakOn "WHEN" $ Text.pack triggerdef + in if Text.null match + then Nothing + else Just $ (rawSQL . fst $ Text.breakOn " EXECUTE" match) () + + getEvents :: Int16 -> Set TriggerEvent + getEvents tgtype = + foldl (\set (mask, event) -> + if testBit tgtype mask + then Set.insert event set + else set + ) + Set.empty + -- Taken from PostgreSQL sources: src/include/catalog/pg_trigger.h: + [ (2, TriggerInsert) -- #define TRIGGER_TYPE_INSERT (1 << 2) + , (3, TriggerDelete) -- #define TRIGGER_TYPE_DELETE (1 << 3) + , (4, TriggerUpdate) -- #define TRIGGER_TYPE_UPDATE (1 << 4) + ] + + +testDB :: DBT IO a -> IO a +testDB action = do + let cs = defaultConnectionSettings { csConnInfo = "host=localhost user=jsynacek dbname=kontrakcja" } + connSource = unConnectionSource $ simpleSource cs + runDBT connSource defaultTransactionSettings action + +testGetDBTriggers :: IO () +testGetDBTriggers = do + trgs <- testDB getDBTriggers + mapM_ print trgs + + +testTrigger1 :: Trigger +testTrigger1 = Trigger + { triggerTable = "users" + , triggerName = "users_trigger" + , triggerEvents = Set.fromList + [ TriggerInsert + , TriggerUpdate + , TriggerInsert + , TriggerUpdate + , TriggerDelete + ] + , triggerDeferrable = False + , triggerInitiallyDeferred = False + , triggerWhen = Nothing + , triggerFunction = TriggerFunction "testfun1" $ + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } + +testTrigger2 :: Trigger +testTrigger2 = testTrigger1 { triggerName = "users_trigger_2" + , triggerEvents = Set.fromList [TriggerDelete] + } diff --git a/test/Main.hs b/test/Main.hs index 98f1383..7408f6f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -Wno-name-shadowing #-} + module Main where import Control.Monad.Catch @@ -5,8 +7,10 @@ import Control.Monad.IO.Class import Data.Either import Data.Typeable import Data.UUID.Types +import qualified Data.Set as Set import qualified Data.Text as T +import Data.Monoid.Utils import Database.PostgreSQL.PQTypes import Database.PostgreSQL.PQTypes.Checks import Database.PostgreSQL.PQTypes.Model.ColumnType @@ -16,6 +20,7 @@ import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.Migration import Database.PostgreSQL.PQTypes.Model.PrimaryKey import Database.PostgreSQL.PQTypes.Model.Table +import Database.PostgreSQL.PQTypes.Model.Trigger import Database.PostgreSQL.PQTypes.SQL.Builder import Log import Log.Backend.StandardOutput @@ -69,6 +74,7 @@ tableBankSchema1 = , colNullable = False } ] , tblPrimaryKey = pkOnColumn "id" + , tblTriggers = [] } tableBankSchema2 :: Table @@ -368,7 +374,8 @@ schema2Migrations :: (MonadDB m) => [Migration m] schema2Migrations = schema1Migrations ++ [ dropTableMigration tableWitnessedRobberySchema1 , dropTableMigration tableWitnessSchema1 - , createTableMigration tableUnderArrestSchema2 ] + , createTableMigration tableUnderArrestSchema2 + ] schema3Tables :: [Table] schema3Tables = [ tableBankSchema3 @@ -826,6 +833,152 @@ migrationTest1Body step = do migrateDBToSchema5 step testDBSchema5 step +bankTrigger1 :: Trigger +bankTrigger1 = + Trigger { triggerTable = "bank" + , triggerName = "trigger_1" + , triggerEvents = Set.fromList [TriggerDelete] + , triggerDeferrable = False + , triggerInitiallyDeferred = False + , triggerWhen = Nothing + , triggerFunction = TriggerFunction "function_1" $ + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } + +bankTrigger2 :: Trigger +bankTrigger2 = + bankTrigger1 + { triggerFunction = TriggerFunction "function_2" $ + "begin" + <+> " return null;" + <+> "end;" + } + +bankTrigger2Proper :: Trigger +bankTrigger2Proper = + bankTrigger2 { triggerName = "trigger_2" } + +testTriggers :: HasCallStack => (String -> TestM ()) -> TestM () +testTriggers step = do + step "Running trigger tests..." + + step "create the initial database" + migrate [tableBankSchema1] [createTableMigration tableBankSchema1] + + let msg = "checkDatabase fails if there are triggers in the database but not in the schema" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 ] + step msg + assertException msg $ migrate ts ms + + let msg = "checkDatabase fails if there are triggers in the schema but not in the database" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [] + triggerStep msg $ do + assertException msg $ migrate ts ms + + let msg = "test succeeds when creating a single trigger" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [bankTrigger1] + + + let msg = "checkDatabase fails if triggers differ in function name" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger2 ] + triggerStep msg $ do + assertException msg $ migrate ts ms + + -- Attempt to create the same triggers twice. Should fail with a DBException saying + -- that function already exists. + let msg = "database exception is raised if trigger is created twice" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger1 + ] + triggerStep msg $ do + -- assertException is not good enough. We only want to catch DBException here. + try (migrate ts ms) >>= either (\DBException{} -> pure ()) + (const . liftIO $ assertFailure "Failure expected") + + + let msg = "database exception is raised if triggers only differ in function name" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2 + ] + triggerStep msg $ do + -- assertException is not good enough. We only want to catch DBException here. + try (migrate ts ms) >>= either (\DBException{} -> pure ()) + (const . liftIO $ assertFailure "Failure expected") + + let msg = "test successfully creates two triggers" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2Proper] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2Proper + ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [bankTrigger1, bankTrigger2Proper] + + where + triggerStep msg rest = do + recreateTriggerDB + step msg + rest + + migrate tables migrations = do + migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] tables migrations + checkDatabase defaultExtrasOptions [] [] tables + + -- Verify that the given triggers are present in the database. + verify :: (MonadIO m, MonadDB m, HasCallStack) => [Trigger] -> m () + verify triggers = do + dbTriggers <- getDBTriggers + let ok = and $ map (`elem` dbTriggers) triggers + liftIO $ assertBool "Triggers not present in the database." ok + + createTriggerMigration :: MonadDB m => Int -> Trigger -> Migration m + createTriggerMigration from trg = Migration + { mgrTableName = tblName tableBankSchema1 + , mgrFrom = fromIntegral from + , mgrAction = CreateTriggerMigration trg + } + + recreateTriggerDB = do + runSQL_ "DROP TRIGGER IF EXISTS trg__bank__trigger_1 ON bank;" + runSQL_ "DROP TRIGGER IF EXISTS trg__bank__trigger_2 ON bank;" + runSQL_ "DROP FUNCTION IF EXISTS function_1;" + runSQL_ "DROP FUNCTION IF EXISTS function_2;" + runSQL_ "DROP TABLE IF EXISTS bank;" + runSQL_ "DELETE FROM table_versions WHERE name = 'bank'"; + migrate [tableBankSchema1] [createTableMigration tableBankSchema1] migrationTest1 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest1 connSource = @@ -834,8 +987,6 @@ migrationTest1 connSource = migrationTest1Body step - -- freshTestDB step - -- | Test for behaviour of 'checkDatabase' and 'checkDatabaseAllowUnknownObjects' migrationTest2 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest2 connSource = @@ -957,6 +1108,13 @@ migrationTest4 connSource = freshTestDB step +-- | Test triggers. +triggerTests :: ConnectionSourceM (LogT IO) -> TestTree +triggerTests connSource = + testCaseSteps' "Trigger tests" connSource $ \step -> do + freshTestDB step + testTriggers step + eitherExc :: MonadCatch m => (SomeException -> m ()) -> (a -> m ()) -> m a -> m () eitherExc left right c = try c >>= either left right @@ -994,6 +1152,7 @@ main = do , migrationTest2 connSource , migrationTest3 connSource , migrationTest4 connSource + , triggerTests connSource ] where ings =