feat(fem): update Forms to accept chunk size.
AlbertZyy committed Dec 10, 2024
1 parent 68ed414 commit 4e9ac1d
Showing 4 changed files with 61 additions and 39 deletions.
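In short: `add_integrator` gains a per-group `chunk_size` argument, and both forms now assemble through a chunk-aware iterator instead of the old `retain_ints` path. A minimal usage sketch of the new API (not part of the diff; the mesh, space, and integrator names are the usual fealpy ones and are meant as illustration only):

    from fealpy.mesh import TriangleMesh
    from fealpy.functionspace import LagrangeFESpace
    from fealpy.fem import BilinearForm, ScalarDiffusionIntegrator

    mesh = TriangleMesh.from_box([0, 1, 0, 1], nx=64, ny=64)
    space = LagrangeFESpace(mesh, p=1)

    bform = BilinearForm(space)
    # chunk_size=0 (default): assemble the whole local tensor at once.
    # chunk_size>0: assemble the integrator's region in chunks of that many cells.
    bform.add_integrator(ScalarDiffusionIntegrator(), chunk_size=1024)
    A = bform.assembly(format='csr')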
24 changes: 13 additions & 11 deletions fealpy/fem/bilinear_form.py
@@ -43,9 +43,10 @@ def check_space(self):
            raise ValueError("Spaces should have the same dtype, "
                             f"but got {s0.ftype} and {s1.ftype}.")

-    def _scalar_assembly(self, retain_ints: bool, batch_size: int):
+    def _scalar_assembly(self):
        self.check_space()
        space = self._spaces
+        batch_size = self.batch_size
        ugdof = space[0].number_of_global_dofs()
        vgdof = space[1].number_of_global_dofs() if (len(space) > 1) else ugdof
        init_value_shape = (0,) if (batch_size == 0) else (batch_size, 0)
@@ -56,12 +57,13 @@ def _scalar_assembly(self, retain_ints: bool, batch_size: int):
            values = bm.empty(init_value_shape, dtype=space[0].ftype, device=bm.get_device(space[0])),
            spshape = sparse_shape
        )

-        for group in self.integrators.keys():
-            group_tensor, e2dofs = self._assembly_group(group, retain_ints)
-            ue2dof = e2dofs[0]
-            ve2dof = e2dofs[1] if (len(e2dofs) > 1) else ue2dof
+        # for group in self.integrators.keys():
+        #     group_tensor, e2dofs = self._assembly_group(group, retain_ints)
+        for group_tensor, e2dofs_tuple in self.assembly_local_iterative():
+            ue2dof = e2dofs_tuple[0]
+            ve2dof = e2dofs_tuple[1] if (len(e2dofs_tuple) > 1) else ue2dof
            local_shape = group_tensor.shape[-3:] # (NC, vldof, uldof)

            if (batch_size > 0) and (group_tensor.ndim == 3): # Case: no batch dimension
                group_tensor = bm.stack([group_tensor]*batch_size, axis=0)
            I = bm.broadcast_to(ve2dof[:, :, None], local_shape)
@@ -73,12 +75,12 @@ def _scalar_assembly(self, retain_ints: bool, batch_size: int):
        return M

    @overload
-    def assembly(self, *, retain_ints: bool=False) -> CSRTensor: ...
+    def assembly(self) -> CSRTensor: ...
    @overload
-    def assembly(self, *, format: Literal['coo'], retain_ints: bool=False) -> COOTensor: ...
+    def assembly(self, *, format: Literal['coo']) -> COOTensor: ...
    @overload
-    def assembly(self, *, format: Literal['csr'], retain_ints: bool=False) -> CSRTensor: ...
-    def assembly(self, *, format='csr', retain_ints: bool=False):
+    def assembly(self, *, format: Literal['csr']) -> CSRTensor: ...
+    def assembly(self, *, format='csr'):
        """Assembly the bilinear form matrix.

        Parameters:
@@ -88,7 +90,7 @@ def assembly(self, *, format='csr', retain_ints: bool=False):
        Returns:
            global_matrix (CSRTensor | COOTensor): Global sparse matrix shaped ([batch, ]gdof, gdof).
        """
-        M = self._scalar_assembly(retain_ints, self.batch_size)
+        M = self._scalar_assembly()
        if getattr(self, '_transposed', False):
            M = M.T

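Call-site note (not in the diff): `assembly()` and `_scalar_assembly()` no longer take `retain_ints`; batching is read from `self.batch_size`, and chunking is configured per integrator group via `add_integrator(..., chunk_size=...)`. A hypothetical before/after:

    A = bform.assembly(format='csr')                      # after this commit
    # A = bform.assembly(format='csr', retain_ints=True)  # before; now raises TypeError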
44 changes: 30 additions & 14 deletions fealpy/fem/form.py
@@ -16,6 +16,7 @@
class Form(Generic[_I], ABC):
    _spaces: Tuple[_FS, ...]
    integrators: Dict[str, _I]
+    chunk_sizes: Dict[str, int]
    batch_size: int
    sparse_shape: Tuple[int, ...]

@@ -32,6 +33,7 @@ def __init__(self, *space, batch_size: int=0):
            space = space[0]
        self._spaces = space
        self.integrators = {}
+        self.chunk_sizes = {}
        self._cursor = 0
        self.batch_size = batch_size

@@ -69,18 +71,19 @@ def space(self):
        return self._spaces

    @overload
-    def add_integrator(self: Self, I: _I, /, *, region: Optional[Index] = None, group: str = ...) -> Self: ...
+    def add_integrator(self: Self, I: _I, /, *, region: Optional[Index] = None, chunk_size: int = 0, group: str = ...) -> Self: ...
    @overload
-    def add_integrator(self: Self, I: Sequence[_I], /, *, region: Optional[Index] = None, group: str = ...) -> Self: ...
+    def add_integrator(self: Self, I: Sequence[_I], /, *, region: Optional[Index] = None, chunk_size: int = 0, group: str = ...) -> Self: ...
    @overload
-    def add_integrator(self: Self, *I: _I, region: Optional[Index] = None, group: str = ...) -> Self: ...
-    def add_integrator(self, *I, region: Optional[Index] = None, group=None):
+    def add_integrator(self: Self, *I: _I, region: Optional[Index] = None, chunk_size: int = 0, group: str = ...) -> Self: ...
+    def add_integrator(self, *I, region: Optional[Index] = None, chunk_size=0, group=None):
        """Add integrator(s) to the form.

        Parameters:
            *I (Integrator): The integrator(s) to add as a new group.
                Also accepts sequence of integrators.
            index (Index | None, optional):
+            chunk_size (int, optional):
            group (str | None, optional): Name of the group. Defaults to None.

        Returns:
@@ -100,7 +103,7 @@ def add_integrator(self, *I, region: Optional[Index] = None, group=None):
        else:
            I = GroupIntegrator(*I, region=region)

-        return self._add_integrator_impl(I, group)
+        return self._add_integrator_impl(I, group, chunk_size)

    @overload
    def __lshift__(self: Self, other: Integrator) -> Self: ...
@@ -110,9 +113,10 @@ def __lshift__(self, other):
        else:
            return NotImplemented

-    def _add_integrator_impl(self, I: _I, group: Optional[str] = None):
+    def _add_integrator_impl(self, I: _I, group: Optional[str] = None, chunk_size: int = 0):
        group = f'_group_{self._cursor}' if group is None else group
        self._cursor += 1
+        self.chunk_sizes[group] = chunk_size

        if group in self.integrators:
            self.integrators[group] += I
@@ -128,19 +132,30 @@ def _assembly_group(self, group: str, /, *args, **kwds):
            etg = (etg, )
        return integrator(self.space), etg

+    def assembly_local_iterative(self):
+        """Assembly local matrix considering chunk size.
+        Yields local matrix and to_global_dof tuple."""
+        for key, int_ in self.integrators.items():
+            chunk_size = self.chunk_sizes[key]
+            if chunk_size == 0:
+                logger.debug(f"(ASSEMBLY LOCAL FULL) {key}")
+                yield self._assembly_group(key)
+            else:
+                logger.debug(f"(ASSEMBLY LOCAL ITER) {key}, {chunk_size} chunks")
+                yield from IntegralIter.split(int_, chunk_size)(self.space)


# An iteration util for the `_assembly_group` method.
class IntegralIter():
    def __init__(self, integrator: Integrator, /, indices_or_segments: Union[Iterable[TensorLike], TensorLike]):
        self.integrator = integrator
        self.indices_or_segments = indices_or_segments

-    def get(self, space: Union[_FS, Tuple[_FS, ...]], region: Index):
-        self.integrator.set_region(region)
-        etg = self.integrator.to_global_dof(space)
+    def kernel(self, space: Union[_FS, Tuple[_FS, ...]], /, indices: Index):
+        etg = self.integrator.to_global_dof(space, indices=indices)
        if not isinstance(etg, (tuple, list)):
            etg = (etg, )
-        return self.integrator(space), etg
+        return self.integrator(space, indices=indices), etg

    def __call__(self, spaces: Tuple[_FS, ...]):
        if isinstance(self.indices_or_segments, TensorLike):
@@ -152,7 +167,7 @@ def __call__(self, spaces: Tuple[_FS, ...]):

    def _call_impl_indices(self, spaces: Tuple[_FS, ...], /, indices: Iterable[TensorLike]):
        for index in indices:
-            yield self.get(spaces, index)
+            yield self.kernel(spaces, index)

    def _call_impl_segments(self, spaces: Tuple[_FS, ...], /, segments: TensorLike):
        assert segments.ndim == 1
@@ -161,13 +176,14 @@ def _call_impl_segments(self, spaces: Tuple[_FS, ...], /, segments: TensorLike):
        length = segments.shape[0] + 1

        for i in range(length):
+            logger.debug(f"(ITERATION) {i}/{length}")
            stop = segments[i] if (i + 1 < length) else None
            slicing = slice(start, stop, 1)
-            yield self.get(spaces, slicing)
+            yield self.kernel(spaces, slicing)
            start = stop

    @classmethod
-    def split(cls, integrator: Integrator, /, index: TensorLike, chunk_size=0):
-        size = index.shape[0]
+    def split(cls, integrator: Integrator, /, chunk_size=0):
+        size = integrator.get_region().shape[0]
        segments = bm.arange(chunk_size, size, chunk_size)
        return cls(integrator, segments)
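To make the segment arithmetic above concrete, here is the same logic run standalone (a sketch: numpy stands in for the backend manager `bm`, and the sizes are made up). A region of 10 cells with `chunk_size=4` yields segment boundaries [4, 8] and therefore the slices 0:4, 4:8, and 8:end, so the final chunk may be smaller than `chunk_size`:

    import numpy as np

    size, chunk_size = 10, 4                             # 10 cells, chunks of 4
    segments = np.arange(chunk_size, size, chunk_size)   # array([4, 8])

    start = 0
    length = segments.shape[0] + 1                       # number of chunks: 3
    for i in range(length):
        stop = int(segments[i]) if (i + 1 < length) else None
        print(slice(start, stop, 1))   # slice(0, 4, 1), slice(4, 8, 1), slice(8, None, 1)
        start = stop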
13 changes: 9 additions & 4 deletions fealpy/fem/integrator.py
@@ -1,10 +1,10 @@

from typing import Union, Optional, Any, TypeVar, Tuple, List, Dict, Sequence
from typing import overload, Generic
+import logging

from .. import logger
-from ..typing import TensorLike, Index, CoefLike, _S
from ..backend import backend_manager as bm
+from ..typing import TensorLike, Index, CoefLike
from ..functionspace.space import FunctionSpace as _FS
from ..utils import ftype_memory_size
@@ -175,10 +175,15 @@ def const(self, space: _SpaceGroup, /):
        return ConstIntegrator(value, to_gdof)

    def __call__(self, space: _SpaceGroup, /, indices: _OpIndex = None) -> TensorLike:
+        logger.debug(f"(INTEGRATOR RUN) {self.__class__.__name__}, on {space.__class__.__name__}")
        meth = getattr(self, self._assembly_name_map[self._method], None)
        if indices is None:
-            return meth(space) # Old API
-        return meth(space, indices=indices)
+            val = meth(space) # Old API
+        else:
+            val = meth(space, indices=indices)
+        if logger.level == logging._nameToLevel['INFO']:
+            logger.info(f"Local tensor sized {ftype_memory_size(val)} Mb.")
+        return val

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._method})"
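Note that the new guard compares the logger level with `==`, so the per-call memory report fires only when the level is exactly INFO (a more verbose DEBUG setting skips it). A sketch of turning it on, assuming the package exposes its logger at the top level as the relative import above suggests:

    import logging
    from fealpy import logger   # assumption: re-exported package logger

    logger.setLevel(logging.INFO)
    # each integrator call now logs: "Local tensor sized <n> Mb."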
19 changes: 9 additions & 10 deletions fealpy/fem/linear_form.py
@@ -33,9 +33,10 @@ def check_space(self):
        if len(self._spaces) != 1:
            raise ValueError("LinearForm should have only one space.")

-    def _scalar_assembly(self, retain_ints: bool, batch_size: int):
+    def _scalar_assembly(self):
        self.check_space()
        space = self._spaces[0]
+        batch_size = self.batch_size
        gdof = space.number_of_global_dofs()
        init_value_shape = (0,) if (batch_size == 0) else (batch_size, 0)
        sparse_shape = (gdof, )
@@ -46,25 +47,23 @@ def _scalar_assembly(self, retain_ints: bool, batch_size: int):
            spshape = sparse_shape
        )

-        for group in self.integrators.keys():
-            group_tensor, e2dofs = self._assembly_group(group, retain_ints)
-
+        for group_tensor, e2dofs_tuple in self.assembly_local_iterative():
            if (batch_size > 0) and (group_tensor.ndim == 2):
                group_tensor = bm.stack([group_tensor]*batch_size, axis=0)

-            indices = e2dofs[0].reshape(1, -1)
+            indices = e2dofs_tuple[0].reshape(1, -1)
            group_tensor = bm.reshape(group_tensor, self._values_ravel_shape)
            M = M.add(COOTensor(indices, group_tensor, sparse_shape))

        return M

    @overload
-    def assembly(self, *, retain_ints: bool=False) -> TensorLike: ...
+    def assembly(self) -> TensorLike: ...
    @overload
-    def assembly(self, *, format: Literal['coo'], retain_ints: bool=False) -> COOTensor: ...
+    def assembly(self, *, format: Literal['coo']) -> COOTensor: ...
    @overload
-    def assembly(self, *, format: Literal['dense'], retain_ints: bool=False) -> TensorLike: ...
-    def assembly(self, *, format='dense', retain_ints: bool=False):
+    def assembly(self, *, format: Literal['dense']) -> TensorLike: ...
+    def assembly(self, *, format='dense'):
        """Assembly the linear form vector.

        Parameters:
@@ -74,7 +73,7 @@ def assembly(self, *, format='dense', retain_ints: bool=False):
        Returns:
            global_vector (COOTensor | TensorLike): Global sparse vector shaped ([batch, ]gdof).
        """
-        V = self._scalar_assembly(retain_ints, self.batch_size)
+        V = self._scalar_assembly()

        if format == 'dense':
            self._V = V.to_dense()
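An end-to-end sketch on the linear-form side, mirroring the bilinear example at the top (the integrator name and its `source` keyword are illustrative, not taken from this diff):

    from fealpy.fem import LinearForm, ScalarSourceIntegrator

    lform = LinearForm(space)                # `space` as in the sketch above
    lform.add_integrator(ScalarSourceIntegrator(source=1.0), chunk_size=1024)
    b = lform.assembly(format='dense')       # shaped (gdof,), or (batch, gdof) if batch_size > 0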
