diff --git a/docs/_quarto.yml b/docs/_quarto.yml index d69673cc4..08ecac6b3 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -46,6 +46,7 @@ website: - get-started/column-selection.qmd - get-started/row-selection.qmd - get-started/nanoplots.qmd + - get-started/targeted-styles.qmd format: html: diff --git a/docs/get-started/targeted-styles.qmd b/docs/get-started/targeted-styles.qmd new file mode 100644 index 000000000..a3394ec75 --- /dev/null +++ b/docs/get-started/targeted-styles.qmd @@ -0,0 +1,131 @@ +--- +title: Targeted styles +jupyter: python3 +--- + +In [Styling the Table Body](./basic-styling), we discussed styling table data with `.tab_style()`. +In this article we'll cover how the same method can be used to style many other parts of the table, like the header, specific spanner labels, the footer, and more. + +:::{.callout-warning} +This feature is currently a work in progress, and not yet released. Great Tables must be installed from github in order to try it. +::: + + +## Kitchen sink + +Below is a big example that shows all possible `loc` specifiers being used. + +```{python} +from great_tables import GT, exibble, loc, style + +# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 +brewer_colors = [ + "#a6cee3", + "#1f78b4", + "#b2df8a", + "#33a02c", + "#fb9a99", + "#e31a1c", + "#fdbf6f", + "#ff7f00", + "#cab2d6", + "#6a3d9a", + "#ffff99", + "#b15928", +] + +c = iter(brewer_colors) + +gt = ( + GT(exibble.loc[[0, 1, 4], ["num", "char", "fctr", "row", "group"]]) + .tab_header("title", "subtitle") + .tab_stub(rowname_col="row", groupname_col="group") + .tab_source_note("yo") + .tab_spanner("spanner", ["char", "fctr"]) + .tab_stubhead("stubhead") +) + +( + gt.tab_style(style.fill(next(c)), loc.body()) + # Columns ----------- + # TODO: appears in browser, but not vs code + .tab_style(style.fill(next(c)), loc.column_labels(columns="num")) + .tab_style(style.fill(next(c)), loc.column_header()) + .tab_style(style.fill(next(c)), loc.spanner_labels(ids=["spanner"])) + # Header ----------- + .tab_style(style.fill(next(c)), loc.header()) + .tab_style(style.fill(next(c)), loc.subtitle()) + .tab_style(style.fill(next(c)), loc.title()) + # Footer ----------- + .tab_style(style.borders(weight="3px"), loc.source_notes()) + .tab_style(style.fill(next(c)), loc.footer()) + # Stub -------------- + .tab_style(style.fill(next(c)), loc.row_groups()) + .tab_style(style.borders(weight="3px"), loc.stub(rows=1)) + .tab_style(style.fill(next(c)), loc.stub()) + .tab_style(style.fill(next(c)), loc.stubhead()) +) +``` + +## Body + +```{python} +gt.tab_style(style.fill("yellow"), loc.body()) +``` + +## Column labels + +```{python} +( + gt + .tab_style(style.fill("yellow"), loc.column_header()) + .tab_style(style.fill("blue"), loc.column_labels(columns="num")) + .tab_style(style.fill("red"), loc.spanner_labels(ids=["spanner"])) +) + +``` + + + +## Header + +```{python} +( + gt.tab_style(style.fill("yellow"), loc.header()) + .tab_style(style.fill("blue"), loc.title()) + .tab_style(style.fill("red"), loc.subtitle()) +) +``` + +## Footer + +```{python} +( + gt.tab_style( + style.fill("yellow"), + loc.source_notes(), + ).tab_style( + style.borders(weight="3px"), + loc.footer(), + ) +) +``` + +## Stub + +```{python} +( + gt.tab_style(style.fill("yellow"), loc.stub()) + .tab_style(style.fill("blue"), loc.row_groups()) + .tab_style( + style.borders(style="dashed", weight="3px", color="red"), + loc.stub(rows=[1]), + ) +) +``` + +## Stubhead + +```{python} +gt.tab_style(style.fill("yellow"), loc.stubhead()) +``` diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index 58355a435..697c2404b 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field, replace from enum import Enum, auto -from typing import Any, Callable, Tuple, TypeVar, overload, TYPE_CHECKING +from typing import Any, Callable, Literal, Tuple, TypeVar, Union, overload, TYPE_CHECKING from typing_extensions import Self, TypeAlias @@ -28,6 +28,7 @@ if TYPE_CHECKING: from ._helpers import Md, Html, UnitStr, Text + from ._locations import Loc T = TypeVar("T") @@ -610,7 +611,7 @@ def order_groups(self, group_order: RowGroups): # TODO: validate return self.__class__(self.rows, self.group_rows.reorder(group_order)) - def group_indices_map(self) -> list[tuple[int, str | None]]: + def group_indices_map(self) -> list[tuple[int, GroupRowInfo | None]]: return self.group_rows.indices_map(len(self.rows)) def __iter__(self): @@ -740,7 +741,7 @@ def reorder(self, group_ids: list[str | MISSING_GROUP]) -> Self: return self.__class__(reordered) - def indices_map(self, n: int) -> list[tuple[int, str | None]]: + def indices_map(self, n: int) -> list[tuple[int, GroupRowInfo]]: """Return pairs of row index, group label for all rows in data. Note that when no groupings exist, n is used to return from range(n). @@ -751,7 +752,7 @@ def indices_map(self, n: int) -> list[tuple[int, str | None]]: if not len(self._d): return [(ii, None) for ii in range(n)] - return [(ind, info.defaulted_label()) for info in self for ind in info.indices] + return [(ind, info) for info in self for ind in info.indices] # Spanners ---- @@ -852,7 +853,7 @@ class FootnotePlacement(Enum): @dataclass(frozen=True) class FootnoteInfo: - locname: str | None = None + locname: Loc | None = None grpname: str | None = None colname: str | None = None locnum: int | None = None @@ -869,8 +870,7 @@ class FootnoteInfo: @dataclass(frozen=True) class StyleInfo: - locname: str - locnum: int + locname: Loc grpname: str | None = None colname: str | None = None rownum: int | None = None diff --git a/great_tables/_locations.py b/great_tables/_locations.py index 617bd7ecf..3274c2450 100644 --- a/great_tables/_locations.py +++ b/great_tables/_locations.py @@ -1,16 +1,23 @@ from __future__ import annotations import itertools -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, Union from typing_extensions import TypeAlias # note that types like Spanners are only used in annotations for concretes of the # resolve generic, but we need to import at runtime, due to singledispatch looking # up annotations -from ._gt_data import ColInfoTypeEnum, FootnoteInfo, FootnotePlacement, GTData, Spanners, StyleInfo +from ._gt_data import ( + ColInfoTypeEnum, + FootnoteInfo, + FootnotePlacement, + GTData, + Spanners, + StyleInfo, +) from ._styles import CellStyle from ._tbl_data import PlDataFrame, PlExpr, eval_select, eval_transform @@ -35,6 +42,7 @@ class CellPos: column: int row: int colname: str + rowname: str | None = None @dataclass @@ -43,42 +51,62 @@ class Loc: @dataclass -class LocTitle(Loc): +class LocHeader(Loc): """A location for targeting the table title and subtitle.""" - groups: Literal["title", "subtitle"] + +@dataclass +class LocTitle(Loc): + """A location for targeting the title.""" + + +@dataclass +class LocSubTitle(Loc): + """A location for targeting the subtitle.""" @dataclass class LocStubhead(Loc): - groups: Literal["stubhead"] = "stubhead" + """A location for targeting the table stubhead and stubhead label.""" @dataclass -class LocColumnSpanners(Loc): - """A location for column spanners.""" +class LocStubheadLabel(Loc): + """A location for targetting the stubhead.""" + - # TODO: these can also be tidy selectors - ids: list[str] +@dataclass +class LocColumnHeader(Loc): + """A location for column spanners and column labels.""" @dataclass class LocColumnLabels(Loc): - # TODO: these can be tidyselectors - columns: list[str] + columns: SelectExpr = None @dataclass -class LocRowGroups(Loc): - # TODO: these can be tidyselectors - groups: list[str] +class LocSpannerLabels(Loc): + """A location for column spanners.""" + + ids: SelectExpr = None @dataclass class LocStub(Loc): - # TODO: these can be tidyselectors - # TODO: can this take integers? - rows: list[str] + """A location for targeting the table stub, row group labels, summary labels, and body.""" + + rows: RowSelectExpr = None + + +@dataclass +class LocRowGroups(Loc): + rows: RowSelectExpr = None + + +@dataclass +class LocSummaryLabel(Loc): + rows: RowSelectExpr = None @dataclass @@ -108,6 +136,7 @@ class LocBody(Loc): ------ See [`GT.tab_style()`](`great_tables.GT.tab_style`). """ + columns: SelectExpr = None rows: RowSelectExpr = None @@ -115,41 +144,23 @@ class LocBody(Loc): @dataclass class LocSummary(Loc): # TODO: these can be tidyselectors - groups: list[str] - columns: list[str] - rows: list[str] - - -@dataclass -class LocGrandSummary(Loc): - # TODO: these can be tidyselectors - columns: list[str] - rows: list[str] - - -@dataclass -class LocStubSummary(Loc): - # TODO: these can be tidyselectors - groups: list[str] - rows: list[str] + columns: SelectExpr = None + rows: RowSelectExpr = None @dataclass -class LocStubGrandSummary(Loc): - rows: list[str] +class LocFooter(Loc): + """A location for targeting the footer.""" @dataclass class LocFootnotes(Loc): - groups: Literal["footnotes"] = "footnotes" + """A location for targeting footnotes.""" @dataclass class LocSourceNotes(Loc): - # This dataclass in R has a `groups` field, which is a literal value. - # In python, we can use an isinstance check to determine we're seeing an - # instance of this class - groups: Literal["source_notes"] = "source_notes" + """A location for targeting source notes.""" # Utils ================================================================================ @@ -289,17 +300,17 @@ def resolve_rows_i( expr: list[str | int] = [expr] if isinstance(data, GTData): - if expr is None: - if null_means == "everything": - return [(row.rowname, ii) for ii, row in enumerate(data._stub)] - else: - return [] - row_names = [row.rowname for row in data._stub] else: row_names = data - if isinstance(expr, list): + if expr is None: + if null_means == "everything": + return [(row.rowname, ii) for ii, row in enumerate(data._stub)] + else: + return [] + + elif isinstance(expr, list): # TODO: manually doing row selection here for now target_names = set(x for x in expr if isinstance(x, str)) target_pos = set( @@ -355,7 +366,7 @@ def resolve(loc: Loc, *args: Any, **kwargs: Any) -> Loc | list[CellPos]: @resolve.register -def _(loc: LocColumnSpanners, spanners: Spanners) -> LocColumnSpanners: +def _(loc: LocSpannerLabels, spanners: Spanners) -> LocSpannerLabels: # unique labels (with order preserved) spanner_ids = [span.spanner_id for span in spanners] @@ -363,7 +374,30 @@ def _(loc: LocColumnSpanners, spanners: Spanners) -> LocColumnSpanners: resolved_spanners = [spanner_ids[idx] for idx in resolved_spanners_idx] # Create a list object - return LocColumnSpanners(ids=resolved_spanners) + return LocSpannerLabels(ids=resolved_spanners) + + +@resolve.register +def _(loc: LocColumnLabels, data: GTData) -> list[CellPos]: + cols = resolve_cols_i(data=data, expr=loc.columns) + cell_pos = [CellPos(col[1], 0, colname=col[0]) for col in cols] + return cell_pos + + +@resolve.register +def _(loc: LocRowGroups, data: GTData) -> set[int]: + # TODO: what are the rules for matching row groups? + # TODO: resolve_rows_i will match a list expr to row names (not group names) + group_pos = set(pos for _, pos in resolve_rows_i(data, loc.rows)) + return list(group_pos) + + +@resolve.register +def _(loc: LocStub, data: GTData) -> set[int]: + # TODO: what are the rules for matching row groups? + rows = resolve_rows_i(data=data, expr=loc.rows) + cell_pos = set(row[1] for row in rows) + return cell_pos @resolve.register @@ -383,27 +417,114 @@ def _(loc: LocBody, data: GTData) -> list[CellPos]: # Style generic ======================================================================== +# LocHeader +# LocTitle +# LocSubTitle +# LocStubhead +# LocStubheadLabel +# LocColumnLabels +# LocColumnLabel +# LocSpannerLabel +# LocStub +# LocRowGroupLabel +# LocRowLabel +# LocSummaryLabel +# LocBody +# LocSummary +# LocFooter +# LocFootnotes +# LocSourceNotes + + @singledispatch def set_style(loc: Loc, data: GTData, style: list[str]) -> GTData: """Set style for location.""" raise NotImplementedError(f"Unsupported location type: {type(loc)}") +@set_style.register(LocHeader) +@set_style.register(LocTitle) +@set_style.register(LocSubTitle) +@set_style.register(LocStubhead) +@set_style.register(LocStubheadLabel) +@set_style.register(LocColumnHeader) +@set_style.register(LocFooter) +@set_style.register(LocSourceNotes) +def _( + loc: ( + LocHeader + | LocTitle + | LocSubTitle + | LocStubhead + | LocStubheadLabel + | LocColumnHeader + | LocFooter + | LocSourceNotes + ), + data: GTData, + style: list[CellStyle], +) -> GTData: + # validate ---- + for entry in style: + entry._raise_if_requires_data(loc) + + return data._replace(_styles=data._styles + [StyleInfo(locname=loc, styles=style)]) + + @set_style.register -def _(loc: LocTitle, data: GTData, style: list[CellStyle]) -> GTData: +def _(loc: LocColumnLabels, data: GTData, style: list[CellStyle]) -> GTData: + positions: list[CellPos] = resolve(loc, data) + + # evaluate any column expressions in styles + styles = [entry._evaluate_expressions(data._tbl_data) for entry in style] + + all_info: list[StyleInfo] = [] + for col_pos in positions: + crnt_info = StyleInfo( + locname=loc, + colname=col_pos.colname, + rownum=col_pos.row, + styles=styles, + ) + all_info.append(crnt_info) + return data._replace(_styles=data._styles + all_info) + + +@set_style.register +def _(loc: LocSpannerLabels, data: GTData, style: list[CellStyle]) -> GTData: # validate ---- for entry in style: entry._raise_if_requires_data(loc) + # TODO resolve - # set ---- - if loc.groups == "title": - info = StyleInfo(locname="title", locnum=1, styles=style) - elif loc.groups == "subtitle": - info = StyleInfo(locname="subtitle", locnum=2, styles=style) - else: - raise ValueError(f"Unknown title group: {loc.groups}") + new_loc = resolve(loc, data._spanners) + return data._replace( + _styles=data._styles + [StyleInfo(locname=new_loc, grpname=new_loc.ids, styles=style)] + ) - return data._styles.append(info) + +@set_style.register +def _(loc: LocRowGroups, data: GTData, style: list[CellStyle]) -> GTData: + # validate ---- + for entry in style: + entry._raise_if_requires_data(loc) + + row_groups = resolve(loc, data) + return data._replace( + _styles=data._styles + [StyleInfo(locname=loc, grpname=row_groups, styles=style)] + ) + + +@set_style.register +def _(loc: LocStub, data: GTData, style: list[CellStyle]) -> GTData: + # validate ---- + for entry in style: + entry._raise_if_requires_data(loc) + # TODO resolve + cells = resolve(loc, data) + + new_styles = [StyleInfo(locname=loc, rownum=rownum, styles=style) for rownum in cells] + return data._replace(_styles=data._styles + new_styles) @set_style.register @@ -417,7 +538,7 @@ def _(loc: LocBody, data: GTData, style: list[CellStyle]) -> GTData: for col_pos in positions: row_styles = [entry._from_row(data._tbl_data, col_pos.row) for entry in style_ready] crnt_info = StyleInfo( - locname="data", locnum=5, colname=col_pos.colname, rownum=col_pos.row, styles=row_styles + locname=loc, colname=col_pos.colname, rownum=col_pos.row, styles=row_styles ) all_info.append(crnt_info) @@ -436,21 +557,11 @@ def set_footnote(loc: Loc, data: GTData, footnote: str, placement: PlacementOpti @set_footnote.register(type(None)) def _(loc: None, data: GTData, footnote: str, placement: PlacementOptions) -> GTData: place = FootnotePlacement[placement] - info = FootnoteInfo(locname="none", locnum=0, footnotes=[footnote], placement=place) + info = FootnoteInfo(locname="none", footnotes=[footnote], placement=place) return data._replace(_footnotes=data._footnotes + [info]) @set_footnote.register def _(loc: LocTitle, data: GTData, footnote: str, placement: PlacementOptions) -> GTData: - # TODO: note that footnote here is annotated as a string, but I think that in R it - # can be a list of strings. - place = FootnotePlacement[placement] - if loc.groups == "title": - info = FootnoteInfo(locname="title", locnum=1, footnotes=[footnote], placement=place) - elif loc.groups == "subtitle": - info = FootnoteInfo(locname="subtitle", locnum=2, footnotes=[footnote], placement=place) - else: - raise ValueError(f"Unknown title group: {loc.groups}") - - return data._replace(_footnotes=data._footnotes + [info]) + raise NotImplementedError() diff --git a/great_tables/_modify_rows.py b/great_tables/_modify_rows.py index 013fe458e..72e553091 100644 --- a/great_tables/_modify_rows.py +++ b/great_tables/_modify_rows.py @@ -15,8 +15,12 @@ def row_group_order(self: GTSelf, groups: RowGroups) -> GTSelf: def _remove_from_body_styles(styles: Styles, column: str) -> Styles: + # TODO: refactor + from ._utils_render_html import _is_loc + from ._locations import LocBody + new_styles = [ - info for info in styles if not (info.locname == "data" and info.colname == column) + info for info in styles if not (_is_loc(info.locname, LocBody) and info.colname == column) ] return new_styles diff --git a/great_tables/_styles.py b/great_tables/_styles.py index c14cb1937..867c3f5e1 100644 --- a/great_tables/_styles.py +++ b/great_tables/_styles.py @@ -125,6 +125,14 @@ def _raise_if_requires_data(self, loc: Loc): ) +@dataclass +class CellStyleCss(CellStyle): + rule: str + + def _to_html_style(self): + return self.rule + + @dataclass class CellStyleText(CellStyle): """A style specification for cell text. diff --git a/great_tables/_utils_render_html.py b/great_tables/_utils_render_html.py index 4f8489a6a..4abd55ec7 100644 --- a/great_tables/_utils_render_html.py +++ b/great_tables/_utils_render_html.py @@ -1,19 +1,48 @@ from __future__ import annotations -from itertools import chain +from itertools import chain, groupby +from math import isnan from typing import Any, cast from great_tables._spanners import spanners_print_matrix from htmltools import HTML, TagList, css, tags -from ._gt_data import GTData +from ._gt_data import GTData, Styles, GroupRowInfo from ._tbl_data import _get_cell, cast_frame_to_string, n_rows, replace_null_frame from ._text import _process_text, _process_text_id from ._utils import heading_has_subtitle, heading_has_title, seq_groups +from . import _locations as loc -def create_heading_component_h(data: GTData) -> str: +def _is_loc(loc: str | loc.Loc, cls: type[loc.Loc]): + if isinstance(loc, str): + return loc == cls.groups + + return isinstance(loc, cls) + +def _flatten_styles(styles: Styles, wrap: bool = False) -> str: + # flatten all StyleInfo.styles lists + style_entries = list(chain(*[x.styles for x in styles])) + rendered_styles = [el._to_html_style() for el in style_entries] + + # TODO dedupe rendered styles in sequence + + if wrap: + if rendered_styles: + # return style html attribute + return f' style="{" ".join(rendered_styles)}"' + # if no rendered styles, just return a blank + return "" + if rendered_styles: + # return space-separated list of rendered styles + return " ".join(rendered_styles) + # if not wrapping the styles for html element, + # return None so htmltools omits a style attribute + return None + + +def create_heading_component_h(data: GTData) -> str: title = data._heading.title subtitle = data._heading.subtitle @@ -31,6 +60,13 @@ def create_heading_component_h(data: GTData) -> str: title = _process_text(title) subtitle = _process_text(subtitle) + # Filter list of StyleInfo for the various header components + styles_header = [x for x in data._styles if _is_loc(x.locname, loc.LocHeader)] + styles_title = [x for x in data._styles if _is_loc(x.locname, loc.LocTitle)] + styles_subtitle = [x for x in data._styles if _is_loc(x.locname, loc.LocSubTitle)] + title_style = _flatten_styles(styles_header + styles_title, wrap=True) + subtitle_style = _flatten_styles(styles_header + styles_subtitle, wrap=True) + # Get the effective number of columns, which is number of columns # that will finally be rendered accounting for the stub layout n_cols_total = data._boxhead._get_effective_number_of_columns( @@ -40,15 +76,15 @@ def create_heading_component_h(data: GTData) -> str: if has_subtitle: heading = f""" - {title} + {title} - {subtitle} + {subtitle} """ else: heading = f""" - {title} + {title} """ return heading @@ -67,8 +103,6 @@ def create_columns_component_h(data: GTData) -> str: # Get necessary data objects for composing the column labels and spanners stubh = data._stubhead - # TODO: skipping styles for now - # styles_tbl = dt_styles_get(data = data) boxhead = data._boxhead # TODO: The body component of the table is only needed for determining RTL alignment @@ -97,13 +131,11 @@ def create_columns_component_h(data: GTData) -> str: # Get the column headings headings_info = boxhead._get_default_columns() - # TODO: Skipping styles for now - # Get the style attrs for the stubhead label - # stubhead_style_attrs = subset(styles_tbl, locname == "stubhead") - # Get the style attrs for the spanner column headings - # spanner_style_attrs = subset(styles_tbl, locname == "columns_groups") - # Get the style attrs for the spanner column headings - # column_style_attrs = subset(styles_tbl, locname == "columns_columns") + # Filter list of StyleInfo for the various stubhead and column labels components + styles_stubhead = [x for x in data._styles if _is_loc(x.locname, loc.LocStubhead)] + styles_column_labels = [x for x in data._styles if _is_loc(x.locname, loc.LocColumnHeader)] + styles_spanner_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSpannerLabels)] + styles_column_label = [x for x in data._styles if _is_loc(x.locname, loc.LocColumnLabels)] # If columns are present in the stub, then replace with a set stubhead label or nothing if len(stub_layout) > 0 and stubh is not None: @@ -124,18 +156,13 @@ def create_columns_component_h(data: GTData) -> str: if spanner_row_count == 0: # Create the cell for the stubhead label if len(stub_layout) > 0: - stubhead_style = None - # FIXME: Ignore styles for now - # if stubhead_style_attrs is not None and len(stubhead_style_attrs) > 0: - # stubhead_style = stubhead_style_attrs[0].html_style - table_col_headings.append( tags.th( HTML(_process_text(stub_label)), class_=f"gt_col_heading gt_columns_bottom_border gt_{stubhead_label_alignment}", rowspan="1", colspan=len(stub_layout), - style=stubhead_style, + style=_flatten_styles(styles_stubhead), scope="colgroup" if len(stub_layout) > 1 else "col", id=_process_text_id(stub_label), ) @@ -143,13 +170,8 @@ def create_columns_component_h(data: GTData) -> str: # Create the headings in the case where there are no spanners at all ------------------------- for info in headings_info: - # NOTE: Ignore styles for now - # styles_column = subset(column_style_attrs, colnum == i) - # - # Convert the code above this comment from R to valid python - # if len(styles_column) > 0: - # column_style = styles_column[0].html_style - column_style = None + # Filter by column label / id, join with overall column labels style + styles_i = [x for x in styles_column_label if x.colname == info.var] table_col_headings.append( tags.th( @@ -157,16 +179,16 @@ def create_columns_component_h(data: GTData) -> str: class_=f"gt_col_heading gt_columns_bottom_border gt_{info.defaulted_align}", rowspan=1, colspan=1, - style=column_style, + style=_flatten_styles(styles_column_labels + styles_i), scope="col", id=_process_text_id(info.column_label), ) ) # Join the cells into a string and begin each with a newline - th_cells = "\n" + "\n".join([" " + str(tag) for tag in table_col_headings]) + "\n" + # th_cells = "\n" + "\n".join([" " + str(tag) for tag in table_col_headings]) + "\n" - table_col_headings = tags.tr(HTML(th_cells), class_="gt_col_headings") + table_col_headings = tags.tr(*table_col_headings, class_="gt_col_headings") # # Create the spanners and column labels in the case where there *are* spanners ------------- @@ -196,20 +218,13 @@ def create_columns_component_h(data: GTData) -> str: # Create the cell for the stubhead label if len(stub_layout) > 0: - # NOTE: Ignore styles for now - # if len(stubhead_style_attrs) > 0: - # stubhead_style = stubhead_style_attrs.html_style - # else: - # stubhead_style = None - stubhead_style = None - level_1_spanners.append( tags.th( HTML(_process_text(stub_label)), class_=f"gt_col_heading gt_columns_bottom_border gt_{str(stubhead_label_alignment)}", rowspan=2, colspan=len(stub_layout), - style=stubhead_style, + style=_flatten_styles(styles_stubhead), scope="colgroup" if len(stub_layout) > 1 else "col", id=_process_text_id(stub_label), ) @@ -229,14 +244,8 @@ def create_columns_component_h(data: GTData) -> str: for ii, (span_key, h_info) in enumerate(zip(spanner_col_names, headings_info)): if spanner_ids[level_1_index][span_key] is None: - # NOTE: Ignore styles for now - # styles_heading = filter( - # lambda x: x.get('locname') == "columns_columns" and x.get('colname') == headings_vars[i], - # styles_tbl if 'styles_tbl' in locals() else [] - # ) - # - # heading_style = next(styles_heading, {}).get('html_style', None) - heading_style = None + # Filter by column label / id, join with overall column labels style + styles_i = [x for x in styles_column_label if x.colname == h_info.var] # Get the alignment values for the first set of column labels first_set_alignment = h_info.defaulted_align @@ -248,7 +257,7 @@ def create_columns_component_h(data: GTData) -> str: class_=f"gt_col_heading gt_columns_bottom_border gt_{str(first_set_alignment)}", rowspan=2, colspan=1, - style=heading_style, + style=_flatten_styles(styles_column_labels + styles_i), scope="col", id=_process_text_id(h_info.column_label), ) @@ -258,21 +267,14 @@ def create_columns_component_h(data: GTData) -> str: # If colspans[i] == 0, it means that a previous cell's # `colspan` will cover us if colspans[ii] > 0: - # NOTE: Ignore styles for now - # FIXME: this needs to be rewritten - # styles_spanners = filter( - # spanner_style_attrs, - # locname == "columns_groups", - # grpname == spanner_ids[level_1_index, ][i] - # ) - # - # spanner_style = - # if (nrow(styles_spanners) > 0) { - # styles_spanners$html_style - # } else { - # NULL - # } - spanner_style = None + # Filter by column label / id, join with overall column labels style + # TODO check this filter logic + styles_i = [ + x + for x in styles_spanner_label + # TODO: refactor use of set + if set(x.grpname) & set([spanner_ids_level_1_index[ii]]) + ] level_1_spanners.append( tags.th( @@ -283,7 +285,7 @@ def create_columns_component_h(data: GTData) -> str: class_="gt_center gt_columns_top_border gt_column_spanner_outer", rowspan=1, colspan=colspans[ii], - style=spanner_style, + style=_flatten_styles(styles_column_labels + styles_i), scope="colgroup" if colspans[ii] > 1 else "col", id=_process_text_id(spanner_ids_level_1_index[ii]), ) @@ -301,18 +303,9 @@ def create_columns_component_h(data: GTData) -> str: spanned_column_labels = [] for j in range(len(remaining_headings)): - # Skip styles for now - # styles_remaining = styles_tbl[ - # (styles_tbl["locname"] == "columns_columns") & - # (styles_tbl["colname"] == remaining_headings[j]) - # ] - # - # remaining_style = ( - # styles_remaining["html_style"].values[0] - # if len(styles_remaining) > 0 - # else None - # ) - remaining_style = None + # Filter by column label / id, join with overall column labels style + # TODO check this filter logic + styles_i = [x for x in styles_column_label if x.colname == remaining_headings[j]] remaining_alignment = boxhead._get_boxhead_get_alignment_by_var( var=remaining_headings[j] @@ -324,7 +317,7 @@ def create_columns_component_h(data: GTData) -> str: class_=f"gt_col_heading gt_columns_bottom_border gt_{remaining_alignment}", rowspan=1, colspan=1, - style=remaining_style, + style=_flatten_styles(styles_column_labels + styles_i), scope="col", id=_process_text_id(remaining_headings_labels[j]), ) @@ -359,18 +352,14 @@ def create_columns_component_h(data: GTData) -> str: for colspan, span_label in zip(colspans, spanners_row.values()): if colspan > 0: - # Skip styles for now - # styles_spanners = styles_tbl[ - # (styles_tbl["locname"] == "columns_groups") & - # (styles_tbl["grpname"] in spanners_vars) - # ] - # - # spanner_style = ( - # styles_spanners["html_style"].values[0] - # if len(styles_spanners) > 0 - # else None - # ) - spanner_style = None + # Filter by column label / id, join with overall column labels style + # TODO check this filter logic + styles_i = [ + x + for x in styles_column_label + # TODO: refactor use of set + if set(x.grpname) & set([colspan, span_label]) + ] if span_label: span = tags.span( @@ -386,7 +375,7 @@ def create_columns_component_h(data: GTData) -> str: class_="gt_center gt_columns_bottom_border gt_columns_top_border gt_column_spanner_outer", rowspan=1, colspan=colspan, - style=spanner_style, + style=_flatten_styles(styles_column_labels + styles_i), scope="colgroup" if colspan > 1 else "col", ) ) @@ -400,6 +389,8 @@ def create_columns_component_h(data: GTData) -> str: rowspan=1, colspan=len(stub_layout), scope="colgroup" if len(stub_layout) > 1 else "col", + # TODO check if ok to just use base styling? + style=_flatten_styles(styles_column_labels), ), ) @@ -409,6 +400,8 @@ def create_columns_component_h(data: GTData) -> str: tags.tr( level_i_spanners, class_="gt_col_headings gt_spanner_row", + # TODO check if ok to just use base styling? + style=_flatten_styles(styles_column_labels), ) ), ) @@ -417,7 +410,7 @@ def create_columns_component_h(data: GTData) -> str: higher_spanner_rows, table_col_headings, ) - return str(table_col_headings) + return table_col_headings def create_body_component_h(data: GTData) -> str: @@ -426,8 +419,15 @@ def create_body_component_h(data: GTData) -> str: _str_orig_data = cast_frame_to_string(data._tbl_data) tbl_data = replace_null_frame(data._body.body, _str_orig_data) - # Filter list of StyleInfo to only those that apply to the body (where locname="data") - styles_body = [x for x in data._styles if x.locname == "data"] + # Filter list of StyleInfo to only those that apply to the stub + styles_row_group_label = [x for x in data._styles if _is_loc(x.locname, loc.LocRowGroups)] + styles_row_label = [x for x in data._styles if _is_loc(x.locname, loc.LocStub)] + styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryLabel)] + + # Filter list of StyleInfo to only those that apply to the body + styles_cells = [x for x in data._styles if _is_loc(x.locname, loc.LocBody)] + # styles_body = [x for x in data._styles if _is_loc(x.locname, loc.LocBody2)] + # styles_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocSummary)] # Get the default column vars column_vars = data._boxhead._get_default_columns() @@ -453,11 +453,11 @@ def create_body_component_h(data: GTData) -> str: body_rows: list[str] = [] # iterate over rows (ordered by groupings) - prev_group_label = None + prev_group_info = None - ordered_index = data._stub.group_indices_map() + ordered_index: list[tuple[int, GroupRowInfo]] = data._stub.group_indices_map() - for i, group_label in ordered_index: + for i, group_info in ordered_index: # For table striping we want to add a striping CSS class to the even-numbered # rows in the rendered table; to target these rows, determine if `i` in the current @@ -466,27 +466,28 @@ def create_body_component_h(data: GTData) -> str: body_cells: list[str] = [] + # Create table row specifically for group (if applicable) if has_stub_column and has_groups and not has_two_col_stub: colspan_value = data._boxhead._get_effective_number_of_columns( stub=data._stub, options=data._options ) - # Generate a row that contains the row group label (this spans the entire row) but - # only if `i` indicates there should be a row group label - if group_label != prev_group_label: + # Only create if this is the first row of data within the group + if group_info is not prev_group_info: + group_label = group_info.defaulted_label() group_class = ( "gt_empty_group_heading" if group_label == "" else "gt_group_heading_row" ) + _styles = [style for style in styles_row_group_label if i in style.grpname] + group_styles = _flatten_styles(_styles, wrap=True) group_row = f""" - {group_label} + {group_label} """ - prev_group_label = group_label - body_rows.append(group_row) - # Create a single cell and append result to `body_cells` + # Create row cells for colinfo in column_vars: cell_content: Any = _get_cell(tbl_data, i, colinfo.var) cell_str: str = str(cell_content) @@ -502,17 +503,8 @@ def create_body_component_h(data: GTData) -> str: cell_alignment = colinfo.defaulted_align # Get the style attributes for the current cell by filtering the - # `styles_body` list for the current row and column - styles_i = [x for x in styles_body if x.rownum == i and x.colname == colinfo.var] - - # Develop the `style` attribute for the current cell - if len(styles_i) > 0: - # flatten all StyleInfo.styles lists - style_entries = list(chain(*[x.styles for x in styles_i])) - rendered_styles = [el._to_html_style() for el in style_entries] - cell_styles = f'style="{" ".join(rendered_styles)}"' + " " - else: - cell_styles = "" + # `styles_cells` list for the current row and column + _body_styles = [x for x in styles_cells if x.rownum == i and x.colname == colinfo.var] if is_stub_cell: @@ -520,6 +512,8 @@ def create_body_component_h(data: GTData) -> str: classes = ["gt_row", "gt_left", "gt_stub"] + _rowname_styles = [x for x in styles_row_label if x.rownum == i] + if table_stub_striped and odd_i_row: classes.append("gt_striped") @@ -529,17 +523,24 @@ def create_body_component_h(data: GTData) -> str: classes = ["gt_row", f"gt_{cell_alignment}"] + _rowname_styles = [] + if table_body_striped and odd_i_row: classes.append("gt_striped") # Ensure that `classes` becomes a space-separated string classes = " ".join(classes) + cell_styles = _flatten_styles( + _body_styles + _rowname_styles, + wrap=True, + ) body_cells.append( - f""" <{el_name} {cell_styles}class="{classes}">{cell_str}""" + f""" <{el_name}{cell_styles} class="{classes}">{cell_str}""" ) - prev_group_label = group_label + prev_group_info = group_info + body_rows.append(" \n" + "\n".join(body_cells) + "\n ") all_body_rows = "\n".join(body_rows) @@ -552,6 +553,10 @@ def create_body_component_h(data: GTData) -> str: def create_source_notes_component_h(data: GTData) -> str: source_notes = data._source_notes + # Filter list of StyleInfo to only those that apply to the source notes + styles_footer = [x for x in data._styles if _is_loc(x.locname, loc.LocFooter)] + styles_source_notes = [x for x in data._styles if _is_loc(x.locname, loc.LocSourceNotes)] + # If there are no source notes, then return an empty string if source_notes == []: return "" @@ -573,13 +578,14 @@ def create_source_notes_component_h(data: GTData) -> str: source_notes_tr: list[str] = [] + _styles = _flatten_styles(styles_footer + styles_source_notes, wrap=True) for note in source_notes: note_str = _process_text(note) source_notes_tr.append( f""" - {note_str} + {note_str} """ ) @@ -618,6 +624,9 @@ def create_source_notes_component_h(data: GTData) -> str: def create_footnotes_component_h(data: GTData): + # Filter list of StyleInfo to only those that apply to the footnotes + styles_footnotes = [x for x in data._styles if _is_loc(x.locname, loc.LocFootnotes)] + return "" diff --git a/great_tables/loc.py b/great_tables/loc.py index eec0149b7..e463ab132 100644 --- a/great_tables/loc.py +++ b/great_tables/loc.py @@ -1,9 +1,42 @@ from __future__ import annotations from ._locations import ( - LocBody as body, - LocStub as stub, + # Header ---- + LocHeader as header, + LocTitle as title, + LocSubTitle as subtitle, + # + # Stubhead ---- + LocStubhead as stubhead, + # + # Column Labels ---- + LocColumnHeader as column_header, + LocSpannerLabels as spanner_labels, LocColumnLabels as column_labels, + # + # Stub ---- + LocStub as stub, + LocRowGroups as row_groups, + # + # Body ---- + LocBody as body, + # + # Footer ---- + LocFooter as footer, + LocSourceNotes as source_notes, ) -__all__ = ("body", "stub", "column_labels") +__all__ = ( + "header", + "title", + "subtitle", + "stubhead", + "column_header", + "spanner_labels", + "column_labels", + "stub", + "row_groups", + "body", + "footer", + "source_notes", +) diff --git a/great_tables/style.py b/great_tables/style.py index e6b4c480e..7bd85d96e 100644 --- a/great_tables/style.py +++ b/great_tables/style.py @@ -4,6 +4,7 @@ CellStyleText as text, CellStyleFill as fill, CellStyleBorders as borders, + CellStyleCss as css, ) -__all__ = ("text", "fill", "borders") +__all__ = ("text", "fill", "borders", "css") diff --git a/tests/__snapshots__/test_utils_render_html.ambr b/tests/__snapshots__/test_utils_render_html.ambr index f671a3d3a..77598d1a8 100644 --- a/tests/__snapshots__/test_utils_render_html.ambr +++ b/tests/__snapshots__/test_utils_render_html.ambr @@ -35,6 +35,54 @@ ''' # --- +# name: test_loc_kitchen_sink + ''' + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
title
subtitle
stubheadnum + spanner +
charfctr
grp_a
row_10.1111apricotone
yo
+ + + + ''' +# --- # name: test_multiple_spanners_pads_for_stubhead_label ''' diff --git a/tests/test_gt_data.py b/tests/test_gt_data.py index ee89b3d70..9feea05d0 100644 --- a/tests/test_gt_data.py +++ b/tests/test_gt_data.py @@ -31,7 +31,8 @@ def test_stub_order_groups(): stub2 = stub.order_groups(["c", "a", "b"]) assert stub2.group_ids == ["c", "a", "b"] - assert stub2.group_indices_map() == [(3, "c"), (1, "a"), (0, "b"), (2, "b")] + indice_labels = [(ii, info.defaulted_label()) for ii, info in stub2.group_indices_map()] + assert indice_labels == [(3, "c"), (1, "a"), (0, "b"), (2, "b")] def test_boxhead_reorder(): diff --git a/tests/test_locations.py b/tests/test_locations.py index 55006b41e..cd10f7940 100644 --- a/tests/test_locations.py +++ b/tests/test_locations.py @@ -1,12 +1,13 @@ import pandas as pd import polars as pl +import polars.selectors as cs import pytest from great_tables import GT from great_tables._gt_data import Spanners from great_tables._locations import ( CellPos, LocBody, - LocColumnSpanners, + LocSpannerLabels, LocTitle, resolve, resolve_cols_i, @@ -116,6 +117,9 @@ def test_resolve_rows_i_raises(bad_expr): assert "a callable that takes a DataFrame and returns a boolean Series" in expected +# Resolve Loc tests -------------------------------------------------------------------------------- + + def test_resolve_loc_body(): gt = GT(pd.DataFrame({"x": [1, 2], "y": [3, 4]})) @@ -132,25 +136,41 @@ def test_resolve_loc_body(): assert pos.colname == "x" -def test_resolve_column_spanners_simple(): +@pytest.mark.xfail +def test_resolve_loc_spanners_label_single(): + spanners = Spanners.from_ids(["a", "b"]) + loc = LocSpannerLabels(ids="a") + + new_loc = resolve(loc, spanners) + + assert new_loc.ids == ["a"] + + +@pytest.mark.parametrize( + "expr", + [ + ["a", "c"], + pytest.param(cs.by_name("a", "c"), marks=pytest.mark.xfail), + ], +) +def test_resolve_loc_spanners_label(expr): # note that this essentially a no-op ids = ["a", "b", "c"] spanners = Spanners.from_ids(ids) - loc = LocColumnSpanners(ids=["a", "c"]) + loc = LocSpannerLabels(ids=expr) new_loc = resolve(loc, spanners) - assert new_loc == loc assert new_loc.ids == ["a", "c"] -def test_resolve_column_spanners_error_missing(): +def test_resolve_loc_spanner_label_error_missing(): # note that this essentially a no-op ids = ["a", "b", "c"] spanners = Spanners.from_ids(ids) - loc = LocColumnSpanners(ids=["a", "d"]) + loc = LocSpannerLabels(ids=["a", "d"]) with pytest.raises(ValueError): resolve(loc, spanners) @@ -190,7 +210,7 @@ def test_set_style_loc_body_from_column(expr): def test_set_style_loc_title_from_column_error(snapshot): df = pd.DataFrame({"x": [1, 2], "color": ["red", "blue"]}) gt_df = GT(df) - loc = LocTitle("title") + loc = LocTitle() style = CellStyleText(color=FromColumn("color")) with pytest.raises(TypeError) as exc_info: diff --git a/tests/test_utils_render_html.py b/tests/test_utils_render_html.py index da0241a75..0d84ee314 100644 --- a/tests/test_utils_render_html.py +++ b/tests/test_utils_render_html.py @@ -29,7 +29,7 @@ def assert_rendered_columns(snapshot, gt): built = gt._build_data("html") columns = create_columns_component_h(built) - assert snapshot == columns + assert snapshot == str(columns) def assert_rendered_body(snapshot, gt): @@ -191,3 +191,52 @@ def test_multiple_spanners_pads_for_stubhead_label(snapshot): ) assert_rendered_columns(snapshot, gt) + + +# Location style rendering ------------------------------------------------------------------------- +# these tests focus on location classes being correctly picked up +def test_loc_column_labels(): + gt = GT(pl.DataFrame({"x": [1], "y": [2]})) + + new_gt = gt.tab_style(style.fill("yellow"), loc.column_labels(columns=["x"])) + el = create_columns_component_h(new_gt._build_data("html")) + + assert el.name == "tr" + assert el.children[0].attrs["style"] == "background-color: yellow;" + assert "style" not in el.children[1].attrs + + +def test_loc_kitchen_sink(snapshot): + gt = ( + GT(exibble.loc[[0], ["num", "char", "fctr", "row", "group"]]) + .tab_header("title", "subtitle") + .tab_stub(rowname_col="row", groupname_col="group") + .tab_source_note("yo") + .tab_spanner("spanner", ["char", "fctr"]) + .tab_stubhead("stubhead") + ) + + new_gt = ( + gt.tab_style(style.css("BODY"), loc.body()) + # Columns ----------- + .tab_style(style.css("COLUMN_LABEL"), loc.column_labels(columns="num")) + .tab_style(style.css("COLUMN_HEADER"), loc.column_header()) + .tab_style(style.css("SPANNER_LABEL"), loc.spanner_labels(ids=["spanner"])) + # Header ----------- + .tab_style(style.css("HEADER"), loc.header()) + .tab_style(style.css("SUBTITLE"), loc.subtitle()) + .tab_style(style.css("TITLE"), loc.title()) + # Footer ----------- + .tab_style(style.css("FOOTER"), loc.footer()) + .tab_style(style.css("SOURCE_NOTES"), loc.source_notes()) + # .tab_style(style.css("AAA"), loc.footnotes()) + # Stub -------------- + .tab_style(style.css("GROUP_LABEL"), loc.row_groups()) + .tab_style(style.css("STUB"), loc.stub()) + .tab_style(style.css("ROW_LABEL"), loc.stub(rows=[0])) + .tab_style(style.css("STUBHEAD"), loc.stubhead()) + ) + + html = new_gt.as_raw_html() + cleaned = html[html.index("