
Commit 082ce51

Allow any torch version; Fixed device propagation
famura authored Jul 8, 2024
1 parent 759e37f commit 082ce51
Showing 8 changed files with 1,899 additions and 1,444 deletions.
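The fix in the files below follows one pattern: tensors that used to be created on the default (CPU) device are now created directly on the requested device. This matters because `nn.Module.to()` only moves registered parameters and buffers; plain tensor attributes stay on whatever device they were created on. A minimal sketch of the failure mode, with made-up class and attribute names:

```python
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self, hidden_size: int, device: str = "cpu"):
        super().__init__()
        # Registered parameter: moved along with the module by .to().
        self.weight = nn.Parameter(torch.randn(hidden_size, device=device))
        # Plain tensor attribute: NOT moved by .to(), so it must be created on the right device.
        self.stimuli = torch.zeros(hidden_size, device=device)


if torch.cuda.is_available():
    m = Toy(8).to("cuda")   # device argument left at its default "cpu"
    print(m.weight.device)  # cuda:0
    print(m.stimuli.device) # cpu -> device mismatch unless device= is forwarded
```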
8 changes: 5 additions & 3 deletions neuralfields/custom_layers.py
@@ -210,11 +210,13 @@ def __init__(
         self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3
 
         # Initialize the weights values the same way PyTorch does.
-        new_weight_init = torch.zeros(self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size)
+        new_weight_init = torch.zeros(
+            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
+        )
         nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))
 
         # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
-        self.weight = nn.Parameter(new_weight_init, requires_grad=True)
+        self.weight = nn.Parameter(new_weight_init)
 
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """Computes the 1-dim convolution just like [Conv1d][torch.nn.Conv1d], however, the kernel has mirrored weights,
@@ -228,7 +230,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
             3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
         """
         # Reconstruct symmetric weights for convolution (original size).
-        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype)
+        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)
 
         # Loop over input channels.
         for i in range(self.orig_weight_shape[1]):
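Two details in this hunk are easy to miss: `nn.Parameter` defaults to `requires_grad=True`, so dropping the explicit argument changes nothing, and the reconstructed symmetric kernel is now allocated on `self.weight.device`, so the forward pass also works when the module lives on the GPU. The mirroring itself boils down to something like the following (an illustrative sketch for an odd kernel size, not the repository's channel loop):

```python
import torch

# Only ceil(k / 2) weights are stored; the full kernel of size k is rebuilt by mirroring them.
half = torch.tensor([1.0, 2.0, 3.0])         # stored half of a size-5 kernel
full = torch.cat([half, half[:-1].flip(0)])  # [1., 2., 3., 2., 1.]
print(full)
```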
3 changes: 2 additions & 1 deletion neuralfields/neural_fields.py
@@ -98,6 +98,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )
 
         # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
@@ -113,7 +114,7 @@ def __init__(
             stride=1,
             dilation=1,
             groups=1,
-            # device=device,
+            device=device,
             dtype=dtype,
         )
         init_param_(self.conv_layer, **init_param_kwargs)
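Uncommenting `device=device` lets `nn.Conv1d` allocate its weights directly on the target device instead of initializing them on the CPU and moving them afterwards. A hedged usage sketch with made-up channel and kernel sizes:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = nn.Conv1d(
    in_channels=4,
    out_channels=4,
    kernel_size=5,
    padding=2,
    stride=1,
    dilation=1,
    groups=1,
    device=device,
    dtype=torch.get_default_dtype(),
)
print(conv.weight.device)  # already on the chosen device, no extra copy needed
```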
28 changes: 17 additions & 11 deletions neuralfields/potential_based.py
@@ -31,6 +31,7 @@ def __init__(
         output_size: Optional[int] = None,
         input_embedding: Optional[nn.Module] = None,
         output_embedding: Optional[nn.Module] = None,
+        device: Union[str, torch.device] = "cpu",
     ):
         """
         Args:
@@ -53,6 +54,7 @@ def __init__(
             output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,)
                 By default, a [linear layer][torch.nn.Linear] without biases is used.
+            device: Device to move this module to (after initialization).
         """
         # Call torch.nn.Module's constructor.
         super().__init__()
@@ -65,38 +67,42 @@ def __init__(
         self.input_size = input_size
         self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
         self.output_size = self._hidden_size if output_size is None else output_size
-        self._stimuli_external = torch.zeros(self.hidden_size)
-        self._stimuli_internal = torch.zeros(self.hidden_size)
+        self._stimuli_external = torch.zeros(self.hidden_size, device=device)
+        self._stimuli_internal = torch.zeros(self.hidden_size, device=device)
 
         # Create the common layers.
         self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
         self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)
 
         # Initialize the values of the potentials.
         if potentials_init is not None:
-            self._potentials_init = potentials_init.detach().clone()
+            self._potentials_init = potentials_init.detach().clone().to(device=device)
         else:
             if activation_nonlin is torch.sigmoid:
-                self._potentials_init = -7 * torch.ones(1, self.hidden_size)
+                self._potentials_init = -7 * torch.ones(1, self.hidden_size, device=device)
             else:
-                self._potentials_init = torch.zeros(1, self.hidden_size)
+                self._potentials_init = torch.zeros(1, self.hidden_size, device=device)
 
         # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
-        self.resting_level = nn.Parameter(torch.randn(self.hidden_size), requires_grad=True)
+        self.resting_level = nn.Parameter(torch.randn(self.hidden_size, device=device))
 
         # Initialize the potential dynamics' time constant.
         self.tau_learnable = tau_learnable
-        self._log_tau_init = torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_tau_init = torch.log(
+            torch.as_tensor(tau_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.tau_learnable:
-            self._log_tau = nn.Parameter(self._log_tau_init, requires_grad=True)
+            self._log_tau = nn.Parameter(self._log_tau_init)
         else:
             self._log_tau = self._log_tau_init
 
         # Initialize the potential dynamics' cubic decay.
         self.kappa_learnable = kappa_learnable
-        self._log_kappa_init = torch.log(torch.as_tensor(kappa_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_kappa_init = torch.log(
+            torch.as_tensor(kappa_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.kappa_learnable:
-            self._log_kappa = nn.Parameter(self._log_kappa_init, requires_grad=True)
+            self._log_kappa = nn.Parameter(self._log_kappa_init)
         else:
             self._log_kappa = self._log_kappa_init
 
@@ -182,7 +188,7 @@ def init_hidden(
         if potentials_init is None:
             if batch_size is None:
                 return self._potentials_init.view(-1)
-            return self._potentials_init.repeat(batch_size, 1)
+            return self._potentials_init.repeat(batch_size, 1).to(device=self.device)
 
         return potentials_init.to(device=self.device)
 
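With this change the stimuli, initial potentials, resting level, and the log-transformed time constant and cubic decay are all created on the requested device. Storing `tau` and `kappa` as logarithms is the usual way to keep a learnable quantity positive: the optimizer updates the unconstrained log value and the positive quantity is recovered with `exp()`. A minimal sketch of that parameterization (variable names are illustrative, not taken from the file):

```python
import torch
from torch import nn

tau_init = 10.0
log_tau = nn.Parameter(
    torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
)
tau = torch.exp(log_tau)  # stays positive no matter what value the optimizer drives log_tau to
print(tau)                # roughly tensor([10.]), with a grad_fn attached
```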
19 changes: 8 additions & 11 deletions neuralfields/simple_neural_fields.py
@@ -327,6 +327,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )
 
         # Create the layer that converts the activations of the previous time step into potentials (internal stimulus).
@@ -345,9 +346,9 @@ def __init__(
         self.capacity_learnable = capacity_learnable
         if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
             if _is_iterable(activation_nonlin):
-                self._init_capacity(activation_nonlin[0])
+                self._init_capacity(activation_nonlin[0], device)
             else:
-                self._init_capacity(activation_nonlin)  # type: ignore[arg-type]
+                self._init_capacity(activation_nonlin, device)  # type: ignore[arg-type]
         else:
             self._log_capacity = None
 
@@ -360,27 +361,23 @@ def __init__(
         # Move the complete model to the given device.
         self.to(device=device)
 
-    def _init_capacity(self, activation_nonlin: ActivationFunction) -> None:
+    def _init_capacity(self, activation_nonlin: ActivationFunction, device: Union[str, torch.device]) -> None:
         """Initialize the value of the capacity parameter $C$ depending on the activation function.
 
         Args:
             activation_nonlin: Nonlinear activation function used.
         """
         if activation_nonlin is torch.sigmoid:
             # sigmoid(7.) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([7.0], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([7.0], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         elif activation_nonlin is torch.tanh:
             # tanh(3.8) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([3.8], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([3.8], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         else:
             raise ValueError(
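The constants 7.0 and 3.8 match the inline comments: both drive their activation function to roughly 0.999, i.e., close to saturation. A quick check:

```python
import torch

print(torch.sigmoid(torch.tensor(7.0)))  # tensor(0.9991)
print(torch.tanh(torch.tensor(3.8)))     # tensor(0.9990)
```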