
Observed_masks not behaving as expected #28914

Closed
2 of 4 tasks
Tracked by #33345
dparr005 opened this issue Feb 7, 2024 · 18 comments

Comments

@dparr005

dparr005 commented Feb 7, 2024

System Info

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
use_cpu: false

Who can help?

@pacman100 @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am training a TimeSeriesTransformerForPrediction model, but I get the following error during training:

torch/nn/parallel/distributed.py", line 1026, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 2: 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
 In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error

Expected behavior

For context, this issue occurs with TimeSeriesTransformerForPrediction.

From what I can tell, it only happens when there are 0s at the beginning of the past_values segments. I believe the error is thrown because past_observed_mask contains a corresponding 0 for every 0 that precedes the first non-zero value (see the pictures below). I would like the model to learn from these 0s, since they are genuine 0s and not NaN or missing values (as a 0 in past_observed_mask would imply).

[Two screenshots; images not reproduced.]
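Taking the comments in data_utils.py (step 3) at face value, AddObservedValuesIndicator sets the mask to 0 only where the target is NaN, so genuine zeros should stay observed (mask 1). Zeros at the start of past_observed_mask may instead come from InstanceSplitter left-padding a window that begins before the series does; that is our reading of gluonts, not something confirmed in this thread. The NumPy sketch below imitates both steps; `observed_indicator` and `left_pad` are hypothetical helpers, not gluonts APIs.

```python
import numpy as np


def observed_indicator(target):
    """Imitate the idea of AddObservedValuesIndicator: 1.0 where a value exists, 0.0 where it is NaN."""
    mask = (~np.isnan(target)).astype(np.float32)
    # NaNs are replaced by 0.0 in the values, but the mask remembers they were missing.
    filled = np.nan_to_num(target, nan=0.0)
    return filled, mask


def left_pad(values, mask, past_length):
    """Left-pad a too-short history with zeros and mark the padded steps as unobserved."""
    pad = past_length - len(values)
    if pad <= 0:
        return values[-past_length:], mask[-past_length:]
    padded_values = np.concatenate([np.zeros(pad, dtype=values.dtype), values])
    padded_mask = np.concatenate([np.zeros(pad, dtype=mask.dtype), mask])
    return padded_values, padded_mask


target = np.array([0.0, 0.0, 5.0, np.nan, 7.0])
values, mask = observed_indicator(target)
print(mask)  # real zeros are observed: [1. 1. 1. 0. 1.]

past_values, past_mask = left_pad(values, mask, past_length=8)
print(past_values)  # [0. 0. 0. 0. 0. 5. 0. 7.]
print(past_mask)    # padding gets 0: [0. 0. 0. 1. 1. 1. 0. 1.]
```

In this sketch a leading 0 in the values with mask 0 can only be padding, while a real 0 keeps mask 1.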

When I follow the error message's advice and set find_unused_parameters=True, I get the following error:

ValueError: Expected parameter df (Tensor of shape (256, 45)) of distribution Chi2() to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values.
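The Chi2 message says the df parameter of the output distribution violated the constraint df > 0 and contained "invalid values", which usually points to NaN or infinite values reaching the distribution head (from the inputs, the scaling, or the loss). A quick way to narrow this down is to check every floating-point tensor in a batch before the forward pass. `find_non_finite` is a hypothetical helper, shown on a toy batch:

```python
import torch


def find_non_finite(batch):
    """Return the names of floating-point tensors in a batch dict that contain NaN or +/-inf."""
    bad = []
    for name, value in batch.items():
        # Integer tensors (e.g. categorical IDs) cannot hold NaN, so only check floats.
        if torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                bad.append(name)
    return bad


batch = {
    "past_values": torch.tensor([[0.0, 0.0, 3.0]]),
    "future_values": torch.tensor([[float("nan"), 1.0]]),
    "static_categorical_features": torch.tensor([[0]]),
}
print(find_non_finite(batch))  # ['future_values']
```

The same `torch.isfinite` check can be applied to `outputs.loss` after each step to find the first batch where things go wrong.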

Can someone please advise how to fix this issue?

@mueller

mueller commented Feb 7, 2024

Sorry, wrong mueller. Can't really help you 😄 But thanks for tagging me!

@dparr005
Author

dparr005 commented Feb 7, 2024

Oops, sorry. I hope I fixed it now.

@ArthurZucker
Collaborator

cc @niels as well 🤗

@amyeroberts
Collaborator

Hi @dparr005, thanks for opening this issue!

Could you share a minimal code snippet to reproduce the error?

cc @kashif

@kashif
Contributor

kashif commented Feb 8, 2024

Checking, thanks!

@kashif
Contributor

kashif commented Feb 8, 2024

@dparr005, can you try running the training on a single GPU to see whether the issue still occurs? Also, since your data has fairly sane magnitudes, perhaps set scaling=None in the config.
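The `scaling` option of `TimeSeriesTransformerConfig` defaults to "mean". The sketch below is our simplified approximation of mean scaling over observed steps, meant only to show why the observed mask and near-zero windows affect the scale; the actual scaler in transformers also derives a batch-level default scale, so treat `mean_scale` and its `fallback_scale` as hypothetical:

```python
import numpy as np


def mean_scale(values, observed_mask, scaling="mean", minimum_scale=1e-10, fallback_scale=1.0):
    """Simplified per-series mean scaling (an approximation, not the library code)."""
    if scaling is None:
        # Roughly what scaling=None means: pass the values through unchanged.
        return values, 1.0
    num_observed = observed_mask.sum()
    if num_observed > 0:
        # Mean absolute value over observed steps only; masked steps are ignored.
        scale = np.abs(values * observed_mask).sum() / num_observed
    else:
        # Hypothetical fallback when nothing in the window is observed.
        scale = fallback_scale
    # Clamp so that an all-zero window does not cause division by zero.
    scale = max(scale, minimum_scale)
    return values / scale, scale


values = np.array([0.0, 0.0, 4.0, 8.0])
mask = np.array([0.0, 0.0, 1.0, 1.0])
scaled, scale = mean_scale(values, mask)
print(scale)  # 6.0, the leading (masked) zeros do not lower the scale

_, tiny = mean_scale(np.zeros(3), np.ones(3))
print(tiny)  # 1e-10, an all-zero observed window gets the minimum scale
```

With data of reasonable magnitude, skipping scaling removes this step entirely, which is presumably the motivation for the suggestion above.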

@dparr005
Author

dparr005 commented Feb 8, 2024

I used the code from https://huggingface.co/blog/time-series-transformers as a basis, with my own data. I am not sure which portion of the code you are asking for.

In addition, I was already able to run the code on a single GPU (in a local Jupyter Notebook). But when I run it on the HPC cluster with multiple GPUs, it does not work. My hypothesis is that the samples are somehow seeded differently, and that is why it runs in the Jupyter Notebook but not in the multi-GPU configuration.
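One way to test the seeding hypothesis is to fix the random seeds at the top of `main()`, so that notebook and cluster runs draw the same windows. We believe the gluonts instance samplers draw from NumPy's global RNG, but that is an assumption; accelerate also provides `accelerate.utils.set_seed`, which additionally seeds torch. A minimal hypothetical helper for Python and NumPy:

```python
import os
import random

import numpy as np


def seed_everything(seed, per_process=False):
    """Seed Python's and NumPy's global RNGs; optionally offset the seed by the local rank."""
    if per_process:
        # LOCAL_RANK is set by launchers such as accelerate or torchrun; default to 0.
        seed += int(os.environ.get("LOCAL_RANK", 0))
    random.seed(seed)
    np.random.seed(seed)
    return seed


seed_everything(42)
first = np.random.rand(3)
seed_everything(42)
second = np.random.rand(3)
print(np.array_equal(first, second))  # True
```

Using the same seed on every rank makes all processes sample identical windows; `per_process=True` keeps runs reproducible while still giving each rank different data.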

@dparr005
Author

dparr005 commented Feb 8, 2024

Actually, the error looks more like the following:

future_observed_mask=batch["future_observed_mask"].to(device)
packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
accelerate/utils/operations.py", line 553, in forward
    return model_forward(*args, **kwargs)
packages/accelerate/utils/operations.py", line 541, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
packages/torch/nn/parallel/distributed.py", line 1026, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 

If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 0: 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
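To turn indices such as "34 35 ... 49" into parameter names, one option is a single forward/backward pass without DDP, then listing trainable parameters whose `.grad` is still `None`. The sketch uses a toy module; with the real model, replace `TwoHeads` with `TimeSeriesTransformerForPrediction` and the toy loss with `outputs.loss`. The assumption that DDP's indices follow `named_parameters()` order over trainable parameters is ours; `TORCH_DISTRIBUTED_DEBUG=DETAIL`, as the message suggests, reports the names directly.

```python
import torch
from torch import nn


class TwoHeads(nn.Module):
    """Toy model: the 'unused' head never contributes to the output or the loss."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)
        self.unused = nn.Linear(4, 1)

    def forward(self, x):
        return self.used(x)


def params_without_grad(model):
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]


def names_for_indices(model, indices):
    """Map DDP-style indices to names (assumes trainable params in registration order)."""
    names = [n for n, p in model.named_parameters() if p.requires_grad]
    return [names[i] for i in indices]


model = TwoHeads()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
print(params_without_grad(model))        # ['unused.weight', 'unused.bias']
print(names_for_indices(model, [2, 3]))  # ['unused.weight', 'unused.bias']
```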

@kashif
Contributor

kashif commented Feb 8, 2024

I see, @dparr005, so the issue is multi-GPU training... Perhaps I need to gather the losses etc. I now have a multi-GPU setup, so I can finally test this.

@dparr005
Author

dparr005 commented Feb 8, 2024

Yes. That was my other hypothesis: that the code somehow expects a gather step (collecting results from all GPUs once a training epoch is done) before moving on to the next epoch.

What do you need from me to test this hypothesis?

@dparr005
Author

I just ran the given code (from GitHub) in a multi-GPU environment, but it gives the same type of errors.

The distribution environment is:

partialState:  Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 3
Process index: 0
Local process index: 0
Device: cuda:0

I converted the Jupyter Notebook into a .py script and launch it from SLURM via:
crun -p ~/envs/ts_tnn python -m accelerate.commands.launch --config_file config.yaml --num_processes=3 --multi_gpu fromGithub.py

Can anyone help me? It seems to me that this is an issue with how accelerate is being used.

@muellerzr
Contributor

What does your exact code look like here?

@dparr005
Author

from data_utils import *
from datasets import load_dataset
from functools import partial
from gluonts.time_feature import get_lags_for_frequency
from gluonts.time_feature import time_features_from_frequency_str
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction
from accelerate import Accelerator
from torch.optim import AdamW
from evaluate import load
from gluonts.time_feature import get_seasonality
import matplotlib.dates as mdates
from accelerate.state import PartialState



def getModel(prediction_length, lags_sequence, time_features_len, train_dataset_len):
    config = TimeSeriesTransformerConfig(
        prediction_length=prediction_length,
        # context length:
        context_length=prediction_length * 2,
        # lags coming from helper given the freq:
        lags_sequence=lags_sequence,
        # we'll add 2 time features ("month of year" and "age", see further):
        num_time_features=time_features_len + 1,
        # we have a single static categorical feature, namely time series ID:
        num_static_categorical_features=1,
        # it has 366 possible values:
        cardinality=[train_dataset_len],
        # the model will learn an embedding of size 2 for each of the 366 possible values:
        embedding_dimension=[2],

        # transformer params:
        encoder_layers=4,
        decoder_layers=4,
        d_model=32,
    )
    model = TimeSeriesTransformerForPrediction(config)
    return config, model



    


def main():
    ### set up data
    dataset = load_dataset("monash_tsf", "tourism_monthly")
    train_example = dataset["train"][0]
    validation_example = dataset["validation"][0]
    freq = "1M"
    prediction_length = 24

    assert len(train_example["target"]) + prediction_length == len(
        validation_example["target"]
    )
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]
    
    ### make sure that the data is in the correct form
    train_dataset.set_transform(partial(transform_start_field, freq=freq))
    test_dataset.set_transform(partial(transform_start_field, freq=freq))
    lags_sequence = get_lags_for_frequency(freq)
    time_features = time_features_from_frequency_str(freq)
    config, model = getModel(prediction_length, lags_sequence, len(time_features), len(train_dataset))
    
    # get data loaders:
    train_dataloader = create_train_dataloader(
        config=config,
        freq=freq,
        data=train_dataset,
        batch_size=256,
        num_batches_per_epoch=100,
    )

    test_dataloader = create_backtest_dataloader(
        config=config,
        freq=freq,
        data=test_dataset,
        batch_size=64,
    )
    
    ### Init accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)
    ps = PartialState()
    print("partialState: ", ps)
    optimizer = AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=1e-1)

    model, optimizer, train_dataloader = accelerator.prepare(
        model,
        optimizer,
        train_dataloader,
    )

    model.train()
    for epoch in range(40):
        for idx, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model(
                static_categorical_features=batch["static_categorical_features"].to(device)
                if config.num_static_categorical_features > 0
                else None,
                static_real_features=batch["static_real_features"].to(device)
                if config.num_static_real_features > 0
                else None,
                past_time_features=batch["past_time_features"].to(device),
                past_values=batch["past_values"].to(device),
                future_time_features=batch["future_time_features"].to(device),
                future_values=batch["future_values"].to(device),
                past_observed_mask=batch["past_observed_mask"].to(device),
                future_observed_mask=batch["future_observed_mask"].to(device),
            )
            loss = outputs.loss

            # Backpropagation
            accelerator.backward(loss)
            optimizer.step()
            if idx % 100 == 0:
                print("loss:", loss.item())
                
    ### Inference
    model.eval()

    forecasts = []

    for batch in test_dataloader:
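        # After accelerator.prepare(), the model may be wrapped in DDP, which has no generate(); unwrap it first.
        outputs = accelerator.unwrap_model(model).generate(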
            static_categorical_features=batch["static_categorical_features"].to(device)
            if config.num_static_categorical_features > 0
            else None,
            static_real_features=batch["static_real_features"].to(device)
            if config.num_static_real_features > 0
            else None,
            past_time_features=batch["past_time_features"].to(device),
            past_values=batch["past_values"].to(device),
            future_time_features=batch["future_time_features"].to(device),
            past_observed_mask=batch["past_observed_mask"].to(device),
        )
        forecasts.append(outputs.sequences.cpu().numpy())
    forecasts = np.vstack(forecasts)
    mase_metric = load("evaluate-metric/mase")
    smape_metric = load("evaluate-metric/smape")

    forecast_median = np.median(forecasts, 1)

    mase_metrics = []
    smape_metrics = []
    for item_id, ts in enumerate(test_dataset):
        training_data = ts["target"][:-prediction_length]
        ground_truth = ts["target"][-prediction_length:]
        mase = mase_metric.compute(
            predictions=forecast_median[item_id],
            references=np.array(ground_truth),
            training=np.array(training_data),
            periodicity=get_seasonality(freq),
        )
        mase_metrics.append(mase["mase"])

        smape = smape_metric.compute(
            predictions=forecast_median[item_id],
            references=np.array(ground_truth),
        )
        smape_metrics.append(smape["smape"])
        
    ### print results of the evaluation
    print(f"MASE: {np.mean(mase_metrics)}")
    print(f"sMAPE: {np.mean(smape_metrics)}")
    
    plt.scatter(mase_metrics, smape_metrics, alpha=0.3)
    plt.xlabel("MASE")
    plt.ylabel("sMAPE")
    plt.savefig("figures/github_results.pdf")
    plt.show()
    
    
    def plot(ts_index):
        fig, ax = plt.subplots()

        index = pd.period_range(
            start=test_dataset[ts_index][FieldName.START],
            periods=len(test_dataset[ts_index][FieldName.TARGET]),
            freq=freq,
        ).to_timestamp()

        # Major ticks every half year, minor ticks every month,
        ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1, 7)))
        ax.xaxis.set_minor_locator(mdates.MonthLocator())

        ax.plot(
            index[-2 * prediction_length :],
            test_dataset[ts_index]["target"][-2 * prediction_length :],
            label="actual",
        )

        plt.plot(
            index[-prediction_length:],
            np.median(forecasts[ts_index], axis=0),
            label="median",
        )

        plt.fill_between(
            index[-prediction_length:],
            forecasts[ts_index].mean(0) - forecasts[ts_index].std(axis=0),
            forecasts[ts_index].mean(0) + forecasts[ts_index].std(axis=0),
            alpha=0.3,
            interpolate=True,
            label="+/- 1-std",
        )
        plt.legend()
        plt.show()
    
if __name__ == "__main__":
    main()

@dparr005
Author

This is the code in data_utils.py. The code above is fromGithub.py.

from datasets import DatasetDict
from gluonts.itertools import Map
from datasets import Dataset, Features, Value, Sequence
from gluonts.dataset.pandas import PandasDataset
import datasets
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
from functools import lru_cache
from functools import partial
from gluonts.time_feature import get_lags_for_frequency
from gluonts.time_feature import time_features_from_frequency_str
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction
import os
from os.path import exists
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch
import torch.distributed as dist
from accelerate import Accelerator
from torch.optim import AdamW
import sys
from gluonts.time_feature import (
    time_features_from_frequency_str,
    TimeFeature,
    get_lags_for_frequency,
)
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
    AddAgeFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AsNumpyArray,
    Chain,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
    RemoveFields,
    SelectFields,
    SetField,
    TestSplitSampler,
    Transformation,
    ValidationSplitSampler,
    VstackFeatures,
    RenameFields,
)

from transformers import PretrainedConfig
from gluonts.transform.sampler import InstanceSampler
from typing import Optional
from typing import Iterable
import torch
from gluonts.itertools import Cyclic, Cached
from gluonts.dataset.loader import as_stacked_batches
import matplotlib.dates as mdates


def getRank():
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
    except KeyError:
        local_rank = 0
    return local_rank

class ProcessStartField():
    ts_id = 0
    def __call__(self, data):
        data["start"] = data["start"].to_timestamp()
        data["feat_static_cat"] = [self.ts_id]
        self.ts_id += 1
        return data

@lru_cache(10_000)
def convert_to_pandas_period(date, freq):
    return pd.Period(date, freq)


def transform_start_field(batch, freq):
    batch["start"] = [convert_to_pandas_period(date, freq) for date in batch["start"]]
    return batch
def create_transformation(freq: str, config: PretrainedConfig) -> Transformation:
    remove_field_names = []
    if config.num_static_real_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)
    if config.num_dynamic_real_features == 0:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    if config.num_static_categorical_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_CAT)

    # a bit like torchvision.transforms.Compose
    return Chain(
        # step 1: remove static/dynamic fields if not specified
        [RemoveFields(field_names=remove_field_names)]
        # step 2: convert the data to NumPy (potentially not needed)
        + (
            [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                )
            ]
            if config.num_static_categorical_features > 0
            else []
        )
        + (
            [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                )
            ]
            if config.num_static_real_features > 0
            else []
        )
        + [
            AsNumpyArray(
                field=FieldName.TARGET,
                # we expect an extra dim for the multivariate case:
                expected_ndim=1 if config.input_size == 1 else 2,
            ),
            # step 3: handle the NaN's by filling in the target with zero
            # and return the mask (which is in the observed values)
            # true for observed values, false for nan's
            # the decoder uses this mask (no loss is incurred for unobserved values)
            # see loss_weights inside the xxxForPrediction model
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            # step 4: add temporal features based on freq of the dataset
            # month of year in the case when freq="M"
            # these serve as positional encodings
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features_from_frequency_str(freq),
                pred_length=config.prediction_length,
            ),
            # step 5: add another temporal feature (just a single number)
            # tells the model where in the life the value of the time series is
            # sort of running counter
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=config.prediction_length,
                log_scale=True,
            ),
            # step 6: vertically stack all the temporal features into the key FEAT_TIME
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if config.num_dynamic_real_features > 0
                    else []
                ),
            ),
            # step 7: rename to match HuggingFace names
            RenameFields(
                mapping={
                    FieldName.FEAT_STATIC_CAT: "static_categorical_features",
                    FieldName.FEAT_STATIC_REAL: "static_real_features",
                    FieldName.FEAT_TIME: "time_features",
                    FieldName.TARGET: "values",
                    FieldName.OBSERVED_VALUES: "observed_mask",
                }
            ),
        ]
    )
def create_instance_splitter(
    config: PretrainedConfig,
    mode: str,
    train_sampler: Optional[InstanceSampler] = None,
    validation_sampler: Optional[InstanceSampler] = None,
) -> Transformation:
    assert mode in ["train", "validation", "test"]

    instance_sampler = {
        "train": train_sampler
        or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=config.prediction_length
        ),
        "validation": validation_sampler
        or ValidationSplitSampler(min_future=config.prediction_length),
        "test": TestSplitSampler(),
    }[mode]

    return InstanceSplitter(
        target_field="values",
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=config.context_length + max(config.lags_sequence),
        future_length=config.prediction_length,
        time_series_fields=["time_features", "observed_mask"],
    )

def create_train_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    num_batches_per_epoch: int,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = True,
    **kwargs,
) -> Iterable:
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
        "future_values",
        "future_observed_mask",
    ]

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data, is_train=True)
    if cache_data:
        transformed_data = Cached(transformed_data)

    # we initialize a Training instance
    instance_splitter = create_instance_splitter(config, "train")

    # the instance splitter will sample a window of
    # context length + lags + prediction length (from the 366 possible transformed time series)
    # randomly from within the target time series and return an iterator.
    stream = Cyclic(transformed_data).stream()
    training_instances = instance_splitter.apply(stream)
    
    return as_stacked_batches(
        training_instances,
        batch_size=batch_size,
        shuffle_buffer_length=shuffle_buffer_length,
        field_names=TRAINING_INPUT_NAMES,
        output_type=torch.tensor,
        num_batches_per_epoch=num_batches_per_epoch,
    )

def create_backtest_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    **kwargs,
):
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data)

    # We create a Validation Instance splitter which will sample the very last
    # context window seen during training only for the encoder.
    instance_sampler = create_instance_splitter(config, "validation")

    # we apply the transformations in train mode
    testing_instances = instance_sampler.apply(transformed_data, is_train=True)
    
    return as_stacked_batches(
        testing_instances,
        batch_size=batch_size,
        output_type=torch.tensor,
        field_names=PREDICTION_INPUT_NAMES,
    )

def create_test_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    **kwargs,
):
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data, is_train=False)

    # We create a test Instance splitter to sample the very last
    # context window from the dataset provided.
    instance_sampler = create_instance_splitter(config, "test")

    # We apply the transformations in test mode
    testing_instances = instance_sampler.apply(transformed_data, is_train=False)
    
    return as_stacked_batches(
        testing_instances,
        batch_size=batch_size,
        output_type=torch.tensor,
        field_names=PREDICTION_INPUT_NAMES,
    )
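In `create_instance_splitter`, each past window spans `config.context_length + max(config.lags_sequence)` steps. If we read gluonts correctly, a window whose split point leaves fewer steps of history than that is left-padded with zeros, and the padded steps get observed_mask 0; with `ExpectedNumInstanceSampler`, whose `min_past` we believe defaults to 0, such early split points can be sampled during training. That would match the leading zeros in past_observed_mask described in the issue, though it is our hypothesis. A small arithmetic sketch, where `num_padded_steps` and the lag list are hypothetical:

```python
def num_padded_steps(history_length, context_length, lags_sequence):
    """Number of zero-padded (mask 0) steps at the start of a past window."""
    # The splitter requests context_length + max(lags) past steps.
    past_length = context_length + max(lags_sequence)
    return max(0, past_length - history_length)


# context_length = 2 * prediction_length = 48, as in getModel(); the lags are illustrative only.
lags = [1, 2, 3, 12, 24, 36]
for history in (10, 60, 84, 200):
    print(history, num_padded_steps(history, 48, lags))
# 10 74
# 60 24
# 84 0
# 200 0
```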

@dparr005
Author

As a reminder, the code is taken from the GitHub code mentioned earlier. It works as a Jupyter Notebook but not as a Python script launched via SLURM.

@amyeroberts
Collaborator

Gentle ping @muellerzr

@amyeroberts
Collaborator

Another ping @muellerzr


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
