Skip to content

Commit

Permalink
feat(D3 plugin): improve categories using (#273)
Browse files Browse the repository at this point in the history
* feat(D3 plugin): improve categories using

* fix: fix prepareCategoricalScatterData

* fix: review fixes
  • Loading branch information
korvin89 authored Sep 6, 2023
1 parent 1dd8522 commit eabf291
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 77 deletions.
7 changes: 3 additions & 4 deletions src/plugins/d3/__stories__/bar-x/category.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ const Template: Story = () => {
visible: true,
data: [
{
category: 'A',
label: 10,
x: 'A',
y: 100,
},
{
category: 'B',
label: 12,
x: 'B',
y: 80,
},
],
Expand All @@ -39,8 +39,7 @@ const Template: Story = () => {
visible: true,
data: [
{
category: 'C',
x: 95.5,
x: 'C',
y: 120,
},
],
Expand Down
25 changes: 18 additions & 7 deletions src/plugins/d3/__stories__/scatter/LinearCategories.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,23 @@ const shapeScatterSeriesData = (args: {data: Record<string, any>[]; groupBy: str
acc[seriesName] = [];
}

acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
...(map.category && {category: d[map.category]}),
});
const categoriesType = map.categoriesType as 'x' | 'y' | 'none' | undefined;
const isCategorical = categoriesType === 'x' || categoriesType === 'y';

if (isCategorical && map.category) {
acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
[map.categoriesType]: d[map.category],
});
} else if (!isCategorical) {
acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
});
}

