diff --git a/DESCRIPTION b/DESCRIPTION index fa6fc893..9fd71f69 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: Strategus Type: Package Title: Coordinate and Execute OHDSI HADES Modules -Version: 1.0.0 -Date: 2024-10-08 +Version: 1.1.0 +Date: 2024-11-12 Authors@R: c( person("Anthony", "Sena", email = "sena@ohdsi.org", role = c("aut", "cre")), person("Martijn", "Schuemie", email = "schuemie@ohdsi.org", role = c("aut")), diff --git a/R/StrategusModule.R b/R/StrategusModule.R index 8e4f7d49..38f3d6d0 100644 --- a/R/StrategusModule.R +++ b/R/StrategusModule.R @@ -159,14 +159,30 @@ StrategusModule <- R6::R6Class( } private$jobContext$settings <- moduleSpecification$settings + # Make sure that the covariate settings for the analysis are updated + # to reflect the location of the cohort tables if we are executing + # on a CDM. + if (inherits(executionSettings, "CdmExecutionSettings")) { + private$jobContext$settings <- .replaceCovariateSettings( + moduleSettings = private$jobContext$settings, + executionSettings = executionSettings + ) + } + # Assemble the job context from the analysis specification # for the given module. private$jobContext$sharedResources <- analysisSpecifications$sharedResources private$jobContext$moduleExecutionSettings <- executionSettings private$jobContext$moduleExecutionSettings$resultsSubFolder <- file.path(private$jobContext$moduleExecutionSettings$resultsFolder, self$moduleName) + if (!dir.exists(private$jobContext$moduleExecutionSettings$resultsSubFolder)) { + dir.create(private$jobContext$moduleExecutionSettings$resultsSubFolder, showWarnings = F, recursive = T) + } if (is(private$jobContext$moduleExecutionSettings, "ExecutionSettings")) { private$jobContext$moduleExecutionSettings$workSubFolder <- file.path(private$jobContext$moduleExecutionSettings$workFolder, self$moduleName) + if (!dir.exists(private$jobContext$moduleExecutionSettings$workSubFolder)) { + dir.create(private$jobContext$moduleExecutionSettings$workSubFolder, showWarnings = F, recursive = T) + } } }, .getModuleSpecification = function(analysisSpecifications, moduleName) { @@ -294,3 +310,64 @@ StrategusModule <- R6::R6Class( } ) ) + +# Utility function to set the cohort table & schema on +# createCohortBasedCovariateSettings with information from +# the execution settings (Issue #181) +.replaceCovariateSettingsCohortTableNames <- function(covariateSettings, executionSettings) { + errorMessages <- checkmate::makeAssertCollection() + checkmate::assertList(covariateSettings, min.len = 1, add = errorMessages) + checkmate::assertClass(executionSettings, "CdmExecutionSettings", add = errorMessages) + checkmate::reportAssertions(collection = errorMessages) + + .replaceProperties <- function(s) { + if (inherits(s, "covariateSettings") && "fun" %in% names(attributes(s))) { + if (attr(s, "fun") == "getDbCohortBasedCovariatesData") { + # Set the covariateCohortDatabaseSchema & covariateCohortTable values + s$covariateCohortDatabaseSchema = executionSettings$workDatabaseSchema + s$covariateCohortTable = executionSettings$cohortTableNames$cohortTable + } + } + return(s) + } + if (is.null(names(covariateSettings))) { + # List of lists + modifiedCovariateSettings <- lapply(covariateSettings, .replaceProperties) + } else { + # Plain list + modifiedCovariateSettings <- .replaceProperties(covariateSettings) + } + return(modifiedCovariateSettings) +} + +.replaceCovariateSettings <- function(moduleSettings, executionSettings) { + errorMessages <- checkmate::makeAssertCollection() + checkmate::assertList(moduleSettings, min.len = 1, add = errorMessages) + checkmate::assertClass(executionSettings, "CdmExecutionSettings", add = errorMessages) + checkmate::reportAssertions(collection = errorMessages) + + # A helper function to perform the replacement + replaceHelper <- function(x) { + if (is.list(x) && inherits(x, "covariateSettings")) { + # If the element is a list and of type covariate settings + # replace the cohort table names + return(.replaceCovariateSettingsCohortTableNames(x, executionSettings)) + } else if (is.list(x)) { + # If the element is a list, recurse on each element + # Keep the original attributes by saving them before modification + attrs <- attributes(x) + newList <- lapply(x, replaceHelper) + # Restore attributes to the new list + attributes(newList) <- attrs + return(newList) + } else { + # If the element is not a list or "covariateSettings", return it as is + return(x) + } + } + + # Call the helper function on the input list + return(replaceHelper(moduleSettings)) +} + + diff --git a/tests/testthat/test-Settings.R b/tests/testthat/test-Settings.R index 08970cb3..551c1be0 100644 --- a/tests/testthat/test-Settings.R +++ b/tests/testthat/test-Settings.R @@ -408,3 +408,80 @@ test_that("Create results data model settings", { expect_equal(class(settings), c("ResultsDataModelSettings")) }) + +test_that("Test internal function for modifying covariate settings", { + # Create module settings that contain a combination of + # 1) covariate settings that do not contain cohort table settings + # 2) covariate settings that contain cohort table settings + # 3) a list of covariate setting that has 1 & 2 above + # 4) Something other than a covariate setting object + esModuleSettingsCreator <- EvidenceSynthesisModule$new() + evidenceSynthesisSourceCmGrid <- esModuleSettingsCreator$createEvidenceSynthesisSource( + sourceMethod = "CohortMethod", + likelihoodApproximation = "adaptive grid" + ) + + cov1 <- FeatureExtraction::createDefaultCovariateSettings() + cov2 <- FeatureExtraction::createCohortBasedCovariateSettings( + analysisId = 999, + covariateCohorts = data.frame( + cohortId = 1, + cohortName = "test" + ) + ) + covariateSettings <- list(cov1, cov2) + moduleSettings <- list( + analysis = list( + something = covariateSettings, + somethingElse = list( + nested1 = cov1, + nested2 = cov2, + nested3 = covariateSettings + ), + esSettings = evidenceSynthesisSourceCmGrid + ) + ) + workDatabaseSchema <- "foo" + cohortTableNames <- CohortGenerator::getCohortTableNames(cohortTable = "unit_test") + executionSettings <- createCdmExecutionSettings( + workDatabaseSchema = workDatabaseSchema, + cdmDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + workFolder = "temp", + resultsFolder = "temp" + ) + + testReplacedModuleSettings <- .replaceCovariateSettings(moduleSettings, executionSettings) + # For visual inspection + #ParallelLogger::saveSettingsToJson(moduleSettings, "before_unit_test.json") + #ParallelLogger::saveSettingsToJson(testReplacedModuleSettings, "after_unit_test.json") + expect_equal(testReplacedModuleSettings$analysis$something[[1]]$covariateCohortDatabaseSchema, NULL) + expect_equal(testReplacedModuleSettings$analysis$something[[1]]$covariateCohortTable, NULL) + expect_equal(testReplacedModuleSettings$analysis$something[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema) + expect_equal(testReplacedModuleSettings$analysis$something[[2]]$covariateCohortTable, cohortTableNames$cohortTable) + + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested1$covariateCohortDatabaseSchema, NULL) + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested1$covariateCohortTable, NULL) + + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested2$covariateCohortDatabaseSchema, workDatabaseSchema) + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested2$covariateCohortTable, cohortTableNames$cohortTable) + + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[1]]$covariateCohortDatabaseSchema, NULL) + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[1]]$covariateCohortTable, NULL) + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema) + expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[2]]$covariateCohortTable, cohortTableNames$cohortTable) + expect_equal(class(testReplacedModuleSettings$analysis$esSettings), class(moduleSettings$analysis$esSettings)) + + # Additional tests for the table name replacement function + test1 <- .replaceCovariateSettingsCohortTableNames(covariateSettings, executionSettings) + expect_equal(test1[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema) + expect_equal(test1[[2]]$covariateCohortTable, cohortTableNames$cohortTable) + + test2 <- .replaceCovariateSettingsCohortTableNames(cov1, executionSettings) + expect_equal(test2$covariateCohortDatabaseSchema, NULL) + expect_equal(test2$covariateCohortTable, NULL) + + test3 <- .replaceCovariateSettingsCohortTableNames(cov2, executionSettings) + expect_equal(test3$covariateCohortDatabaseSchema, workDatabaseSchema) + expect_equal(test3$covariateCohortTable, cohortTableNames$cohortTable) +})