
Jax/Flax pretraining of wav2vec2 #19588

Closed
aapot opened this issue Oct 13, 2022 · 34 comments

@aapot

aapot commented Oct 13, 2022

There is a Jax/Flax based script available for pretraining wav2vec2 here. I have been trying to pretrain a new wav2vec2 model for Finnish on TPUs using that script, but it seems impossible to get the model to train properly. I know the script is under research_projects, so I am wondering if anyone has been able to successfully pretrain wav2vec2 models with it? Or if anyone has made their own updates to the script to fix potential problems?

For me, it looks like the codevector_perplexity will always collapse to a value of 2 and stay there, which I believe is not a good thing. Also, the contrastive loss is usually very unstable. I attached the image below showcasing those issues. In addition, I have tried pretraining wav2vec2 with the official fairseq implementation, where the training looks to be working fine without those issues. So I believe the HF Jax/Flax implementation is broken somehow.

[image: training curves showing the codevector_perplexity collapsing to 2 and the unstable contrastive loss]

Also, I think the HF Jax/Flax wav2vec2 implementation is not fully on par with the HF PyTorch wav2vec2 implementation. For example, I noticed this comment by @patrickvonplaten #14471 (comment), and I think the comment's point number 1 is not implemented in the Jax/Flax version. Also, in the PyTorch wav2vec2 pretraining PR comment #13877 (comment), gradient scaling is implemented to avoid issues when training on multiple devices. I wonder if the same would be needed for the Jax/Flax script when training on 8 TPU cores? I tried implementing those myself, but then I found this script where @patrickvonplaten seems to have already implemented point number 1: https://huggingface.co/patrickvonplaten/wav2vec2-german-flax/blob/main/run_wav2vec2_pretrain_flax.py

Anyhow, even with those potential fixes I haven't been able to get the training to work properly. That's a real pity, since Jax/Flax training would be really great when using TPUs.

@LysandreJik
Member

cc @sanchit-gandhi

@sanchit-gandhi
Contributor

sanchit-gandhi commented Oct 14, 2022

Hey @aapot! Cool to see you're trying out pre-training of Wav2Vec2 in Flax on Finnish 🇫🇮

Indeed, the script is under 'research projects' as it remains unsolved. Pre-training of Wav2Vec2 models is notoriously difficult due to issues with stability, giving rise to the phenomena you've experienced such as code vector collapse and unstable contrastive loss. AFAIK there isn't a working implementation for training Transformers W2V2 models in Flax, which makes it an interesting topic to pursue!

You've done a great job at digging through issues and PRs to find the aforementioned points! Both the points you've raised look to be missing from the Flax Wav2Vec2 script. Did you try gradient scaling in your experiments?

One thing we can try is running the PyTorch and Flax scripts step-by-step in parallel and inspecting where they diverge. We can do this with a tiny dummy model (hf-internal-testing/tiny-random-wav2vec2 for instance) to make it fast to debug and the same training inputs. When we've identified a divergence between the two we can fix the Flax script by porting the corresponding PyTorch code. LMK if you'd be interested in doing this and I can provide further pointers!

@patrickvonplaten
Contributor

Just for reference, I never got Wav2Vec2 to work in JAX, but it should definitely be possible (I didn't spend too much time on it).

@aapot
Author

aapot commented Oct 15, 2022

@sanchit-gandhi yep, this would be interesting to get working! Yes, I also tried gradient scaling as it was implemented in the PyTorch pretrain script (basically multiplying the gradients by (num devices / total samples)), without luck.
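For reference, a rough sketch of the kind of scaling I tried inside the pmapped train step (the axis name and num_losses are just placeholders for whatever the script actually uses):

import jax

def scale_gradients(grads, num_losses):
    # Called inside a pmap with axis_name="batch" (placeholder name).
    # Total number of loss terms (masked time steps) across all devices.
    total_losses = jax.lax.psum(num_losses, axis_name="batch")
    # Sum the gradients over devices and normalise by the total loss count,
    # i.e. the same effect as multiplying averaged grads by (num devices / total samples).
    summed_grads = jax.lax.psum(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda g: g / total_losses, summed_grads)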

I'd be interested in putting some time into fixing this, so feel free to provide further pointers. Training these ASR and NLP models for Finnish is a free-time hobby project with @R4ZZ3, so I cannot promise anything yet, but let's get this fixed 🤗

@sanchit-gandhi
Contributor

sanchit-gandhi commented Oct 17, 2022

Awesome @aapot, that's great to hear!

Essentially what you want to do is run the PyTorch script and Flax script with identical args (for the model, data and training args). In doing this, the PyTorch and Flax models should receive identical inputs, and thus should compute identical losses if the training scripts are the same.

What you want to then do is compare the outputs of the PT and Flax training scripts after each step of pre-training:

  1. First check that the data collators are identical by inspecting the returned elements of the batch ("input_values", "attention_mask", "mask_time_indices") -> we need to make sure the inputs to the models are the same before we can assess the model outputs
  2. Check Gumbel temp is the same
  3. Check outputs of the models are the same (projected_quantized_states, projected_states, codevector_perplexity)
  4. Check contrastive loss is the same
  5. Check diversity loss is the same -> once all the losses match then we can move onto making sure the gradients and updates are the same (easier to verify, and very much likely to be the case if the losses are the same)

It's likely the bug in the Flax script lies in 3, 4 or 5! Once you identify where the losses deviate, you can dig deeper into the code for PT and Flax and try to find the line(s) of code where the functionality is different.

How you debug this is up to you. To make this quick and easy, I'd recommend using a dummy model (hf-internal-testing/tiny-random-wav2vec2) and a dummy dataset (hf-internal-testing/librispeech_asr_dummy) -> in total this is about 10MB of downloaded data and the script should run very fast. I'd also first run training on CPU only for both PT and Flax, such that the number of devices is fixed at one (no gradient scaling effects).

For comparing the outputs, you can either run the scripts side-by-side and print intermediate values, or combine them into a single notebook and print cell outputs after each step (I can give you a template for this if you want to use a notebook). Print statements are easy to use, but don't provide much detail other than numeric values. What I'd do is first add print statements for each of the items listed in 1-5 to quickly see which values match ✅ and which values don't ❌. After that you can go deeper with either: more print statements, breakpoints (ipdb), or a debugger.
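As a concrete example, a quick way to compare the collator outputs (step 1) once you have both batches in hand - pt_batch and fx_batch are placeholder names for the dicts returned by the PyTorch and Flax data collators on the same samples:

import numpy as np

for key in ("input_values", "attention_mask", "mask_time_indices"):
    # np.asarray handles both torch tensors and jnp arrays here.
    pt_val = np.asarray(pt_batch[key]).astype("float64")
    fx_val = np.asarray(fx_batch[key]).astype("float64")
    max_diff = np.abs(pt_val - fx_val).max()
    print(f"{key}: shapes {pt_val.shape} vs {fx_val.shape}, max abs diff {max_diff:.2e}")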

It might take a bit of time to establish a good set-up for debugging quickly, but once you've got this set-up it should be a case of finding where the losses are different and then fixing for Flax! You might also need to disable shuffling of the training dataset to make sure the training inputs are passed in the same way to PT as Flax.

These should make for good starting points (haven't tried them, but they're similar to the configs I use for debugging ASR fine-tuning):

PT

python run_wav2vec2_pretraining.py \
    --dataset_name="hf-internal-testing/librispeech_asr_dummy" \
    --dataset_config_names="clean" \
    --train_split_name="validation" \
    --model_name_or_path="hf-internal-testing/tiny-random-wav2vec2" \
    --output_dir="./" \
    --max_train_steps="10" \
    --num_warmup_steps="2" \
    --learning_rate="0.005" \
    --logging_steps="1" \
    --save_strategy="no" \
    --per_device_train_batch_size="8" \
    --do_train

Flax

JAX_PLATFORM_NAME=cpu python run_wav2vec2_pretrain_flax.py \
    --dataset_name="hf-internal-testing/librispeech_asr_dummy" \
    --dataset_config_names="clean" \
    --train_split_name="validation" \
    --model_name_or_path="hf-internal-testing/tiny-random-wav2vec2" \
    --output_dir="./" \
    --max_train_steps="10" \
    --num_warmup_steps="2" \
    --learning_rate="0.005" \
    --logging_steps="1" \
    --save_strategy="no" \
    --per_device_train_batch_size="8" \
    --do_train

@aapot
Author

aapot commented Oct 18, 2022

Thanks for those pointers @sanchit-gandhi, sounds reasonable! I'll start digging into this soon, will keep you updated here.

@Aaryan369

Hi,

> For me, it looks like the codevector_perplexity will always collapse to a value of 2 and stay there, which I believe is not a good thing. Also, the contrastive loss is usually very unstable.

I tried pre-training the jax wav2vec2 model on my own data and I came across similar problems. Tried with multiple huge chunks of my own dataset and the perplexity always collapsed to 2 while the loss fluctuated a lot. I also noticed that across all my datasets the eval loss was always 0.09969.

So, if I finetune this pretrained model, will it give any good results?

Also do you guys have any code to fine-tune this pretrained model that I can use?

@R4ZZ3

R4ZZ3 commented Oct 20, 2022

For finetuning we have used these resources as base:
https://huggingface.co/blog/fine-tune-wav2vec2-english
https://huggingface.co/blog/wav2vec2-with-ngram

Also, we are going to try out these. We just need to fix some of our datasets first, as we have lower-case material. Luckily we have trained a T5 model for casing + punctuation correction. https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#sequence-to-sequence

@sanchit-gandhi
Contributor

sanchit-gandhi commented Oct 20, 2022

Hey @Aaryan369! Thanks for sharing your experience - it seems like there's an inherent bug in the JAX pre-training implementation in how the loss terms are computed, leading to code vector perplexity collapse and unstable loss.

You can certainly try fine-tuning a pre-trained Wav2Vec2 model. If your fine-tuning data is in-domain with the pre-training you can expect good results with very little data - as little as 10 minutes as shown by the Wav2Vec2 paper!

If your fine-tuning data is more out-of-domain with the pre-training data, you can expect to require much more data to achieve good results. This is really on a case-by-case basis, so you'll have to make that decision based on what you know about your fine-tuning situation!

In terms of pre-trained models, there are English-only checkpoints and multilingual ones (https://huggingface.co/facebook/wav2vec2-large-xlsr-53 for example). The English-only ones will fare better for English speech tasks, and the multilingual ones for most others.

The resources @R4ZZ3 has kindly linked are perfect for fine-tuning in PyTorch. If you want to fine-tune in JAX, I'd advise you to try: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc.py This script closely resembles the PyTorch one in Transformers: https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition

It's on my list to add this JAX CTC fine-tuning script to Transformers over the coming weeks!

@peregilk
Contributor

We are compiling a large speech corpus in Norwegian (100k+ hours). We expect it to be ready in roughly a month. Our plan is to pretrain a Wav2Vec2. We have access to TPUs through TRC and ideally we would like to train this in Flax instead of XLA/PT.

This is a high priority project for us, and I am happy to assist in both testing and debugging here.

@sanchit-gandhi
Contributor

sanchit-gandhi commented Oct 28, 2022

100k is mega! Very excited to see how pre-training JAX Wav2Vec2 in Finnish goes with this much data. Just out of interest, are you set on producing a pre-trained checkpoint in Finnish? Or is the end goal downstream ASR? The multilingual Whisper models are pre-trained on 1066h of labelled Finnish audio-transcription data (out of 670,000h total). They get good results with zero-shot transfer learning (i.e. no fine-tuning) on Finnish Common Voice 9 (17.0% WER) and Finnish VoxPopuli (15.5% WER), c.f. Tables 11 and 12 from the Whisper paper. You could definitely improve upon these results with fine-tuning! Might be a faster route to a performant, downstream Finnish ASR model than Wav2Vec2 pre-training + fine-tuning?
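If you want to gauge the zero-shot numbers yourself, a minimal sketch with the transformers pipeline (the checkpoint size and audio file are placeholders, and you may need extra generation settings to force the language):

from transformers import pipeline

# Zero-shot transcription with a multilingual Whisper checkpoint.
# "audio.wav" is a placeholder for a 16 kHz mono recording.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
print(asr("audio.wav")["text"])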

@aapot
Author

aapot commented Oct 29, 2022

Just a quick update: I finally had time to start the actual debugging. More info to follow soon.

@sanchit-gandhi
Contributor

Great! Keep us posted!

@aapot
Author

aapot commented Nov 2, 2022

Alright, here are some findings so far:
Step 1: mask_time_indices and sampled_negative_indices were not the same as in the PT implementation. Fixed that by pretty much just copying the functions for those from PT to Flax.
Step 2: The PT and Flax gumbel decay were using different step numbers; fixed that by decrementing the Flax step number by one for the gumbel decay. After that, the gumbel temp seemed to remain the same for about the first 5 steps; after that it started deviating a tiny bit between Flax and PT, which was weird. Although this probably is not our biggest problem at the moment.
Step 3: Comparing model outputs seemed a bit hard, I guess because the model weights are initialized differently at random?

I have found a couple of differences between the Flax and PT model code so far:

  1. Flax was missing the layerdrop functionality; fixed that.
  2. The Flax and PT gumbel softmax are implemented differently. The PT version uses the hard=True option of torch.nn.functional.gumbel_softmax, which returns the samples as discretized one-hot vectors. The Flax gumbel softmax implementation returns soft samples. I tried to implement the hard option in Flax by copying it from the PT code. My current implementation looks like this:
y_soft = nn.softmax((hidden_states + gumbels) / temperature)
index = y_soft.argmax(axis=-1)
y_hard = jnp.zeros_like(hidden_states).at[jnp.arange(len(hidden_states)), index].set(1.0)
codevector_probs = y_hard - y_soft + y_soft

when the PT code looks like this:

y_soft = gumbels.softmax(dim)
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft

At first, I also had PT's y_soft.detach() implemented as codevector_probs = y_hard - jax.lax.stop_gradient(y_soft) + y_soft, but I noticed it seemed to make the codevector collapse again. Without it, based on some testing, the Flax codevector doesn't seem to collapse anymore (the codevector_perplexity rises and stays at a high level, not collapsing close to zero as originally). Although the model still doesn't seem to learn properly, so I bet there is still more to investigate and fix. It could also be that my Flax gumbel softmax hard option is not yet implemented correctly.
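For reference, here is a self-contained sketch of the straight-through ('hard') gumbel softmax I'm experimenting with in JAX; whether the stop_gradient should be kept is exactly the open question above:

import jax
import jax.numpy as jnp

def hard_gumbel_softmax(rng, logits, temperature):
    # Sample standard Gumbel noise and compute the soft (differentiable) samples.
    gumbels = jax.random.gumbel(rng, logits.shape)
    y_soft = jax.nn.softmax((logits + gumbels) / temperature, axis=-1)
    # Hard one-hot samples taken at the argmax of the soft distribution.
    index = jnp.argmax(y_soft, axis=-1)
    y_hard = jax.nn.one_hot(index, logits.shape[-1], dtype=logits.dtype)
    # Straight-through estimator, mirroring torch.nn.functional.gumbel_softmax(hard=True):
    # the forward pass returns one-hot vectors, the backward pass uses the soft gradients.
    return y_hard - jax.lax.stop_gradient(y_soft) + y_soft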

In addition, I have made initial updates (some more, smaller updates could still be made) to the run_wav2vec2_pretrain_flax.py script to make it more up to date and comparable to the PT run_wav2vec2_pretraining_no_trainer.py script. My updates are available on my fork and branch: https://github.com/aapot/transformers/tree/w2v2-jax-flax-pretrain

@aapot
Author

aapot commented Nov 4, 2022

Continuing with the updates:
Step 4: the contrastive loss calculation is the same in Flax and PT.
Step 5: the diversity loss calculation looks to be the same, but I'll verify that later.

@peregilk
Contributor

peregilk commented Nov 4, 2022

Really great work, @aapot. I do however understand that there are still some issues here (since the contrastive loss starts to increase after a while), and that the issue is most likely related to the Flax gumbel implementation. Any chance that anyone at 🤗 can take a look at that? What do you think @sanchit-gandhi @patrickvonplaten?

When this is done I'll be glad to contribute with larger training runs and finetuning/testing on downstream tasks.

@sanchit-gandhi
Contributor

> Comparing model outputs seemed a bit hard, I guess because the model weights are initialized differently at random?

You could hack into the code and load pre-trained weights! I'd recommend the checkpoint at https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2
PyTorch:

from transformers import Wav2Vec2ForPreTraining

model = Wav2Vec2ForPreTraining.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")

JAX:

from transformers import FlaxWav2Vec2ForPreTraining

model = FlaxWav2Vec2ForPreTraining.from_pretrained("hf-internal-testing/tiny-random-wav2vec2", from_pt=True)

=> this will initialise the models with the same weights!
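With identical weights and identical batches you can then compare the pre-training outputs directly - something like this, where pt_outputs / fx_outputs are placeholder names for the outputs of the two models on the same batch:

import numpy as np

def max_abs_diff(pt_tensor, fx_array):
    # Element-wise max absolute difference between a torch tensor and a jnp array.
    return np.abs(pt_tensor.detach().numpy() - np.asarray(fx_array)).max()

for name in ("projected_states", "projected_quantized_states", "codevector_perplexity"):
    diff = max_abs_diff(getattr(pt_outputs, name), getattr(fx_outputs, name))
    print(f"{name}: max abs diff {diff:.2e}")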

From the PyTorch code, it seems as though we should break y_soft from the computation graph in the codevector_probs calculation. Maybe it's worth quickly double-checking what they do in fairseq here as well? jax.lax.stop_gradient can be a bit fiddly, but I think it's the best option for stopping the backprop for a variable.

Sounds like you're making good progress @aapot! Keep us posted with updates and questions, happy to help!

@sanchit-gandhi
Contributor

Oh, one more question! You're running both on CPU, right? JAX will definitely diverge from PT on GPU/TPU due to differences in matmul precision (c.f. jax-ml/jax#10413 (comment)).
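If you ever do need to compare on an accelerator, forcing full-precision matmuls in JAX narrows that gap - a quick sketch (the forward-pass call is a placeholder):

import jax

# Option 1: set the default matmul precision globally.
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: scope it to the comparison only.
with jax.default_matmul_precision("float32"):
    fx_outputs = fx_model(**fx_batch)  # placeholder Flax forward pass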

@aapot
Author

aapot commented Nov 4, 2022

Thanks for the tips @sanchit-gandhi!

Actually, I also had in mind to use pre-trained weights to compare the model outputs that way; will try it soon.

Will also check the fairseq implementation to see if that could reveal more things to fix.

Yup, I am running both Jax and PT on my local laptop on CPU when debugging.

@sanchit-gandhi
Contributor

Okay great! Tiny pre-trained models on CPU is the way to go here!

@aapot
Author

aapot commented Nov 8, 2022

After using the pre-trained weights to continue pretraining for one more step with the same input data, I think the following is happening with the model outputs:

  • projected_states has a max difference of 0.276 (abs of the Flax and PT matrices subtracted from each other, and the max value of that difference matrix)
  • projected_quantized_states has a max difference of 0.588
  • codevector_perplexity is the same

The projected_quantized_states difference is due to the GumbelVectorQuantizer, because its input extract_features from the wav2vec2 module actually matches between Flax and PT. Maybe the difference happening in the GumbelVectorQuantizer is because of the randomized gumbel sampling?

In addition, I checked the fairseq gumbel softmax implementation and they are also using PyTorch's torch.nn.functional.gumbel_softmax with the hard=True option. I am starting to think the main problem could be in this gumbel softmax implementation in Flax.

If someone could verify whether using the codevector_probs = y_hard - jax.lax.stop_gradient(y_soft) + y_soft version makes the codevector (perplexity) collapse, that would be great. For me, using codevector_probs = y_hard - y_soft + y_soft doesn't seem to make it collapse, but I'm not sure that's the correct way to implement the gumbel softmax in Flax either. For example, in the local Flax vs PT testing, with PT the codevector perplexity rises from ~100 to ~400 over 5 epochs of pretraining from scratch. With Flax, without jax.lax.stop_gradient, the perplexity rises very similarly. But if I use jax.lax.stop_gradient, the perplexity rises only to ~250. Some time ago I tried the same with a real base-sized w2v2 Flax model pretrained with Finnish data, and with jax.lax.stop_gradient the codevector perplexity seemed to collapse completely quite early in the training.

@peregilk
Contributor

peregilk commented Nov 8, 2022

Fantastic work @aapot!

I noticed the following comment in the pull request from @patrickvonplaten to @ThomAub:
"PyTorch module to Flax? This might be a bit difficult and require some googling to see if others have already implement gumbel softmax in jax/Flax or not. If you could take a look at this, it would be very useful!" (#12271 (comment)).

Might there be issues here?

@sanchit-gandhi
Contributor

sanchit-gandhi commented Nov 9, 2022

> abs of the Flax and PT matrices subtracted from each other, and the max value of that difference matrix

This is exactly the way we want to compute differences between PT and Flax in the projected states space 👌 For reference, a matching implementation should have a max abs difference of 1e-5.

> Maybe the difference happening in the GumbelVectorQuantizer is because of the randomized gumbel sampling?

This seems logical! What I would do is dive into the GumbelVectorQuantizer and check the intermediate variables up to where the randomised sampling is performed. If they match up until the sampling, that's a good sign. Forcing the sampling to be the same between PT and Flax is a bit tricky... IMO we have two options:

  1. Pre-define a sequence of 'pseudo-random' matrices. Hard-code these in PT and Flax (e.g. 3 matrices of the correct dimension with pre-defined elements, the same elements used in PT and Flax). Replace the sampled matrix with one of our pre-defined matrices in the GumbelVectorQuantizer at each training step: this ensures the matrices are the same in PT and Flax.
  2. Temporarily use the PT implementation of the randomised Gumbel sampling in the Flax script, such that the same seed and thus the same pseudo-random numbers are used. This will require sampling a PyTorch tensor and then converting it back to a jnp array (see the sketch below).

Unfortunately, both of these methods are a bit hacky. The first might be easier IMO - you don't have to define anything too crazy, and just 2 or 3 different matrices would do (we just need to verify the outputs are the same over 2 or 3 training steps).
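For option 2, a rough sketch of what I mean - sample the Gumbel noise once in PyTorch (the same way torch.nn.functional.gumbel_softmax does) and reuse the exact same values on the Flax side; logits_shape and the plumbing into the two quantizers are left as placeholders:

import torch
import jax.numpy as jnp

torch.manual_seed(0)
# Standard Gumbel noise, sampled as in torch.nn.functional.gumbel_softmax.
pt_gumbels = -torch.empty(logits_shape).exponential_().log()  # logits_shape is a placeholder
# Reuse the identical noise in the Flax quantizer instead of drawing fresh samples.
fx_gumbels = jnp.asarray(pt_gumbels.numpy())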

> I am starting to think the main problem could be in this gumbel softmax implementation in Flax.

Sounds like we're narrowing down!

Maybe we can try forcing the same Gumbel quantiser outputs and then experiment with / without stop gradient. The fact that fairseq and HF PT use y_hard suggests we should use stop gradient!

@sanchit-gandhi
Contributor

Good point @peregilk - worth having a look to see if there are any OSS implementations of the Gumbel ops in JAX/Flax online! (As far as I'm aware there aren't, but I might be wrong!)

@peregilk
Contributor

peregilk commented Dec 4, 2022

Please keep this issue open. There is still activity going on to resolve it.

@sanchit-gandhi
Contributor

Hope the analysis is going ok @aapot, think you're doing a great job here! Feel free to share any updates / ask questions, more than happy to help!

@aapot
Author

aapot commented Dec 23, 2022

Hi @sanchit-gandhi, unfortunately I have been very busy the past month, so I haven't had time to investigate this jax gumbel quantizer further. Now that the recent Hugging Face Whisper fine-tuning event is over (I participated too), I'll get back to debugging this wav2vec2 pretraining after a short Christmas break :) In any case, I am planning to create a PR from my current work even if the Gumbel quantizer doesn't get fixed, because my current branch has pretty much brought the wav2vec2 Flax model and pretraining code implementation up to date with the PyTorch version. But I hope we get the Gumbel part fixed too.

@sanchit-gandhi
Contributor

Hey @aapot! Hope you had a nice Xmas break and that you enjoyed the Whisper event 🙂

Thanks for the update! Sounds good regarding opening a PR with the current changes - these are certainly welcome fixes! We can iterate on the PR to see if we can get the Gumbel part fixed too. Feel free to ping me here or on the new PR with questions / queries - more than happy to help and excited to see this one to completion!

@aapot
Author

aapot commented Jan 10, 2023

@sanchit-gandhi quick update on the GumbelVectorQuantizer with option 1 you mentioned earlier (replacing the gumbel-sampled matrix with a predefined matrix).
First, I checked that the hidden_states inside the GumbelVectorQuantizer, just before the actual gumbel sampling, had a diff of 5.7e-07 between Flax and PT for the first training step, so that looks good.
Next, I saved the matrices from PT's nn.functional.gumbel_softmax for the first three steps and then used them inside the Flax GumbelVectorQuantizer for the first three steps. By doing that, the model output's projected_quantized_states were actually the same between PT and Flax for the first training step (diff 0).
But for the second step, the projected_quantized_states diff already jumped to 0.3 (although the diff before the linear projection layer was 0.01, so the linear projection adds some diff to the projected_quantized_states output). For the second step, the hidden_states also had a diff of 0.38 inside the GumbelVectorQuantizer.
For the third step the divergence continues, with a diff of 0.47 for projected_quantized_states (0.02 before the linear projection) and a diff of 0.43 for the hidden_states inside the GumbelVectorQuantizer.

Any ideas how to proceed?

@sanchit-gandhi
Contributor

Hey @aapot!

Thanks for the update - really cool to see the progress you're making here! Sounds like you've got a nice system going for debugging and comparing the PT-FX outputs!

That's great that the hidden_states are equivalent before the Gumbel sampling ✅ And good to see that the codevectors had a diff of 0 - exactly what we wanted by forcing the sampled matrix! Was the codevector_perplexity also equivalent in this case?

From the last experiment, it sounds pretty likely that the nn.Modules are now equivalent between PT and Flax (we're getting the same tensors out when we override the Gumbel sampling).

I would suggest we quickly verify that all the loss terms are equivalent with this forced-sampling set-up. If they match for the first training step, that's perfect: it means we should have a closely matching implementation.

Note that with our 'forced sampling' method we can verify that we get the same losses between PT and Flax, but since we change how the code vectors are computed in Flax (by forcing the sampled Gumbel matrix), we can't expect the gradients to be correct - forcing the Gumbel sampling is going to mess up the backprop, so anything after the first parameter update is going to be divergent.

So once we've verified that all the loss terms are the same (contrastive, diversity, total loss), I would re-instate stochastic sampling of the Gumbel matrix in Flax and see whether we can train a stable system!

How does that sound?

@aapot
Author

aapot commented Jan 30, 2023

@sanchit-gandhi sounds reasonable! The codevector_perplexity diff was in the range of 1e-4 with the fixed gumbels. I also checked the loss terms: the contrastive loss diff is 2e-3, the diversity loss diff is 2e-7, and the total loss diff is 2e-3. What would be the best way to try to train a stable system next?

@sanchit-gandhi
Contributor

sanchit-gandhi commented Feb 2, 2023

Hey @aapot! Awesome - thanks for getting back with these results! Really enjoying hearing your updates here! Shall we double-check that the code vector perplexity is being computed correctly? The diff for this value and the contrastive loss looks a little high for a dummy model (it should be < 1e-5)! We can quickly check the code vector perplexity function and verify that it matches PT (and correct it if not!)
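For reference, a JAX sketch of how I'd expect the perplexity to be computed, mirroring my reading of the PyTorch quantizer - please double-check it against _compute_perplexity in modeling_wav2vec2.py rather than taking it as gospel:

import jax.numpy as jnp

def compute_perplexity(probs, mask=None):
    # probs: (batch * sequence, num_groups, num_vars) codevector probabilities.
    if mask is not None:
        # Only count time steps that are actually masked for the contrastive task.
        mask_extended = mask.flatten()[:, None, None]
        probs = jnp.where(mask_extended, probs, 0.0)
        marginal_probs = probs.sum(axis=0) / mask.sum()
    else:
        marginal_probs = probs.mean(axis=0)
    # Perplexity of the marginal codevector distribution, summed over the groups.
    return jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()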

@sanchit-gandhi
Contributor

Hey @aapot! We're really close here! Any chance you've had the opportunity to look into the code vector perplexity and the error propagation onto the contrastive loss? Once we're confident with these we can start scaling up to full training runs

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
