Skip to content

Commit

Permalink
refactor creating connection record (#1021)
Browse files Browse the repository at this point in the history
  • Loading branch information
epoberezkin authored Mar 2, 2024
1 parent 294d7ec commit ce78646
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
29 changes: 14 additions & 15 deletions src/Simplex/Messaging/Agent/Store/SQLite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ module Simplex.Messaging.Agent.Store.SQLite
createNewConn,
updateNewConnRcv,
updateNewConnSnd,
createRcvConn, -- no longer used
createSndConn,
getConn,
getDeletedConn,
Expand Down Expand Up @@ -543,11 +542,8 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
ConnData {connId} -> Right . (connId,) <$> create connId

createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} cMode = do
fst <$$> createConn_ gVar cData create
where
create connId =
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
createNewConn db gVar cData cMode = do
fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode)

updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
updateNewConnRcv db connId rq =
Expand All @@ -569,22 +565,25 @@ updateNewConnSnd db connId sq =
updateConn :: IO (Either StoreError SndQueue)
updateConn = Right <$> addConnSndQueue_ db connId sq

createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
insertRcvQueue_ db connId q serverKeyHash_

createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue))
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@SndQueue {server} =
createSndConn db gVar cData q@SndQueue {server} =
-- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, True)
createConnRecord db connId cData SCMInvitation
insertSndQueue_ db connId q serverKeyHash_

createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO ()
createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs} cMode =
DB.execute
db
[sql|
INSERT INTO connections
(user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)
|]
(userId, connId, cMode, connAgentVersion, enableNtfs, True)

checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
fromMaybe False
Expand Down
10 changes: 9 additions & 1 deletion tests/AgentTests/SQLiteTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import Control.Concurrent.Async (concurrently_)
import Control.Concurrent.STM
import Control.Exception (SomeException)
import Control.Monad (replicateM_)
import Control.Monad.Trans.Except
import Crypto.Random (ChaChaDRG)
import Data.ByteArray (ScrubbedBytes)
import Data.ByteString.Char8 (ByteString)
import Data.List (isInfixOf)
Expand Down Expand Up @@ -91,7 +93,7 @@ storeTests = do
testForeignKeysEnabled
describe "db methods" $ do
describe "Queue and Connection management" $ do
describe "createRcvConn" $ do
describe "create Rcv connection" $ do
testCreateRcvConn
testCreateRcvConnRandomId
testCreateRcvConnDuplicate
Expand Down Expand Up @@ -227,6 +229,12 @@ sndQueue1 =
smpClientVersion = 1
}

createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
createRcvConn db g cData rq cMode = runExceptT $ do
connId <- ExceptT $ createNewConn db g cData cMode
rq' <- ExceptT $ updateNewConnRcv db connId rq
pure (connId, rq')

testCreateRcvConn :: SpecWith SQLiteStore
testCreateRcvConn =
it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do
Expand Down

0 comments on commit ce78646

Please sign in to comment.