-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhance] Support reading class_weight from file in loss functions to help MMDet3D #513
Conversation
Codecov Report
@@ Coverage Diff @@
## master #513 +/- ##
==========================================
+ Coverage 86.56% 86.58% +0.02%
==========================================
Files 99 99
Lines 5164 5172 +8
Branches 836 838 +2
==========================================
+ Hits 4470 4478 +8
Misses 535 535
Partials 159 159
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
This PR is ready for review @xvjiarui. Thanks! |
… help MMDet3D (open-mmlab#513) * support reading class_weight from file in loss function * add unit test of loss with class_weight from file * minor fix * move get_class_weight to utils
* Add `init_weights` method to `FlaxMixin` * Rn `random_state` -> `shape_state` * `PRNGKey(0)` for `jax.eval_shape` * No allow mismatched sizes * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * docstring diffusers Co-authored-by: Suraj Patil <surajp815@gmail.com>
* First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization. * Remove FlaxUNet2DConfig class. * ignore_for_config non-config args. * Implement `FlaxModelMixin` * Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`. * Import `FlaxUNet2DConditionModel` if flax is available. * Rm unused method `framework` * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Indicate types in flax.struct.dataclass as pointed out by @mishig25 Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu> * Fix typo in transformer block. * make style * some more changes * make style * Add comment * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Rm unneeded comment * Update docstrings * correct ignore kwargs * make style * Update docstring examples * Make style * Style: remove empty line. * Apply style (after upgrading black from pinned version) * Remove some commented code and unused imports. * Add init_weights (not yet in use until open-mmlab#513). * Trickle down deterministic to blocks. * Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful. * Flax UNet docstrings, default props as in PyTorch. * Fix minor typos in PyTorch docstrings. * Use FlaxUNet2DConditionOutput as output from UNet. * make style Co-authored-by: Mishig Davaadorj <dmishig@gmail.com> Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* resolve comments * update changelog * update changelog * add master * fix changelog * remove FAQ
Hi, I am Ziyi, a developer from MMDet3D. Currently we are implementing 3D segmentation models, which require using loss functions from MMSeg (e.g. CrossEntropy). It's very common to use class_weight in 3D Seg, and in some scenarios we need to load class_weight from files. So I modify the code and support this feature. Hope you can approve it. Thanks.