Add call override for use_running_average #683
Conversation
Should we just get rid of the [...]?
That was my first guess, but I left it there because we use this very handy pattern in ImageNet:
Typically you have one template for a batchnorm layer throughout a network with a bunch of hyperparameters. This way you don't have to propagate train all throughout the model.
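For illustration, the template pattern being referred to looks roughly like the following sketch (the module name and hyperparameter values are placeholders, not the actual ImageNet example code):

import functools
from flax import linen as nn

class ConvBlock(nn.Module):
  features: int
  train: bool = True

  @nn.compact
  def __call__(self, x):
    # One batchnorm "template" carries the shared hyperparameters, so every
    # call site gets consistent settings without repeating them or train.
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5)
    x = nn.Conv(self.features, (3, 3))(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Conv(self.features, (3, 3))(x)
    x = norm()(x)
    return nn.relu(x)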
Hmm. I haven't thought about this too carefully, but I'm uneasy about having two different ways to specify the same setting. If the [...]
If you specify a BatchNorm in setup you have to pass [...]. I don't think it's a footgun because it is usually passed as [...]. But it is definitely a little tricky. Partially applying modules kind of blurs the line between hyperparameters and call arguments.
I see. The goal is to support top-level modules that: (1) use [...]
So in the case where someone develops entirely with submodules defined in setup, [...]. The footgun note was more about something like the following (which we could solve by having the default attribute value for [...]).
BTW, somehow PyTorch only lets you define the equivalent attribute in the constructor to BatchNorm (https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d) -- and they "define all submodules during setup" (in Flax speak). How do they get around this problem? Do people just create two different module instances in top-level training code for train vs. test time?
PyTorch has the implicit training=True/False flag on every module.
Oh wow. They ask you to globally, mutably, set your module into eval mode (https://discuss.pytorch.org/t/example-on-how-to-use-batch-norm/216/19)...
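For reference, the PyTorch pattern under discussion looks roughly like this minimal sketch (not taken from the linked thread):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.BatchNorm1d(16),
    torch.nn.Dropout(p=0.5))

model.train()  # mutates every submodule: batchnorm updates running stats, dropout is active
# ... training steps ...
model.eval()   # mutates every submodule: running stats are used, dropout becomes a no-op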
I would like for us to get additional opinions about this before we go down this route. As I wrote in #683 (comment), I am afraid that more and more attributes will end up being ones you want to be able to override this way, and that this pattern will infest our module implementations. If the reasoning is solely to let folks define a single module instance at their top level that they can use for train or eval, then maybe there's a simpler solution, which is to define at the top level:

class MyModule(nn.Module):
  train: bool
  num_features: int

  def setup(self):
    self.bn1 = nn.BatchNorm(use_running_average=not self.train, [...])

  def __call__(self, x):
    [...]
    x = self.bn1(x)

MyModulePartial = functools.partial(MyModule, num_features=1000)
MyModuleTrain = MyModulePartial(train=True)
MyModuleEval = MyModulePartial(train=False)
This way we go back to the "hyperparameters are just partially applied arguments" pattern.
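For concreteness, a sketch of how the two instances might be used (assuming the elided body takes a single array input; shapes and names are illustrative):

import jax
import jax.numpy as jnp

x = jnp.ones((8, 16))
variables = MyModuleTrain.init(jax.random.PRNGKey(0), x)

# Training: batch_stats must be mutable so BatchNorm can update its running averages.
y, updated_state = MyModuleTrain.apply(variables, x, mutable=['batch_stats'])

# Evaluation: same variables, separate top-level instance, nothing mutable.
y = MyModuleEval.apply(variables, x)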
I just think we'll also get negative feedback about this override pattern (I myself don't like it). So if we are going down this path I want to hear opinions from more folks, showing concretely the two options on the table.
@jheek @avital as someone who's new to JAX / Flax, after getting used to things, I think either approach can work; I just feel it should be consistent. Consistent in that I'd like to pass my train/not-train state either exclusively through call args or exclusively through dataclass attributes, but not a mix of both, as is the case between BatchNorm vs. Dropout right now. If for some reason it is deemed that stochastic layers like dropout are different, and deterministic/non-deterministic state may change with more granularity, couldn't that also apply to batchnorm in some use cases where the phases of using running stats vs. not might be finer grained than train / not train? Essentially, why this (where I've wound up right now),
and not either of these (ignoring potential use of partial binding on purpose)?
EDIT: I guess I should add, given a choice, I'd pass [...]. Aside from train / not train, what if you wanted to freeze a subset of BN layers to use running stats at some epoch (in both train and eval), scale a dropout rate with epoch (% of training), or disable a stochastic layer (i.e. DropBlock) after a certain epoch? Should the preferred pattern be to instantiate a new model at each phase transition with updated attributes, or to pass the relevant bools/scaling factors via call args? In any case it seems like it should be consistent and not a mix.
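As an illustration of the kind of mixed pattern being described (a sketch, not the original snippets from the comment):

from flax import linen as nn

class Block(nn.Module):
  train: bool  # BatchNorm mode is fixed as a module attribute...

  @nn.compact
  def __call__(self, x, training: bool):  # ...while Dropout mode arrives as a call argument.
    x = nn.Dense(64)(x)
    x = nn.BatchNorm(use_running_average=not self.train)(x)  # attribute-driven
    x = nn.Dropout(rate=0.1)(x, deterministic=not training)  # call-argument-driven
    return x
    # (apply needs rngs={'dropout': key} whenever dropout is active)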
Hi @jheek thanks for writing up https://github.com/google/flax/pull/683/files/5e1a7c24787f498cbd24f81d3516ef3a52e2adb8#diff-a4e69f05db605d788793032faac416daf270becbb42ef8abd8f08c69946bd373.
- I suggested changes to the language to make the different terms "parameter", "property", "hyperparameter", "attribute", and "argument" more precise, plus other clarifications to improve readability. PTAL.
- I think perhaps we should make these live in Read the Docs, and add a notes/ directory with a README that just points there?
docs/notes/parameters.md
@@ -0,0 +1,89 @@
# Design note: dealing with Module paramaters
- # Design note: dealing with Module paramaters
+ # Design Note: Dealing with Module Parameters
The use of the word "parameters" is confusing here because Modules have self.param and self.params, which mean something different. Can we call these "module arguments" or "module hyperparameters"?
Oh, maybe it's "module properties"? Because it's something that ends up being either an argument or a hyperparameter, so we need a third word. And it can't be "parameter".
docs/notes/parameters.md
@nn.compact
def __call__(self, x, deterministic=None):
  deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
What purpose does the first 'deterministic' string argument serve?
That's for a helpful error message. I should maybe make that optional though.
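For context, the semantics are roughly as follows (a simplified sketch of merge_param, not the actual Flax implementation); the name string only exists to make the errors readable:

def merge_param(name, a, b):
  # Exactly one of the module attribute `a` and the call argument `b` may be set.
  if a is None and b is None:
    raise ValueError(f'Pass "{name}" either as a module attribute or as a call argument.')
  if a is not None and b is not None:
    raise ValueError(f'"{name}" was set both as a module attribute and as a call argument; pass it only once.')
  return a if b is None else b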
@rwightman - I tend to agree with you that I prefer consistency in applying this sort of thing, and in practice so far I haven't found myself wanting to exploit a finer degree of granularity for things like dropout / batchnorm, but I'm curious if there are use cases that demand it that I'm unaware of. @jheek @avital - In a lot of recent model development for large language models I have adopted the pattern of defining several top-level configurations for a model: TrainModel / EvalModel / PredictModel, each of which sets these train/eval/infer-time options across the model. (I do so by typically aggregating "global" model options in a single dataclass that's threaded through the entirety of the "user-defined" portions of the model, and I actually have separate train/eval/predict configs, but this is an implementation detail.)
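A sketch of the kind of threaded-config pattern being described (illustrative only; the names and options are assumptions, not code from those models):

from flax import linen as nn
from flax import struct

@struct.dataclass
class ModelConfig:
  deterministic: bool = False        # controls dropout-like layers
  use_running_average: bool = False  # controls batchnorm-like layers
  dropout_rate: float = 0.1

class Block(nn.Module):
  config: ModelConfig

  @nn.compact
  def __call__(self, x):
    cfg = self.config
    x = nn.Dense(128)(x)
    x = nn.BatchNorm(use_running_average=cfg.use_running_average)(x)
    x = nn.Dropout(rate=cfg.dropout_rate, deterministic=cfg.deterministic)(x)
    return x

# Three top-level variants, one set of weights.
TrainModel   = Block(config=ModelConfig(deterministic=False, use_running_average=False))
EvalModel    = Block(config=ModelConfig(deterministic=True, use_running_average=True))
PredictModel = Block(config=ModelConfig(deterministic=True, use_running_average=True))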
If we follow @levskaya's proposal to encourage multiple top-level modules, e.g. [...]
@avital One way to do things, or at least a canonical way, certainly makes for less confusion. @levskaya Are any of those large language models, or models using that pattern, open source and available to look at? It would be helpful to digest more examples... For the cases I'm thinking of where one might definitely need to pass something through call args to decay dropout rates, they are dropout(ish) layers that'd need their own implementation anyway. For the majority of use cases I can think of with normal dropout / BN layers, the 'phases' are fairly coarse, even if they don't align with train/not-train -- different top-level modules would suffice. One thing I did notice with existing layers / examples is that the [...]
I like the dataclass config pattern, especially for larger models. I'm not sure about having multiple top-level models though. It's not great for the functional training-loop patterns, because you cannot easily write your training+eval loop with just an apply function anymore (important for interop with other JAX libraries). The second issue is interactive mode and composability. What if I want to apply dropout or use_running_average on some inputs but not others (let's say in a GAN discriminator)? Or how would I switch the mode if I have an interactive module instance? @rwightman I fully agree that there should be no default for these mode-like arguments.
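As an illustration of the per-call granularity point, a sketch assuming the call-time override this PR adds to BatchNorm (the discriminator itself is made up):

from flax import linen as nn

class Discriminator(nn.Module):
  @nn.compact
  def __call__(self, x, use_running_average: bool):
    x = nn.Dense(128)(x)
    # The same module instance can update or freeze running statistics per call.
    x = nn.BatchNorm()(x, use_running_average=use_running_average)
    x = nn.relu(x)
    return nn.Dense(1)(x)

# disc = Discriminator()
# Real batches update the running statistics (batch_stats must be mutable)...
# logits_real, new_state = disc.apply(variables, real_batch, False, mutable=['batch_stats'])
# ...while fake batches use the frozen statistics.
# logits_fake = disc.apply(variables, fake_batch, True)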
Hey everybody :-), we could replace this line (line 154 in 36eaa1e) by

def __call__(self,
             inputs_q: Array,
             inputs_kv: Array,
             mask: Optional[Array] = None,
             deterministic: Optional[bool] = None):
  ...
  deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)

? It would solve the problem described in #821 that we're having at the moment for Flax Transformers :-)
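If that change lands, a hypothetical call site could then pick the mode per apply call (attn, variables, q, kv, and key are placeholders):

attn = nn.MultiHeadDotProductAttention(num_heads=8, dropout_rate=0.1)

# Training: attention dropout is active, so a 'dropout' RNG is required.
# y = attn.apply(variables, q, kv, deterministic=False, rngs={'dropout': key})

# Evaluation: same instance and variables, no RNG needed.
# y = attn.apply(variables, q, kv, deterministic=True)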
Codecov Report
@@ Coverage Diff @@
## master #683 +/- ##
==========================================
+ Coverage 81.22% 81.24% +0.01%
==========================================
Files 57 57
Lines 4533 4548 +15
==========================================
+ Hits 3682 3695 +13
- Misses 851 853 +2
Continue to review full report at Codecov.
Added a few small comments/suggestions, but generally LGTM!
Co-authored-by: Avital Oliver <avital@thewe.net>