return acc;
}, {});
Expand Down Expand Up @@ -133,7 +144,7 @@ const Template: Story = () => {
const shapedScatterSeriesData = shapeScatterSeriesData({
data: penguins,
groupBy,
map: {x, y, category},
map: {x, y, category, categoriesType},
});
const shapedScatterSeries = shapeScatterSeries(shapedScatterSeriesData);
const data = shapeScatterChartData(shapedScatterSeries, categoriesType, categories);
Expand Down
36 changes: 25 additions & 11 deletions src/plugins/d3/renderer/components/Tooltip/DefaultContent.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import React from 'react';
import get from 'lodash/get';

import type {
ScatterSeriesData,
BarXSeriesData,
TooltipHoveredData,
} from '../../../../../types/widget-data';
import type {ChartKitWidgetSeriesData, TooltipHoveredData} from '../../../../../types/widget-data';

import type {PreparedAxis} from '../../hooks';
import {getDataCategoryValue} from '../../utils';

type Props = {
hovered: TooltipHoveredData;
xAxis: PreparedAxis;
yAxis: PreparedAxis;
};

const getXRowData = (xAxis: PreparedAxis, data: ChartKitWidgetSeriesData) => {
const categories = get(xAxis, 'categories', [] as string[]);

return xAxis.type === 'category'
? getDataCategoryValue({axisDirection: 'x', categories, data})
: (data as {x: number}).x;
};

const getYRowData = (yAxis: PreparedAxis, data: ChartKitWidgetSeriesData) => {
const categories = get(yAxis, 'categories', [] as string[]);

return yAxis.type === 'category'
? getDataCategoryValue({axisDirection: 'y', categories, data})
: (data as {y: number}).y;
};

export const DefaultContent = ({hovered, xAxis, yAxis}: Props) => {
const {data, series} = hovered;

switch (series.type) {
case 'scatter': {
const scatterData = data as ScatterSeriesData;
const xRow = xAxis.type === 'category' ? scatterData.category : scatterData.x;
const yRow = yAxis.type === 'category' ? scatterData.category : scatterData.y;
const xRow = getXRowData(xAxis, data);
const yRow = getYRowData(yAxis, data);

return (
<div>
<div>
Expand All @@ -36,9 +50,9 @@ export const DefaultContent = ({hovered, xAxis, yAxis}: Props) => {
);
}
case 'bar-x': {
const barXData = data as BarXSeriesData;
const xRow = xAxis.type === 'category' ? barXData.category : barXData.x;
const yRow = yAxis.type === 'category' ? barXData.category : barXData.y;
const xRow = getXRowData(xAxis, data);
const yRow = getYRowData(yAxis, data);

return (
<div>
<div>{xRow}</div>
Expand Down
38 changes: 27 additions & 11 deletions src/plugins/d3/renderer/hooks/useAxisScales/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import get from 'lodash/get';
import type {ChartOptions} from '../useChartOptions/types';
import {
getOnlyVisibleSeries,
getDataCategoryValue,
getDomainDataYBySeries,
isAxisRelatedSeries,
getDomainDataXBySeries,
isAxisRelatedSeries,
isSeriesWithCategoryValues,
} from '../../utils';
import type {AxisDirection} from '../../utils';
import {PreparedSeries} from '../useSeries/types';

export type ChartScale =
Expand All @@ -35,10 +37,22 @@ const isNumericalArrayData = (data: unknown[]): data is number[] => {
return data.every((d) => typeof d === 'number' || d === null);
};

const filterCategoriesByVisibleSeries = (categories: string[], series: PreparedSeries[]) => {
const filterCategoriesByVisibleSeries = (args: {
axisDirection: AxisDirection;
categories: string[];
series: PreparedSeries[];
}) => {
const {axisDirection, categories, series} = args;

return categories.filter((category) => {
return series.some((s) => {
return isSeriesWithCategoryValues(s) && s.data.some((d) => d.category === category);
return (
isSeriesWithCategoryValues(s) &&
s.data.some((d) => {
const dataCategory = getDataCategoryValue({axisDirection, categories, data: d});
return dataCategory === category;
})
);
});
});
};
Expand Down Expand Up @@ -75,10 +89,11 @@ const createScales = (args: Args) => {
}
case 'category': {
if (xCategories) {
const filteredCategories = filterCategoriesByVisibleSeries(
xCategories,
visibleSeries,
);
const filteredCategories = filterCategoriesByVisibleSeries({
axisDirection: 'x',
categories: xCategories,
series: visibleSeries,
});
xScale = scaleBand().domain(filteredCategories).range([0, boundsWidth]);
}

Expand Down Expand Up @@ -122,10 +137,11 @@ const createScales = (args: Args) => {
}
case 'category': {
if (yCategories) {
const filteredCategories = filterCategoriesByVisibleSeries(
yCategories,
visibleSeries,
);
const filteredCategories = filterCategoriesByVisibleSeries({
axisDirection: 'y',
categories: yCategories,
series: visibleSeries,
});
yScale = scaleBand().domain(filteredCategories).range([boundsHeight, 0]);
}

Expand Down
21 changes: 14 additions & 7 deletions src/plugins/d3/renderer/hooks/useShapes/bar-x.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import React from 'react';
import {ChartOptions} from '../useChartOptions/types';
import {ChartScale} from '../useAxisScales';
import {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import {BarXSeriesData} from '../../../../../types/widget-data';
import {group, pointer, select} from 'd3';
import type {ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import get from 'lodash/get';

import type {BarXSeriesData} from '../../../../../types/widget-data';
import {block} from '../../../../../utils/cn';
import {group, pointer, ScaleBand, ScaleLinear, ScaleTime, select} from 'd3';
import {PreparedBarXSeries} from '../useSeries/types';

import {getDataCategoryValue} from '../../utils';
import type {ChartScale} from '../useAxisScales';
import type {ChartOptions} from '../useChartOptions/types';
import type {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import type {PreparedBarXSeries} from '../useSeries/types';

const DEFAULT_BAR_RECT_WIDTH = 50;
const DEFAULT_LINEAR_BAR_RECT_WIDTH = 20;
Expand Down Expand Up @@ -44,8 +49,10 @@ const getRectProperties = (args: {
if (xAxis.type === 'category') {
const xBandScale = xScale as ScaleBand<string>;
const maxWidth = xBandScale.bandwidth() - MIN_RECT_GAP;
const categories = get(xAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'x', categories, data: point});
width = Math.min(maxWidth, DEFAULT_BAR_RECT_WIDTH);
cx = (xBandScale(point.category as string) || 0) + xBandScale.step() / 2 - width / 2;
cx = (xBandScale(dataCategory) || 0) + xBandScale.step() / 2 - width / 2;
} else {
const xLinearScale = xScale as ScaleLinear<number, number> | ScaleTime<number, number>;
const [min, max] = xLinearScale.domain();
Expand Down
48 changes: 22 additions & 26 deletions src/plugins/d3/renderer/hooks/useShapes/scatter.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import React from 'react';
import {pointer, select} from 'd3';
import type {ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import React from 'react';
import {ChartOptions} from '../useChartOptions/types';
import {ChartScale} from '../useAxisScales';
import {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import {ScatterSeries, ScatterSeriesData} from '../../../../../types/widget-data';
import get from 'lodash/get';

import type {ScatterSeries, ScatterSeriesData} from '../../../../../types/widget-data';
import {block} from '../../../../../utils/cn';

import {getDataCategoryValue} from '../../utils';
import type {ChartScale} from '../useAxisScales';
import type {PreparedAxis} from '../useChartOptions/types';
import type {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';

type ScatterSeriesShapeProps = {
top: number;
left: number;
series: ScatterSeries;
xAxis: ChartOptions['xAxis'];
xAxis: PreparedAxis;
xScale: ChartScale;
yAxis: ChartOptions['yAxis'];
yAxis: PreparedAxis[];
yScale: ChartScale;
svgContainer: SVGSVGElement | null;
onSeriesMouseMove?: OnSeriesMouseMove;
Expand All @@ -23,26 +27,20 @@ type ScatterSeriesShapeProps = {
const b = block('d3-scatter');
const DEFAULT_SCATTER_POINT_RADIUS = 4;

const prepareCategoricalScatterData = (data: ScatterSeriesData[]) => {
return data.filter((d) => typeof d.category === 'string');
};

const prepareLinearScatterData = (data: ScatterSeriesData[]) => {
return data.filter((d) => typeof d.x === 'number' && typeof d.y === 'number');
};

const getCxAttr = (args: {
point: ScatterSeriesData;
xAxis: ChartOptions['xAxis'];
xScale: ChartScale;
}) => {
const getCxAttr = (args: {point: ScatterSeriesData; xAxis: PreparedAxis; xScale: ChartScale}) => {
const {point, xAxis, xScale} = args;

let cx: number;

if (xAxis.type === 'category') {
const xBandScale = xScale as ScaleBand<string>;
cx = (xBandScale(point.category as string) || 0) + xBandScale.step() / 2;
const categories = get(xAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'x', categories, data: point});
cx = (xBandScale(dataCategory) || 0) + xBandScale.step() / 2;
} else {
const xLinearScale = xScale as ScaleLinear<number, number> | ScaleTime<number, number>;
cx = xLinearScale(point.x as number);
Expand All @@ -51,18 +49,16 @@ const getCxAttr = (args: {
return cx;
};

const getCyAttr = (args: {
point: ScatterSeriesData;
yAxis: ChartOptions['yAxis'];
yScale: ChartScale;
}) => {
const getCyAttr = (args: {point: ScatterSeriesData; yAxis: PreparedAxis; yScale: ChartScale}) => {
const {point, yAxis, yScale} = args;

let cy: number;

if (yAxis[0].type === 'category') {
if (yAxis.type === 'category') {
const yBandScale = yScale as ScaleBand<string>;
cy = (yBandScale(point.category as string) || 0) + yBandScale.step() / 2;
const categories = get(yAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'y', categories, data: point});
cy = (yBandScale(dataCategory) || 0) + yBandScale.step() / 2;
} else {
const yLinearScale = yScale as ScaleLinear<number, number> | ScaleTime<number, number>;
cy = yLinearScale(point.y as number);
Expand Down Expand Up @@ -95,7 +91,7 @@ export function ScatterSeriesShape(props: ScatterSeriesShapeProps) {
svgElement.selectAll('*').remove();
const preparedData =
xAxis.type === 'category' || yAxis[0]?.type === 'category'
? prepareCategoricalScatterData(series.data)
? series.data
: prepareLinearScatterData(series.data);

svgElement
Expand All @@ -107,7 +103,7 @@ export function ScatterSeriesShape(props: ScatterSeriesShapeProps) {
.attr('fill', (d) => d.color || series.color || '')
.attr('r', (d) => d.radius || DEFAULT_SCATTER_POINT_RADIUS)
.attr('cx', (d) => getCxAttr({point: d, xAxis, xScale}))
.attr('cy', (d) => getCyAttr({point: d, yAxis, yScale}))
.attr('cy', (d) => getCyAttr({point: d, yAxis: yAxis[0], yScale}))
.on('mousemove', (e, d) => {
const [x, y] = pointer(e, svgContainer);
onSeriesMouseMove?.({
Expand Down
Loading

0 comments on commit eabf291

Please sign in to comment.