Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix constant numbering in SLATE #3808

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions firedrake/slate/slac/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap

# Pick the constants associated with a Tensor()/TSFC kernel
tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers)
kernel_data.extend([(c, c.name) for c in wrapper_constants if c in tsfc_constants])
kernel_data.extend([
(constant, constant_name)
for constant, constant_name in wrapper_constants
if constant in tsfc_constants
])
return kernel_data

def loopify_tsfc_kernel_data(self, kernel_data):
Expand Down Expand Up @@ -254,7 +258,10 @@ def collect_coefficients(self):

def collect_constants(self):
""" All constants of self.expression as a list """
return self.expression.constants()
return tuple(
(constant, f"c_{i}")
for i, constant in enumerate(self.expression.constants())
)

def initialise_terminals(self, var2terminal, coefficients):
""" Initilisation of the variables in which coefficients
Expand Down Expand Up @@ -361,9 +368,9 @@ def generate_wrapper_kernel_args(self, tensor2temp):
dtype=self.tsfc_parameters["scalar_type"])
args.append(kernel_args.CoefficientKernelArg(coeff_loopy_arg))

for constant in self.bag.constants:
for constant, constant_name in self.bag.constants:
constant_loopy_arg = loopy.GlobalArg(
constant.name,
constant_name,
shape=constant.dat.cdim,
dtype=self.tsfc_parameters["scalar_type"]
)
Expand Down
3 changes: 1 addition & 2 deletions firedrake/slate/slac/tsfc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None):
cxt_kernels = []
assert prefix is not None
for orig_it_type, integrals in transformed_integrals.items():
subkernel_prefix = prefix + "%s_to_" % orig_it_type
form = Form(integrals)
kernels = tsfc_compile(form,
subkernel_prefix,
f"{prefix}{orig_it_type}_to_",
parameters=tsfc_parameters,
split=False, diagonal=tensor.diagonal)

Expand Down
Loading