Skip to content

Commit

Permalink
Fixing a few more bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 5, 2024
1 parent 7b1ad16 commit 87e2c83
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions examples/high_order_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def run_language_interpolation(cfg: DictConfig):
if cfg.net.model_type in [
"high_order_transformer",
"high_order_input_transformer",
"dual_convolution"
]:
print('Using transformer dataloader')
# dataset_generator is only one type so using the default
datamodule = TransformerDataModule(
characters_per_feature=cfg.data.characters_per_feature,
Expand Down
7 changes: 4 additions & 3 deletions language_interpolation/dual_convolutional_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
super().__init__()

self._out_width = out_width
self.device = device

self.input_layer = HighOrderMLP(
layer_type="continuous",
Expand Down Expand Up @@ -68,7 +69,7 @@ def forward(self, x: Tensor):
if val.shape[1] % 2 == 1:
# Add padding to the end, hope this doesn't bust anything
val = torch.cat(
[val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1
[val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1, device=self.device
)

valshape = val.shape
Expand All @@ -91,12 +92,12 @@ def __init__(
segments: int = None,
device: str = "cpu",
):

super().__init__()
self.dual_layer = DualConvolutionalLayer(
n=n,
in_width=in_width,
out_width=embedding_dimension,
hidden_layer=hidden_layers,
hidden_layers=hidden_layers,
hidden_width=hidden_width,
in_segments=in_segments,
segments=segments,
Expand Down
2 changes: 1 addition & 1 deletion language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def forward(self, x) :
max_context=cfg.data.max_features,
non_linearity=torch.nn.ReLU(),
)
elif cfg.net.model_type == "dual_convolutional_network":
elif cfg.net.model_type == "dual_convolution":
model = DualConvolutionNetwork(
n=cfg.net.n,
in_width=cfg.net.in_width,
Expand Down

0 comments on commit 87e2c83

Please sign in to comment.