diff --git a/changelog.md b/changelog.md index 455673110..a1a0bc40b 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,9 @@ +Unreleased (3.1.3) +======== + +- @JoseD92 + - [#155](https://github.com/bitemyapp/esqueleto/pull/149): Added `insertSelectWithConflict` postgres function. + 3.1.2 ======== diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index d6d165b4c..5ba734d60 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -887,6 +887,54 @@ instance ( ToSomeValues a toSomeValues c ++ toSomeValues d ++ toSomeValues e ++ toSomeValues f ++ toSomeValues g ++ toSomeValues h +type family KnowResult a where + KnowResult (i -> o) = KnowResult o + KnowResult a = a + +-- | A class for constructors or function which result type is known. +-- +-- @since 3.1.3 +class FinalResult a where + finalR :: a -> KnowResult a + +instance FinalResult (Unique val) where + finalR = id + +instance (FinalResult b) => FinalResult (a -> b) where + finalR f = finalR (f undefined) + +-- | Convert a constructor for a 'Unique' key on a record to the 'UniqueDef' that defines it. You +-- can supply just the constructor itself, or a value of the type - the library is capable of figuring +-- it out from there. +-- +-- @since 3.1.3 +toUniqueDef :: forall a val. (KnowResult a ~ (Unique val), PersistEntity val,FinalResult a) => + a -> UniqueDef +toUniqueDef uniqueConstructor = uniqueDef + where + proxy :: Proxy val + proxy = Proxy + unique :: Unique val + unique = finalR uniqueConstructor + -- there must be a better way to get the constrain name from a unique, make this not a list search + filterF = (==) (persistUniqueToFieldNames unique) . uniqueFields + uniqueDef = head . filter filterF . entityUniques . entityDef $ proxy + +-- | Render updates to be use in a SET clause for a given sql backend. +-- +-- @since 3.1.3 +renderUpdates :: (BackendCompatible SqlBackend backend) => + backend + -> [SqlExpr (Update val)] + -> (TLB.Builder, [PersistValue]) +renderUpdates conn = uncommas' . concatMap renderUpdate + where + mk :: SqlExpr (Value ()) -> [(TLB.Builder, [PersistValue])] + mk (ERaw _ f) = [f info] + mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME + renderUpdate :: SqlExpr (Update val) -> [(TLB.Builder, [PersistValue])] + renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused + info = (projectBackend conn, initialIdentState) -- | Data type that represents an @INNER JOIN@ (see 'LeftOuterJoin' for an example). data InnerJoin a b = a `InnerJoin` b diff --git a/src/Database/Esqueleto/PostgreSQL.hs b/src/Database/Esqueleto/PostgreSQL.hs index cbbf7887e..8f70fb2d4 100644 --- a/src/Database/Esqueleto/PostgreSQL.hs +++ b/src/Database/Esqueleto/PostgreSQL.hs @@ -1,6 +1,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings - , GADTs, CPP + , GADTs, CPP, Rank2Types + , ScopedTypeVariables #-} -- | This module contain PostgreSQL-specific functions. -- @@ -20,6 +21,8 @@ module Database.Esqueleto.PostgreSQL , random_ , upsert , upsertBy + , insertSelectWithConflict + , insertSelectWithConflictCount -- * Internal , unsafeSqlAggregateFunction ) where @@ -33,15 +36,19 @@ import Database.Esqueleto.Internal.Language hiding (random_) import Database.Esqueleto.Internal.PersistentImport hiding (upsert, upsertBy) import Database.Esqueleto.Internal.Sql import Database.Esqueleto.Internal.Internal (EsqueletoError(..), CompositeKeyError(..), - UnexpectedCaseError(..), SetClause) + UnexpectedCaseError(..), SetClause, Ident(..), + uncommas, FinalResult(..), toUniqueDef, + KnowResult, renderUpdates) import Database.Persist.Class (OnlyOneUniqueKey) import Data.List.NonEmpty ( NonEmpty( (:|) ) ) +import Data.Int (Int64) +import Data.Proxy (Proxy(..)) import Control.Arrow ((***), first) import Control.Exception (Exception, throw, throwIO) +import Control.Monad (void) import Control.Monad.IO.Class (MonadIO (..)) import qualified Control.Monad.Trans.Reader as R - -- | (@random()@) Split out into database specific modules -- because MySQL uses `rand()`. -- @@ -199,18 +206,95 @@ upsertBy uniqueKey record updates = do where addVals l = map toPersistValue (toPersistFields record) ++ l ++ persistUniqueToValues uniqueKey entDef = entityDef (Just record) - uDef = head $ filter ((==) (persistUniqueToFieldNames uniqueKey) . uniqueFields) $ entityUniques entDef + uDef = toUniqueDef uniqueKey updatesText conn = first builderToText $ renderUpdates conn updates handler conn f = fmap head $ uncurry rawSql $ (***) (f entDef (uDef :| [])) addVals $ updatesText conn - renderUpdates :: SqlBackend - -> [SqlExpr (Update val)] - -> (TLB.Builder, [PersistValue]) - renderUpdates conn = uncommas' . concatMap renderUpdate - where - mk :: SqlExpr (Value ()) -> [(TLB.Builder, [PersistValue])] - mk (ERaw _ f) = [f info] - mk (ECompositeKey _) = throw (CompositeKeyErr MakeSetError) -- FIXME - renderUpdate :: SqlExpr (Update val) -> [(TLB.Builder, [PersistValue])] - renderUpdate (ESet f) = mk (f undefined) -- second parameter of f is always unused - info = (projectBackend conn, initialIdentState) \ No newline at end of file + +-- | Inserts into a table the results of a query similar to 'insertSelect' but allows +-- to update values that violate a constraint during insertions. +-- +-- Example of usage: +-- +-- @ +-- share [ mkPersist sqlSettings +-- , mkDeleteCascade sqlSettings +-- , mkMigrate "migrate" +-- ] [persistLowerCase| +-- Bar +-- num Int +-- deriving Eq Show +-- Foo +-- num Int +-- UniqueFoo num +-- deriving Eq Show +-- |] +-- +-- insertSelectWithConflict +-- UniqueFoo -- (UniqueFoo undefined) or (UniqueFoo anyNumber) would also work +-- (from $ \b -> +-- return $ Foo <# (b ^. BarNum) +-- ) +-- (\current excluded -> +-- [FooNum =. (current ^. FooNum) +. (excluded ^. FooNum)] +-- ) +-- @ +-- +-- Inserts to table Foo all Bar.num values and in case of conflict SomeFooUnique, +-- the conflicting value is updated to the current plus the excluded. +-- +-- @since 3.1.3 +insertSelectWithConflict :: forall a m val. ( + FinalResult a, + KnowResult a ~ (Unique val), + MonadIO m, + PersistEntity val) => + a + -- ^ Unique constructor or a unique, this is used just to get the name of the postgres constraint, the value(s) is(are) never used, so if you have a unique "MyUnique 0", "MyUnique undefined" would work as well. + -> SqlQuery (SqlExpr (Insertion val)) + -- ^ Insert query. + -> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)]) + -- ^ A list of updates to be applied in case of the constraint being violated. The expression takes the current and excluded value to produce the updates. + -> SqlWriteT m () +insertSelectWithConflict unique query = void . insertSelectWithConflictCount unique query + +-- | Same as 'insertSelectWithConflict' but returns the number of rows affected. +-- +-- @since 3.1.3 +insertSelectWithConflictCount :: forall a val m. ( + FinalResult a, + KnowResult a ~ (Unique val), + MonadIO m, + PersistEntity val) => + a + -> SqlQuery (SqlExpr (Insertion val)) + -> (SqlExpr (Entity val) -> SqlExpr (Entity val) -> [SqlExpr (Update val)]) + -> SqlWriteT m Int64 +insertSelectWithConflictCount unique query conflictQuery = do + conn <- R.ask + uncurry rawExecuteCount $ + combine + (toRawSql INSERT_INTO (conn, initialIdentState) (fmap EInsertFinal query)) + (conflict conn) + where + proxy :: Proxy val + proxy = Proxy + updates = conflictQuery entCurrent entExcluded + combine (tlb1,vals1) (tlb2,vals2) = (builderToText (tlb1 `mappend` tlb2), vals1 ++ vals2) + entExcluded = EEntity $ I "excluded" + tableName = unDBName . entityDB . entityDef + entCurrent = EEntity $ I (tableName proxy) + uniqueDef = toUniqueDef unique + constraint = TLB.fromText . unDBName . uniqueDBName $ uniqueDef + renderedUpdates :: (BackendCompatible SqlBackend backend) => backend -> (TLB.Builder, [PersistValue]) + renderedUpdates conn = renderUpdates conn updates + conflict conn = (foldr1 mappend ([ + TLB.fromText "ON CONFLICT ON CONSTRAINT \"", + constraint, + TLB.fromText "\" DO " + ] ++ if null updates then [TLB.fromText "NOTHING"] else [ + TLB.fromText "UPDATE SET ", + updatesTLB + ]),values) + where + (updatesTLB,values) = renderedUpdates conn diff --git a/test/Common/Test.hs b/test/Common/Test.hs index 9c6d36734..4f55d1b0f 100644 --- a/test/Common/Test.hs +++ b/test/Common/Test.hs @@ -52,6 +52,7 @@ module Common.Test , Circle (..) , Numbers (..) , OneUnique(..) + , Unique(..) ) where import Control.Monad (forM_, replicateM, replicateM_, void) diff --git a/test/PostgreSQL/Test.hs b/test/PostgreSQL/Test.hs index 407cb4119..4ef69c4e4 100644 --- a/test/PostgreSQL/Test.hs +++ b/test/PostgreSQL/Test.hs @@ -978,6 +978,50 @@ testUpsert = u3e <- EP.upsert u3 [OneUniqueName =. val "fifth"] liftIO $ entityVal u3e `shouldBe` u1{oneUniqueName="fifth"} +testInsertSelectWithConflict :: Spec +testInsertSelectWithConflict = + describe "insertSelectWithConflict test" $ do + it "Should do Nothing when no updates set" $ run $ do + _ <- insert p1 + _ <- insert p2 + _ <- insert p3 + n1 <- EP.insertSelectWithConflictCount UniqueValue ( + from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum) + ) + (\current excluded -> []) + uniques1 <- select $ from $ \u -> return u + n2 <- EP.insertSelectWithConflictCount UniqueValue ( + from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum) + ) + (\current excluded -> []) + uniques2 <- select $ from $ \u -> return u + liftIO $ n1 `shouldBe` 3 + liftIO $ n2 `shouldBe` 0 + let test = map (OneUnique "test" . personFavNum) [p1,p2,p3] + liftIO $ map entityVal uniques1 `shouldBe` test + liftIO $ map entityVal uniques2 `shouldBe` test + it "Should update a value if given an update on conflict" $ run $ do + _ <- insert p1 + _ <- insert p2 + _ <- insert p3 + -- Note, have to sum 4 so that the update does not conflicts again with another row. + n1 <- EP.insertSelectWithConflictCount UniqueValue ( + from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum) + ) + (\current excluded -> [OneUniqueValue =. val 4 +. (current ^. OneUniqueValue) +. (excluded ^. OneUniqueValue)]) + uniques1 <- select $ from $ \u -> return u + n2 <- EP.insertSelectWithConflictCount UniqueValue ( + from $ \p -> return $ OneUnique <# val "test" <&> (p ^. PersonFavNum) + ) + (\current excluded -> [OneUniqueValue =. val 4 +. (current ^. OneUniqueValue) +. (excluded ^. OneUniqueValue)]) + uniques2 <- select $ from $ \u -> return u + liftIO $ n1 `shouldBe` 3 + liftIO $ n2 `shouldBe` 3 + let test = map (OneUnique "test" . personFavNum) [p1,p2,p3] + test2 = map (OneUnique "test" . (+4) . (*2) . personFavNum) [p1,p2,p3] + liftIO $ map entityVal uniques1 `shouldBe` test + liftIO $ map entityVal uniques2 `shouldBe` test2 + type JSONValue = Maybe (JSONB A.Value) createSaneSQL :: (PersistField a) => SqlExpr (Value a) -> T.Text -> [PersistValue] -> IO () @@ -1051,6 +1095,7 @@ main = do testPostgresqlTextFunctions testInsertUniqueViolation testUpsert + testInsertSelectWithConflict describe "PostgreSQL JSON tests" $ do -- NOTE: We only clean the table once, so we -- can use its contents across all JSON tests