Skip to content

Commit

Permalink
Merge pull request #134 from r-causal/style_pkg
Browse files Browse the repository at this point in the history
`styler::style_pkg()`
  • Loading branch information
malcolmbarrett authored Jan 29, 2024
2 parents 969ff03 + 85f8387 commit 7b5007a
Show file tree
Hide file tree
Showing 28 changed files with 1,536 additions and 641 deletions.
106 changes: 80 additions & 26 deletions R/StatsandGeoms.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
StatNodes <- ggplot2::ggproto("StatNodes", ggplot2::Stat,
StatNodes <- ggplot2::ggproto(
"StatNodes",
ggplot2::Stat,
compute_layer = function(data, scales, params) {
if (all(c("xend", "yend") %in% names(data))) {
unique(dplyr::select(data, -xend, -yend))
Expand All @@ -8,7 +10,9 @@ StatNodes <- ggplot2::ggproto("StatNodes", ggplot2::Stat,
}
)

StatNodesRepel <- ggplot2::ggproto("StatNodesRepel", ggplot2::Stat,
StatNodesRepel <- ggplot2::ggproto(
"StatNodesRepel",
ggplot2::Stat,
compute_layer = function(data, scales, params) {
if (all(c("xend", "yend") %in% names(data))) {
data <- unique(dplyr::select(data, -xend, -yend))
Expand All @@ -28,27 +32,41 @@ StatNodesRepel <- ggplot2::ggproto("StatNodesRepel", ggplot2::Stat,
}
)

GeomDagPoint <- ggplot2::ggproto("GeomDagPoint", ggplot2::GeomPoint,
GeomDagPoint <- ggplot2::ggproto(
"GeomDagPoint",
ggplot2::GeomPoint,
default_aes = ggplot2::aes(
shape = 19, colour = "black", size = 16, fill = NA,
alpha = NA, stroke = 0.5
shape = 19,
colour = "black",
size = 16,
fill = NA,
alpha = NA,
stroke = 0.5
)
)

GeomDagNode <- ggplot2::ggproto("GeomDagNode", ggplot2::Geom,
GeomDagNode <- ggplot2::ggproto(
"GeomDagNode",
ggplot2::Geom,
required_aes = c("x", "y"),
non_missing_aes = c("size", "shape", "colour", "internal_colour"),
default_aes = ggplot2::aes(
shape = 19, colour = "black", size = 16, fill = NA,
alpha = NA, stroke = 0.5, internal_colour = "white"
shape = 19,
colour = "black",
size = 16,
fill = NA,
alpha = NA,
stroke = 0.5,
internal_colour = "white"
),
draw_panel = function(data, panel_params, coord, na.rm = FALSE) {
coords <- coord$transform(data, panel_params)
grid::gList(
ggname(
"geom_dag_node",
grid::pointsGrob(
coords$x, coords$y,
coords$x,
coords$y,
pch = coords$shape,
gp = grid::gpar(
col = alpha(coords$colour, coords$alpha),
Expand All @@ -61,7 +79,8 @@ GeomDagNode <- ggplot2::ggproto("GeomDagNode", ggplot2::Geom,
ggname(
"geom_dag_node",
grid::pointsGrob(
coords$x, coords$y,
coords$x,
coords$y,
pch = coords$shape,
gp = grid::gpar(
col = alpha(coords$internal_colour, coords$alpha),
Expand All @@ -74,7 +93,8 @@ GeomDagNode <- ggplot2::ggproto("GeomDagNode", ggplot2::Geom,
ggname(
"geom_dag_node",
grid::pointsGrob(
coords$x, coords$y,
coords$x,
coords$y,
pch = coords$shape,
gp = grid::gpar(
col = alpha(coords$colour, coords$alpha),
Expand All @@ -89,12 +109,25 @@ GeomDagNode <- ggplot2::ggproto("GeomDagNode", ggplot2::Geom,
draw_key = ggplot2::draw_key_point
)

GeomDagText <- ggplot2::ggproto("GeomDagText", ggplot2::GeomText, default_aes = ggplot2::aes(
colour = "white", size = 4, angle = 0, hjust = 0.5,
vjust = 0.5, alpha = NA, family = "", fontface = "bold", lineheight = 1.2
))
GeomDagText <- ggplot2::ggproto(
"GeomDagText",
ggplot2::GeomText,
default_aes = ggplot2::aes(
colour = "white",
size = 4,
angle = 0,
hjust = 0.5,
vjust = 0.5,
alpha = NA,
family = "",
fontface = "bold",
lineheight = 1.2
)
)

StatEdgeLink <- ggplot2::ggproto("StatEdgeLink", ggraph::StatEdgeLink,
StatEdgeLink <- ggplot2::ggproto(
"StatEdgeLink",
ggraph::StatEdgeLink,
setup_data = function(data, params) {
data <- data[!is.na(data$xend), ]

Expand All @@ -107,7 +140,9 @@ StatEdgeLink <- ggplot2::ggproto("StatEdgeLink", ggraph::StatEdgeLink,
}
)

StatEdgeArc <- ggplot2::ggproto("StatEdgeArc", ggraph::StatEdgeArc,
StatEdgeArc <- ggplot2::ggproto(
"StatEdgeArc",
ggraph::StatEdgeArc,
setup_data = function(data, params) {
data <- data[!is.na(data$xend), ]
data[is.na(data$circular), "circular"] <- FALSE
Expand All @@ -122,7 +157,9 @@ StatEdgeArc <- ggplot2::ggproto("StatEdgeArc", ggraph::StatEdgeArc,
default_aes = ggplot2::aes(filter = TRUE)
)

StatEdgeDiagonal <- ggplot2::ggproto("StatEdgeDiagonal", ggraph::StatEdgeDiagonal,
StatEdgeDiagonal <- ggplot2::ggproto(
"StatEdgeDiagonal",
ggraph::StatEdgeDiagonal,
setup_data = function(data, params) {
data <- data[!is.na(data$xend), ]
data[is.na(data$circular), "circular"] <- FALSE
Expand All @@ -137,7 +174,9 @@ StatEdgeDiagonal <- ggplot2::ggproto("StatEdgeDiagonal", ggraph::StatEdgeDiagona
default_aes = ggplot2::aes(filter = TRUE)
)

StatEdgeFan <- ggplot2::ggproto("StatEdgeFan", ggraph::StatEdgeFan,
StatEdgeFan <- ggplot2::ggproto(
"StatEdgeFan",
ggraph::StatEdgeFan,
setup_data = function(data, params) {
data <- data[!is.na(data$xend), ]

Expand All @@ -156,7 +195,9 @@ StatEdgeFan <- ggplot2::ggproto("StatEdgeFan", ggraph::StatEdgeFan,
)


GeomDAGEdgePath <- ggplot2::ggproto("GeomDAGEdgePath", ggraph::GeomEdgePath,
GeomDAGEdgePath <- ggplot2::ggproto(
"GeomDAGEdgePath",
ggraph::GeomEdgePath,
setup_data = function(data, params) {
ggraph::GeomEdgePath$setup_data(data, params)
},
Expand All @@ -175,19 +216,32 @@ GeomDAGEdgePath <- ggplot2::ggproto("GeomDAGEdgePath", ggraph::GeomEdgePath,
non_missing_aes = c("direction", "direction_type"),
default_aes = ggplot2::aes(
linewidth = NA,
edge_colour = "black", edge_width = 0.6, edge_linetype = "solid",
edge_alpha = NA, start_cap = ggraph::circle(8, "mm"), end_cap = ggraph::circle(8, "mm"), label = NA,
label_pos = 0.5, label_size = 3.88, angle = 0, hjust = 0.5,
vjust = 0.5, family = "", fontface = 1,
lineheight = 1.2, direction = "->", direction_type = "->"
edge_colour = "black",
edge_width = 0.6,
edge_linetype = "solid",
edge_alpha = NA,
start_cap = ggraph::circle(8, "mm"),
end_cap = ggraph::circle(8, "mm"),
label = NA,
label_pos = 0.5,
label_size = 3.88,
angle = 0,
hjust = 0.5,
vjust = 0.5,
family = "",
fontface = 1,
lineheight = 1.2,
direction = "->",
direction_type = "->"
)
)


silence_scales <- function(plot) {
old_scales <- plot$scales
plot$scales <- ggproto(
"ScalesListQuiet", old_scales,
"ScalesListQuiet",
old_scales,
add = silent_add
)
plot
Expand Down
79 changes: 57 additions & 22 deletions R/adjustment_sets.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#' @export
#'
#' @examples
#' dag <- dagify(y ~ x + z2 + w2 + w1,
#' dag <- dagify(
#' y ~ x + z2 + w2 + w1,
#' x ~ z1 + w1,
#' z1 ~ w1 + v,
#' z2 ~ w2 + v,
Expand All @@ -32,7 +33,8 @@
#'
#' ggdag_adjustment_set(dag)
#'
#' ggdag_adjustment_set(dagitty::randomDAG(10, .5),
#' ggdag_adjustment_set(
#' dagitty::randomDAG(10, .5),
#' exposure = "x3",
#' outcome = "x5"
#' )
Expand Down Expand Up @@ -70,24 +72,41 @@ extract_sets <- function(sets) {

#' @rdname adjustment_sets
#' @export
ggdag_adjustment_set <- function(.tdy_dag, exposure = NULL, outcome = NULL, ..., shadow = FALSE,
size = 1, node_size = 16, text_size = 3.88,
label_size = text_size,
text_col = "white", label_col = "black",
edge_width = 0.6, edge_cap = 8, arrow_length = 5,
use_edges = TRUE, use_nodes = TRUE, use_stylized = FALSE,
use_text = TRUE, use_labels = FALSE, label = NULL,
text = NULL, node = deprecated(), stylized = deprecated(),
expand_x = expansion(c(0.25, 0.25)),
expand_y = expansion(c(0.2, 0.2))) {
ggdag_adjustment_set <- function(
.tdy_dag,
exposure = NULL,
outcome = NULL,
...,
shadow = FALSE,
size = 1,
node_size = 16,
text_size = 3.88,
label_size = text_size,
text_col = "white",
label_col = "black",
edge_width = 0.6,
edge_cap = 8,
arrow_length = 5,
use_edges = TRUE,
use_nodes = TRUE,
use_stylized = FALSE,
use_text = TRUE,
use_labels = FALSE,
label = NULL,
text = NULL,
node = deprecated(),
stylized = deprecated(),
expand_x = expansion(c(0.25, 0.25)),
expand_y = expansion(c(0.2, 0.2))
) {
.tdy_dag <- if_not_tidy_daggity(.tdy_dag) %>%
dag_adjustment_sets(exposure = exposure, outcome = outcome, ...)

p <- ggplot2::ggplot(
.tdy_dag,
aes_dag(shape = adjusted, color = adjusted)
) +
ggplot2::facet_wrap(~ set) +
ggplot2::facet_wrap(~set) +
scale_adjusted() +
expand_plot(expand_x = expand_x, expand_y = expand_y)

Expand Down Expand Up @@ -206,15 +225,31 @@ adjust_for <- control_for

#' @rdname control_for
#' @export
ggdag_adjust <- function(.tdy_dag, var = NULL, ...,
size = 1, edge_type = c("link_arc", "link", "arc", "diagonal"),
node_size = 16, text_size = 3.88, label_size = text_size,
text_col = "white", label_col = "black",
edge_width = 0.6, edge_cap = 10, arrow_length = 5,
use_edges = TRUE,
use_nodes = TRUE, use_stylized = FALSE, use_text = TRUE,
use_labels = FALSE, text = NULL, label = NULL,
node = deprecated(), stylized = deprecated(), collider_lines = TRUE) {
ggdag_adjust <- function(
.tdy_dag,
var = NULL,
...,
size = 1,
edge_type = c("link_arc", "link", "arc", "diagonal"),
node_size = 16,
text_size = 3.88,
label_size = text_size,
text_col = "white",
label_col = "black",
edge_width = 0.6,
edge_cap = 10,
arrow_length = 5,
use_edges = TRUE,
use_nodes = TRUE,
use_stylized = FALSE,
use_text = TRUE,
use_labels = FALSE,
text = NULL,
label = NULL,
node = deprecated(),
stylized = deprecated(),
collider_lines = TRUE
) {
.tdy_dag <- if_not_tidy_daggity(.tdy_dag, ...)
if (!is.null(var)) {
.tdy_dag <- .tdy_dag %>% control_for(var)
Expand Down
40 changes: 30 additions & 10 deletions R/canonical.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,38 @@ node_canonical <- function(.dag, ...) {

#' @rdname canonicalize
#' @export
ggdag_canonical <- function(.tdy_dag, ..., edge_type = "link_arc", node_size = 16, text_size = 3.88,
label_size = text_size,
text_col = "white", label_col = text_col, use_edges = TRUE,
use_nodes = TRUE, use_stylized = FALSE, use_text = TRUE,
use_labels = NULL, label = NULL, text = NULL, node = deprecated(),
stylized = deprecated()) {
ggdag_canonical <- function(
.tdy_dag,
...,
edge_type = "link_arc",
node_size = 16,
text_size = 3.88,
label_size = text_size,
text_col = "white",
label_col = text_col,
use_edges = TRUE,
use_nodes = TRUE,
use_stylized = FALSE,
use_text = TRUE,
use_labels = NULL,
label = NULL,
text = NULL,
node = deprecated(),
stylized = deprecated()
) {
if_not_tidy_daggity(.tdy_dag, ...) %>%
node_canonical() %>%
ggdag(
node_size = node_size, text_size = text_size, label_size,
edge_type = edge_type, text_col = text_col, label_col = label_col,
use_edges = use_edges, use_nodes = use_nodes, use_stylized = use_stylized,
use_text = use_text, use_labels = use_labels
node_size = node_size,
text_size = text_size,
label_size,
edge_type = edge_type,
text_col = text_col,
label_col = label_col,
use_edges = use_edges,
use_nodes = use_nodes,
use_stylized = use_stylized,
use_text = use_text,
use_labels = use_labels
)
}
Loading

0 comments on commit 7b5007a

Please sign in to comment.