diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index d69673cc4..08ecac6b3 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -46,6 +46,7 @@ website:
- get-started/column-selection.qmd
- get-started/row-selection.qmd
- get-started/nanoplots.qmd
+ - get-started/targeted-styles.qmd
format:
html:
diff --git a/docs/get-started/targeted-styles.qmd b/docs/get-started/targeted-styles.qmd
new file mode 100644
index 000000000..a3394ec75
--- /dev/null
+++ b/docs/get-started/targeted-styles.qmd
@@ -0,0 +1,131 @@
+---
+title: Targeted styles
+jupyter: python3
+---
+
+In [Styling the Table Body](./basic-styling), we discussed styling table data with `.tab_style()`.
+In this article we'll cover how the same method can be used to style many other parts of the table, like the header, specific spanner labels, the footer, and more.
+
+:::{.callout-warning}
+This feature is currently a work in progress, and not yet released. Great Tables must be installed from github in order to try it.
+:::
+
+
+## Kitchen sink
+
+Below is a big example that shows all possible `loc` specifiers being used.
+
+```{python}
+from great_tables import GT, exibble, loc, style
+
+# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12
+brewer_colors = [
+ "#a6cee3",
+ "#1f78b4",
+ "#b2df8a",
+ "#33a02c",
+ "#fb9a99",
+ "#e31a1c",
+ "#fdbf6f",
+ "#ff7f00",
+ "#cab2d6",
+ "#6a3d9a",
+ "#ffff99",
+ "#b15928",
+]
+
+c = iter(brewer_colors)
+
+gt = (
+ GT(exibble.loc[[0, 1, 4], ["num", "char", "fctr", "row", "group"]])
+ .tab_header("title", "subtitle")
+ .tab_stub(rowname_col="row", groupname_col="group")
+ .tab_source_note("yo")
+ .tab_spanner("spanner", ["char", "fctr"])
+ .tab_stubhead("stubhead")
+)
+
+(
+ gt.tab_style(style.fill(next(c)), loc.body())
+ # Columns -----------
+ # TODO: appears in browser, but not vs code
+ .tab_style(style.fill(next(c)), loc.column_labels(columns="num"))
+ .tab_style(style.fill(next(c)), loc.column_header())
+ .tab_style(style.fill(next(c)), loc.spanner_labels(ids=["spanner"]))
+ # Header -----------
+ .tab_style(style.fill(next(c)), loc.header())
+ .tab_style(style.fill(next(c)), loc.subtitle())
+ .tab_style(style.fill(next(c)), loc.title())
+ # Footer -----------
+ .tab_style(style.borders(weight="3px"), loc.source_notes())
+ .tab_style(style.fill(next(c)), loc.footer())
+ # Stub --------------
+ .tab_style(style.fill(next(c)), loc.row_groups())
+ .tab_style(style.borders(weight="3px"), loc.stub(rows=1))
+ .tab_style(style.fill(next(c)), loc.stub())
+ .tab_style(style.fill(next(c)), loc.stubhead())
+)
+```
+
+## Body
+
+```{python}
+gt.tab_style(style.fill("yellow"), loc.body())
+```
+
+## Column labels
+
+```{python}
+(
+ gt
+ .tab_style(style.fill("yellow"), loc.column_header())
+ .tab_style(style.fill("blue"), loc.column_labels(columns="num"))
+ .tab_style(style.fill("red"), loc.spanner_labels(ids=["spanner"]))
+)
+
+```
+
+
+
+## Header
+
+```{python}
+(
+ gt.tab_style(style.fill("yellow"), loc.header())
+ .tab_style(style.fill("blue"), loc.title())
+ .tab_style(style.fill("red"), loc.subtitle())
+)
+```
+
+## Footer
+
+```{python}
+(
+ gt.tab_style(
+ style.fill("yellow"),
+ loc.source_notes(),
+ ).tab_style(
+ style.borders(weight="3px"),
+ loc.footer(),
+ )
+)
+```
+
+## Stub
+
+```{python}
+(
+ gt.tab_style(style.fill("yellow"), loc.stub())
+ .tab_style(style.fill("blue"), loc.row_groups())
+ .tab_style(
+ style.borders(style="dashed", weight="3px", color="red"),
+ loc.stub(rows=[1]),
+ )
+)
+```
+
+## Stubhead
+
+```{python}
+gt.tab_style(style.fill("yellow"), loc.stubhead())
+```
diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py
index 58355a435..697c2404b 100644
--- a/great_tables/_gt_data.py
+++ b/great_tables/_gt_data.py
@@ -5,7 +5,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from enum import Enum, auto
-from typing import Any, Callable, Tuple, TypeVar, overload, TYPE_CHECKING
+from typing import Any, Callable, Literal, Tuple, TypeVar, Union, overload, TYPE_CHECKING
from typing_extensions import Self, TypeAlias
@@ -28,6 +28,7 @@
if TYPE_CHECKING:
from ._helpers import Md, Html, UnitStr, Text
+ from ._locations import Loc
T = TypeVar("T")
@@ -610,7 +611,7 @@ def order_groups(self, group_order: RowGroups):
# TODO: validate
return self.__class__(self.rows, self.group_rows.reorder(group_order))
- def group_indices_map(self) -> list[tuple[int, str | None]]:
+ def group_indices_map(self) -> list[tuple[int, GroupRowInfo | None]]:
return self.group_rows.indices_map(len(self.rows))
def __iter__(self):
@@ -740,7 +741,7 @@ def reorder(self, group_ids: list[str | MISSING_GROUP]) -> Self:
return self.__class__(reordered)
- def indices_map(self, n: int) -> list[tuple[int, str | None]]:
+ def indices_map(self, n: int) -> list[tuple[int, GroupRowInfo]]:
"""Return pairs of row index, group label for all rows in data.
Note that when no groupings exist, n is used to return from range(n).
@@ -751,7 +752,7 @@ def indices_map(self, n: int) -> list[tuple[int, str | None]]:
if not len(self._d):
return [(ii, None) for ii in range(n)]
- return [(ind, info.defaulted_label()) for info in self for ind in info.indices]
+ return [(ind, info) for info in self for ind in info.indices]
# Spanners ----
@@ -852,7 +853,7 @@ class FootnotePlacement(Enum):
@dataclass(frozen=True)
class FootnoteInfo:
- locname: str | None = None
+ locname: Loc | None = None
grpname: str | None = None
colname: str | None = None
locnum: int | None = None
@@ -869,8 +870,7 @@ class FootnoteInfo:
@dataclass(frozen=True)
class StyleInfo:
- locname: str
- locnum: int
+ locname: Loc
grpname: str | None = None
colname: str | None = None
rownum: int | None = None
diff --git a/great_tables/_locations.py b/great_tables/_locations.py
index 617bd7ecf..3274c2450 100644
--- a/great_tables/_locations.py
+++ b/great_tables/_locations.py
@@ -1,16 +1,23 @@
from __future__ import annotations
import itertools
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from functools import singledispatch
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, Union
from typing_extensions import TypeAlias
# note that types like Spanners are only used in annotations for concretes of the
# resolve generic, but we need to import at runtime, due to singledispatch looking
# up annotations
-from ._gt_data import ColInfoTypeEnum, FootnoteInfo, FootnotePlacement, GTData, Spanners, StyleInfo
+from ._gt_data import (
+ ColInfoTypeEnum,
+ FootnoteInfo,
+ FootnotePlacement,
+ GTData,
+ Spanners,
+ StyleInfo,
+)
from ._styles import CellStyle
from ._tbl_data import PlDataFrame, PlExpr, eval_select, eval_transform
@@ -35,6 +42,7 @@ class CellPos:
column: int
row: int
colname: str
+ rowname: str | None = None
@dataclass
@@ -43,42 +51,62 @@ class Loc:
@dataclass
-class LocTitle(Loc):
+class LocHeader(Loc):
"""A location for targeting the table title and subtitle."""
- groups: Literal["title", "subtitle"]
+
+@dataclass
+class LocTitle(Loc):
+ """A location for targeting the title."""
+
+
+@dataclass
+class LocSubTitle(Loc):
+ """A location for targeting the subtitle."""
@dataclass
class LocStubhead(Loc):
- groups: Literal["stubhead"] = "stubhead"
+ """A location for targeting the table stubhead and stubhead label."""
@dataclass
-class LocColumnSpanners(Loc):
- """A location for column spanners."""
+class LocStubheadLabel(Loc):
+ """A location for targetting the stubhead."""
+
- # TODO: these can also be tidy selectors
- ids: list[str]
+@dataclass
+class LocColumnHeader(Loc):
+ """A location for column spanners and column labels."""
@dataclass
class LocColumnLabels(Loc):
- # TODO: these can be tidyselectors
- columns: list[str]
+ columns: SelectExpr = None
@dataclass
-class LocRowGroups(Loc):
- # TODO: these can be tidyselectors
- groups: list[str]
+class LocSpannerLabels(Loc):
+ """A location for column spanners."""
+
+ ids: SelectExpr = None
@dataclass
class LocStub(Loc):
- # TODO: these can be tidyselectors
- # TODO: can this take integers?
- rows: list[str]
+ """A location for targeting the table stub, row group labels, summary labels, and body."""
+
+ rows: RowSelectExpr = None
+
+
+@dataclass
+class LocRowGroups(Loc):
+ rows: RowSelectExpr = None
+
+
+@dataclass
+class LocSummaryLabel(Loc):
+ rows: RowSelectExpr = None
@dataclass
@@ -108,6 +136,7 @@ class LocBody(Loc):
------
See [`GT.tab_style()`](`great_tables.GT.tab_style`).
"""
+
columns: SelectExpr = None
rows: RowSelectExpr = None
@@ -115,41 +144,23 @@ class LocBody(Loc):
@dataclass
class LocSummary(Loc):
# TODO: these can be tidyselectors
- groups: list[str]
- columns: list[str]
- rows: list[str]
-
-
-@dataclass
-class LocGrandSummary(Loc):
- # TODO: these can be tidyselectors
- columns: list[str]
- rows: list[str]
-
-
-@dataclass
-class LocStubSummary(Loc):
- # TODO: these can be tidyselectors
- groups: list[str]
- rows: list[str]
+ columns: SelectExpr = None
+ rows: RowSelectExpr = None
@dataclass
-class LocStubGrandSummary(Loc):
- rows: list[str]
+class LocFooter(Loc):
+ """A location for targeting the footer."""
@dataclass
class LocFootnotes(Loc):
- groups: Literal["footnotes"] = "footnotes"
+ """A location for targeting footnotes."""
@dataclass
class LocSourceNotes(Loc):
- # This dataclass in R has a `groups` field, which is a literal value.
- # In python, we can use an isinstance check to determine we're seeing an
- # instance of this class
- groups: Literal["source_notes"] = "source_notes"
+ """A location for targeting source notes."""
# Utils ================================================================================
@@ -289,17 +300,17 @@ def resolve_rows_i(
expr: list[str | int] = [expr]
if isinstance(data, GTData):
- if expr is None:
- if null_means == "everything":
- return [(row.rowname, ii) for ii, row in enumerate(data._stub)]
- else:
- return []
-
row_names = [row.rowname for row in data._stub]
else:
row_names = data
- if isinstance(expr, list):
+ if expr is None:
+ if null_means == "everything":
+ return [(row.rowname, ii) for ii, row in enumerate(data._stub)]
+ else:
+ return []
+
+ elif isinstance(expr, list):
# TODO: manually doing row selection here for now
target_names = set(x for x in expr if isinstance(x, str))
target_pos = set(
@@ -355,7 +366,7 @@ def resolve(loc: Loc, *args: Any, **kwargs: Any) -> Loc | list[CellPos]:
@resolve.register
-def _(loc: LocColumnSpanners, spanners: Spanners) -> LocColumnSpanners:
+def _(loc: LocSpannerLabels, spanners: Spanners) -> LocSpannerLabels:
# unique labels (with order preserved)
spanner_ids = [span.spanner_id for span in spanners]
@@ -363,7 +374,30 @@ def _(loc: LocColumnSpanners, spanners: Spanners) -> LocColumnSpanners:
resolved_spanners = [spanner_ids[idx] for idx in resolved_spanners_idx]
# Create a list object
- return LocColumnSpanners(ids=resolved_spanners)
+ return LocSpannerLabels(ids=resolved_spanners)
+
+
+@resolve.register
+def _(loc: LocColumnLabels, data: GTData) -> list[CellPos]:
+ cols = resolve_cols_i(data=data, expr=loc.columns)
+ cell_pos = [CellPos(col[1], 0, colname=col[0]) for col in cols]
+ return cell_pos
+
+
+@resolve.register
+def _(loc: LocRowGroups, data: GTData) -> set[int]:
+ # TODO: what are the rules for matching row groups?
+ # TODO: resolve_rows_i will match a list expr to row names (not group names)
+ group_pos = set(pos for _, pos in resolve_rows_i(data, loc.rows))
+ return list(group_pos)
+
+
+@resolve.register
+def _(loc: LocStub, data: GTData) -> set[int]:
+ # TODO: what are the rules for matching row groups?
+ rows = resolve_rows_i(data=data, expr=loc.rows)
+ cell_pos = set(row[1] for row in rows)
+ return cell_pos
@resolve.register
@@ -383,27 +417,114 @@ def _(loc: LocBody, data: GTData) -> list[CellPos]:
# Style generic ========================================================================
+# LocHeader
+# LocTitle
+# LocSubTitle
+# LocStubhead
+# LocStubheadLabel
+# LocColumnLabels
+# LocColumnLabel
+# LocSpannerLabel
+# LocStub
+# LocRowGroupLabel
+# LocRowLabel
+# LocSummaryLabel
+# LocBody
+# LocSummary
+# LocFooter
+# LocFootnotes
+# LocSourceNotes
+
+
@singledispatch
def set_style(loc: Loc, data: GTData, style: list[str]) -> GTData:
"""Set style for location."""
raise NotImplementedError(f"Unsupported location type: {type(loc)}")
+@set_style.register(LocHeader)
+@set_style.register(LocTitle)
+@set_style.register(LocSubTitle)
+@set_style.register(LocStubhead)
+@set_style.register(LocStubheadLabel)
+@set_style.register(LocColumnHeader)
+@set_style.register(LocFooter)
+@set_style.register(LocSourceNotes)
+def _(
+ loc: (
+ LocHeader
+ | LocTitle
+ | LocSubTitle
+ | LocStubhead
+ | LocStubheadLabel
+ | LocColumnHeader
+ | LocFooter
+ | LocSourceNotes
+ ),
+ data: GTData,
+ style: list[CellStyle],
+) -> GTData:
+ # validate ----
+ for entry in style:
+ entry._raise_if_requires_data(loc)
+
+ return data._replace(_styles=data._styles + [StyleInfo(locname=loc, styles=style)])
+
+
@set_style.register
-def _(loc: LocTitle, data: GTData, style: list[CellStyle]) -> GTData:
+def _(loc: LocColumnLabels, data: GTData, style: list[CellStyle]) -> GTData:
+ positions: list[CellPos] = resolve(loc, data)
+
+ # evaluate any column expressions in styles
+ styles = [entry._evaluate_expressions(data._tbl_data) for entry in style]
+
+ all_info: list[StyleInfo] = []
+ for col_pos in positions:
+ crnt_info = StyleInfo(
+ locname=loc,
+ colname=col_pos.colname,
+ rownum=col_pos.row,
+ styles=styles,
+ )
+ all_info.append(crnt_info)
+ return data._replace(_styles=data._styles + all_info)
+
+
+@set_style.register
+def _(loc: LocSpannerLabels, data: GTData, style: list[CellStyle]) -> GTData:
# validate ----
for entry in style:
entry._raise_if_requires_data(loc)
+ # TODO resolve
- # set ----
- if loc.groups == "title":
- info = StyleInfo(locname="title", locnum=1, styles=style)
- elif loc.groups == "subtitle":
- info = StyleInfo(locname="subtitle", locnum=2, styles=style)
- else:
- raise ValueError(f"Unknown title group: {loc.groups}")
+ new_loc = resolve(loc, data._spanners)
+ return data._replace(
+ _styles=data._styles + [StyleInfo(locname=new_loc, grpname=new_loc.ids, styles=style)]
+ )
- return data._styles.append(info)
+
+@set_style.register
+def _(loc: LocRowGroups, data: GTData, style: list[CellStyle]) -> GTData:
+ # validate ----
+ for entry in style:
+ entry._raise_if_requires_data(loc)
+
+ row_groups = resolve(loc, data)
+ return data._replace(
+ _styles=data._styles + [StyleInfo(locname=loc, grpname=row_groups, styles=style)]
+ )
+
+
+@set_style.register
+def _(loc: LocStub, data: GTData, style: list[CellStyle]) -> GTData:
+ # validate ----
+ for entry in style:
+ entry._raise_if_requires_data(loc)
+ # TODO resolve
+ cells = resolve(loc, data)
+
+ new_styles = [StyleInfo(locname=loc, rownum=rownum, styles=style) for rownum in cells]
+ return data._replace(_styles=data._styles + new_styles)
@set_style.register
@@ -417,7 +538,7 @@ def _(loc: LocBody, data: GTData, style: list[CellStyle]) -> GTData:
for col_pos in positions:
row_styles = [entry._from_row(data._tbl_data, col_pos.row) for entry in style_ready]
crnt_info = StyleInfo(
- locname="data", locnum=5, colname=col_pos.colname, rownum=col_pos.row, styles=row_styles
+ locname=loc, colname=col_pos.colname, rownum=col_pos.row, styles=row_styles
)
all_info.append(crnt_info)
@@ -436,21 +557,11 @@ def set_footnote(loc: Loc, data: GTData, footnote: str, placement: PlacementOpti
@set_footnote.register(type(None))
def _(loc: None, data: GTData, footnote: str, placement: PlacementOptions) -> GTData:
place = FootnotePlacement[placement]
- info = FootnoteInfo(locname="none", locnum=0, footnotes=[footnote], placement=place)
+ info = FootnoteInfo(locname="none", footnotes=[footnote], placement=place)
return data._replace(_footnotes=data._footnotes + [info])
@set_footnote.register
def _(loc: LocTitle, data: GTData, footnote: str, placement: PlacementOptions) -> GTData:
- # TODO: note that footnote here is annotated as a string, but I think that in R it
- # can be a list of strings.
- place = FootnotePlacement[placement]
- if loc.groups == "title":
- info = FootnoteInfo(locname="title", locnum=1, footnotes=[footnote], placement=place)
- elif loc.groups == "subtitle":
- info = FootnoteInfo(locname="subtitle", locnum=2, footnotes=[footnote], placement=place)
- else:
- raise ValueError(f"Unknown title group: {loc.groups}")
-
- return data._replace(_footnotes=data._footnotes + [info])
+ raise NotImplementedError()
diff --git a/great_tables/_modify_rows.py b/great_tables/_modify_rows.py
index 013fe458e..72e553091 100644
--- a/great_tables/_modify_rows.py
+++ b/great_tables/_modify_rows.py
@@ -15,8 +15,12 @@ def row_group_order(self: GTSelf, groups: RowGroups) -> GTSelf:
def _remove_from_body_styles(styles: Styles, column: str) -> Styles:
+ # TODO: refactor
+ from ._utils_render_html import _is_loc
+ from ._locations import LocBody
+
new_styles = [
- info for info in styles if not (info.locname == "data" and info.colname == column)
+ info for info in styles if not (_is_loc(info.locname, LocBody) and info.colname == column)
]
return new_styles
diff --git a/great_tables/_styles.py b/great_tables/_styles.py
index c14cb1937..867c3f5e1 100644
--- a/great_tables/_styles.py
+++ b/great_tables/_styles.py
@@ -125,6 +125,14 @@ def _raise_if_requires_data(self, loc: Loc):
)
+@dataclass
+class CellStyleCss(CellStyle):
+ rule: str
+
+ def _to_html_style(self):
+ return self.rule
+
+
@dataclass
class CellStyleText(CellStyle):
"""A style specification for cell text.
diff --git a/great_tables/_utils_render_html.py b/great_tables/_utils_render_html.py
index 4f8489a6a..4abd55ec7 100644
--- a/great_tables/_utils_render_html.py
+++ b/great_tables/_utils_render_html.py
@@ -1,19 +1,48 @@
from __future__ import annotations
-from itertools import chain
+from itertools import chain, groupby
+from math import isnan
from typing import Any, cast
from great_tables._spanners import spanners_print_matrix
from htmltools import HTML, TagList, css, tags
-from ._gt_data import GTData
+from ._gt_data import GTData, Styles, GroupRowInfo
from ._tbl_data import _get_cell, cast_frame_to_string, n_rows, replace_null_frame
from ._text import _process_text, _process_text_id
from ._utils import heading_has_subtitle, heading_has_title, seq_groups
+from . import _locations as loc
-def create_heading_component_h(data: GTData) -> str:
+def _is_loc(loc: str | loc.Loc, cls: type[loc.Loc]):
+ if isinstance(loc, str):
+ return loc == cls.groups
+
+ return isinstance(loc, cls)
+
+def _flatten_styles(styles: Styles, wrap: bool = False) -> str:
+ # flatten all StyleInfo.styles lists
+ style_entries = list(chain(*[x.styles for x in styles]))
+ rendered_styles = [el._to_html_style() for el in style_entries]
+
+ # TODO dedupe rendered styles in sequence
+
+ if wrap:
+ if rendered_styles:
+ # return style html attribute
+ return f' style="{" ".join(rendered_styles)}"'
+ # if no rendered styles, just return a blank
+ return ""
+ if rendered_styles:
+ # return space-separated list of rendered styles
+ return " ".join(rendered_styles)
+ # if not wrapping the styles for html element,
+ # return None so htmltools omits a style attribute
+ return None
+
+
+def create_heading_component_h(data: GTData) -> str:
title = data._heading.title
subtitle = data._heading.subtitle
@@ -31,6 +60,13 @@ def create_heading_component_h(data: GTData) -> str:
title = _process_text(title)
subtitle = _process_text(subtitle)
+ # Filter list of StyleInfo for the various header components
+ styles_header = [x for x in data._styles if _is_loc(x.locname, loc.LocHeader)]
+ styles_title = [x for x in data._styles if _is_loc(x.locname, loc.LocTitle)]
+ styles_subtitle = [x for x in data._styles if _is_loc(x.locname, loc.LocSubTitle)]
+ title_style = _flatten_styles(styles_header + styles_title, wrap=True)
+ subtitle_style = _flatten_styles(styles_header + styles_subtitle, wrap=True)
+
# Get the effective number of columns, which is number of columns
# that will finally be rendered accounting for the stub layout
n_cols_total = data._boxhead._get_effective_number_of_columns(
@@ -40,15 +76,15 @@ def create_heading_component_h(data: GTData) -> str:
if has_subtitle:
heading = f"""
- {title} |
+ {title} |
- {subtitle} |
+ {subtitle} |
"""
else:
heading = f"""
- {title} |
+ {title} |
"""
return heading
@@ -67,8 +103,6 @@ def create_columns_component_h(data: GTData) -> str:
# Get necessary data objects for composing the column labels and spanners
stubh = data._stubhead
- # TODO: skipping styles for now
- # styles_tbl = dt_styles_get(data = data)
boxhead = data._boxhead
# TODO: The body component of the table is only needed for determining RTL alignment
@@ -97,13 +131,11 @@ def create_columns_component_h(data: GTData) -> str:
# Get the column headings
headings_info = boxhead._get_default_columns()
- # TODO: Skipping styles for now
- # Get the style attrs for the stubhead label
- # stubhead_style_attrs = subset(styles_tbl, locname == "stubhead")
- # Get the style attrs for the spanner column headings
- # spanner_style_attrs = subset(styles_tbl, locname == "columns_groups")
- # Get the style attrs for the spanner column headings
- # column_style_attrs = subset(styles_tbl, locname == "columns_columns")
+ # Filter list of StyleInfo for the various stubhead and column labels components
+ styles_stubhead = [x for x in data._styles if _is_loc(x.locname, loc.LocStubhead)]
+ styles_column_labels = [x for x in data._styles if _is_loc(x.locname, loc.LocColumnHeader)]
+ styles_spanner_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSpannerLabels)]
+ styles_column_label = [x for x in data._styles if _is_loc(x.locname, loc.LocColumnLabels)]
# If columns are present in the stub, then replace with a set stubhead label or nothing
if len(stub_layout) > 0 and stubh is not None:
@@ -124,18 +156,13 @@ def create_columns_component_h(data: GTData) -> str:
if spanner_row_count == 0:
# Create the cell for the stubhead label
if len(stub_layout) > 0:
- stubhead_style = None
- # FIXME: Ignore styles for now
- # if stubhead_style_attrs is not None and len(stubhead_style_attrs) > 0:
- # stubhead_style = stubhead_style_attrs[0].html_style
-
table_col_headings.append(
tags.th(
HTML(_process_text(stub_label)),
class_=f"gt_col_heading gt_columns_bottom_border gt_{stubhead_label_alignment}",
rowspan="1",
colspan=len(stub_layout),
- style=stubhead_style,
+ style=_flatten_styles(styles_stubhead),
scope="colgroup" if len(stub_layout) > 1 else "col",
id=_process_text_id(stub_label),
)
@@ -143,13 +170,8 @@ def create_columns_component_h(data: GTData) -> str:
# Create the headings in the case where there are no spanners at all -------------------------
for info in headings_info:
- # NOTE: Ignore styles for now
- # styles_column = subset(column_style_attrs, colnum == i)
- #
- # Convert the code above this comment from R to valid python
- # if len(styles_column) > 0:
- # column_style = styles_column[0].html_style
- column_style = None
+ # Filter by column label / id, join with overall column labels style
+ styles_i = [x for x in styles_column_label if x.colname == info.var]
table_col_headings.append(
tags.th(
@@ -157,16 +179,16 @@ def create_columns_component_h(data: GTData) -> str:
class_=f"gt_col_heading gt_columns_bottom_border gt_{info.defaulted_align}",
rowspan=1,
colspan=1,
- style=column_style,
+ style=_flatten_styles(styles_column_labels + styles_i),
scope="col",
id=_process_text_id(info.column_label),
)
)
# Join the cells into a string and begin each with a newline
- th_cells = "\n" + "\n".join([" " + str(tag) for tag in table_col_headings]) + "\n"
+ # th_cells = "\n" + "\n".join([" " + str(tag) for tag in table_col_headings]) + "\n"
- table_col_headings = tags.tr(HTML(th_cells), class_="gt_col_headings")
+ table_col_headings = tags.tr(*table_col_headings, class_="gt_col_headings")
#
# Create the spanners and column labels in the case where there *are* spanners -------------
@@ -196,20 +218,13 @@ def create_columns_component_h(data: GTData) -> str:
# Create the cell for the stubhead label
if len(stub_layout) > 0:
- # NOTE: Ignore styles for now
- # if len(stubhead_style_attrs) > 0:
- # stubhead_style = stubhead_style_attrs.html_style
- # else:
- # stubhead_style = None
- stubhead_style = None
-
level_1_spanners.append(
tags.th(
HTML(_process_text(stub_label)),
class_=f"gt_col_heading gt_columns_bottom_border gt_{str(stubhead_label_alignment)}",
rowspan=2,
colspan=len(stub_layout),
- style=stubhead_style,
+ style=_flatten_styles(styles_stubhead),
scope="colgroup" if len(stub_layout) > 1 else "col",
id=_process_text_id(stub_label),
)
@@ -229,14 +244,8 @@ def create_columns_component_h(data: GTData) -> str:
for ii, (span_key, h_info) in enumerate(zip(spanner_col_names, headings_info)):
if spanner_ids[level_1_index][span_key] is None:
- # NOTE: Ignore styles for now
- # styles_heading = filter(
- # lambda x: x.get('locname') == "columns_columns" and x.get('colname') == headings_vars[i],
- # styles_tbl if 'styles_tbl' in locals() else []
- # )
- #
- # heading_style = next(styles_heading, {}).get('html_style', None)
- heading_style = None
+ # Filter by column label / id, join with overall column labels style
+ styles_i = [x for x in styles_column_label if x.colname == h_info.var]
# Get the alignment values for the first set of column labels
first_set_alignment = h_info.defaulted_align
@@ -248,7 +257,7 @@ def create_columns_component_h(data: GTData) -> str:
class_=f"gt_col_heading gt_columns_bottom_border gt_{str(first_set_alignment)}",
rowspan=2,
colspan=1,
- style=heading_style,
+ style=_flatten_styles(styles_column_labels + styles_i),
scope="col",
id=_process_text_id(h_info.column_label),
)
@@ -258,21 +267,14 @@ def create_columns_component_h(data: GTData) -> str:
# If colspans[i] == 0, it means that a previous cell's
# `colspan` will cover us
if colspans[ii] > 0:
- # NOTE: Ignore styles for now
- # FIXME: this needs to be rewritten
- # styles_spanners = filter(
- # spanner_style_attrs,
- # locname == "columns_groups",
- # grpname == spanner_ids[level_1_index, ][i]
- # )
- #
- # spanner_style =
- # if (nrow(styles_spanners) > 0) {
- # styles_spanners$html_style
- # } else {
- # NULL
- # }
- spanner_style = None
+ # Filter by column label / id, join with overall column labels style
+ # TODO check this filter logic
+ styles_i = [
+ x
+ for x in styles_spanner_label
+ # TODO: refactor use of set
+ if set(x.grpname) & set([spanner_ids_level_1_index[ii]])
+ ]
level_1_spanners.append(
tags.th(
@@ -283,7 +285,7 @@ def create_columns_component_h(data: GTData) -> str:
class_="gt_center gt_columns_top_border gt_column_spanner_outer",
rowspan=1,
colspan=colspans[ii],
- style=spanner_style,
+ style=_flatten_styles(styles_column_labels + styles_i),
scope="colgroup" if colspans[ii] > 1 else "col",
id=_process_text_id(spanner_ids_level_1_index[ii]),
)
@@ -301,18 +303,9 @@ def create_columns_component_h(data: GTData) -> str:
spanned_column_labels = []
for j in range(len(remaining_headings)):
- # Skip styles for now
- # styles_remaining = styles_tbl[
- # (styles_tbl["locname"] == "columns_columns") &
- # (styles_tbl["colname"] == remaining_headings[j])
- # ]
- #
- # remaining_style = (
- # styles_remaining["html_style"].values[0]
- # if len(styles_remaining) > 0
- # else None
- # )
- remaining_style = None
+ # Filter by column label / id, join with overall column labels style
+ # TODO check this filter logic
+ styles_i = [x for x in styles_column_label if x.colname == remaining_headings[j]]
remaining_alignment = boxhead._get_boxhead_get_alignment_by_var(
var=remaining_headings[j]
@@ -324,7 +317,7 @@ def create_columns_component_h(data: GTData) -> str:
class_=f"gt_col_heading gt_columns_bottom_border gt_{remaining_alignment}",
rowspan=1,
colspan=1,
- style=remaining_style,
+ style=_flatten_styles(styles_column_labels + styles_i),
scope="col",
id=_process_text_id(remaining_headings_labels[j]),
)
@@ -359,18 +352,14 @@ def create_columns_component_h(data: GTData) -> str:
for colspan, span_label in zip(colspans, spanners_row.values()):
if colspan > 0:
- # Skip styles for now
- # styles_spanners = styles_tbl[
- # (styles_tbl["locname"] == "columns_groups") &
- # (styles_tbl["grpname"] in spanners_vars)
- # ]
- #
- # spanner_style = (
- # styles_spanners["html_style"].values[0]
- # if len(styles_spanners) > 0
- # else None
- # )
- spanner_style = None
+ # Filter by column label / id, join with overall column labels style
+ # TODO check this filter logic
+ styles_i = [
+ x
+ for x in styles_column_label
+ # TODO: refactor use of set
+ if set(x.grpname) & set([colspan, span_label])
+ ]
if span_label:
span = tags.span(
@@ -386,7 +375,7 @@ def create_columns_component_h(data: GTData) -> str:
class_="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer",
rowspan=1,
colspan=colspan,
- style=spanner_style,
+ style=_flatten_styles(styles_column_labels + styles_i),
scope="colgroup" if colspan > 1 else "col",
)
)
@@ -400,6 +389,8 @@ def create_columns_component_h(data: GTData) -> str:
rowspan=1,
colspan=len(stub_layout),
scope="colgroup" if len(stub_layout) > 1 else "col",
+ # TODO check if ok to just use base styling?
+ style=_flatten_styles(styles_column_labels),
),
)
@@ -409,6 +400,8 @@ def create_columns_component_h(data: GTData) -> str:
tags.tr(
level_i_spanners,
class_="gt_col_headings gt_spanner_row",
+ # TODO check if ok to just use base styling?
+ style=_flatten_styles(styles_column_labels),
)
),
)
@@ -417,7 +410,7 @@ def create_columns_component_h(data: GTData) -> str:
higher_spanner_rows,
table_col_headings,
)
- return str(table_col_headings)
+ return table_col_headings
def create_body_component_h(data: GTData) -> str:
@@ -426,8 +419,15 @@ def create_body_component_h(data: GTData) -> str:
_str_orig_data = cast_frame_to_string(data._tbl_data)
tbl_data = replace_null_frame(data._body.body, _str_orig_data)
- # Filter list of StyleInfo to only those that apply to the body (where locname="data")
- styles_body = [x for x in data._styles if x.locname == "data"]
+ # Filter list of StyleInfo to only those that apply to the stub
+ styles_row_group_label = [x for x in data._styles if _is_loc(x.locname, loc.LocRowGroups)]
+ styles_row_label = [x for x in data._styles if _is_loc(x.locname, loc.LocStub)]
+ styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryLabel)]
+
+ # Filter list of StyleInfo to only those that apply to the body
+ styles_cells = [x for x in data._styles if _is_loc(x.locname, loc.LocBody)]
+ # styles_body = [x for x in data._styles if _is_loc(x.locname, loc.LocBody2)]
+ # styles_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocSummary)]
# Get the default column vars
column_vars = data._boxhead._get_default_columns()
@@ -453,11 +453,11 @@ def create_body_component_h(data: GTData) -> str:
body_rows: list[str] = []
# iterate over rows (ordered by groupings)
- prev_group_label = None
+ prev_group_info = None
- ordered_index = data._stub.group_indices_map()
+ ordered_index: list[tuple[int, GroupRowInfo]] = data._stub.group_indices_map()
- for i, group_label in ordered_index:
+ for i, group_info in ordered_index:
# For table striping we want to add a striping CSS class to the even-numbered
# rows in the rendered table; to target these rows, determine if `i` in the current
@@ -466,27 +466,28 @@ def create_body_component_h(data: GTData) -> str:
body_cells: list[str] = []
+ # Create table row specifically for group (if applicable)
if has_stub_column and has_groups and not has_two_col_stub:
colspan_value = data._boxhead._get_effective_number_of_columns(
stub=data._stub, options=data._options
)
- # Generate a row that contains the row group label (this spans the entire row) but
- # only if `i` indicates there should be a row group label
- if group_label != prev_group_label:
+ # Only create if this is the first row of data within the group
+ if group_info is not prev_group_info:
+ group_label = group_info.defaulted_label()
group_class = (
"gt_empty_group_heading" if group_label == "" else "gt_group_heading_row"
)
+ _styles = [style for style in styles_row_group_label if i in style.grpname]
+ group_styles = _flatten_styles(_styles, wrap=True)
group_row = f""" |
- {group_label} |
+ {group_label} |
"""
- prev_group_label = group_label
-
body_rows.append(group_row)
- # Create a single cell and append result to `body_cells`
+ # Create row cells
for colinfo in column_vars:
cell_content: Any = _get_cell(tbl_data, i, colinfo.var)
cell_str: str = str(cell_content)
@@ -502,17 +503,8 @@ def create_body_component_h(data: GTData) -> str:
cell_alignment = colinfo.defaulted_align
# Get the style attributes for the current cell by filtering the
- # `styles_body` list for the current row and column
- styles_i = [x for x in styles_body if x.rownum == i and x.colname == colinfo.var]
-
- # Develop the `style` attribute for the current cell
- if len(styles_i) > 0:
- # flatten all StyleInfo.styles lists
- style_entries = list(chain(*[x.styles for x in styles_i]))
- rendered_styles = [el._to_html_style() for el in style_entries]
- cell_styles = f'style="{" ".join(rendered_styles)}"' + " "
- else:
- cell_styles = ""
+ # `styles_cells` list for the current row and column
+ _body_styles = [x for x in styles_cells if x.rownum == i and x.colname == colinfo.var]
if is_stub_cell:
@@ -520,6 +512,8 @@ def create_body_component_h(data: GTData) -> str:
classes = ["gt_row", "gt_left", "gt_stub"]
+ _rowname_styles = [x for x in styles_row_label if x.rownum == i]
+
if table_stub_striped and odd_i_row:
classes.append("gt_striped")
@@ -529,17 +523,24 @@ def create_body_component_h(data: GTData) -> str:
classes = ["gt_row", f"gt_{cell_alignment}"]
+ _rowname_styles = []
+
if table_body_striped and odd_i_row:
classes.append("gt_striped")
# Ensure that `classes` becomes a space-separated string
classes = " ".join(classes)
+ cell_styles = _flatten_styles(
+ _body_styles + _rowname_styles,
+ wrap=True,
+ )
body_cells.append(
- f""" <{el_name} {cell_styles}class="{classes}">{cell_str}{el_name}>"""
+ f""" <{el_name}{cell_styles} class="{classes}">{cell_str}{el_name}>"""
)
- prev_group_label = group_label
+ prev_group_info = group_info
+
body_rows.append(" \n" + "\n".join(body_cells) + "\n
")
all_body_rows = "\n".join(body_rows)
@@ -552,6 +553,10 @@ def create_body_component_h(data: GTData) -> str:
def create_source_notes_component_h(data: GTData) -> str:
source_notes = data._source_notes
+ # Filter list of StyleInfo to only those that apply to the source notes
+ styles_footer = [x for x in data._styles if _is_loc(x.locname, loc.LocFooter)]
+ styles_source_notes = [x for x in data._styles if _is_loc(x.locname, loc.LocSourceNotes)]
+
# If there are no source notes, then return an empty string
if source_notes == []:
return ""
@@ -573,13 +578,14 @@ def create_source_notes_component_h(data: GTData) -> str:
source_notes_tr: list[str] = []
+ _styles = _flatten_styles(styles_footer + styles_source_notes, wrap=True)
for note in source_notes:
note_str = _process_text(note)
source_notes_tr.append(
f"""
- {note_str} |
+ {note_str} |
"""
)
@@ -618,6 +624,9 @@ def create_source_notes_component_h(data: GTData) -> str:
def create_footnotes_component_h(data: GTData):
+ # Filter list of StyleInfo to only those that apply to the footnotes
+ styles_footnotes = [x for x in data._styles if _is_loc(x.locname, loc.LocFootnotes)]
+
return ""
diff --git a/great_tables/loc.py b/great_tables/loc.py
index eec0149b7..e463ab132 100644
--- a/great_tables/loc.py
+++ b/great_tables/loc.py
@@ -1,9 +1,42 @@
from __future__ import annotations
from ._locations import (
- LocBody as body,
- LocStub as stub,
+ # Header ----
+ LocHeader as header,
+ LocTitle as title,
+ LocSubTitle as subtitle,
+ #
+ # Stubhead ----
+ LocStubhead as stubhead,
+ #
+ # Column Labels ----
+ LocColumnHeader as column_header,
+ LocSpannerLabels as spanner_labels,
LocColumnLabels as column_labels,
+ #
+ # Stub ----
+ LocStub as stub,
+ LocRowGroups as row_groups,
+ #
+ # Body ----
+ LocBody as body,
+ #
+ # Footer ----
+ LocFooter as footer,
+ LocSourceNotes as source_notes,
)
-__all__ = ("body", "stub", "column_labels")
+__all__ = (
+ "header",
+ "title",
+ "subtitle",
+ "stubhead",
+ "column_header",
+ "spanner_labels",
+ "column_labels",
+ "stub",
+ "row_groups",
+ "body",
+ "footer",
+ "source_notes",
+)
diff --git a/great_tables/style.py b/great_tables/style.py
index e6b4c480e..7bd85d96e 100644
--- a/great_tables/style.py
+++ b/great_tables/style.py
@@ -4,6 +4,7 @@
CellStyleText as text,
CellStyleFill as fill,
CellStyleBorders as borders,
+ CellStyleCss as css,
)
-__all__ = ("text", "fill", "borders")
+__all__ = ("text", "fill", "borders", "css")
diff --git a/tests/__snapshots__/test_utils_render_html.ambr b/tests/__snapshots__/test_utils_render_html.ambr
index f671a3d3a..77598d1a8 100644
--- a/tests/__snapshots__/test_utils_render_html.ambr
+++ b/tests/__snapshots__/test_utils_render_html.ambr
@@ -35,6 +35,54 @@
'''
# ---
+# name: test_loc_kitchen_sink
+ '''
+
+
+
+
+ title |
+
+
+ subtitle |
+
+
+ stubhead |
+ num |
+
+ spanner
+ |
+
+
+ char |
+ fctr |
+
+
+
+
+ grp_a |
+
+
+ row_1 |
+ 0.1111 |
+ apricot |
+ one |
+
+
+
+
+
+ yo |
+
+
+
+
+
+
+
+
+ '''
+# ---
# name: test_multiple_spanners_pads_for_stubhead_label
'''
diff --git a/tests/test_gt_data.py b/tests/test_gt_data.py
index ee89b3d70..9feea05d0 100644
--- a/tests/test_gt_data.py
+++ b/tests/test_gt_data.py
@@ -31,7 +31,8 @@ def test_stub_order_groups():
stub2 = stub.order_groups(["c", "a", "b"])
assert stub2.group_ids == ["c", "a", "b"]
- assert stub2.group_indices_map() == [(3, "c"), (1, "a"), (0, "b"), (2, "b")]
+ indice_labels = [(ii, info.defaulted_label()) for ii, info in stub2.group_indices_map()]
+ assert indice_labels == [(3, "c"), (1, "a"), (0, "b"), (2, "b")]
def test_boxhead_reorder():
diff --git a/tests/test_locations.py b/tests/test_locations.py
index 55006b41e..cd10f7940 100644
--- a/tests/test_locations.py
+++ b/tests/test_locations.py
@@ -1,12 +1,13 @@
import pandas as pd
import polars as pl
+import polars.selectors as cs
import pytest
from great_tables import GT
from great_tables._gt_data import Spanners
from great_tables._locations import (
CellPos,
LocBody,
- LocColumnSpanners,
+ LocSpannerLabels,
LocTitle,
resolve,
resolve_cols_i,
@@ -116,6 +117,9 @@ def test_resolve_rows_i_raises(bad_expr):
assert "a callable that takes a DataFrame and returns a boolean Series" in expected
+# Resolve Loc tests --------------------------------------------------------------------------------
+
+
def test_resolve_loc_body():
gt = GT(pd.DataFrame({"x": [1, 2], "y": [3, 4]}))
@@ -132,25 +136,41 @@ def test_resolve_loc_body():
assert pos.colname == "x"
-def test_resolve_column_spanners_simple():
+@pytest.mark.xfail
+def test_resolve_loc_spanners_label_single():
+ spanners = Spanners.from_ids(["a", "b"])
+ loc = LocSpannerLabels(ids="a")
+
+ new_loc = resolve(loc, spanners)
+
+ assert new_loc.ids == ["a"]
+
+
+@pytest.mark.parametrize(
+ "expr",
+ [
+ ["a", "c"],
+ pytest.param(cs.by_name("a", "c"), marks=pytest.mark.xfail),
+ ],
+)
+def test_resolve_loc_spanners_label(expr):
# note that this essentially a no-op
ids = ["a", "b", "c"]
spanners = Spanners.from_ids(ids)
- loc = LocColumnSpanners(ids=["a", "c"])
+ loc = LocSpannerLabels(ids=expr)
new_loc = resolve(loc, spanners)
- assert new_loc == loc
assert new_loc.ids == ["a", "c"]
-def test_resolve_column_spanners_error_missing():
+def test_resolve_loc_spanner_label_error_missing():
# note that this essentially a no-op
ids = ["a", "b", "c"]
spanners = Spanners.from_ids(ids)
- loc = LocColumnSpanners(ids=["a", "d"])
+ loc = LocSpannerLabels(ids=["a", "d"])
with pytest.raises(ValueError):
resolve(loc, spanners)
@@ -190,7 +210,7 @@ def test_set_style_loc_body_from_column(expr):
def test_set_style_loc_title_from_column_error(snapshot):
df = pd.DataFrame({"x": [1, 2], "color": ["red", "blue"]})
gt_df = GT(df)
- loc = LocTitle("title")
+ loc = LocTitle()
style = CellStyleText(color=FromColumn("color"))
with pytest.raises(TypeError) as exc_info:
diff --git a/tests/test_utils_render_html.py b/tests/test_utils_render_html.py
index da0241a75..0d84ee314 100644
--- a/tests/test_utils_render_html.py
+++ b/tests/test_utils_render_html.py
@@ -29,7 +29,7 @@ def assert_rendered_columns(snapshot, gt):
built = gt._build_data("html")
columns = create_columns_component_h(built)
- assert snapshot == columns
+ assert snapshot == str(columns)
def assert_rendered_body(snapshot, gt):
@@ -191,3 +191,52 @@ def test_multiple_spanners_pads_for_stubhead_label(snapshot):
)
assert_rendered_columns(snapshot, gt)
+
+
+# Location style rendering -------------------------------------------------------------------------
+# these tests focus on location classes being correctly picked up
+def test_loc_column_labels():
+ gt = GT(pl.DataFrame({"x": [1], "y": [2]}))
+
+ new_gt = gt.tab_style(style.fill("yellow"), loc.column_labels(columns=["x"]))
+ el = create_columns_component_h(new_gt._build_data("html"))
+
+ assert el.name == "tr"
+ assert el.children[0].attrs["style"] == "background-color: yellow;"
+ assert "style" not in el.children[1].attrs
+
+
+def test_loc_kitchen_sink(snapshot):
+ gt = (
+ GT(exibble.loc[[0], ["num", "char", "fctr", "row", "group"]])
+ .tab_header("title", "subtitle")
+ .tab_stub(rowname_col="row", groupname_col="group")
+ .tab_source_note("yo")
+ .tab_spanner("spanner", ["char", "fctr"])
+ .tab_stubhead("stubhead")
+ )
+
+ new_gt = (
+ gt.tab_style(style.css("BODY"), loc.body())
+ # Columns -----------
+ .tab_style(style.css("COLUMN_LABEL"), loc.column_labels(columns="num"))
+ .tab_style(style.css("COLUMN_HEADER"), loc.column_header())
+ .tab_style(style.css("SPANNER_LABEL"), loc.spanner_labels(ids=["spanner"]))
+ # Header -----------
+ .tab_style(style.css("HEADER"), loc.header())
+ .tab_style(style.css("SUBTITLE"), loc.subtitle())
+ .tab_style(style.css("TITLE"), loc.title())
+ # Footer -----------
+ .tab_style(style.css("FOOTER"), loc.footer())
+ .tab_style(style.css("SOURCE_NOTES"), loc.source_notes())
+ # .tab_style(style.css("AAA"), loc.footnotes())
+ # Stub --------------
+ .tab_style(style.css("GROUP_LABEL"), loc.row_groups())
+ .tab_style(style.css("STUB"), loc.stub())
+ .tab_style(style.css("ROW_LABEL"), loc.stub(rows=[0]))
+ .tab_style(style.css("STUBHEAD"), loc.stubhead())
+ )
+
+ html = new_gt.as_raw_html()
+ cleaned = html[html.index("