Skip to content

Commit

Permalink
Fix constant numbering in SLATE (#3808)
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward authored Oct 21, 2024
1 parent e27b702 commit bf67f79
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
15 changes: 11 additions & 4 deletions firedrake/slate/slac/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap

# Pick the constants associated with a Tensor()/TSFC kernel
tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers)
kernel_data.extend([(c, c.name) for c in wrapper_constants if c in tsfc_constants])
kernel_data.extend([
(constant, constant_name)
for constant, constant_name in wrapper_constants
if constant in tsfc_constants
])
return kernel_data

def loopify_tsfc_kernel_data(self, kernel_data):
Expand Down Expand Up @@ -254,7 +258,10 @@ def collect_coefficients(self):

def collect_constants(self):
""" All constants of self.expression as a list """
return self.expression.constants()
return tuple(
(constant, f"c_{i}")
for i, constant in enumerate(self.expression.constants())
)

def initialise_terminals(self, var2terminal, coefficients):
""" Initilisation of the variables in which coefficients
Expand Down Expand Up @@ -361,9 +368,9 @@ def generate_wrapper_kernel_args(self, tensor2temp):
dtype=self.tsfc_parameters["scalar_type"])
args.append(kernel_args.CoefficientKernelArg(coeff_loopy_arg))

for constant in self.bag.constants:
for constant, constant_name in self.bag.constants:
constant_loopy_arg = loopy.GlobalArg(
constant.name,
constant_name,
shape=constant.dat.cdim,
dtype=self.tsfc_parameters["scalar_type"]
)
Expand Down
3 changes: 1 addition & 2 deletions firedrake/slate/slac/tsfc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None):
cxt_kernels = []
assert prefix is not None
for orig_it_type, integrals in transformed_integrals.items():
subkernel_prefix = prefix + "%s_to_" % orig_it_type
form = Form(integrals)
kernels = tsfc_compile(form,
subkernel_prefix,
f"{prefix}{orig_it_type}_to_",
parameters=tsfc_parameters,
split=False, diagonal=tensor.diagonal)

Expand Down

0 comments on commit bf67f79

Please sign in to comment.