Allow any torch version; Fixed device propagation #1

Merged · 8 commits · Jul 8, 2024
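With device propagation fixed, the model can be constructed directly on the target device instead of silently allocating some tensors on the CPU. A hypothetical usage sketch; the class name and constructor arguments are inferred from the files changed below and are not verified against the package's public API:

```python
import torch

from neuralfields import NeuralField  # assumed import path

# Pick a GPU if one is available; the point of this PR is that `device` is now
# forwarded to every tensor created during construction.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralField(input_size=4, hidden_size=16, device=device)

# All registered parameters and the initial potentials should now live on `device`.
print({name: p.device for name, p in model.named_parameters()})
print(model.init_hidden().device)
```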
8 changes: 5 additions & 3 deletions neuralfields/custom_layers.py
```diff
@@ -210,11 +210,13 @@ def __init__(
         self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3

         # Initialize the weights values the same way PyTorch does.
-        new_weight_init = torch.zeros(self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size)
+        new_weight_init = torch.zeros(
+            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
+        )
         nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

         # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
-        self.weight = nn.Parameter(new_weight_init, requires_grad=True)
+        self.weight = nn.Parameter(new_weight_init)

     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """Computes the 1-dim convolution just like [Conv1d][torch.nn.Conv1d], however, the kernel has mirrored weights,
@@ -228,7 +230,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
             3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
         """
         # Reconstruct symmetric weights for convolution (original size).
-        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype)
+        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)

         # Loop over input channels.
         for i in range(self.orig_weight_shape[1]):
```
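For context, a minimal sketch (not the library's exact code) of the mirrored-weight idea this layer implements: only the first ceil(kernel_size / 2) entries are stored, and the full symmetric kernel is rebuilt on every forward pass. The hunk above makes sure the reconstruction buffer is allocated on the weight's device rather than defaulting to the CPU:

```python
import torch

# Stored half of a symmetric kernel of size 5: half_kernel_size = ceil(5 / 2) = 3.
half = torch.randn(1, 1, 3)

# Rebuild the full kernel [a, b, c, b, a] on the same device and dtype as the stored half.
full = torch.empty(1, 1, 5, dtype=half.dtype, device=half.device)
full[..., :3] = half                    # left half including the center element
full[..., 3:] = half[..., :2].flip(-1)  # mirrored right half
print(full)
```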
3 changes: 2 additions & 1 deletion neuralfields/neural_fields.py
```diff
@@ -98,6 +98,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )

         # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
@@ -113,7 +114,7 @@ def __init__(
             stride=1,
             dilation=1,
             groups=1,
-            # device=device,
+            device=device,
             dtype=dtype,
         )
         init_param_(self.conv_layer, **init_param_kwargs)
```
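The previously commented-out argument suggests it was disabled to stay compatible with older PyTorch releases; nn.Conv1d accepts the device and dtype factory keyword arguments in recent versions, which is presumably what "allow any torch version" refers to. A small sketch of that pattern with placeholder layer sizes (not the library's actual configuration):

```python
import torch
import torch.nn as nn

# Factory kwargs create the layer's parameters directly on the requested device and
# dtype, avoiding a separate .to(...) call after construction.
conv = nn.Conv1d(
    in_channels=1,
    out_channels=1,
    kernel_size=5,
    padding=2,
    stride=1,
    dilation=1,
    groups=1,
    device="cpu",
    dtype=torch.float32,
)
print(conv.weight.device, conv.weight.dtype)
```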
28 changes: 17 additions & 11 deletions neuralfields/potential_based.py
```diff
@@ -31,6 +31,7 @@ def __init__(
         output_size: Optional[int] = None,
         input_embedding: Optional[nn.Module] = None,
         output_embedding: Optional[nn.Module] = None,
+        device: Union[str, torch.device] = "cpu",
     ):
         """
         Args:
@@ -53,6 +54,7 @@ def __init__(
             output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,)
                 By default, a [linear layer][torch.nn.Linear] without biases is used.
+            device: Device to move this module to (after initialization).
         """
         # Call torch.nn.Module's constructor.
         super().__init__()
@@ -65,38 +67,42 @@ def __init__(
         self.input_size = input_size
         self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
         self.output_size = self._hidden_size if output_size is None else output_size
-        self._stimuli_external = torch.zeros(self.hidden_size)
-        self._stimuli_internal = torch.zeros(self.hidden_size)
+        self._stimuli_external = torch.zeros(self.hidden_size, device=device)
+        self._stimuli_internal = torch.zeros(self.hidden_size, device=device)

         # Create the common layers.
         self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
         self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

         # Initialize the values of the potentials.
         if potentials_init is not None:
-            self._potentials_init = potentials_init.detach().clone()
+            self._potentials_init = potentials_init.detach().clone().to(device=device)
         else:
             if activation_nonlin is torch.sigmoid:
-                self._potentials_init = -7 * torch.ones(1, self.hidden_size)
+                self._potentials_init = -7 * torch.ones(1, self.hidden_size, device=device)
             else:
-                self._potentials_init = torch.zeros(1, self.hidden_size)
+                self._potentials_init = torch.zeros(1, self.hidden_size, device=device)

         # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
-        self.resting_level = nn.Parameter(torch.randn(self.hidden_size), requires_grad=True)
+        self.resting_level = nn.Parameter(torch.randn(self.hidden_size, device=device))

         # Initialize the potential dynamics' time constant.
         self.tau_learnable = tau_learnable
-        self._log_tau_init = torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_tau_init = torch.log(
+            torch.as_tensor(tau_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.tau_learnable:
-            self._log_tau = nn.Parameter(self._log_tau_init, requires_grad=True)
+            self._log_tau = nn.Parameter(self._log_tau_init)
         else:
             self._log_tau = self._log_tau_init

         # Initialize the potential dynamics' cubic decay.
         self.kappa_learnable = kappa_learnable
-        self._log_kappa_init = torch.log(torch.as_tensor(kappa_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_kappa_init = torch.log(
+            torch.as_tensor(kappa_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.kappa_learnable:
-            self._log_kappa = nn.Parameter(self._log_kappa_init, requires_grad=True)
+            self._log_kappa = nn.Parameter(self._log_kappa_init)
         else:
             self._log_kappa = self._log_kappa_init

@@ -182,7 +188,7 @@ def init_hidden(
         if potentials_init is None:
             if batch_size is None:
                 return self._potentials_init.view(-1)
-            return self._potentials_init.repeat(batch_size, 1)
+            return self._potentials_init.repeat(batch_size, 1).to(device=self.device)

         return potentials_init.to(device=self.device)
```
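A short sketch of the log-parameterization used for tau and kappa above: storing log(tau) as the (optionally learnable) quantity keeps the underlying value strictly positive, assuming it is recovered by exponentiation elsewhere in the class:

```python
import torch
import torch.nn as nn

tau_init = 10.0
# Store the logarithm as the trainable parameter ...
log_tau = nn.Parameter(torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1)))
# ... and recover tau by exponentiation, which is positive for any real value of log_tau.
tau = torch.exp(log_tau)
print(tau)  # tensor([10.], grad_fn=<ExpBackward0>)
```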
19 changes: 8 additions & 11 deletions neuralfields/simple_neural_fields.py
```diff
@@ -327,6 +327,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )

         # Create the layer that converts the activations of the previous time step into potentials (internal stimulus).
@@ -345,9 +346,9 @@ def __init__(
         self.capacity_learnable = capacity_learnable
         if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
             if _is_iterable(activation_nonlin):
-                self._init_capacity(activation_nonlin[0])
+                self._init_capacity(activation_nonlin[0], device)
             else:
-                self._init_capacity(activation_nonlin)  # type: ignore[arg-type]
+                self._init_capacity(activation_nonlin, device)  # type: ignore[arg-type]
         else:
             self._log_capacity = None

@@ -360,27 +361,23 @@ def __init__(
         # Move the complete model to the given device.
         self.to(device=device)

-    def _init_capacity(self, activation_nonlin: ActivationFunction) -> None:
+    def _init_capacity(self, activation_nonlin: ActivationFunction, device: Union[str, torch.device]) -> None:
         """Initialize the value of the capacity parameter $C$ depending on the activation function.

         Args:
             activation_nonlin: Nonlinear activation function used.
         """
         if activation_nonlin is torch.sigmoid:
             # sigmoid(7.) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([7.0], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([7.0], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         elif activation_nonlin is torch.tanh:
             # tanh(3.8) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([3.8], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([3.8], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         else:
             raise ValueError(
```
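A quick check of the constants referenced in the comments inside _init_capacity: the capacity is initialized at the potential where the saturating nonlinearity reaches roughly 0.999:

```python
import torch

print(torch.sigmoid(torch.tensor(7.0)))  # ~0.9991
print(torch.tanh(torch.tensor(3.8)))     # ~0.9990
```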