diff --git a/api_gen.py b/api_gen.py index c3f1b9c80b7..28fac8fa4f1 100644 --- a/api_gen.py +++ b/api_gen.py @@ -47,8 +47,6 @@ def create_legacy_directory(package_dir): ) with open(os.path.join(api_dir, "__init__.py"), "w") as f: f.write(init_file) - with open(os.path.join(package_dir, "__init__.py"), "w") as f: - f.write(init_file) # Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py` init_file = init_file.replace("from keras.api import _tf_keras\n", "\n") with open(os.path.join(tf_keras_dirpath, "__init__.py"), "w") as f: @@ -128,6 +126,40 @@ def export_version_string(api_init_fname): f.write(contents) +def update_package_init(init_fname): + contents = """ +# Import everything from /api/ into keras. +from keras.api import * # noqa: F403 +from keras.api import __version__ # Import * ignores names start with "_". + +import os + +# Add everything in /api/ to the module search path. +__path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405 + +# Don't pollute namespace. +del os + +# Never autocomplete `.src` or `.api` on an imported keras object. +def __dir__(): + keys = dict.fromkeys((globals().keys())) + keys.pop("src") + keys.pop("api") + return list(keys) + + +# Don't import `.src` or `.api` during `from keras import *`. +__all__ = [ + name + for name in globals().keys() + if not (name.startswith("_") or name in ("src", "api")) +]""" + with open(init_fname) as f: + init_contents = f.read() + with open(init_fname, "w") as f: + f.write(init_contents.replace("\nfrom keras import api", contents)) + + def build(): # Backup the `keras/__init__.py` and restore it on error in api gen. root_path = os.path.dirname(os.path.abspath(__file__)) @@ -149,6 +181,8 @@ def build(): namex.generate_api_files( "keras", code_directory="src", target_directory="api" ) + # Creates `keras/__init__.py` importing from `keras/api` + update_package_init(build_init_fname) # Add __version__ to keras package export_version_string(build_api_init_fname) # Creates `_tf_keras` with full keras API diff --git a/keras/__init__.py b/keras/__init__.py index 1750a42e869..6276b51e1f8 100644 --- a/keras/__init__.py +++ b/keras/__init__.py @@ -4,52 +4,30 @@ since your modifications would be overwritten. """ -from keras.api import _tf_keras -from keras.api import activations -from keras.api import applications -from keras.api import backend -from keras.api import callbacks -from keras.api import config -from keras.api import constraints -from keras.api import datasets -from keras.api import distribution -from keras.api import dtype_policies -from keras.api import export -from keras.api import initializers -from keras.api import layers -from keras.api import legacy -from keras.api import losses -from keras.api import metrics -from keras.api import mixed_precision -from keras.api import models -from keras.api import ops -from keras.api import optimizers -from keras.api import preprocessing -from keras.api import quantizers -from keras.api import random -from keras.api import regularizers -from keras.api import saving -from keras.api import tree -from keras.api import utils -from keras.src.backend.common.keras_tensor import KerasTensor -from keras.src.backend.common.stateless_scope import StatelessScope -from keras.src.backend.exports import Variable -from keras.src.backend.exports import device -from keras.src.backend.exports import name_scope -from keras.src.dtype_policies.dtype_policy import DTypePolicy -from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy -from keras.src.initializers.initializer import Initializer -from keras.src.layers.core.input_layer import Input -from keras.src.layers.input_spec import InputSpec -from keras.src.layers.layer import Layer -from keras.src.losses.loss import Loss -from keras.src.metrics.metric import Metric -from keras.src.models.model import Model -from keras.src.models.sequential import Sequential -from keras.src.ops.function import Function -from keras.src.ops.operation import Operation -from keras.src.optimizers.optimizer import Optimizer -from keras.src.quantizers.quantizers import Quantizer -from keras.src.regularizers.regularizers import Regularizer -from keras.src.version import __version__ -from keras.src.version import version +import os + +# Import everything from /api/ into keras. +from keras.api import * # noqa: F403 +from keras.api import __version__ # Import * ignores names start with "_". + +# Add everything in /api/ to the module search path. +__path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405 + +# Don't pollute namespace. +del os + + +# Never autocomplete `.src` or `.api` on an imported keras object. +def __dir__(): + keys = dict.fromkeys((globals().keys())) + keys.pop("src") + keys.pop("api") + return list(keys) + + +# Don't import `.src` or `.api` during `from keras import *`. +__all__ = [ + name + for name in globals().keys() + if not (name.startswith("_") or name in ("src", "api")) +]