diff --git a/packages/grid/_modules_/grid/hooks/features/columnMenu/useGridColumnMenu.ts b/packages/grid/_modules_/grid/hooks/features/columnMenu/useGridColumnMenu.ts index 0576c6c94947e..35a340bb21978 100644 --- a/packages/grid/_modules_/grid/hooks/features/columnMenu/useGridColumnMenu.ts +++ b/packages/grid/_modules_/grid/hooks/features/columnMenu/useGridColumnMenu.ts @@ -10,6 +10,7 @@ import { useGridApiEventHandler, } from '../../utils'; import { gridColumnMenuSelector } from './columnMenuSelector'; +import { GridColumnMenuApi } from '../../../models'; /** * @requires useGridColumnResize (event) @@ -22,8 +23,11 @@ export const useGridColumnMenu = (apiRef: GridApiRef): void => { const [, setGridState, forceUpdate] = useGridState(apiRef); const columnMenu = useGridSelector(apiRef, gridColumnMenuSelector); - const showColumnMenu = React.useCallback( - (field: string) => { + /** + * API METHODS + */ + const showColumnMenu = React.useCallback( + (field) => { const shouldUpdate = setGridState((state) => { if (state.columnMenu.open && state.columnMenu.field === field) { return state; @@ -44,7 +48,7 @@ export const useGridColumnMenu = (apiRef: GridApiRef): void => { [apiRef, forceUpdate, logger, setGridState], ); - const hideColumnMenu = React.useCallback(() => { + const hideColumnMenu = React.useCallback(() => { const shouldUpdate = setGridState((state) => { if (!state.columnMenu.open && state.columnMenu.field === undefined) { return state; @@ -62,8 +66,8 @@ export const useGridColumnMenu = (apiRef: GridApiRef): void => { } }, [forceUpdate, logger, setGridState]); - const toggleColumnMenu = React.useCallback( - (field: string) => { + const toggleColumnMenu = React.useCallback( + (field) => { logger.debug('Toggle Column Menu'); if (!columnMenu.open || columnMenu.field !== field) { showColumnMenu(field); @@ -74,16 +78,17 @@ export const useGridColumnMenu = (apiRef: GridApiRef): void => { [logger, showColumnMenu, hideColumnMenu, columnMenu], ); - useGridApiMethod( - apiRef, - { - showColumnMenu, - hideColumnMenu, - toggleColumnMenu, - }, - 'ColumnMenuApi', - ); + const columnMenuApi: GridColumnMenuApi = { + showColumnMenu, + hideColumnMenu, + toggleColumnMenu, + }; + + useGridApiMethod(apiRef, columnMenuApi, 'GridColumnMenuApi'); + /** + * EVENTS + */ useGridApiEventHandler(apiRef, GridEvents.columnResizeStart, hideColumnMenu); useGridApiEventHandler(apiRef, GridEvents.rowsScroll, hideColumnMenu); }; diff --git a/packages/grid/_modules_/grid/hooks/features/columns/useGridColumns.ts b/packages/grid/_modules_/grid/hooks/features/columns/useGridColumns.ts index b763a3ad155d8..20a21f8b60fde 100644 --- a/packages/grid/_modules_/grid/hooks/features/columns/useGridColumns.ts +++ b/packages/grid/_modules_/grid/hooks/features/columns/useGridColumns.ts @@ -71,6 +71,9 @@ export function useGridColumns( [logger, setGridState, forceUpdate, apiRef], ); + /** + * API METHODS + */ const getColumn = React.useCallback( (field) => gridColumnLookupSelector(apiRef.current.state)[field], [apiRef], @@ -189,7 +192,7 @@ export function useGridColumns( [apiRef, logger], ); - const colApi: GridColumnApi = { + const columnApi: GridColumnApi = { getColumn, getAllColumns, getColumnIndex, @@ -203,28 +206,11 @@ export function useGridColumns( setColumnWidth, }; - useGridApiMethod(apiRef, colApi, 'ColApi'); - - // The effect do not track any value defined synchronously during the 1st render by hooks called after `useGridColumns` - // As a consequence, the state generated by the 1st run of this useEffect will always be equal to the initialization one - const isFirstRender = React.useRef(true); - React.useEffect(() => { - if (isFirstRender.current) { - isFirstRender.current = false; - return; - } - - logger.info(`GridColumns have changed, new length ${props.columns.length}`); - - const columnsState = createColumnsState({ - apiRef, - columnsTypes, - columnsToUpsert: props.columns, - reset: true, - }); - setGridColumnsState(columnsState); - }, [logger, apiRef, setGridColumnsState, props.columns, columnsTypes]); + useGridApiMethod(apiRef, columnApi, 'GridColumnApi'); + /** + * EVENTS + */ const handlePreProcessorRegister = React.useCallback< GridEventListener >( @@ -259,10 +245,32 @@ export function useGridColumns( useGridApiEventHandler(apiRef, GridEvents.preProcessorRegister, handlePreProcessorRegister); useGridApiEventHandler(apiRef, GridEvents.viewportInnerSizeChange, handleGridSizeChange); - // Grid Option Handlers useGridApiOptionHandler( apiRef, GridEvents.columnVisibilityChange, props.onColumnVisibilityChange, ); + + /** + * EFFECTS + */ + // The effect do not track any value defined synchronously during the 1st render by hooks called after `useGridColumns` + // As a consequence, the state generated by the 1st run of this useEffect will always be equal to the initialization one + const isFirstRender = React.useRef(true); + React.useEffect(() => { + if (isFirstRender.current) { + isFirstRender.current = false; + return; + } + + logger.info(`GridColumns have changed, new length ${props.columns.length}`); + + const columnsState = createColumnsState({ + apiRef, + columnsTypes, + columnsToUpsert: props.columns, + reset: true, + }); + setGridColumnsState(columnsState); + }, [logger, apiRef, setGridColumnsState, props.columns, columnsTypes]); } diff --git a/packages/grid/_modules_/grid/hooks/features/filter/gridFilterUtils.ts b/packages/grid/_modules_/grid/hooks/features/filter/gridFilterUtils.ts new file mode 100644 index 0000000000000..b935229b13326 --- /dev/null +++ b/packages/grid/_modules_/grid/hooks/features/filter/gridFilterUtils.ts @@ -0,0 +1,104 @@ +import { + GridApiRef, + GridFilterItem, + GridFilterModel, + GridLinkOperator, + GridRowId, +} from '../../../models'; + +type GridFilterItemApplier = (rowId: GridRowId) => boolean; + +/** + * Adds default values to the optional fields of a filter items. + * @param {GridFilterItem} item The raw filter item. + * @param {GridApiRef} apiRef The API of the grid. + * @return {GridFilterItem} The clean filter item with an uniq ID and an always-defined operatorValue. + * TODO: Make the typing reflect the different between GridFilterInputItem and GridFilterItem. + */ +export const cleanFilterItem = (item: GridFilterItem, apiRef: GridApiRef) => { + const cleanItem: GridFilterItem = { ...item }; + + if (cleanItem.id == null) { + cleanItem.id = Math.round(Math.random() * 1e5); + } + + if (cleanItem.operatorValue == null) { + // we select a default operator + const column = apiRef.current.getColumn(cleanItem.columnField); + cleanItem.operatorValue = column && column!.filterOperators![0].value!; + } + + return cleanItem; +}; + +/** + * Generates a method to easily check if a row is matching the current filter model. + * @param {GridFilterModel} filterModel The model with which we want to filter the rows. + * @param {GridApiRef} apiRef The API of the grid. + * @returns {GridFilterItemApplier | null} A method that checks if a row is matching the current filter model. If `null`, we consider that all the rows are matching the filters. + */ +export const buildAggregatedFilterApplier = ( + filterModel: GridFilterModel, + apiRef: GridApiRef, +): GridFilterItemApplier | null => { + const { items, linkOperator = GridLinkOperator.And } = filterModel; + + const getFilterCallbackFromItem = (filterItem: GridFilterItem): GridFilterItemApplier | null => { + if (!filterItem.columnField || !filterItem.operatorValue) { + return null; + } + + const column = apiRef.current.getColumn(filterItem.columnField); + if (!column) { + return null; + } + + const parsedValue = column.valueParser + ? column.valueParser(filterItem.value) + : filterItem.value; + const newFilterItem: GridFilterItem = { ...filterItem, value: parsedValue }; + + const filterOperators = column.filterOperators; + if (!filterOperators?.length) { + throw new Error(`MUI: No filter operators found for column '${column.field}'.`); + } + + const filterOperator = filterOperators.find( + (operator) => operator.value === newFilterItem.operatorValue, + )!; + if (!filterOperator) { + throw new Error( + `MUI: No filter operator found for column '${column.field}' and operator value '${newFilterItem.operatorValue}'.`, + ); + } + + const applyFilterOnRow = filterOperator.getApplyFilterFn(newFilterItem, column)!; + if (typeof applyFilterOnRow !== 'function') { + return null; + } + + return (rowId: GridRowId) => { + const cellParams = apiRef.current.getCellParams(rowId, newFilterItem.columnField!); + + return applyFilterOnRow(cellParams); + }; + }; + + const appliers = items + .map(getFilterCallbackFromItem) + .filter((callback): callback is GridFilterItemApplier => !!callback); + + if (appliers.length === 0) { + return null; + } + + return (rowId: GridRowId) => { + // Return `false` as soon as we have a failing filter + if (linkOperator === GridLinkOperator.And) { + return appliers.every((applier) => applier(rowId)); + } + + // Return `true` as soon as we have a passing filter + return appliers.some((applier) => applier(rowId)); + }; +}; diff --git a/packages/grid/_modules_/grid/hooks/features/filter/useGridFilter.ts b/packages/grid/_modules_/grid/hooks/features/filter/useGridFilter.ts index 56508e1211143..e9c42209df19f 100644 --- a/packages/grid/_modules_/grid/hooks/features/filter/useGridFilter.ts +++ b/packages/grid/_modules_/grid/hooks/features/filter/useGridFilter.ts @@ -4,7 +4,7 @@ import { GridComponentProps } from '../../../GridComponentProps'; import { GridApiRef } from '../../../models/api/gridApiRef'; import { GridFilterApi } from '../../../models/api/gridFilterApi'; import { GridFeatureModeConstant } from '../../../models/gridFeatureMode'; -import { GridFilterItem, GridLinkOperator } from '../../../models/gridFilterItem'; +import { GridFilterItem } from '../../../models/gridFilterItem'; import { GridRowId, GridRowModel } from '../../../models/gridRows'; import { useGridApiEventHandler } from '../../utils/useGridApiEventHandler'; import { useGridApiMethod } from '../../utils/useGridApiMethod'; @@ -24,8 +24,7 @@ import { useFirstRender } from '../../utils/useFirstRender'; import { gridRowIdsSelector, gridRowGroupingNameSelector } from '../rows'; import { GridPreProcessingGroup } from '../../core/preProcessing'; import { useGridRegisterFilteringMethod } from './useGridRegisterFilteringMethod'; - -type GridFilterItemApplier = (rowId: GridRowId) => boolean; +import { buildAggregatedFilterApplier, cleanFilterItem } from './gridFilterUtils'; const checkFilterModelValidity = (model: GridFilterModel) => { if (model.items.length > 1) { @@ -87,77 +86,8 @@ export const useGridFilter = ( changeEvent: GridEvents.filterModelChange, }); - const buildAggregatedFilterApplier = React.useCallback( - (filterModel: GridFilterModel): GridFilterItemApplier | null => { - const { items, linkOperator = GridLinkOperator.And } = filterModel; - - const getFilterCallbackFromItem = ( - filterItem: GridFilterItem, - ): GridFilterItemApplier | null => { - if (!filterItem.columnField || !filterItem.operatorValue) { - return null; - } - - const column = apiRef.current.getColumn(filterItem.columnField); - if (!column) { - return null; - } - - const parsedValue = column.valueParser - ? column.valueParser(filterItem.value) - : filterItem.value; - const newFilterItem: GridFilterItem = { ...filterItem, value: parsedValue }; - - const filterOperators = column.filterOperators; - if (!filterOperators?.length) { - throw new Error(`MUI: No filter operators found for column '${column.field}'.`); - } - - const filterOperator = filterOperators.find( - (operator) => operator.value === newFilterItem.operatorValue, - )!; - if (!filterOperator) { - throw new Error( - `MUI: No filter operator found for column '${column.field}' and operator value '${newFilterItem.operatorValue}'.`, - ); - } - - const applyFilterOnRow = filterOperator.getApplyFilterFn(newFilterItem, column)!; - if (typeof applyFilterOnRow !== 'function') { - return null; - } - - return (rowId: GridRowId) => { - const cellParams = apiRef.current.getCellParams(rowId, newFilterItem.columnField!); - - return applyFilterOnRow(cellParams); - }; - }; - - const appliers = items - .map(getFilterCallbackFromItem) - .filter((callback): callback is GridFilterItemApplier => !!callback); - - if (appliers.length === 0) { - return null; - } - - return (rowId: GridRowId) => { - // Return `false` as soon as we have a failing filter - if (linkOperator === GridLinkOperator.And) { - return appliers.every((applier) => applier(rowId)); - } - - // Return `true` as soon as we have a passing filter - return appliers.some((applier) => applier(rowId)); - }; - }, - [apiRef], - ); - /** - * Generate the `visibleRowsLookup` and `visibleDescendantsCountLookup` for the current `filterModel` - * If the tree is not flat, we have to create the lookups even with "server" filtering or 0 filter item to remove to collapsed rows. + * API METHODS */ const applyFilters = React.useCallback(() => { setGridState((state) => { @@ -170,7 +100,7 @@ export const useGridFilter = ( const filterModel = gridFilterModelSelector(state); const isRowMatchingFilters = props.filterMode === GridFeatureModeConstant.client - ? buildAggregatedFilterApplier(filterModel) + ? buildAggregatedFilterApplier(filterModel, apiRef) : null; lastFilteringMethodApplied.current = filteringMethod; @@ -188,33 +118,14 @@ export const useGridFilter = ( }); apiRef.current.publishEvent(GridEvents.visibleRowsSet); forceUpdate(); - }, [apiRef, setGridState, forceUpdate, props.filterMode, buildAggregatedFilterApplier]); - - const cleanFilterItem = React.useCallback( - (item: GridFilterItem) => { - const cleanItem: GridFilterItem = { ...item }; - - if (cleanItem.id == null) { - cleanItem.id = Math.round(Math.random() * 1e5); - } - - if (cleanItem.operatorValue == null) { - // we select a default operator - const column = apiRef.current.getColumn(cleanItem.columnField); - cleanItem.operatorValue = column && column!.filterOperators![0].value!; - } - - return cleanItem; - }, - [apiRef], - ); + }, [apiRef, setGridState, forceUpdate, props.filterMode]); const upsertFilterItem = React.useCallback( (item) => { const filterModel = gridFilterModelSelector(apiRef.current.state); const items = [...filterModel.items]; const itemIndex = items.findIndex((filterItem) => filterItem.id === item.id); - const newItem = cleanFilterItem(item); + const newItem = cleanFilterItem(item, apiRef); if (itemIndex === -1) { items.push(newItem); } else { @@ -222,7 +133,7 @@ export const useGridFilter = ( } apiRef.current.setFilterModel({ ...filterModel, items }); }, - [apiRef, cleanFilterItem], + [apiRef], ); const deleteFilterItem = React.useCallback( @@ -254,11 +165,11 @@ export const useGridFilter = ( if (filterItemOnTarget) { newFilterItems = filterItemsWithValue; } else if (props.disableMultipleColumnsFiltering) { - newFilterItems = [cleanFilterItem({ columnField: targetColumnField })]; + newFilterItems = [cleanFilterItem({ columnField: targetColumnField }, apiRef)]; } else { newFilterItems = [ ...filterItemsWithValue, - cleanFilterItem({ columnField: targetColumnField }), + cleanFilterItem({ columnField: targetColumnField }, apiRef), ]; } @@ -269,7 +180,7 @@ export const useGridFilter = ( } apiRef.current.showPreferences(GridPreferencePanelsValue.filters); }, - [apiRef, logger, cleanFilterItem, props.disableMultipleColumnsFiltering], + [apiRef, logger, props.disableMultipleColumnsFiltering], ); const hideFilterPanel = React.useCallback(() => { @@ -320,20 +231,18 @@ export const useGridFilter = ( return new Map(visibleSortedRows.map((row) => [row.id, row.model])); }, [apiRef]); - useGridApiMethod( - apiRef, - { - setFilterLinkOperator, - unstable_applyFilters: applyFilters, - deleteFilterItem, - upsertFilterItem, - setFilterModel, - showFilterPanel, - hideFilterPanel, - getVisibleRowModels, - }, - 'FilterApi', - ); + const filterApi: GridFilterApi = { + setFilterLinkOperator, + unstable_applyFilters: applyFilters, + deleteFilterItem, + upsertFilterItem, + setFilterModel, + showFilterPanel, + hideFilterPanel, + getVisibleRowModels, + }; + + useGridApiMethod(apiRef, filterApi, 'GridFilterApi'); /** * PRE-PROCESSING diff --git a/packages/grid/_modules_/grid/hooks/features/pagination/useGridPage.ts b/packages/grid/_modules_/grid/hooks/features/pagination/useGridPage.ts index 713bd1eab9228..2d2b5998c5f45 100644 --- a/packages/grid/_modules_/grid/hooks/features/pagination/useGridPage.ts +++ b/packages/grid/_modules_/grid/hooks/features/pagination/useGridPage.ts @@ -65,8 +65,11 @@ export const useGridPage = ( changeEvent: GridEvents.pageChange, }); + /** + * API METHODS + */ const setPage = React.useCallback( - (page: number) => { + (page) => { logger.debug(`Setting page to ${page}`); setGridState((state) => ({ @@ -81,47 +84,53 @@ export const useGridPage = ( [setGridState, forceUpdate, logger], ); - React.useEffect(() => { + const pageApi: GridPageApi = { + setPage, + }; + + useGridApiMethod(apiRef, pageApi, 'GridPageApi'); + + /** + * EVENTS + */ + const handlePageSizeChange: GridEventListener = (pageSize) => { setGridState((state) => { - const rowCount = props.rowCount !== undefined ? props.rowCount : visibleTopLevelRowCount; - const pageCount = getPageCount(rowCount, state.pagination.pageSize); - const page = props.page == null ? state.pagination.page : props.page; + const pageCount = getPageCount(state.pagination.rowCount, pageSize); return { ...state, pagination: applyValidPage({ ...state.pagination, - page, - rowCount, pageCount, + page: state.pagination.page, }), }; }); + forceUpdate(); - }, [setGridState, forceUpdate, visibleTopLevelRowCount, props.rowCount, props.page, apiRef]); + }; - const handlePageSizeChange: GridEventListener = (pageSize) => { + useGridApiEventHandler(apiRef, GridEvents.pageSizeChange, handlePageSizeChange); + + /** + * EFFECTS + */ + React.useEffect(() => { setGridState((state) => { - const pageCount = getPageCount(state.pagination.rowCount, pageSize); + const rowCount = props.rowCount !== undefined ? props.rowCount : visibleTopLevelRowCount; + const pageCount = getPageCount(rowCount, state.pagination.pageSize); + const page = props.page == null ? state.pagination.page : props.page; return { ...state, pagination: applyValidPage({ ...state.pagination, + page, + rowCount, pageCount, - page: state.pagination.page, }), }; }); - forceUpdate(); - }; - - useGridApiEventHandler(apiRef, GridEvents.pageSizeChange, handlePageSizeChange); - - const pageApi: GridPageApi = { - setPage, - }; - - useGridApiMethod(apiRef, pageApi, 'GridPageApi'); + }, [setGridState, forceUpdate, visibleTopLevelRowCount, props.rowCount, props.page, apiRef]); }; diff --git a/packages/grid/_modules_/grid/hooks/features/pagination/useGridPageSize.ts b/packages/grid/_modules_/grid/hooks/features/pagination/useGridPageSize.ts index 8660397fef835..2adf1d1c8a734 100644 --- a/packages/grid/_modules_/grid/hooks/features/pagination/useGridPageSize.ts +++ b/packages/grid/_modules_/grid/hooks/features/pagination/useGridPageSize.ts @@ -40,8 +40,11 @@ export const useGridPageSize = ( changeEvent: GridEvents.pageSizeChange, }); - const setPageSize = React.useCallback( - (pageSize: number) => { + /** + * API METHODS + */ + const setPageSize = React.useCallback( + (pageSize) => { if (pageSize === gridPageSizeSelector(apiRef.current.state)) { return; } @@ -60,18 +63,15 @@ export const useGridPageSize = ( [apiRef, setGridState, forceUpdate, logger], ); - React.useEffect(() => { - if (props.pageSize != null && !props.autoPageSize) { - apiRef.current.setPageSize(props.pageSize); - } - }, [apiRef, props.autoPageSize, props.pageSize]); - const pageSizeApi: GridPageSizeApi = { setPageSize, }; useGridApiMethod(apiRef, pageSizeApi, 'GridPageSizeApi'); + /** + * EVENTS + */ const handleUpdateAutoPageSize = React.useCallback(() => { const dimensions = apiRef.current.getRootDimensions(); if (!props.autoPageSize || !dimensions) { @@ -84,9 +84,18 @@ export const useGridPageSize = ( apiRef.current.setPageSize(maximumPageSizeWithoutScrollBar); }, [apiRef, props.autoPageSize, rowHeight]); + useGridApiEventHandler(apiRef, GridEvents.viewportInnerSizeChange, handleUpdateAutoPageSize); + + /** + * EFFECTS + */ + React.useEffect(() => { + if (props.pageSize != null && !props.autoPageSize) { + apiRef.current.setPageSize(props.pageSize); + } + }, [apiRef, props.autoPageSize, props.pageSize]); + React.useEffect(() => { handleUpdateAutoPageSize(); }, [handleUpdateAutoPageSize]); - - useGridApiEventHandler(apiRef, GridEvents.viewportInnerSizeChange, handleUpdateAutoPageSize); }; diff --git a/packages/grid/_modules_/grid/hooks/features/selection/useGridSelection.ts b/packages/grid/_modules_/grid/hooks/features/selection/useGridSelection.ts index c7bb68ed68cf8..a252ffacd4503 100644 --- a/packages/grid/_modules_/grid/hooks/features/selection/useGridSelection.ts +++ b/packages/grid/_modules_/grid/hooks/features/selection/useGridSelection.ts @@ -21,8 +21,7 @@ import { gridVisibleSortedRowIdsSelector } from '../filter/gridFilterSelector'; import { GRID_CHECKBOX_SELECTION_COL_DEF, GridColDef } from '../../../models'; import { getDataGridUtilityClass } from '../../../gridClasses'; import { useGridStateInit } from '../../utils/useGridStateInit'; -import { useFirstRender } from '../../utils/useFirstRender'; -import { GridPreProcessingGroup } from '../../core/preProcessing'; +import { GridPreProcessingGroup, useGridRegisterPreProcessor } from '../../core/preProcessing'; import { GridCellModes } from '../../../models/gridEditRowModel'; import { GridColumnsRawState } from '../columns/gridColumnsState'; import { isKeyboardEvent } from '../../../utils/keyboardUtils'; @@ -97,6 +96,79 @@ export const useGridSelection = ( const canHaveMultipleSelection = !disableMultipleSelection || checkboxSelection; + const expandRowRangeSelection = React.useCallback( + (id: GridRowId) => { + let endId = id; + const startId = lastRowToggled.current ?? id; + const isSelected = apiRef.current.isRowSelected(id); + if (isSelected) { + const visibleRowIds = gridVisibleSortedRowIdsSelector(apiRef.current.state); + const startIndex = visibleRowIds.findIndex((rowId) => rowId === startId); + const endIndex = visibleRowIds.findIndex((rowId) => rowId === endId); + if (startIndex > endIndex) { + endId = visibleRowIds[endIndex + 1]; + } else { + endId = visibleRowIds[endIndex - 1]; + } + } + + lastRowToggled.current = id; + + apiRef.current.selectRowRange({ startId, endId }, !isSelected); + }, + [apiRef], + ); + + /** + * PRE-PROCESSING + */ + const updateSelectionColumn = React.useCallback( + (columnsState: GridColumnsRawState) => { + const selectionColumn: GridColDef = { + ...GRID_CHECKBOX_SELECTION_COL_DEF, + cellClassName: classes.cellCheckbox, + headerClassName: classes.columnHeaderCheckbox, + headerName: apiRef.current.getLocaleText('checkboxSelectionHeaderName'), + }; + + const shouldHaveSelectionColumn = props.checkboxSelection; + const haveSelectionColumn = columnsState.lookup[selectionColumn.field] != null; + + if (shouldHaveSelectionColumn && !haveSelectionColumn) { + columnsState.lookup[selectionColumn.field] = selectionColumn; + columnsState.all = [selectionColumn.field, ...columnsState.all]; + } else if (!shouldHaveSelectionColumn && haveSelectionColumn) { + delete columnsState.lookup[selectionColumn.field]; + columnsState.all = columnsState.all.filter((field) => field !== selectionColumn.field); + } + + return columnsState; + }, + [apiRef, classes, props.checkboxSelection], + ); + + useGridRegisterPreProcessor(apiRef, GridPreProcessingGroup.hydrateColumns, updateSelectionColumn); + + /** + * API METHODS + */ + const setSelectionModel = React.useCallback( + (model) => { + const currentModel = gridSelectionStateSelector(apiRef.current.state); + if (currentModel !== model) { + logger.debug(`Setting selection model`); + setGridState((state) => ({ ...state, selection: model })); + forceUpdate(); + } + }, + [apiRef, setGridState, forceUpdate, logger], + ); + + const isRowSelected = React.useCallback( + (id) => gridSelectionStateSelector(apiRef.current.state).includes(id), + [apiRef], + ); + const getSelectedRows = React.useCallback( () => selectedGridRowsSelector(apiRef.current.state), [apiRef], @@ -196,46 +268,20 @@ export const useGridSelection = ( [apiRef, logger], ); - const expandRowRangeSelection = React.useCallback( - (id: GridRowId) => { - let endId = id; - const startId = lastRowToggled.current ?? id; - const isSelected = apiRef.current.isRowSelected(id); - if (isSelected) { - const visibleRowIds = gridVisibleSortedRowIdsSelector(apiRef.current.state); - const startIndex = visibleRowIds.findIndex((rowId) => rowId === startId); - const endIndex = visibleRowIds.findIndex((rowId) => rowId === endId); - if (startIndex > endIndex) { - endId = visibleRowIds[endIndex + 1]; - } else { - endId = visibleRowIds[endIndex - 1]; - } - } - - lastRowToggled.current = id; - - apiRef.current.selectRowRange({ startId, endId }, !isSelected); - }, - [apiRef], - ); - - const setSelectionModel = React.useCallback( - (model) => { - const currentModel = gridSelectionStateSelector(apiRef.current.state); - if (currentModel !== model) { - logger.debug(`Setting selection model`); - setGridState((state) => ({ ...state, selection: model })); - forceUpdate(); - } - }, - [apiRef, setGridState, forceUpdate, logger], - ); + const selectionApi: GridSelectionApi = { + selectRow, + selectRows, + selectRowRange, + setSelectionModel, + getSelectedRows, + isRowSelected, + }; - const isRowSelected = React.useCallback( - (id) => gridSelectionStateSelector(apiRef.current.state).includes(id), - [apiRef], - ); + useGridApiMethod(apiRef, selectionApi, 'GridSelectionApi'); + /** + * EVENTS + */ const removeOutdatedSelection = React.useCallback(() => { const currentSelection = gridSelectionStateSelector(apiRef.current.state); const rowsLookup = gridRowsLookupSelector(apiRef.current.state); @@ -384,16 +430,9 @@ export const useGridSelection = ( useGridApiEventHandler(apiRef, GridEvents.cellMouseDown, preventSelectionOnShift); useGridApiEventHandler(apiRef, GridEvents.cellKeyDown, handleCellKeyDown); - const selectionApi: GridSelectionApi = { - selectRow, - selectRows, - selectRowRange, - setSelectionModel, - getSelectedRows, - isRowSelected, - }; - useGridApiMethod(apiRef, selectionApi, 'GridSelectionApi'); - + /** + * EFFECTS + */ React.useEffect(() => { if (propSelectionModel !== undefined) { apiRef.current.setSelectionModel(propSelectionModel); @@ -419,47 +458,4 @@ export const useGridSelection = ( } } }, [apiRef, isRowSelectable, isStateControlled]); - - const updateColumnsPreProcessing = React.useCallback(() => { - const updateCheckboxColumn = (columnsState: GridColumnsRawState) => { - const selectionColumn: GridColDef = { - ...GRID_CHECKBOX_SELECTION_COL_DEF, - cellClassName: classes.cellCheckbox, - headerClassName: classes.columnHeaderCheckbox, - headerName: apiRef.current.getLocaleText('checkboxSelectionHeaderName'), - }; - - const shouldHaveSelectionColumn = props.checkboxSelection; - const haveSelectionColumn = columnsState.lookup[selectionColumn.field] != null; - - if (shouldHaveSelectionColumn && !haveSelectionColumn) { - columnsState.lookup[selectionColumn.field] = selectionColumn; - columnsState.all = [selectionColumn.field, ...columnsState.all]; - } else if (!shouldHaveSelectionColumn && haveSelectionColumn) { - delete columnsState.lookup[selectionColumn.field]; - columnsState.all = columnsState.all.filter((field) => field !== selectionColumn.field); - } - - return columnsState; - }; - - apiRef.current.unstable_registerPreProcessor( - GridPreProcessingGroup.hydrateColumns, - 'selection', - updateCheckboxColumn, - ); - }, [apiRef, props.checkboxSelection, classes]); - - useFirstRender(() => { - updateColumnsPreProcessing(); - }); - - const isFirstRender = React.useRef(true); - React.useEffect(() => { - if (isFirstRender.current) { - isFirstRender.current = false; - return; - } - updateColumnsPreProcessing(); - }, [updateColumnsPreProcessing]); };