Skip to content

Commit

Permalink
Merge pull request #152 from r-causal/fix_guides
Browse files Browse the repository at this point in the history
Fix guides
  • Loading branch information
malcolmbarrett authored Mar 7, 2024
2 parents b53b23f + aa84adb commit 795fe76
Show file tree
Hide file tree
Showing 17 changed files with 234 additions and 285 deletions.
31 changes: 17 additions & 14 deletions R/drelationship.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ node_dconnected <- function(.tdy_dag, from = NULL, to = NULL, controlling_for =
if (!is.null(controlling_for)) {
.tdy_dag <- control_for(.tdy_dag, controlling_for)
} else {
.tdy_dag <- .tdy_dag %>%
dplyr::mutate(collider_line = FALSE, adjusted = "unadjusted")
controlling_for <- c()
}

Expand Down Expand Up @@ -169,16 +167,8 @@ node_drelationship <- function(.tdy_dag, from = NULL, to = NULL, controlling_for

if (!is.null(controlling_for)) {
.tdy_dag <- control_for(.tdy_dag, controlling_for)
} else {
.tdy_dag <- dplyr::mutate(
.tdy_dag,
collider_line = FALSE,
adjusted = "unadjusted"
)
controlling_for <- c()
}


.dseparated <- dagitty::dseparated(pull_dag(.tdy_dag), from, to, controlling_for)
.from <- from
.to <- to
Expand Down Expand Up @@ -232,11 +222,24 @@ ggdag_drelationship <- function(
stylized = deprecated(),
collider_lines = TRUE
) {
p <- if_not_tidy_daggity(.tdy_dag) %>%
node_drelationship(from = from, to = to, controlling_for = controlling_for, ...) %>%
ggplot2::ggplot(aes_dag(shape = adjusted, col = d_relationship))
df <- if_not_tidy_daggity(.tdy_dag) %>%
node_drelationship(
from = from,
to = to,
controlling_for = controlling_for,
...
)

has_adjusted <- "adjusted" %in% names(pull_dag_data(df))
if (has_adjusted) {
mapping <- aes_dag(shape = adjusted, color = d_relationship)
} else {
mapping <- aes_dag(color = d_relationship)
}

p <- ggplot2::ggplot(df, mapping)

if (collider_lines) p <- p + geom_dag_collider_edges()
if (has_adjusted && collider_lines) p <- p + geom_dag_collider_edges()

p <- p + geom_dag(
size = size,
Expand Down
37 changes: 14 additions & 23 deletions R/instrumental.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ node_instrumental <- function(.dag, exposure = NULL, outcome = NULL, ...) {
if (purrr::is_empty(i_vars)) {
.dag <- dplyr::mutate(
.dag,
adjusted = factor(
"unadjusted",
levels = c("unadjusted", "adjusted"),
exclude = NA
),
instrumental = NA
)
return(.dag)
Expand All @@ -58,14 +53,6 @@ node_instrumental <- function(.dag, exposure = NULL, outcome = NULL, ...) {
)
if (!is.null(.z)) {
.dag <- .dag %>% control_for(.z, activate_colliders = FALSE)
} else {
.dag <- .dag %>% dplyr::mutate(
adjusted = factor(
"unadjusted",
levels = c("unadjusted", "adjusted"),
exclude = NA
)
)
}
.dag <- .dag %>% dplyr::mutate(
instrumental = ifelse(name == .i, "instrumental", NA)
Expand Down Expand Up @@ -106,17 +93,21 @@ ggdag_instrumental <- function(
) {
.tdy_dag <- if_not_tidy_daggity(.tdy_dag) %>%
node_instrumental(exposure = exposure, outcome = outcome, ...)
has_instrumental <- !all(is.na((pull_dag_data(.tdy_dag)$instrumental)))
has_adjusted <- "adjusted" %in% names(pull_dag_data(.tdy_dag))
mapping <- aes_dag()
if (has_adjusted) {
mapping$shape <- substitute(adjusted)
}

if (all(is.na((pull_dag_data(.tdy_dag)$instrumental)))) {
mapping <- aes_dag(shape = adjusted)
} else {
mapping <- aes_dag(shape = adjusted, color = instrumental)
if (has_instrumental) {
mapping$colour <- substitute(instrumental)
}

p <- .tdy_dag %>%
ggplot2::ggplot(mapping) +
scale_adjusted() +
breaks("instrumental")
ggplot2::ggplot(mapping)
if (has_adjusted) p <- p + scale_adjusted()
if (has_instrumental) p <- p + breaks("instrumental")

p <- p +
geom_dag(
Expand All @@ -141,10 +132,10 @@ ggdag_instrumental <- function(
stylized = stylized
)

if (all(is.na(pull_dag_data(.tdy_dag)$instrumental))) {
p <- p + ggplot2::facet_wrap(~"{No instrumental variables present}")
} else {
if (has_instrumental) {
p <- p + ggplot2::facet_wrap(~instrumental_name)
} else {
p <- p + ggplot2::facet_wrap(~"{No instrumental variables present}")
}
p
}
8 changes: 4 additions & 4 deletions R/relations.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ ggdag_children <- function(
node_children(.var) %>%
ggplot2::ggplot(aes_dag(color = children)) +
scale_adjusted() +
breaks(c("parent", "child"))
breaks(c("parent", "child"), drop = FALSE)

p <- p + geom_dag(
size = size,
Expand Down Expand Up @@ -288,7 +288,7 @@ ggdag_parents <- function(
node_parents(.var) %>%
ggplot2::ggplot(aes_dag(color = parent)) +
scale_adjusted() +
breaks(c("parent", "child"))
breaks(c("parent", "child"), drop = FALSE)

p <- p + geom_dag(
size = size,
Expand Down Expand Up @@ -344,7 +344,7 @@ ggdag_ancestors <- function(
node_ancestors(.var) %>%
ggplot2::ggplot(aes_dag(color = ancestor)) +
scale_adjusted() +
breaks(c("ancestor", "descendant"))
breaks(c("ancestor", "descendant"), drop = FALSE)

p <- p + geom_dag(
size = size,
Expand Down Expand Up @@ -401,7 +401,7 @@ ggdag_descendants <- function(
node_descendants(.var) %>%
ggplot2::ggplot(aes_dag(color = descendant)) +
scale_adjusted() +
breaks(c("ancestor", "descendant"))
breaks(c("ancestor", "descendant"), drop = FALSE)

p <- p + geom_dag(
size = size,
Expand Down
22 changes: 16 additions & 6 deletions R/themes.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,27 @@ theme_dag_gray_grid <- theme_dag_grey_grid
scale_adjusted <- function(include_alpha = FALSE) {
list(
ggplot2::scale_linetype_manual(name = NULL, values = "dashed"),
ggplot2::scale_shape_manual(drop = FALSE, values = c("adjusted" = 15, "unadjusted" = 19), limits = c("adjusted", "unadjusted")),
ggplot2::scale_shape_manual(
values = c("adjusted" = 15, "unadjusted" = 19),
limits = c("adjusted", "unadjusted")
),
ggplot2::scale_color_discrete(limits = c("adjusted", "unadjusted")),
if (include_alpha) ggplot2::scale_alpha_manual(drop = FALSE, values = c("adjusted" = .30, "unadjusted" = 1), limits = c("adjusted", "unadjusted")),
if (include_alpha) ggraph::scale_edge_alpha_manual(name = NULL, drop = FALSE, values = c("adjusted" = .30, "unadjusted" = 1), limits = c("adjusted", "unadjusted"))
if (include_alpha) ggplot2::scale_alpha_manual(
values = c("adjusted" = .30, "unadjusted" = 1),
limits = c("adjusted", "unadjusted")
),
if (include_alpha) ggraph::scale_edge_alpha_manual(
name = NULL,
values = c("adjusted" = .30, "unadjusted" = 1),
limits = c("adjusted", "unadjusted")
)
)
}

breaks <- function(breaks = ggplot2::waiver(), name = ggplot2::waiver()) {
breaks <- function(breaks = ggplot2::waiver(), name = ggplot2::waiver(), drop = TRUE) {
list(
ggplot2::scale_color_discrete(name = name, drop = FALSE, breaks = breaks),
ggplot2::scale_fill_discrete(name = name, drop = FALSE, breaks = breaks)
ggplot2::scale_color_discrete(name = name, breaks = breaks, drop = drop),
ggplot2::scale_fill_discrete(name = name, breaks = breaks, drop = drop)
)
}

Expand Down
Loading

0 comments on commit 795fe76

Please sign in to comment.