Skip to content

Commit

Permalink
Abstract authentication storage layer for unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kostmo committed May 20, 2024
1 parent 1ceb672 commit 54f9019
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 96 deletions.
22 changes: 14 additions & 8 deletions app/tournament/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Main where
import Control.Monad.Trans.Reader (runReaderT)
import Data.Maybe (fromMaybe)
import Data.Yaml (decodeFileThrow)
import Database.SQLite.Simple (withConnection)
import Network.Wai.Handler.Warp (Port)
import Options.Applicative
import Swarm.Game.State (Sha1 (..))
Expand Down Expand Up @@ -76,17 +77,22 @@ main = do
PersistenceLayer
{ scenarioStorage =
ScenarioPersistence
{ lookupCache = withConnInfo lookupScenarioSolution
, storeCache = withConnInfo insertScenario
, getContent = withConnInfo lookupScenarioContent
{ lookupCache = withConn lookupScenarioSolution
, storeCache = withConn insertScenario
, getContent = withConn lookupScenarioContent
}
, solutionStorage =
ScenarioPersistence
{ lookupCache = withConnInfo lookupSolutionSubmission
, storeCache = withConnInfo insertSolutionSubmission
, getContent = withConnInfo lookupSolutionContent
{ lookupCache = withConn lookupSolutionSubmission
, storeCache = withConn insertSolutionSubmission
, getContent = withConn lookupSolutionContent
}
, authenticationStorage =
AuthenticationStorage
{ usernameFromCookie = withConn getUsernameFromCookie
, cookieFromUsername = withConn insertCookie
}
}
where
withConnInfo f x =
runReaderT (f x) databaseFilename
withConn f x =
withConnection databaseFilename $ runReaderT $ f x
61 changes: 35 additions & 26 deletions src/swarm-tournament/Swarm/Web/Tournament.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import Data.Text qualified as T
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding (decodeUtf8, decodeUtf8', encodeUtf8)
import Data.Yaml (decodeEither', defaultEncodeOptions, encodeWith)
import Database.SQLite.Simple (withConnection)
import GHC.Generics (Generic)
import Network.HTTP.Client qualified as HC
import Network.HTTP.Client.TLS (tlsManagerSettings)
Expand Down Expand Up @@ -108,8 +109,8 @@ mkApp appData =
:<|> listScenarios
:<|> listSolutions
:<|> echoUsername
:<|> doGithubCallback (gitHubCredentials appData)
:<|> doLocalDevelopmentLogin (developmentMode appData)
:<|> doGithubCallback (authenticationStorage $ persistence appData) (gitHubCredentials appData)
:<|> doLocalDevelopmentLogin (authenticationStorage $ persistence appData) (developmentMode appData)
:<|> doLogout
where
echoUsername = return
Expand Down Expand Up @@ -144,8 +145,12 @@ data LoginProblem = LoginProblem
--- | The auth handler wraps a function from Request -> Handler UserAlias.
--- We look for a token in the request headers that we expect to be in the cookie.
--- The token is then passed to our `lookupAccount` function.
authHandler :: GitHubCredentials -> DeploymentEnvironment -> AuthHandler Request UserAlias
authHandler creds deployMode = mkAuthHandler handler
authHandler ::
AuthenticationStorage IO ->
GitHubCredentials ->
DeploymentEnvironment ->
AuthHandler Request UserAlias
authHandler authStorage creds deployMode = mkAuthHandler handler
where
url = case deployMode of
LocalDevelopment _ -> "api/private/login/local"
Expand All @@ -159,12 +164,10 @@ authHandler creds deployMode = mkAuthHandler handler
. lookup myAppCookieName
$ parseCookies cookie

userLookup cookieText = runReaderT (getUsernameFromCookie cookieText) databaseFilename

-- \| A method that, when given a cookie/password, will return a UserAlias.
-- A method that, when given a cookie/password, will return a 'UserAlias'.
lookupAccount :: TL.Text -> Handler UserAlias
lookupAccount cookieText = do
maybeUser <- liftIO $ userLookup cookieText
maybeUser <- liftIO $ usernameFromCookie authStorage cookieText
case maybeUser of
Nothing -> throwError (err403 {errBody = encode $ LoginProblem "Invalid cookie password" url})
Just usr -> return usr
Expand Down Expand Up @@ -254,38 +257,43 @@ downloadRedactedScenario (AppData _ _ persistenceLayer _) scenarioSha1 = do

listScenarios :: Handler [TournamentGame]
listScenarios =
Handler $ liftIO $ runReaderT listGames databaseFilename
Handler . liftIO . withConnection databaseFilename $ runReaderT listGames

listSolutions :: Sha1 -> Handler GameWithSolutions
listSolutions sha1 =
Handler $ liftIO $ runReaderT (listSubmissions sha1) databaseFilename
Handler . liftIO . withConnection databaseFilename . runReaderT $ listSubmissions sha1

doGithubCallback ::
AuthenticationStorage IO ->
GitHubCredentials ->
Maybe TokenExchangeCode ->
LoginHandler
doGithubCallback creds maybeCode = do
doGithubCallback authStorage creds maybeCode = do
c <- maybe (fail "Missing 'code' parameter") return maybeCode

manager <- liftIO $ HC.newManager tlsManagerSettings
receivedTokens <- exchangeCode manager creds c

let aToken = token $ accessToken receivedTokens
userInfo <- fetchAuthenticatedUser manager aToken

doLoginResponse refererUrl (UserAlias $ login userInfo) receivedTokens
let user = UserAlias $ login userInfo
x <- doLoginResponse authStorage refererUrl user
liftIO . withConnection databaseFilename . runReaderT $ do
insertGitHubTokens user receivedTokens
return x
where
refererUrl = "/list-games.html"

doLocalDevelopmentLogin :: DeploymentEnvironment -> Maybe TL.Text -> LoginHandler
doLocalDevelopmentLogin envType maybeRefererUrl =
doLocalDevelopmentLogin ::
AuthenticationStorage IO ->
DeploymentEnvironment ->
Maybe TL.Text ->
LoginHandler
doLocalDevelopmentLogin authStorage envType maybeRefererUrl =
case envType of
ProdDeployment -> error "Login bypass not available in production"
LocalDevelopment user ->
doLoginResponse refererUrl user $
ReceivedTokens
(Expirable (AccessToken "abcd") 100)
(Expirable (RefreshToken "efgh") 10000)
doLoginResponse authStorage refererUrl user
where
refererUrl = fromMaybe "foo" maybeRefererUrl

Expand All @@ -304,16 +312,13 @@ doLogout maybeRefererUrl =
addHeader ((makeCookieHeader "") {setCookieMaxAge = Just 0}) NoContent

doLoginResponse ::
AuthenticationStorage IO ->
TL.Text ->
UserAlias ->
ReceivedTokens ->
LoginHandler
doLoginResponse refererUrl userAlias receivedTokens = do
doLoginResponse authStorage refererUrl userAlias = do
cookieString <-
liftIO $
flip runReaderT databaseFilename $
insertGitHubAuth userAlias receivedTokens

liftIO $ cookieFromUsername authStorage userAlias
return $
addHeader refererUrl $
addHeader (makeCookieHeader $ LBS.toStrict $ encodeUtf8 cookieString) NoContent
Expand All @@ -335,7 +340,11 @@ app appData = Servant.serveWithContext (Proxy :: Proxy ToplevelAPI) context serv
$ defaultParseRequestBodyOptions
}

thisAuthHandler = authHandler (gitHubCredentials appData) (developmentMode appData)
thisAuthHandler =
authHandler
(authenticationStorage $ persistence appData)
(gitHubCredentials appData)
(developmentMode appData)
context = thisAuthHandler :. multipartOpts :. EmptyContext

server :: Server ToplevelAPI
Expand Down
108 changes: 64 additions & 44 deletions src/swarm-tournament/Swarm/Web/Tournament/Database/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,15 @@ newtype UserId = UserId Int
instance ToField UserId where
toField (UserId x) = toField x

data AuthenticationStorage m = AuthenticationStorage
{ usernameFromCookie :: TL.Text -> m (Maybe UserAlias)
, cookieFromUsername :: UserAlias -> m TL.Text
}

data PersistenceLayer m = PersistenceLayer
{ scenarioStorage :: ScenarioPersistence m ScenarioUploadResponsePayload
, solutionStorage :: ScenarioPersistence m SolutionUploadResponsePayload
, authenticationStorage :: AuthenticationStorage m
}

data ScenarioPersistence m a = ScenarioPersistence
Expand Down Expand Up @@ -109,57 +115,70 @@ instance FromRow SolutionFileCharacterization where

-- * Authentication

-- | If the username already exists, overwrite it.
insertGitHubAuth ::
-- | If the username already exists, overwrite the row.
insertCookie ::
UserAlias ->
ReceivedTokens ->
ReaderT ConnectInfo IO TL.Text
insertGitHubAuth gitHubUsername gitHubTokens = do
connInfo <- ask
currentTime <- liftIO getCurrentTime
let expirationOf = mkExpirationTime currentTime
liftIO $ withConnection connInfo $ \conn -> do
ReaderT Connection IO TL.Text
insertCookie gitHubUsername = do
conn <- ask
liftIO $ do
[Only cookieString] <-
query
conn
"REPLACE INTO users (alias, github_access_token, github_access_token_expires_at, github_refresh_token, github_refresh_token_expires_at) VALUES (?, ?, ?, ?, ?) RETURNING cookie;"
( gitHubUsername
, token $ accessToken gitHubTokens
, expirationOf accessToken
, token $ refreshToken gitHubTokens
, expirationOf refreshToken
)
"REPLACE INTO users (alias) VALUES (?) RETURNING cookie;"
(Only gitHubUsername)
return cookieString

-- | If the username already exists, overwrite the row.
insertGitHubTokens ::
UserAlias ->
ReceivedTokens ->
ReaderT Connection IO ()
insertGitHubTokens gitHubUsername gitHubTokens = do
conn <- ask
currentTime <- liftIO getCurrentTime
let expirationOf = mkExpirationTime currentTime
liftIO $ do
execute
conn
"REPLACE INTO github_tokens (alias, github_access_token, github_access_token_expires_at, github_refresh_token, github_refresh_token_expires_at) VALUES (?, ?, ?, ?, ?);"
( gitHubUsername
, token $ accessToken gitHubTokens
, expirationOf accessToken
, token $ refreshToken gitHubTokens
, expirationOf refreshToken
)
return ()
where
mkExpirationTime currTime accessor =
addUTCTime (fromIntegral $ expirationSeconds $ accessor gitHubTokens) currTime

getUsernameFromCookie ::
TL.Text ->
ReaderT ConnectInfo IO (Maybe UserAlias)
ReaderT Connection IO (Maybe UserAlias)
getUsernameFromCookie cookieText = do
connInfo <- ask
liftIO . fmap (fmap (UserAlias . fromOnly) . listToMaybe) . withConnection connInfo $ \conn ->
conn <- ask
liftIO . fmap (fmap (UserAlias . fromOnly) . listToMaybe) $
query conn "SELECT alias FROM users WHERE cookie = ?;" (Only cookieText)

-- * Retrieval

lookupScenarioContent :: Sha1 -> ReaderT ConnectInfo IO (Maybe LBS.ByteString)
lookupScenarioContent :: Sha1 -> ReaderT Connection IO (Maybe LBS.ByteString)
lookupScenarioContent sha1 = do
connInfo <- ask
liftIO . fmap (fmap fromOnly . listToMaybe) . withConnection connInfo $ \conn ->
conn <- ask
liftIO . fmap (fmap fromOnly . listToMaybe) $
query conn "SELECT content FROM scenarios WHERE content_sha1 = ?;" (Only sha1)

lookupSolutionContent :: Sha1 -> ReaderT ConnectInfo IO (Maybe LBS.ByteString)
lookupSolutionContent :: Sha1 -> ReaderT Connection IO (Maybe LBS.ByteString)
lookupSolutionContent sha1 = do
connInfo <- ask
liftIO . fmap (fmap fromOnly . listToMaybe) . withConnection connInfo $ \conn ->
conn <- ask
liftIO . fmap (fmap fromOnly . listToMaybe) $
query conn "SELECT content FROM solution_submission WHERE content_sha1 = ?;" (Only sha1)

lookupSolutionSubmission :: Sha1 -> ReaderT ConnectInfo IO (Maybe AssociatedSolutionCharacterization)
lookupSolutionSubmission :: Sha1 -> ReaderT Connection IO (Maybe AssociatedSolutionCharacterization)
lookupSolutionSubmission contentSha1 = do
connInfo <- ask
liftIO $ withConnection connInfo $ \conn -> runMaybeT $ do
conn <- ask
liftIO $ runMaybeT $ do
evaluationId :: Int <-
MaybeT $
fmap fromOnly . listToMaybe
Expand All @@ -170,23 +189,24 @@ lookupSolutionSubmission contentSha1 = do
<$> query conn "SELECT scenario, wall_time_seconds, ticks, seed, char_count, ast_size FROM evaluated_solution WHERE id = ?;" (Only evaluationId)

-- | There should only be one builtin solution for the scenario.
lookupScenarioSolution :: Sha1 -> ReaderT ConnectInfo IO (Maybe AssociatedSolutionCharacterization)
lookupScenarioSolution :: Sha1 -> ReaderT Connection IO (Maybe AssociatedSolutionCharacterization)
lookupScenarioSolution scenarioSha1 = do
connInfo <- ask
solnChar <- liftIO . fmap listToMaybe . withConnection connInfo $ \conn ->
query conn "SELECT wall_time_seconds, ticks, seed, char_count, ast_size FROM evaluated_solution WHERE builtin AND scenario = ? LIMIT 1;" (Only scenarioSha1)
conn <- ask
solnChar <-
liftIO . fmap listToMaybe $
query conn "SELECT wall_time_seconds, ticks, seed, char_count, ast_size FROM evaluated_solution WHERE builtin AND scenario = ? LIMIT 1;" (Only scenarioSha1)
return $ AssociatedSolutionCharacterization scenarioSha1 <$> solnChar

listGames :: ReaderT ConnectInfo IO [TournamentGame]
listGames :: ReaderT Connection IO [TournamentGame]
listGames = do
connInfo <- ask
liftIO $ withConnection connInfo $ \conn ->
conn <- ask
liftIO $
query_ conn "SELECT original_filename, scenario_uploader, scenario, submission_count, swarm_git_sha1, title FROM agg_scenario_submissions;"

listSubmissions :: Sha1 -> ReaderT ConnectInfo IO GameWithSolutions
listSubmissions :: Sha1 -> ReaderT Connection IO GameWithSolutions
listSubmissions scenarioSha1 = do
connInfo <- ask
liftIO $ withConnection connInfo $ \conn -> do
conn <- ask
liftIO $ do
[game] <- query conn "SELECT original_filename, scenario_uploader, scenario, submission_count, swarm_git_sha1, title FROM agg_scenario_submissions WHERE scenario = ?;" (Only scenarioSha1)
solns <- query conn "SELECT uploaded_at, solution_submitter, solution_sha1, wall_time_seconds, ticks, seed, char_count, ast_size FROM all_solution_submissions WHERE scenario = ?;" (Only scenarioSha1)
return $ GameWithSolutions game solns
Expand All @@ -195,10 +215,10 @@ listSubmissions scenarioSha1 = do

insertScenario ::
CharacterizationResponse ScenarioUploadResponsePayload ->
ReaderT ConnectInfo IO Sha1
ReaderT Connection IO Sha1
insertScenario s = do
connInfo <- ask
h <- liftIO $ withConnection connInfo $ \conn -> do
conn <- ask
h <- liftIO $ do
[Only resultList] <-
query
conn
Expand All @@ -219,10 +239,10 @@ insertScenario s = do

insertSolutionSubmission ::
CharacterizationResponse SolutionUploadResponsePayload ->
ReaderT ConnectInfo IO Sha1
ReaderT Connection IO Sha1
insertSolutionSubmission (CharacterizationResponse solutionUpload s (SolutionUploadResponsePayload scenarioSha)) = do
connInfo <- ask
liftIO $ withConnection connInfo $ \conn -> do
conn <- ask
liftIO $ do
solutionEvalId <- insertSolution conn False scenarioSha $ characterization s
[Only echoedSha1] <-
query
Expand Down
1 change: 1 addition & 0 deletions swarm.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,7 @@ executable swarm-host-tournament
build-depends:
base,
optparse-applicative >=0.16 && <0.19,
sqlite-simple,
transformers,
warp,
yaml,
Expand Down
9 changes: 8 additions & 1 deletion test/tournament-host/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,21 @@ main = do
{ getContent = return . fmap content . (`NEM.lookup` scenariosMap)
}
, solutionStorage = noPersistence
, authenticationStorage =
AuthenticationStorage
{ usernameFromCookie = const $ return $ Just fakeUser
, cookieFromUsername = const $ return "fake-cookie-value"
}
}

fakeUser = UserAlias "test-user"

mkAppData scenariosMap =
Tournament.AppData
{ Tournament.swarmGameGitVersion = Sha1 "abcdef"
, Tournament.gitHubCredentials = Tournament.GitHubCredentials "" ""
, Tournament.persistence = mkPersistenceLayer scenariosMap
, Tournament.developmentMode = Tournament.LocalDevelopment $ UserAlias "test-user"
, Tournament.developmentMode = Tournament.LocalDevelopment fakeUser
}

type LocalFileLookup = NEMap Sha1 FilePathAndContent
Expand Down
Loading

0 comments on commit 54f9019

Please sign in to comment.