
Commit e3c8dcc

Merge branch 'dev' into add_truncate_normal_distribution_for_torch_distributions

melopeo authored Aug 22, 2023
2 parents: 0f5912b + 482e940
Showing 5 changed files with 108 additions and 88 deletions.
70 changes: 3 additions & 67 deletions src/gluonts/torch/model/wavenet/estimator.py
@@ -18,7 +18,7 @@
import numpy as np

from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
@@ -41,7 +41,7 @@
InstanceSplitter,
SetField,
RemoveFields,
SimpleTransformation,
QuantizeMeanScaled,
VstackFeatures,
Identity,
TestSplitSampler,
@@ -67,70 +67,6 @@
]


class QuantizeScaled(SimpleTransformation):
"""Rescale and quantize the target variable.
Requires `past_target_field`, and `future_target_field` to be present.
The mean absolute value of the past_target is used to rescale
past_target and future_target. Then the bin_edges are used to quantize
the rescaled target.
The calculated scale is stored in the `scale_field`.
Parameters
----------
bin_edges
The bin edges for quantization.
past_target_field, optional
The field name that contains `past_target`,
by default "past_target"
past_observed_values_field, optional
The field name that contains `past_observed_values`,
by default "past_observed_values"
future_target_field, optional
The field name that contains `future_target`,
by default "future_target"
scale_field, optional
The field name where scale will be stored,
by default "scale"
"""

@validated()
def __init__(
self,
bin_edges: List[float],
past_target_field: str = "past_target",
past_observed_values_field: str = "past_observed_values",
future_target_field: str = "future_target",
scale_field: str = "scale",
):
self.bin_edges = np.array(bin_edges)
self.future_target_field = future_target_field
self.past_target_field = past_target_field
self.past_observed_values_field = past_observed_values_field
self.scale_field = scale_field

def transform(self, data: DataEntry) -> DataEntry:
target = data[self.past_target_field]
weights = data.get(
self.past_observed_values_field, np.ones_like(target)
)
m = np.sum(np.abs(target) * weights) / np.sum(weights)
scale = m if m > 0 else 1.0
data[self.future_target_field] = np.digitize(
data[self.future_target_field] / scale,
bins=self.bin_edges,
right=False,
)
data[self.past_target_field] = np.digitize(
data[self.past_target_field] / scale,
bins=self.bin_edges,
right=False,
)
data[self.scale_field] = np.array([scale], dtype=np.float32)
return data


class WaveNetEstimator(PyTorchLightningEstimator):
@validated()
def __init__(
@@ -392,7 +328,7 @@ def _create_instance_splitter(self, mode: str):
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
) + QuantizeScaled(bin_edges=self.bin_edges)
) + QuantizeMeanScaled(bin_edges=self.bin_edges)

def create_training_data_loader(
self,
22 changes: 1 addition & 21 deletions src/gluonts/torch/model/wavenet/module.py
@@ -18,27 +18,7 @@

from gluonts.core.component import validated
from gluonts.torch.modules.feature import FeatureEmbedder


class LookupValues(nn.Module):
"""Lookup bin values from bin indices.
Parameters
----------
bin_values
Tensor of bin values with shape (num_bins, ).
"""

@validated()
def __init__(self, bin_values: torch.Tensor):
super().__init__()
self.register_buffer("bin_values", bin_values)

def forward(self, indices: torch.Tensor) -> torch.Tensor:
indices = torch.clamp(indices, 0, self.bin_values.shape[0] - 1)
return torch.index_select(
self.bin_values, 0, indices.reshape(-1)
).view_as(indices)
from gluonts.torch.modules.lookup_table import LookupValues


class CausalDilatedResidualLayer(nn.Module):
38 changes: 38 additions & 0 deletions src/gluonts/torch/modules/lookup_table.py
@@ -0,0 +1,38 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import torch
import torch.nn as nn

from gluonts.core.component import validated


class LookupValues(nn.Module):
"""A lookup table mapping bin indices to values.
Parameters
----------
bin_values
Tensor of bin values with shape (num_bins, ).
"""

@validated()
def __init__(self, bin_values: torch.Tensor):
super().__init__()
self.register_buffer("bin_values", bin_values)

def forward(self, indices: torch.Tensor) -> torch.Tensor:
indices = torch.clamp(indices, 0, self.bin_values.shape[0] - 1)
return torch.index_select(
self.bin_values, 0, indices.reshape(-1)
).view_as(indices)
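
A minimal usage sketch of the relocated `LookupValues` module; the bin values and indices below are illustrative and not part of the commit:

```python
import torch

from gluonts.torch.modules.lookup_table import LookupValues

# Illustrative bin values; each index selects the corresponding value.
lookup = LookupValues(bin_values=torch.tensor([-1.0, 0.0, 1.0, 2.5]))

# Out-of-range indices (5 and -1) are clamped to the valid range [0, 3].
indices = torch.tensor([[0, 3, 2], [1, 5, -1]])
values = lookup(indices)
# values ~ [[-1.0, 2.5,  1.0],
#           [ 0.0, 2.5, -1.0]]
```

Registering `bin_values` as a buffer keeps it on the module's device and in its state dict without making it a trainable parameter.
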
2 changes: 2 additions & 0 deletions src/gluonts/transform/__init__.py
@@ -41,6 +41,7 @@
"InstanceSplitter",
"ListFeatures",
"MapTransformation",
"QuantizeMeanScaled",
"RemoveFields",
"RenameFields",
"SampleTargetDim",
@@ -89,6 +90,7 @@
VstackFeatures,
cdf_to_gaussian_forward_transform,
Valmap,
QuantizeMeanScaled,
)
from .feature import (
AddAgeFeature,
64 changes: 64 additions & 0 deletions src/gluonts/transform/convert.py
@@ -904,3 +904,67 @@ def flatmap_transform(
if len(times) > 0 or not self.drop_empty:
data[self.target_field] = [times, sizes]
yield data


class QuantizeMeanScaled(SimpleTransformation):
"""Rescale and quantize the target variable.
Requires `past_target_field`, and `future_target_field` to be present.
The mean absolute value of the past_target is used to rescale
past_target and future_target. Then the bin_edges are used to quantize
the rescaled target.
The calculated scale is stored in the `scale_field`.
Parameters
----------
bin_edges
The bin edges for quantization.
past_target_field, optional
The field name that contains `past_target`,
by default "past_target"
past_observed_values_field, optional
The field name that contains `past_observed_values`,
by default "past_observed_values"
future_target_field, optional
The field name that contains `future_target`,
by default "future_target"
scale_field, optional
The field name where scale will be stored,
by default "scale"
"""

@validated()
def __init__(
self,
bin_edges: List[float],
past_target_field: str = "past_target",
past_observed_values_field: str = "past_observed_values",
future_target_field: str = "future_target",
scale_field: str = "scale",
):
self.bin_edges = np.array(bin_edges)
self.future_target_field = future_target_field
self.past_target_field = past_target_field
self.past_observed_values_field = past_observed_values_field
self.scale_field = scale_field

def transform(self, data: DataEntry) -> DataEntry:
target = data[self.past_target_field]
weights = data.get(
self.past_observed_values_field, np.ones_like(target)
)
m = np.sum(np.abs(target) * weights) / np.sum(weights)
scale = m if m > 0 else 1.0
data[self.future_target_field] = np.digitize(
data[self.future_target_field] / scale,
bins=self.bin_edges,
right=False,
)
data[self.past_target_field] = np.digitize(
data[self.past_target_field] / scale,
bins=self.bin_edges,
right=False,
)
data[self.scale_field] = np.array([scale], dtype=np.float32)
return data
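
A minimal sketch of applying the relocated `QuantizeMeanScaled` transformation to a single data entry with the default field names; the bin edges and target values below are illustrative:

```python
import numpy as np

from gluonts.transform import QuantizeMeanScaled

transform = QuantizeMeanScaled(bin_edges=[0.5, 1.0, 1.5])

entry = {
    "past_target": np.array([2.0, 0.0, 4.0]),
    "past_observed_values": np.array([1.0, 1.0, 1.0]),
    "future_target": np.array([1.0, 3.0]),
}
out = transform.transform(entry)

# The mean absolute past target is 2.0, so both targets are divided by 2.0
# before being digitized against the bin edges.
# out["scale"]          -> [2.]
# out["past_target"]    -> [2 0 3]
# out["future_target"]  -> [1 3]
```

The stored `scale` lets downstream code map quantized predictions back to the original magnitude.
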
