
Add call override for use_running_average #683

Merged

Conversation

@jheek (Member) commented Nov 27, 2020

No description provided.

@jheek jheek requested a review from levskaya November 27, 2020 09:57
@google-cla google-cla bot added the cla: yes label Nov 27, 2020
@avital (Contributor) commented Nov 30, 2020

Should we just get rid of the use_running_average attribute entirely?

@jheek (Member, Author) commented Nov 30, 2020

That was my first thought, but I left it there because we use this very handy pattern in ImageNet:

norm = partial(nn.BatchNorm, use_running_average=not train, momentum=..., eps=..., ...)

Typically you have one template for a BatchNorm layer that is reused throughout a network, with a bunch of hyperparameters bound once. This way you don't have to propagate train all throughout the model.
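
For concreteness, here is a minimal sketch of that template pattern, with illustrative hyperparameter values (the block name, layer sizes, and values below are made up for illustration, not taken from the ImageNet example):

from functools import partial
import flax.linen as nn

class ResNetBlock(nn.Module):  # hypothetical block, for illustration only
  train: bool

  @nn.compact
  def __call__(self, x):
    # Bind the BatchNorm hyperparameters once and reuse the template everywhere.
    norm = partial(nn.BatchNorm, use_running_average=not self.train,
                   momentum=0.9, epsilon=1e-5)
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = norm()(x)
    return nn.relu(x)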

@avital (Contributor) commented Dec 1, 2020

Hmm. I haven't thought about this too carefully, but I'm uneasy about having two different ways to specify use_running_average for BatchNorm. That seems not so clean, as well as a footgun for users ("but I specified use_running_average! why isn't it using my setting?").

If the partial use case is strong, maybe that's an argument against this PR. What is the argument in favor of it? (I didn't see any motivation in the description of this PR; maybe we should make sure to include one.)

@jheek (Member, Author) commented Dec 1, 2020

If you specify a BatchNorm in setup, you have to pass use_running_average at call time, because at setup time you don't yet know whether train is True or False.

I don't think it's a footgun, because it is usually passed as use_running_average=not train, and either you can't pass it at construction time because you don't have train yet, or you define the layer inside a compact method and can pick either option.

But it is definitely a little tricky. Partially applying modules kind of blurs the line between hyperparameters and call arguments.
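
For concreteness, a minimal sketch of the call-time override this PR adds (the module and argument names below are illustrative):

import flax.linen as nn

class Model(nn.Module):  # illustrative module, not from the Flax examples
  def setup(self):
    # At setup time we don't yet know whether we're training or evaluating,
    # so use_running_average is left unspecified here...
    self.bn = nn.BatchNorm(momentum=0.9)

  def __call__(self, x, train: bool):
    # ...and is decided per call instead.
    return self.bn(x, use_running_average=not train)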

@avital (Contributor) commented Dec 2, 2020

I see. The goal is to support top-level modules that (1) use setup, and (2) are instantiated only once by the training code for both train and test time, e.g.

class MyModule(nn.Module):
  def setup(self):
    self.bn1 = BatchNorm(...)

  def __call__(self, x, is_training):
    return self.bn1(x)  # how do you use `is_training` here?

So in the case where someone develops entirely with submodules defined in setup, does this imply that all attributes end up needing to be overridable in __call__, for all layers?


The footgun note was more about something like the following (which we could solve by having the default attribute value for use_running_average be a sentinel object like UNDECIDED, and then only letting you specify a value in __call__ if the instance attribute value is UNDECIDED):

class MyModule(nn.Module):
  def setup(self):
    # A user may look at this and be surprised that the batch norm ends up not
    # using the running average (especially if the module declaration is long and
    # it's not easy to find where `self.bn1` is used, or it is used in multiple places).
    self.bn1 = BatchNorm(use_running_average=True)

  def __call__(self, x):
    return self.bn1(x, use_running_average=False)
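
For what it's worth, a rough sketch of that sentinel idea, using None in the role of UNDECIDED; this is an illustration of the proposal, not the actual BatchNorm implementation:

from typing import Optional
import flax.linen as nn

class MyBatchNorm(nn.Module):  # illustrative, not flax.linen.BatchNorm
  use_running_average: Optional[bool] = None  # None plays the role of UNDECIDED

  @nn.compact
  def __call__(self, x, use_running_average: Optional[bool] = None):
    # Exactly one of the attribute and the call argument must be provided.
    if (self.use_running_average is None) == (use_running_average is None):
      raise ValueError(
          'Pass use_running_average either as an attribute or as a call '
          'argument, but not both and not neither.')
    if use_running_average is None:
      use_running_average = self.use_running_average
    # ...normalize x using batch statistics or the running average...
    return x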

@avital (Contributor) commented Dec 2, 2020

BTW, somehow PyTorch only lets you define the equivalent attribute in the constructor of BatchNorm (https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d) -- and they "define all submodules during setup" (in Flax speak). How do they get around this problem? Do people just create two different module instances in top-level training code for train vs. test time?

@jheek (Member, Author) commented Dec 2, 2020

PyTorch has an implicit training=True/False flag on every Module, so BatchNorm will still only track the running statistics when you are in training mode.
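
For reference, the implicit mode flag referred to here is toggled like this in PyTorch:

import torch.nn as nn

# Every PyTorch Module carries a `training` attribute; train()/eval() flips it
# recursively, and BatchNorm2d consults it to decide whether to update its
# running statistics.
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU())
model.train()  # training mode: batch statistics are used and running stats are updated
model.eval()   # eval mode: the stored running statistics are used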

@avital (Contributor) commented Dec 2, 2020

Oh wow. They ask you to globally, mutably, set your module into eval mode (https://discuss.pytorch.org/t/example-on-how-to-use-batch-norm/216/19)...

@avital (Contributor) commented Dec 2, 2020

I would like for us to get additional opinions about this before we go down this route. As I wrote in #683 (comment), I am afraid that more and more attributes will end up being ones you want to override this way, and that this pattern will infest our module implementations.

If the reasoning is solely to let folks define a single module instance at their top level that they can use for train or eval, then maybe there's a simpler solution, which is to define, at the top level:

class MyModule(nn.Module):
  train: bool
  num_features: int

  def setup(self):
    self.bn1 = BatchNorm(use_running_average=not self.train, [...])

  def __call__(self, x):
    [...]
    x = self.bn1(x)

MyModulePartial = functools.partial(MyModule, num_features=1000)
MyModuleTrain = MyModulePartial(train=True)
MyModuleEval = MyModulePartial(train=False)

@jheek (Member, Author) commented Dec 2, 2020

This way we go back to the "hyperparameters are just partially applied arguments" pattern.
Both in the original abstraction and in Linen we got a lot of negative feedback on that.

@avital (Contributor) commented Dec 2, 2020

I just think we'll also get negative feedback about this override pattern (I myself don't like it). So if we are going down this path, I want to hear opinions from more folks, showing the two options on the table concretely.

@jheek jheek self-assigned this Dec 3, 2020
@rwightman commented Dec 10, 2020

@jheek @avital As someone who's new to JAX/Flax, after getting used to things, I think either approach can work; I just feel it should be consistent. Consistent in that I'd like to pass my train/not-train state either exclusively through call args or exclusively through dataclass attrs, but not a mix of both, as is currently the case between BatchNorm and Dropout.

If for some reason it is deemed that stochastic layers like dropout are different, and that deterministic/non-deterministic state may change with more granularity, couldn't that also apply to batch norm in some use cases, where the phases of using running stats vs. not might be finer grained than train / not train?

Essentially, why this (where I've wound up right now):

def __call__(self, x, training: bool):
  ...
  x = self.norm_layer(name='bn', use_running_average=not training)(x)
  x = self.act_fn(x)
  x = x.mean((1, 2))
  x = Dropout(rate=self.drop_rate)(x, deterministic=not training)
  ...

and not either of these (ignoring the potential use of partial binding on purpose):

def __call__(self, x, training: bool):
  ...
  x = self.norm_layer(name='bn')(x, use_running_average=not training)
  x = self.act_fn(x)
  x = x.mean((1, 2))
  x = Dropout(rate=self.drop_rate)(x, deterministic=not training)
  ...

def __call__(self, x):
  ...
  x = self.norm_layer(name='bn', use_running_average=not self.train)(x)
  x = self.act_fn(x)
  x = x.mean((1, 2))
  x = Dropout(rate=self.drop_rate, deterministic=not self.train)(x)
  ...

EDIT: I guess I should add that, given a choice, I'd pass training state (or other state that might change with some phase of training) exclusively via call args, and not as a module attr that requires a different module instantiation.

Aside from train/not train, what if you wanted to freeze a subset of BN layers to use running stats at some epoch (in both train and eval), scale a dropout rate with epoch (% of training), or disable a stochastic layer (e.g. DropBlock) after a certain epoch? Should the preferred pattern be to instantiate a new model with updated attrs at each phase transition, or to pass the relevant bools/scaling factors via call args? In any case it seems like it should be consistent and not a mix.

@jheek jheek force-pushed the use_running_average-on-__call__ branch from e04449c to 5e1a7c2 on December 11, 2020 at 13:49
@avital avital (Contributor) left a comment

Hi @jheek, thanks for writing up https://github.com/google/flax/pull/683/files/5e1a7c24787f498cbd24f81d3516ef3a52e2adb8#diff-a4e69f05db605d788793032faac416daf270becbb42ef8abd8f08c69946bd373.

  1. I suggested changes to the language to make the different terms "parameter", "property", "hyperparameter", "attribute", and "argument" more precise, plus other clarifications to improve readability. PTAL.
  2. I think perhaps we should make these live in Read the Docs? And add a notes/ directory with a README that just points there?

@@ -0,0 +1,89 @@
# Design note: dealing with Module paramaters

Contributor:

Suggested change
# Design note: dealing with Module paramaters
# Design Note: Dealing with Module Parameters

Contributor:

The use of the word "parameters" is confusing here, because Modules have self.param and self.params, which mean something different.

Can we call these "module arguments" or "module hyperparameters"?

Contributor:

Oh, maybe it's "module properties"? It's something that ends up being either an argument or a hyperparameter, so we need a third word, and it can't be "parameter".


@nn.compact
def __call__(self, x, deterministic=None):
  deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)

Contributor:

What purpose does the first 'deterministic' string argument serve?

Member Author:

That's for a helpful error message. I should maybe make that optional, though.
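
For illustration, a rough sketch of such a merge helper, where the name only shows up in the error messages (a sketch only, not the actual nn.merge_param implementation):

from typing import Optional, TypeVar

T = TypeVar('T')

def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T:
  # `name` is only used to produce readable error messages.
  if a is None and b is None:
    raise ValueError(
        f'Pass "{name}" either as a module attribute or as a call argument.')
  if a is not None and b is not None:
    raise ValueError(
        f'"{name}" should be passed as a module attribute or as a call '
        'argument, but not both.')
  return a if a is not None else b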

@levskaya (Collaborator) commented

@rwightman - For myself, I tend to agree with you: I prefer consistency in applying this sort of thing, and in practice so far I haven't found myself wanting a finer degree of granularity for things like dropout/batchnorm. But I'm curious whether there are use cases that demand it that I'm unaware of.

@jheek @avital - In a lot of recent model development for large language models I have adopted the pattern of defining several top-level configurations for a model: TrainModel / EvalModel / PredictModel, which set these train/eval/infer-time options across the model. (I typically do this by aggregating "global" model options in a single dataclass that's threaded through the entirety of the "user-defined" portions of the model, with separate train/eval/predict configs, but this is an implementation detail.)
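
A minimal sketch of that pattern (all names below, such as ModelConfig, Block, TrainModel, and EvalModel, are illustrative, not from an actual codebase):

from dataclasses import dataclass
import flax.linen as nn

@dataclass(frozen=True)
class ModelConfig:  # hypothetical "global options" dataclass
  deterministic: bool
  use_running_average: bool
  dropout_rate: float = 0.1

class Block(nn.Module):  # hypothetical model piece that threads the config through
  config: ModelConfig

  @nn.compact
  def __call__(self, x):
    cfg = self.config
    x = nn.Dense(features=128)(x)
    x = nn.BatchNorm(use_running_average=cfg.use_running_average)(x)
    x = nn.Dropout(rate=cfg.dropout_rate, deterministic=cfg.deterministic)(x)
    return x

TrainModel = Block(config=ModelConfig(deterministic=False, use_running_average=False))
EvalModel = Block(config=ModelConfig(deterministic=True, use_running_average=True))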

@avital (Contributor) commented Dec 12, 2020

If we follow @levskaya's proposal to encourage multiple top-level modules, e.g. TrainModel, EvalModel, PredictModel, then this PR isn't necessary and we can use attributes for everything other than the actual inputs to a layer. This has the added benefit of having "one way to do things".

@rwightman commented Dec 12, 2020

@avital One way to do things, or at least a canonical way, certainly makes for less confusion.

@levskaya Are any of those large language models, or models using that pattern, open source and available to look at? It would be helpful to digest more examples.

For the cases I'm thinking of where one might definitely need to pass something through call args (e.g. to decay dropout rates), they are dropout-ish layers that would need their own impl anyway. For the majority of use cases I can think of with normal dropout/BN layers, the 'phases' are fairly coarse, even if they don't align with train/not train -- different top-level modules would suffice.

One thing I did notice with existing layers/examples is that the train, use_running_avg, and deterministic arguments or attrs often have default values. I'd recommend they don't, for the simple reason that requiring them can prevent a lot of hard-to-track-down bugs. One big gotcha in PyTorch that has generated a lot of forum/issue-tracker activity over the years is forgetting to flip the train/eval state of your models. Requiring the explicit setting or binding of those values can prevent bugs where all (or, more insidiously, just a few) such layers in the model are in the wrong mode. I feel the verbosity-vs-ARRG tradeoff there is worth it.

@avital avital added this to the Improve Linen milestone Dec 12, 2020
@jheek (Member, Author) commented Dec 15, 2020

I like the dataclass config pattern, especially for larger models.

I'm not sure about having multiple top-level models, though. It's not great for the functional training-loop pattern, because you can no longer easily write your training+eval loop against a single apply function (important for interop with other JAX libraries).

The second issue is interactive mode and composability. What if I want to apply dropout or use_running_average on some inputs but not others (let's say in a GAN discriminator)? Or how would I switch the mode if I have an interactive module instance?
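
One way to read the GAN example, purely illustrative and relying on the call-time override from this PR:

import flax.linen as nn

class Discriminator(nn.Module):  # illustrative discriminator skeleton
  def setup(self):
    self.bn = nn.BatchNorm(momentum=0.9)
    self.dense = nn.Dense(features=1)

  def __call__(self, real, fake):
    # The same BatchNorm layer is applied with a different mode per input.
    real_score = self.dense(self.bn(real, use_running_average=False))
    fake_score = self.dense(self.bn(fake, use_running_average=True))
    return real_score, fake_score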

@rwightman I fully agree that there should be no default for these mode-like arguments.

@patrickvonplaten (Contributor) commented Jan 13, 2021

Hey everybody :-),
could we maybe use this PR to add deterministic as an input param to MultiHeadDotProductAttention's __call__ method?

We could replace this line:

def __call__(self,

by

  def __call__(self,
               inputs_q: Array,
               inputs_kv: Array,
               mask: Optional[Array] = None,
               deterministic: Optional[bool] = None):
    ...
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)

?

It would solve the problem described in #821 that we're having at the moment with Flax Transformers :-)

@BertrandRdp BertrandRdp added the Priority: P1 - soon label (response within 5 business days; resolution within 30 days; assignee required) Jan 14, 2021
@jheek jheek force-pushed the use_running_average-on-__call__ branch 3 times, most recently from e973657 to e662a1b on February 3, 2021 at 13:35
@codecov-io commented

Codecov Report

Merging #683 (e662a1b) into master (278df51) will increase coverage by 0.01%.
The diff coverage is 90.90%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #683      +/-   ##
==========================================
+ Coverage   81.22%   81.24%   +0.01%     
==========================================
  Files          57       57              
  Lines        4533     4548      +15     
==========================================
+ Hits         3682     3695      +13     
- Misses        851      853       +2     
Impacted Files                Coverage Δ
flax/linen/module.py          94.83% <77.77%> (-0.48%) ⬇️
flax/linen/attention.py       96.26% <100.00%> (ø)
flax/linen/normalization.py   89.28% <100.00%> (+0.19%) ⬆️
flax/linen/stochastic.py      92.30% <100.00%> (+1.39%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 278df51...e662a1b. Read the comment docs.

@jheek jheek force-pushed the use_running_average-on-__call__ branch from e662a1b to cd8c6b2 on February 3, 2021 at 14:18
@jheek jheek requested a review from avital February 3, 2021 15:34
@jheek jheek force-pushed the use_running_average-on-__call__ branch from cd8c6b2 to 5959ce8 on February 3, 2021 at 15:55
@jheek jheek force-pushed the use_running_average-on-__call__ branch from 5959ce8 to 655d74a on February 3, 2021 at 16:18
@avital avital (Contributor) left a comment

Added a few small comments/suggestions, but generally LGTM!

jheek and others added 2 commits February 4, 2021 15:22