diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py
index dec58914f3..5b3c309ec4 100644
--- a/captum/influence/_core/tracincp_fast_rand_proj.py
+++ b/captum/influence/_core/tracincp_fast_rand_proj.py
@@ -1380,6 +1380,7 @@ def _set_projections_tracincp_fast_rand_proj(
         1
     ]  # this is the dimension of the input of the last fully-connected layer
     device = batch_jacobians.device
+    dtype = batch_jacobians.dtype

     # choose projection if needed
     # without projection, the dimension of the intermediate quantities returned
@@ -1409,8 +1410,8 @@ def _set_projections_tracincp_fast_rand_proj(
         )

         projection_quantities = jacobian_projection.to(
-            device
-        ), layer_input_projection.to(device)
+            device=device, dtype=dtype
+        ), layer_input_projection.to(device=device, dtype=dtype)

     return projection_quantities

diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py
index dbfc0de550..26a5146785 100644
--- a/tests/influence/_utils/common.py
+++ b/tests/influence/_utils/common.py
@@ -190,35 +190,42 @@ def get_random_model_and_data(
         BasicLinearNet(in_features, hidden_nodes, out_features)
         if not unpack_inputs
         else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs)
-    )
+    ).double()

     num_checkpoints = 5

     for i in range(num_checkpoints):
-        net.linear1.weight.data = torch.normal(3, 4, (hidden_nodes, in_features))
-        net.linear2.weight.data = torch.normal(5, 6, (out_features, hidden_nodes))
+        net.linear1.weight.data = torch.normal(
+            3, 4, (hidden_nodes, in_features)
+        ).double()
+        net.linear2.weight.data = torch.normal(
+            5, 6, (out_features, hidden_nodes)
+        ).double()
         if unpack_inputs:
             net.pre.weight.data = torch.normal(
                 3, 4, (in_features, in_features * num_inputs)
             )
+        if hasattr(net, "pre"):
+            net.pre.weight.data = net.pre.weight.data.double()
         checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"])
         net_adjusted = _wrap_model_in_dataparallel(net) if use_gpu else net
         torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))

     num_samples = 50
     num_train = 32
-    all_labels = torch.normal(1, 2, (num_samples, out_features))
+    all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
     train_labels = all_labels[:num_train]
     test_labels = all_labels[num_train:]

     if unpack_inputs:
         all_samples = [
-            torch.normal(0, 1, (num_samples, in_features)) for _ in range(num_inputs)
+            torch.normal(0, 1, (num_samples, in_features)).double()
+            for _ in range(num_inputs)
         ]
         train_samples = [ts[:num_train] for ts in all_samples]
         test_samples = [ts[num_train:] for ts in all_samples]
     else:
-        all_samples = torch.normal(0, 1, (num_samples, in_features))
+        all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
         train_samples = all_samples[:num_train]
         test_samples = all_samples[num_train:]
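
For context, a minimal standalone sketch (not part of the patch) of the failure the dtype cast guards against: the updated tests build double-precision models and data, while randomly drawn projection matrices default to float32, and PyTorch refuses to matmul mixed float32/float64 tensors. The names `batch_jacobians` and `jacobian_projection` mirror the patched function, but the shapes below are purely illustrative.

import torch

# Illustrative shapes only; this is a sketch, not captum's implementation.
batch_jacobians = torch.randn(8, 16, dtype=torch.float64)  # double-precision batch
jacobian_projection = torch.randn(16, 4)  # random projection, defaults to float32

try:
    # Without the cast, mixing a float64 jacobian with a float32 projection
    # matrix fails inside the matmul.
    torch.matmul(batch_jacobians, jacobian_projection)
except RuntimeError as err:
    print("dtype mismatch:", err)

# The patch moves the projection matrices onto the batch's device *and* dtype,
# as in `jacobian_projection.to(device=device, dtype=dtype)` above.
jacobian_projection = jacobian_projection.to(
    device=batch_jacobians.device, dtype=batch_jacobians.dtype
)
print(torch.matmul(batch_jacobians, jacobian_projection).dtype)  # torch.float64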