diff --git a/docs/_ext/flax_module.py b/docs/_ext/flax_module.py new file mode 100644 index 0000000000..ece47caf3f --- /dev/null +++ b/docs/_ext/flax_module.py @@ -0,0 +1,76 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sphinx directive for visualizing Flax modules. + +Use directive as follows: + +.. flax_module:: + :module: flax.linen + :class: Dense + +""" + +from docutils import nodes +from docutils.parsers.rst import directives +from docutils.statemachine import ViewList + +import sphinx +from sphinx.util.docutils import SphinxDirective +from docs.conf_sphinx_patch import generate_autosummary_content +import sphinx.ext.autosummary.generate as ag +import importlib + + +def render_module(modname: str, qualname: str, app): + parent = importlib.import_module(modname) + obj = getattr(parent, qualname) + template = ag.AutosummaryRenderer(app) + template_name = "flax_module" + imported_members = False + recursive = False + context = {} + return generate_autosummary_content( + qualname, obj, parent, template, template_name, imported_members, + app, recursive, context, modname, qualname) + +class FlaxModuleDirective(SphinxDirective): + has_content = True + option_spec = { + 'module': directives.unchanged, + 'class': directives.unchanged, + } + + def run(self): + module_template = render_module( + self.options['module'], self.options['class'], self.env.app + ) + module_template = module_template.splitlines() + + # Create a container for the rendered nodes + container_node = nodes.container() + self.content = ViewList(module_template, self.content.parent) + self.state.nested_parse(self.content, self.content_offset, container_node) + + return [container_node] + + +def setup(app): + app.add_directive('flax_module', FlaxModuleDirective) + + return { + 'version': sphinx.__display_version__, + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } diff --git a/docs/_templates/autosummary/flax_module.rst b/docs/_templates/autosummary/flax_module.rst index 5f51933cff..b3f605222a 100644 --- a/docs/_templates/autosummary/flax_module.rst +++ b/docs/_templates/autosummary/flax_module.rst @@ -1,4 +1,4 @@ -{{ fullname | escape | underline}} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} diff --git a/docs/api_reference/flax.linen.rst b/docs/api_reference/flax.linen.rst deleted file mode 100644 index 126b6113fa..0000000000 --- a/docs/api_reference/flax.linen.rst +++ /dev/null @@ -1,278 +0,0 @@ - -flax.linen package -================== - -.. currentmodule:: flax.linen - -Linen is the Flax Module system. Read more about our design goals in the `Linen README `_. - - - -Module ------------------------- - -.. autoclass:: Module - :members: setup, variable, param, bind, unbind, apply, init, init_with_output, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb - -Init/Apply ------------------------- - -.. currentmodule:: flax.linen -.. autofunction:: apply -.. autofunction:: init -.. autofunction:: init_with_output - -Variable dictionary ----------------------- - -.. automodule:: flax.core.variables -.. autoclass:: Variable - - -Compact methods ----------------------- - -.. currentmodule:: flax.linen -.. autofunction:: compact - - -No wrap methods ----------------------- - -.. currentmodule:: flax.linen -.. autofunction:: nowrap - - -Profiling ----------------------- - -.. automodule:: flax.linen -.. currentmodule:: flax.linen - -.. autosummary:: - :toctree: _autosummary - - enable_named_call - disable_named_call - override_named_call - - -Inspection ----------------------- - -.. automodule:: flax.linen -.. currentmodule:: flax.linen - -.. autosummary:: - :toctree: _autosummary - - tabulate - - -Transformations ----------------------- - -.. automodule:: flax.linen.transforms -.. currentmodule:: flax.linen - -.. autosummary:: - :toctree: _autosummary - - vmap - scan - jit - remat - remat_scan - map_variables - jvp - vjp - custom_vjp - while_loop - cond - switch - - -SPMD ----------------------- - -.. automodule:: flax.linen.spmd -.. currentmodule:: flax.linen - -.. autosummary:: - :toctree: _autosummary - - Partitioned - with_partitioning - get_partition_spec - get_sharding - LogicallyPartitioned - logical_axis_rules - set_logical_axis_rules - get_logical_axis_rules - logical_to_mesh_axes - logical_to_mesh - logical_to_mesh_sharding - with_logical_constraint - with_logical_partitioning - - -Linear Modules ------------------------- - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - Dense - DenseGeneral - Conv - ConvTranspose - ConvLocal - Embed - - -Normalization ------------------------- - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - BatchNorm - LayerNorm - GroupNorm - - -Pooling ------------------------- - -.. autosummary:: - :toctree: _autosummary - - max_pool - avg_pool - pool - - -Activation functions ------------------------- - -.. automodule:: flax.linen.activation -.. currentmodule:: flax.linen.activation - -.. autosummary:: - :toctree: _autosummary - - PReLU - celu - elu - gelu - glu - hard_sigmoid - hard_silu - hard_swish - hard_tanh - leaky_relu - log_sigmoid - log_softmax - logsumexp - one_hot - relu - relu6 as relu6, - selu - sigmoid - silu - soft_sign - softmax - softplus - standardize - swish - tanh - - -Initializers ------------------------- - -.. automodule:: flax.linen.initializers -.. currentmodule:: flax.linen.initializers - -.. autosummary:: - :toctree: _autosummary - - constant - delta_orthogonal - glorot_normal - glorot_uniform - he_normal - he_uniform - kaiming_normal - kaiming_uniform - lecun_normal - lecun_uniform - normal - ones - ones_init - orthogonal - uniform - standardize - variance_scaling - xavier_normal - xavier_uniform - zeros - zeros_init - - -Combinators ------------------------- - -.. currentmodule:: flax.linen - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - Sequential - - -Attention primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - - dot_product_attention_weights - dot_product_attention - make_attention_mask - make_causal_mask - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - SelfAttention - MultiHeadDotProductAttention - - -Stochastic ------------------------- - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - Dropout - - -RNN primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - :template: flax_module - - LSTMCell - OptimizedLSTMCell - GRUCell - RNNCellBase - RNN - Bidirectional diff --git a/docs/api_reference/flax.linen/activation_functions.rst b/docs/api_reference/flax.linen/activation_functions.rst new file mode 100644 index 0000000000..4acd47db9e --- /dev/null +++ b/docs/api_reference/flax.linen/activation_functions.rst @@ -0,0 +1,63 @@ + +Activation functions +------------------------ + +.. automodule:: flax.linen.activation +.. currentmodule:: flax.linen.activation + +.. autofunction:: PReLU +.. autofunction:: celu +.. autofunction:: elu +.. autofunction:: gelu +.. autofunction:: glu +.. autofunction:: hard_sigmoid +.. autofunction:: hard_silu +.. autofunction:: hard_swish +.. autofunction:: hard_tanh +.. autofunction:: leaky_relu +.. autofunction:: log_sigmoid +.. autofunction:: log_softmax +.. autofunction:: logsumexp +.. autofunction:: one_hot +.. autofunction:: relu +.. autofunction:: relu6 as relu6, +.. autofunction:: selu +.. autofunction:: sigmoid +.. autofunction:: silu +.. autofunction:: soft_sign +.. autofunction:: softmax +.. autofunction:: softplus +.. autofunction:: standardize +.. autofunction:: swish +.. autofunction:: tanh + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + PReLU + celu + elu + gelu + glu + hard_sigmoid + hard_silu + hard_swish + hard_tanh + leaky_relu + log_sigmoid + log_softmax + logsumexp + one_hot + relu + relu6 as relu6, + selu + sigmoid + silu + soft_sign + softmax + softplus + standardize + swish + tanh \ No newline at end of file diff --git a/docs/api_reference/flax.linen/decorators.rst b/docs/api_reference/flax.linen/decorators.rst new file mode 100644 index 0000000000..5dc8ae6c3b --- /dev/null +++ b/docs/api_reference/flax.linen/decorators.rst @@ -0,0 +1,15 @@ +Decorators +---------------------- + +.. currentmodule:: flax.linen + +.. autofunction:: compact +.. autofunction:: nowrap + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + compact + nowrap \ No newline at end of file diff --git a/docs/api_reference/flax.linen/index.rst b/docs/api_reference/flax.linen/index.rst new file mode 100644 index 0000000000..e54bd1ca00 --- /dev/null +++ b/docs/api_reference/flax.linen/index.rst @@ -0,0 +1,20 @@ + +flax.linen +========== + +Linen is the Flax Module system. Read more about our design goals in the `Linen README `_. + +.. toctree:: + :maxdepth: 2 + + module + init_apply + layers + activation_functions + initializers + transformations + inspection + variable + spmd + decorators + profiling \ No newline at end of file diff --git a/docs/api_reference/flax.linen/init_apply.rst b/docs/api_reference/flax.linen/init_apply.rst new file mode 100644 index 0000000000..364f8d8478 --- /dev/null +++ b/docs/api_reference/flax.linen/init_apply.rst @@ -0,0 +1,19 @@ + +Init/Apply +============== + +.. currentmodule:: flax.linen + +.. autofunction:: apply +.. autofunction:: init +.. autofunction:: init_with_output + + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + apply + init + init_with_output \ No newline at end of file diff --git a/docs/api_reference/flax.linen/initializers.rst b/docs/api_reference/flax.linen/initializers.rst new file mode 100644 index 0000000000..1f59e162e2 --- /dev/null +++ b/docs/api_reference/flax.linen/initializers.rst @@ -0,0 +1,54 @@ +Initializers +------------------------ + +.. automodule:: flax.linen.initializers +.. currentmodule:: flax.linen.initializers + +.. autofunction:: constant +.. autofunction:: delta_orthogonal +.. autofunction:: glorot_normal +.. autofunction:: glorot_uniform +.. autofunction:: he_normal +.. autofunction:: he_uniform +.. autofunction:: kaiming_normal +.. autofunction:: kaiming_uniform +.. autofunction:: lecun_normal +.. autofunction:: lecun_uniform +.. autofunction:: normal +.. autofunction:: ones +.. autofunction:: ones_init +.. autofunction:: orthogonal +.. autofunction:: uniform +.. autofunction:: standardize +.. autofunction:: variance_scaling +.. autofunction:: xavier_normal +.. autofunction:: xavier_uniform +.. autofunction:: zeros +.. autofunction:: zeros_init + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + constant + delta_orthogonal + glorot_normal + glorot_uniform + he_normal + he_uniform + kaiming_normal + kaiming_uniform + lecun_normal + lecun_uniform + normal + ones + ones_init + orthogonal + uniform + standardize + variance_scaling + xavier_normal + xavier_uniform + zeros + zeros_init \ No newline at end of file diff --git a/docs/api_reference/flax.linen/inspection.rst b/docs/api_reference/flax.linen/inspection.rst new file mode 100644 index 0000000000..6c627bf9ba --- /dev/null +++ b/docs/api_reference/flax.linen/inspection.rst @@ -0,0 +1,14 @@ + +Inspection +---------------------- + +.. currentmodule:: flax.linen + +.. autofunction:: tabulate + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + tabulate diff --git a/docs/api_reference/flax.linen/layers.rst b/docs/api_reference/flax.linen/layers.rst new file mode 100644 index 0000000000..e84c4d0a93 --- /dev/null +++ b/docs/api_reference/flax.linen/layers.rst @@ -0,0 +1,149 @@ +Layers +====== + +.. currentmodule:: flax.linen + +Linear Modules +------------------------ + +.. flax_module:: + :module: flax.linen + :class: Dense + +.. flax_module:: + :module: flax.linen + :class: DenseGeneral + +.. flax_module:: + :module: flax.linen + :class: Conv + +.. flax_module:: + :module: flax.linen + :class: ConvTranspose + +.. flax_module:: + :module: flax.linen + :class: ConvLocal + +.. flax_module:: + :module: flax.linen + :class: Embed + +Pooling +------------------------ + +.. autofunction:: max_pool +.. autofunction:: avg_pool +.. autofunction:: pool + +Normalization +------------------------ + +.. flax_module:: + :module: flax.linen + :class: BatchNorm + +.. flax_module:: + :module: flax.linen + :class: LayerNorm + +.. flax_module:: + :module: flax.linen + :class: GroupNorm + + +Combinators +------------------------ + +.. flax_module:: + :module: flax.linen + :class: Sequential + +Stochastic +------------------------ + +.. flax_module:: + :module: flax.linen + :class: Dropout + +Attention +------------------------ + +.. flax_module:: + :module: flax.linen + :class: SelfAttention + +.. flax_module:: + :module: flax.linen + :class: MultiHeadDotProductAttention + +.. autofunction:: dot_product_attention_weights +.. autofunction:: dot_product_attention +.. autofunction:: make_attention_mask +.. autofunction:: make_causal_mask + +Recurrent +------------------------ + +.. flax_module:: + :module: flax.linen + :class: RNNCellBase + +.. flax_module:: + :module: flax.linen + :class: LSTMCell + +.. flax_module:: + :module: flax.linen + :class: OptimizedLSTMCell + +.. flax_module:: + :module: flax.linen + :class: GRUCell + +.. flax_module:: + :module: flax.linen + :class: RNN + +.. flax_module:: + :module: flax.linen + :class: Bidirectional + + +**Summary** + +.. autosummary:: + :toctree: _autosummary + :template: flax_module + + Dense + DenseGeneral + Conv + ConvTranspose + ConvLocal + Embed + BatchNorm + LayerNorm + GroupNorm + Sequential + Dropout + SelfAttention + MultiHeadDotProductAttention + RNNCellBase + LSTMCell + OptimizedLSTMCell + GRUCell + RNN + Bidirectional + +.. autosummary:: + :toctree: _autosummary + + max_pool + avg_pool + pool + dot_product_attention_weights + dot_product_attention + make_attention_mask + make_causal_mask \ No newline at end of file diff --git a/docs/api_reference/flax.linen/module.rst b/docs/api_reference/flax.linen/module.rst new file mode 100644 index 0000000000..8760fbe0e2 --- /dev/null +++ b/docs/api_reference/flax.linen/module.rst @@ -0,0 +1,8 @@ +Module +------------------------ + +.. automodule:: flax.linen +.. currentmodule:: flax.linen + +.. autoclass:: Module + :members: setup, variable, param, bind, unbind, apply, init, init_with_output, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb \ No newline at end of file diff --git a/docs/api_reference/flax.linen/profiling.rst b/docs/api_reference/flax.linen/profiling.rst new file mode 100644 index 0000000000..d3e91c44a5 --- /dev/null +++ b/docs/api_reference/flax.linen/profiling.rst @@ -0,0 +1,17 @@ +Profiling +---------------------- + +.. currentmodule:: flax.linen + +.. autofunction:: enable_named_call +.. autofunction:: disable_named_call +.. autofunction:: override_named_call + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + enable_named_call + disable_named_call + override_named_call \ No newline at end of file diff --git a/docs/api_reference/flax.linen/spmd.rst b/docs/api_reference/flax.linen/spmd.rst new file mode 100644 index 0000000000..2f510402ca --- /dev/null +++ b/docs/api_reference/flax.linen/spmd.rst @@ -0,0 +1,40 @@ + +SPMD +---------------------- + +.. automodule:: flax.linen.spmd +.. currentmodule:: flax.linen + +.. autofunction:: Partitioned +.. autofunction:: with_partitioning +.. autofunction:: get_partition_spec +.. autofunction:: get_sharding +.. autofunction:: LogicallyPartitioned +.. autofunction:: logical_axis_rules +.. autofunction:: set_logical_axis_rules +.. autofunction:: get_logical_axis_rules +.. autofunction:: logical_to_mesh_axes +.. autofunction:: logical_to_mesh +.. autofunction:: logical_to_mesh_sharding +.. autofunction:: with_logical_constraint +.. autofunction:: with_logical_partitioning + +**Summary** + +.. autosummary:: + :toctree: _autosummary + + Partitioned + with_partitioning + get_partition_spec + get_sharding + LogicallyPartitioned + logical_axis_rules + set_logical_axis_rules + get_logical_axis_rules + logical_to_mesh_axes + logical_to_mesh + logical_to_mesh_sharding + with_logical_constraint + with_logical_partitioning + diff --git a/docs/api_reference/flax.linen/transformations.rst b/docs/api_reference/flax.linen/transformations.rst new file mode 100644 index 0000000000..82bf21234a --- /dev/null +++ b/docs/api_reference/flax.linen/transformations.rst @@ -0,0 +1,34 @@ +Transformations +---------------------- + +.. automodule:: flax.linen.transforms +.. currentmodule:: flax.linen + +.. autofunction:: vmap +.. autofunction:: scan +.. autofunction:: jit +.. autofunction:: remat +.. autofunction:: remat_scan +.. autofunction:: map_variables +.. autofunction:: jvp +.. autofunction:: vjp +.. autofunction:: custom_vjp +.. autofunction:: while_loop +.. autofunction:: cond +.. autofunction:: switch + +.. autosummary:: + :toctree: _autosummary + + vmap + scan + jit + remat + remat_scan + map_variables + jvp + vjp + custom_vjp + while_loop + cond + switch \ No newline at end of file diff --git a/docs/api_reference/flax.linen/variable.rst b/docs/api_reference/flax.linen/variable.rst new file mode 100644 index 0000000000..675e98868e --- /dev/null +++ b/docs/api_reference/flax.linen/variable.rst @@ -0,0 +1,14 @@ + +Variable dictionary +---------------------- + +.. automodule:: flax.core.variables +.. autoclass:: flax.linen.Variable + +**Summary** + +.. currentmodule:: flax.linen +.. autosummary:: + :toctree: _autosummary + + Variable \ No newline at end of file diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst index 03c3de9153..fc4c797d79 100644 --- a/docs/api_reference/index.rst +++ b/docs/api_reference/index.rst @@ -2,9 +2,9 @@ API Reference ============= .. toctree:: - :maxdepth: 3 + :maxdepth: 4 - flax.linen + flax.linen/index flax.serialization flax.core.frozen_dict flax.struct diff --git a/docs/conf.py b/docs/conf.py index 75e0dab71b..1114cbabdb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -60,6 +60,7 @@ 'sphinx.ext.viewcode', 'myst_nb', 'codediff', + 'flax_module', 'sphinx_design', ]