Fixed tests
famura committed Jul 8, 2024
1 parent c95122c commit 47101de
Showing 5 changed files with 43 additions and 31 deletions.
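All five changed files apply the same device-placement pattern: tensors and parameters are created directly on the target device via the device= factory argument, rather than being allocated on the CPU and moved later. A minimal sketch of that pattern (illustrative only, not code from this repository):

import torch
import torch.nn as nn

# Assumed example: create a buffer, a learnable parameter, and a module directly
# on the requested device instead of calling .to(device) afterwards.
device = "cuda" if torch.cuda.is_available() else "cpu"
buffer = torch.zeros(16, device=device)                # plain tensor on the target device
weight = nn.Parameter(torch.randn(16, device=device))  # learnable parameter on the target device
layer = nn.Linear(16, 8, device=device)                # module built with factory kwargs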
8 changes: 5 additions & 3 deletions neuralfields/custom_layers.py
@@ -210,11 +210,13 @@ def __init__(
         self.half_kernel_size = math.ceil(self.weight.size(2) / 2) # kernel_size = 4 --> 2, kernel_size = 5 --> 3

         # Initialize the weights values the same way PyTorch does.
-        new_weight_init = torch.zeros(self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size)
+        new_weight_init = torch.zeros(
+            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
+        )
         nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

         # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
-        self.weight = nn.Parameter(new_weight_init, requires_grad=True)
+        self.weight = nn.Parameter(new_weight_init)

     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """Computes the 1-dim convolution just like [Conv1d][torch.nn.Conv1d], however, the kernel has mirrored weights,
@@ -228,7 +230,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
             3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
         """
         # Reconstruct symmetric weights for convolution (original size).
-        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype)
+        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)

         # Loop over input channels.
         for i in range(self.orig_weight_shape[1]):
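The forward pass of this layer rebuilds a full, symmetric kernel from the stored half before convolving. As a hedged illustration of the idea (not the library's exact loop), a symmetric kernel can be obtained by concatenating the stored half with its flipped copy:

import torch

# Assumed sketch: mirror the stored half-kernel along the kernel axis.
kernel_size = 5
half = torch.randn(8, 4, (kernel_size + 1) // 2)  # (out_channels, in_channels, ceil(kernel_size / 2))
mirrored = torch.cat([half, half.flip(-1)[..., kernel_size % 2:]], dim=-1)
assert mirrored.shape[-1] == kernel_size
assert torch.allclose(mirrored, mirrored.flip(-1))  # symmetric about the center tap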
3 changes: 2 additions & 1 deletion neuralfields/neural_fields.py
@@ -98,6 +98,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )

         # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
@@ -113,7 +114,7 @@ def __init__(
             stride=1,
             dilation=1,
             groups=1,
-            # device=device,
+            device=device,
             dtype=dtype,
         )
         init_param_(self.conv_layer, **init_param_kwargs)
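The previously commented-out device argument can simply be passed through because PyTorch convolution modules accept device and dtype as factory keyword arguments, so the layer's weights are allocated on the requested device from the start. A short illustration:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# The weight tensor is created directly on `device`; no later .to() call is needed.
conv = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=5, padding=2, bias=False, device=device)
print(conv.weight.device)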
26 changes: 16 additions & 10 deletions neuralfields/potential_based.py
@@ -31,6 +31,7 @@ def __init__(
         output_size: Optional[int] = None,
         input_embedding: Optional[nn.Module] = None,
         output_embedding: Optional[nn.Module] = None,
+        device: Union[str, torch.device] = "cpu",
     ):
         """
         Args:
@@ -53,6 +54,7 @@ def __init__(
             output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,)
                 By default, a [linear layer][torch.nn.Linear] without biases is used.
+            device: Device to move this module to (after initialization).
         """
         # Call torch.nn.Module's constructor.
         super().__init__()
@@ -65,38 +67,42 @@ def __init__(
         self.input_size = input_size
         self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
         self.output_size = self._hidden_size if output_size is None else output_size
-        self._stimuli_external = torch.zeros(self.hidden_size)
-        self._stimuli_internal = torch.zeros(self.hidden_size)
+        self._stimuli_external = torch.zeros(self.hidden_size, device=device)
+        self._stimuli_internal = torch.zeros(self.hidden_size, device=device)

         # Create the common layers.
         self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
         self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

         # Initialize the values of the potentials.
         if potentials_init is not None:
-            self._potentials_init = potentials_init.detach().clone()
+            self._potentials_init = potentials_init.detach().clone().to(device=device)
         else:
             if activation_nonlin is torch.sigmoid:
-                self._potentials_init = -7 * torch.ones(1, self.hidden_size)
+                self._potentials_init = -7 * torch.ones(1, self.hidden_size, device=device)
             else:
-                self._potentials_init = torch.zeros(1, self.hidden_size)
+                self._potentials_init = torch.zeros(1, self.hidden_size, device=device)

         # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
-        self.resting_level = nn.Parameter(torch.randn(self.hidden_size), requires_grad=True)
+        self.resting_level = nn.Parameter(torch.randn(self.hidden_size, device=device))

         # Initialize the potential dynamics' time constant.
         self.tau_learnable = tau_learnable
-        self._log_tau_init = torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_tau_init = torch.log(
+            torch.as_tensor(tau_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.tau_learnable:
-            self._log_tau = nn.Parameter(self._log_tau_init, requires_grad=True)
+            self._log_tau = nn.Parameter(self._log_tau_init)
         else:
             self._log_tau = self._log_tau_init

         # Initialize the potential dynamics' cubic decay.
         self.kappa_learnable = kappa_learnable
-        self._log_kappa_init = torch.log(torch.as_tensor(kappa_init, dtype=torch.get_default_dtype()).reshape(-1))
+        self._log_kappa_init = torch.log(
+            torch.as_tensor(kappa_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
+        )
         if self.kappa_learnable:
-            self._log_kappa = nn.Parameter(self._log_kappa_init, requires_grad=True)
+            self._log_kappa = nn.Parameter(self._log_kappa_init)
         else:
             self._log_kappa = self._log_kappa_init
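Both the time constant and the cubic-decay coefficient are stored in log space, which keeps the effective values strictly positive regardless of what the optimizer does to the underlying parameter. A hedged sketch of that parameterization (the exp-based read-out is assumed here; it is not part of this diff):

import torch
import torch.nn as nn

# Assumed read-out: the effective time constant would be recovered via exp().
log_tau = nn.Parameter(torch.log(torch.as_tensor([2.0])))
tau = torch.exp(log_tau)  # strictly positive for any value of log_tau
print(tau)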
19 changes: 8 additions & 11 deletions neuralfields/simple_neural_fields.py
@@ -327,6 +327,7 @@ def __init__(
             potentials_init=potentials_init,
             input_embedding=input_embedding,
             output_embedding=output_embedding,
+            device=device,
         )

         # Create the layer that converts the activations of the previous time step into potentials (internal stimulus).
@@ -345,9 +346,9 @@ def __init__(
         self.capacity_learnable = capacity_learnable
         if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
             if _is_iterable(activation_nonlin):
-                self._init_capacity(activation_nonlin[0])
+                self._init_capacity(activation_nonlin[0], device)
             else:
-                self._init_capacity(activation_nonlin)  # type: ignore[arg-type]
+                self._init_capacity(activation_nonlin, device)  # type: ignore[arg-type]
         else:
             self._log_capacity = None
@@ -360,27 +361,23 @@ def __init__(
         # Move the complete model to the given device.
         self.to(device=device)

-    def _init_capacity(self, activation_nonlin: ActivationFunction) -> None:
+    def _init_capacity(self, activation_nonlin: ActivationFunction, device: Union[str, torch.device]) -> None:
         """Initialize the value of the capacity parameter $C$ depending on the activation function.

         Args:
             activation_nonlin: Nonlinear activation function used.
         """
         if activation_nonlin is torch.sigmoid:
             # sigmoid(7.) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([7.0], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([7.0], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         elif activation_nonlin is torch.tanh:
             # tanh(3.8) approx 0.999
-            self._log_capacity_init = torch.log(torch.tensor([3.8], dtype=torch.get_default_dtype()))
+            self._log_capacity_init = torch.log(torch.tensor([3.8], device=device, dtype=torch.get_default_dtype()))
             self._log_capacity = (
-                nn.Parameter(self._log_capacity_init, requires_grad=True)
-                if self.capacity_learnable
-                else self._log_capacity_init
+                nn.Parameter(self._log_capacity_init) if self.capacity_learnable else self._log_capacity_init
             )
         else:
             raise ValueError(
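The constants chosen in _init_capacity come from the saturation points noted in the comments: sigmoid(7.0) and tanh(3.8) are both approximately 0.999. This can be checked directly:

import torch

print(torch.sigmoid(torch.tensor(7.0)))  # ~0.9991
print(torch.tanh(torch.tensor(3.8)))     # ~0.9990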
18 changes: 12 additions & 6 deletions tests/test_simple_neural_fields.py
@@ -83,13 +83,15 @@ def test_simple_neural_fields(
     # Get and set the parameters.
     param_vec = snf.param_values
     assert isinstance(param_vec, torch.Tensor)
-    new_param_vec = param_vec + torch.randn_like(param_vec)
+    new_param_vec = param_vec + torch.randn_like(param_vec, device=device)
     snf.param_values = new_param_vec
     assert torch.allclose(snf.param_values, new_param_vec)

     # Compute dp/dt.
     for _ in range(10):
-        p_dot = snf.potentials_dot(potentials=torch.randn(hidden_size), stimuli=torch.randn(hidden_size))
+        p_dot = snf.potentials_dot(
+            potentials=torch.randn(hidden_size, device=device), stimuli=torch.randn(hidden_size, device=device)
+        )
         assert isinstance(p_dot, torch.Tensor)
         assert p_dot.shape == (hidden_size,)
         assert isinstance(snf.stimuli_internal, torch.Tensor)
@@ -100,7 +102,7 @@
     # Compute the unbatched forward pass.
     hidden = None
     for _ in range(5):
-        outputs, hidden_next = snf.forward_one_step(inputs=torch.randn(input_size), hidden=hidden)
+        outputs, hidden_next = snf.forward_one_step(inputs=torch.randn(input_size, device=device), hidden=hidden)
         hidden = hidden_next.clone()
         assert isinstance(outputs, torch.Tensor)
         assert outputs.shape == (1, output_size or snf.hidden_size)
@@ -110,16 +112,20 @@
     # Compute the batched forward pass.
     hidden = None
     for _ in range(5):
-        outputs, hidden_next = snf.forward_one_step(inputs=torch.randn(batch_size, input_size), hidden=hidden)
+        outputs, hidden_next = snf.forward_one_step(
+            inputs=torch.randn(batch_size, input_size, device=device), hidden=hidden
+        )
         hidden = hidden_next.clone()
         assert isinstance(outputs, torch.Tensor)
         assert outputs.shape == (batch_size, output_size or snf.hidden_size)
         assert isinstance(hidden_next, torch.Tensor)
         assert hidden_next.shape == (batch_size, snf.hidden_size)

     # Evaluate a time series of inputs.
-    for hidden in (None, torch.randn(batch_size, hidden_size)):
-        output_seq, hidden_seq = snf.forward(inputs=torch.randn(batch_size, len_input_seq, input_size), hidden=hidden)
+    for hidden in (None, torch.randn(batch_size, hidden_size, device=device)):
+        output_seq, hidden_seq = snf.forward(
+            inputs=torch.randn(batch_size, len_input_seq, input_size, device=device), hidden=hidden
+        )
         assert isinstance(output_seq, torch.Tensor)
         assert output_seq.shape == (batch_size, len_input_seq, output_size or snf.hidden_size)
         assert isinstance(hidden_seq, torch.Tensor)
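The test now builds every input tensor on the same device as the model under test. A hedged sketch of how such a device parametrization is commonly set up in pytest (the repository's actual fixtures and parameter lists may differ):

import pytest
import torch

DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

@pytest.mark.parametrize("device", DEVICES)
def test_tensor_lands_on_requested_device(device):
    t = torch.randn(4, device=device)
    assert t.device.type == device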
