diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index 821594c4fd573..ca61709f9aa07 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -1585,6 +1585,8 @@ def func(x): target_backend is not None and ivy.backend != "" and ivy.current_backend_str() != target_backend.backend + # keras supports inputs instantiated with different backends + and ivy.current_backend_str() != "keras" ): raise ivy.utils.exceptions.IvyInvalidBackendException( "Operation not allowed. Array was instantiated with backend" diff --git a/ivy/functional/backends/keras/__init__.py b/ivy/functional/backends/keras/__init__.py new file mode 100644 index 0000000000000..f2f241e64acdc --- /dev/null +++ b/ivy/functional/backends/keras/__init__.py @@ -0,0 +1,294 @@ +# global +import os +import sys +import keras + + +backend = "keras" +backend_version = {"version": keras.__version__} +keras_backend = os.getenv("KERAS_BACKEND", default="tensorflow").lower() + +# local +import ivy +from ivy.func_wrapper import _dtype_from_version + + +# noinspection PyUnresolvedReferences +if not ivy.is_local(): + _module_in_memory = sys.modules[__name__] +else: + _module_in_memory = sys.modules[ivy.import_module_path].import_cache[__name__] + +use = ivy.utils.backend.ContextManager(_module_in_memory) + + +# wrap dunder methods of native tensors to return NotImplemented to prioritize Ivy array methods. 
+def dunder_wrapper(func): + def rep_method(*args, **kwargs): + for arg in args: + if ivy.is_ivy_array(arg): + return NotImplemented + return func(*args, **kwargs) + + return rep_method + + +# check for previously imported tensorflow modules +modules_to_patch = [] +tensors_to_patch = [] +tmp_globals = dict(globals()) +for name, value in tmp_globals.items(): + if value == "tensorflow.python.framework.ops.Tensor": + tensors_to_patch.append(name) + try: + if value.__name__ == "tensorflow": + modules_to_patch.append(name) + except AttributeError: + pass + +methods_to_patch = [ + "__add__", + "__sub__", + "__mul__", + "__div__", + "__truediv__", + "__floordiv__", + "__mod__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__ne__", + "__eq__", + "__and__", + "__or__", + "__xor__", + "__pow__", + "__matmul__", +] + +for module in modules_to_patch: + for method in methods_to_patch: + exec( + module + + ".Tensor." + + method + + " = dunder_wrapper(" + + module + + ".Tensor." + + method + + ")" + ) + +for tensor in tensors_to_patch: + for method in methods_to_patch: + exec(tensor + "." + method + " = dunder_wrapper(" + tensor + "." 
+ method + ")")
+
+
+
+# NOTE(review): removed a leftover debug print ("setting NativeArray keras")
+if keras_backend == "jax":
+    import jax
+    import jax.numpy as jnp
+    import jaxlib
+
+    if jax.__version__ >= "0.4.1":  # NOTE(review): lexicographic string compare -- confirm ok for versions like "0.10.x"
+        JaxArray = jax.Array
+        NativeArray = jax.Array
+    else:
+        JaxArray = jaxlib.xla_extension.DeviceArray
+        NativeArray = jaxlib.xla_extension.DeviceArray
+
+    # noinspection PyUnresolvedReferences,PyProtectedMember
+    NativeDevice = jaxlib.xla_extension.Device
+    NativeDtype = jnp.dtype
+    NativeShape = tuple
+
+    NativeSparseArray = None
+elif keras_backend == "torch":
+    import torch
+
+    NativeArray = torch.Tensor
+    NativeDevice = torch.device
+    NativeDtype = torch.dtype
+    NativeShape = torch.Size
+    NativeSparseArray = torch.Tensor
+else:
+    import tensorflow as tf
+    from tensorflow.python.framework.dtypes import DType
+    from tensorflow.python.framework.tensor_shape import TensorShape
+    from tensorflow.python.types.core import Tensor
+
+    NativeArray = Tensor
+    NativeDevice = str
+    NativeDtype = DType
+    NativeShape = TensorShape
+    NativeSparseArray = tf.SparseTensor
+
+
+# devices
+valid_devices = ("cpu", "gpu", "tpu")
+
+# native data types
+native_int8 = tf.int8  # NOTE(review): `tf` is only bound in the tensorflow branch above -- this section raises NameError under jax/torch KERAS_BACKEND; confirm intent
+native_int16 = tf.int16
+native_int32 = tf.int32
+native_int64 = tf.int64
+native_uint8 = tf.uint8
+native_uint16 = tf.uint16
+native_uint32 = tf.uint32
+native_uint64 = tf.uint64
+native_bfloat16 = tf.bfloat16
+native_float16 = tf.float16
+native_float32 = tf.float32
+native_float64 = tf.float64
+native_complex64 = tf.complex64
+native_complex128 = tf.complex128
+native_double = native_float64
+native_bool = tf.bool
+
+# valid data types
+# ToDo: Add complex dtypes to valid_dtypes and fix all resulting failures. 
+
+# update these to add new dtypes
+valid_dtypes = {
+    "3.4.1 and below": (
+        ivy.int8,
+        ivy.int16,
+        ivy.int32,
+        ivy.int64,
+        ivy.uint8,
+        ivy.uint16,
+        ivy.uint32,
+        ivy.uint64,
+        ivy.bfloat16,
+        ivy.float16,
+        ivy.float32,
+        ivy.float64,
+        ivy.complex64,
+        ivy.complex128,
+        ivy.bool,
+    )
+}
+valid_numeric_dtypes = {
+    "3.4.1 and below": (
+        ivy.int8,
+        ivy.int16,
+        ivy.int32,
+        ivy.int64,
+        ivy.uint8,
+        ivy.uint16,
+        ivy.uint32,
+        ivy.uint64,
+        ivy.bfloat16,
+        ivy.float16,
+        ivy.float32,
+        ivy.float64,
+        ivy.complex64,
+        ivy.complex128,
+    )
+}
+valid_int_dtypes = {
+    "3.4.1 and below": (
+        ivy.int8,
+        ivy.int16,
+        ivy.int32,
+        ivy.int64,
+        ivy.uint8,
+        ivy.uint16,
+        ivy.uint32,
+        ivy.uint64,
+    )
+}
+valid_float_dtypes = {
+    "3.4.1 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
+}
+valid_uint_dtypes = {
+    "3.4.1 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
+}
+valid_complex_dtypes = {"3.4.1 and below": (ivy.complex64, ivy.complex128)}
+
+# leave these untouched
+valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
+valid_numeric_dtypes = _dtype_from_version(valid_numeric_dtypes, backend_version)
+valid_int_dtypes = _dtype_from_version(valid_int_dtypes, backend_version)
+valid_float_dtypes = _dtype_from_version(valid_float_dtypes, backend_version)
+valid_uint_dtypes = _dtype_from_version(valid_uint_dtypes, backend_version)
+valid_complex_dtypes = _dtype_from_version(valid_complex_dtypes, backend_version)
+
+# invalid data types
+# update these to add new dtypes
+invalid_dtypes = {"3.4.1 and below": ()}
+invalid_numeric_dtypes = {"3.4.1 and below": ()}
+invalid_int_dtypes = {"3.4.1 and below": ()}
+invalid_float_dtypes = {"3.4.1 and below": ()}
+invalid_uint_dtypes = {"3.4.1 and below": ()}
+invalid_complex_dtypes = {"3.4.1 and below": ()}
+
+# leave these untouched
+invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
+invalid_numeric_dtypes = _dtype_from_version(invalid_numeric_dtypes, backend_version) 
+invalid_int_dtypes = _dtype_from_version(invalid_int_dtypes, backend_version) +invalid_float_dtypes = _dtype_from_version(invalid_float_dtypes, backend_version) +invalid_uint_dtypes = _dtype_from_version(invalid_uint_dtypes, backend_version) +invalid_complex_dtypes = _dtype_from_version(invalid_complex_dtypes, backend_version) + +native_inplace_support = False + +supports_gradients = True + + +def closest_valid_dtype(type=None, /, as_native=False): + if type is None: + type = ivy.default_dtype() + return ivy.as_ivy_dtype(type) if not as_native else ivy.as_native_dtype(type) + + +# local sub-modules +from . import activations +from .activations import * +from . import creation +from .creation import * +from . import data_type +from .data_type import * +from . import device +from .device import * +from . import elementwise +from .elementwise import * +from . import general +from .general import * +from . import gradients +from .gradients import * +from . import layers +from .layers import * +from . import linear_algebra as linalg +from .linear_algebra import * +from . import manipulation +from .manipulation import * +from . import random +from .random import * +from . import searching +from .searching import * +from . import set +from .set import * +from . import sorting +from .sorting import * +from . import statistical +from .statistical import * +from . import utility +from .utility import * +from . import experimental +from .experimental import * +from . import control_flow_ops +from .control_flow_ops import * + + +# sub-backends +from . import sub_backends +from .sub_backends import * + +from . 
import module +# from .module import Model + + +# NativeModule = Model diff --git a/ivy/functional/backends/keras/activations.py b/ivy/functional/backends/keras/activations.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/control_flow_ops.py b/ivy/functional/backends/keras/control_flow_ops.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/creation.py b/ivy/functional/backends/keras/creation.py new file mode 100644 index 0000000000000..b1b6d8e61eb53 --- /dev/null +++ b/ivy/functional/backends/keras/creation.py @@ -0,0 +1,182 @@ +import ivy +from .func_wrapper import use_keras_backend_framework + + +# Array API Standard # +# -------------------# + + +@use_keras_backend_framework +def arange( + start, + /, + stop=None, + step=1, + *, + dtype=None, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def asarray( + obj, + /, + *, + copy=None, + dtype=None, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def empty(shape, *, dtype, device=None, out=None): + pass + + +@use_keras_backend_framework +def empty_like(x, /, *, dtype, device=None, out=None): + pass + + +@use_keras_backend_framework +def eye( + n_rows, + n_cols=None, + /, + *, + k=0, + batch_shape=None, + dtype, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def to_dlpack(x, /, *, out=None): + pass + + +@use_keras_backend_framework +def from_dlpack(x, /, *, out=None): + pass + + +@use_keras_backend_framework +def full( + shape, + fill_value, + *, + dtype=None, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def full_like( + x, + /, + fill_value, + *, + dtype, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def linspace( + start, + stop, + /, + num, + *, + axis=None, + endpoint=True, + dtype, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def meshgrid(*arrays, 
sparse=False, indexing="xy", out=None): + pass + + +@use_keras_backend_framework +def ones(shape, *, dtype, device=None, out=None): + pass + + +@use_keras_backend_framework +def ones_like(x, /, *, dtype, device=None, out=None): + pass + + +@use_keras_backend_framework +def tril(x, /, *, k=0, out=None): + pass + + +@use_keras_backend_framework +def triu(x, /, *, k=0, out=None): + pass + + +@use_keras_backend_framework +def zeros(shape, *, dtype, device=None, out=None): + pass + + +@use_keras_backend_framework +def zeros_like(x, /, *, dtype, device=None, out=None): + pass + + +# Extra # +# ------# + + +array = asarray + + +@use_keras_backend_framework +def copy_array(x, *, to_ivy_array=True, out=None): + pass + + +@use_keras_backend_framework +def one_hot( + indices, + depth, + /, + *, + on_value=None, + off_value=None, + axis=None, + dtype=None, + device=None, + out=None, +): + pass + + +@use_keras_backend_framework +def frombuffer(buffer, dtype=float, count=-1, offset=0): + pass + + +@use_keras_backend_framework +def triu_indices(n_rows, n_cols=None, k=0, /, *, device=None): + pass diff --git a/ivy/functional/backends/keras/data_type.py b/ivy/functional/backends/keras/data_type.py new file mode 100644 index 0000000000000..27b9a5a528e63 --- /dev/null +++ b/ivy/functional/backends/keras/data_type.py @@ -0,0 +1,64 @@ +from .func_wrapper import use_keras_backend_framework + + +# Array API Standard # +# -------------------# + + +@use_keras_backend_framework +def astype(x, dtype, /, *, copy=True, out=None): + pass + + +@use_keras_backend_framework +def broadcast_arrays(*arrays): + pass + + +@use_keras_backend_framework +def broadcast_to(x, /, shape, *, out=None): + pass + + +@use_keras_backend_framework +def finfo(type, /): + pass + + +@use_keras_backend_framework +def iinfo(type, /): + pass + + +@use_keras_backend_framework +def result_type(*arrays_and_dtypes): + pass + + +# Extra # +# ------# + + +@use_keras_backend_framework +def as_ivy_dtype(dtype_in, /): + pass + + 
+@use_keras_backend_framework +def as_native_dtype(dtype_in): + pass + + +@use_keras_backend_framework +def dtype(x, *, as_native=False): + pass + + +@use_keras_backend_framework +def dtype_bits(dtype_in, /): + pass + + +@use_keras_backend_framework +def is_native_dtype(dtype_in, /): + pass diff --git a/ivy/functional/backends/keras/device.py b/ivy/functional/backends/keras/device.py new file mode 100644 index 0000000000000..b021c542d8d6a --- /dev/null +++ b/ivy/functional/backends/keras/device.py @@ -0,0 +1,60 @@ +"""Tensorflow device functions. + +Collection of TensorFlow general functions, wrapped to fit Ivy syntax +and signature. +""" + +# global +_round = round +import tensorflow as tf +from typing import Union, Optional + +# local +import ivy +from ivy.functional.ivy.device import Profiler as BaseProfiler +from .func_wrapper import use_keras_backend_framework + + +@use_keras_backend_framework +def dev(x, /, *, as_native=False): + pass + + +@use_keras_backend_framework +def to_device(x, device, /, *, stream=None, out=None): + pass + + +@use_keras_backend_framework +def as_ivy_dev(device, /): + pass + + +@use_keras_backend_framework +def as_native_dev(device, /): + pass + + +@use_keras_backend_framework +def clear_cached_mem_on_dev(device, /): + pass + + +@use_keras_backend_framework +def num_gpus(): + pass + + +@use_keras_backend_framework +def gpu_is_available(): + pass + + +@use_keras_backend_framework +def tpu_is_available(): + pass + + +@use_keras_backend_framework +def handle_soft_device_variable(*args, fn, **kwargs): + pass diff --git a/ivy/functional/backends/keras/elementwise.py b/ivy/functional/backends/keras/elementwise.py new file mode 100644 index 0000000000000..a3cd0a5a75696 --- /dev/null +++ b/ivy/functional/backends/keras/elementwise.py @@ -0,0 +1,821 @@ +# global +from typing import Union, Optional +import keras +import jax +import tensorflow as tf +import torch + +# local +import ivy +from ivy.func_wrapper import ( + with_unsupported_dtypes, + 
with_supported_dtypes, +) +from ivy import promote_types_of_inputs +from . import backend_version +from .func_wrapper import use_keras_backend_framework + + +@use_keras_backend_framework +def abs( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def acos( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def acosh( + x, + /, + *, + out=None +): + pass + + +@with_unsupported_dtypes({"3.4.1 and below": ("complex",)}, backend_version) +def add( + x1: Union[float, jax.Array, tf.Tensor, torch.Tensor], + x2: Union[float, jax.Array, tf.Tensor, torch.Tensor], + /, + *, + alpha=None, + out: Optional[Union[jax.Array, tf.Tensor, torch.Tensor]] = None, +) -> Union[jax.Array, tf.Tensor, torch.Tensor]: + x1, x2 = promote_types_of_inputs(x1, x2) + + if alpha is not None: + x2 = keras.ops.multiply(x2, alpha) + + ret = keras.ops.add(x1, x2) + + if ivy.exists(out): + ivy.inplace_update(out, ret) + return ret + + +add.support_native_out = True + + +@use_keras_backend_framework +def asin( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def asinh( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def atan( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def atan2( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def atanh( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_and( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_invert( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_left_shift( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_or( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_right_shift( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def bitwise_xor( + x1, + x2, + /, + *, + out=None +): + pass + + 
+@use_keras_backend_framework +def ceil( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def cos( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def cosh( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def divide( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def equal( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +@with_unsupported_dtypes({"2.15.0 and below": ("integer",)}, backend_version) +def exp( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def exp2( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def expm1( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def floor( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def floor_divide( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def fmin( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def greater( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def greater_equal( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def isfinite( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def isinf( + x, + /, + *, + detect_positive=True, + detect_negative=True, + out=None +): + pass + + +@use_keras_backend_framework +def isnan( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def lcm( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def less( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def less_equal( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def log( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def log10( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def log1p( + x, + /, + 
*, + out=None +): + pass + + +@use_keras_backend_framework +def log2( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logaddexp( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def real( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logaddexp2( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logical_and( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logical_not( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logical_or( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def logical_xor( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def multiply( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def negative( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def not_equal( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def positive( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def pow( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def remainder( + x1, + x2, + /, + *, + modulus=True, + out=None +): + pass + + +@use_keras_backend_framework +def round( + x, + /, + *, + decimals=0, + out=None +): + pass + + +@use_keras_backend_framework +def sign( + x, + /, + *, + np_variant=True, + out=None +): + pass + + +@use_keras_backend_framework +def sin( + x, + /, + *, + out=None +): + return tf.sin(x) + + +@use_keras_backend_framework +def sinh( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def sqrt( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def square( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def subtract( + x1, + x2, + /, + *, + alpha=None, + out=None +): + pass + + +@use_keras_backend_framework +def 
tan( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def tanh( + x, + /, + *, + complex_mode="jax", + out=None +): + pass + + +@use_keras_backend_framework +def trapz( + y, + /, + *, + x=None, + dx=1.0, + axis=-1, + out=None +): + pass + + +@use_keras_backend_framework +def trunc( + x, + /, + *, + out=None +): + pass + + +# Extra # +# ------# + + +@use_keras_backend_framework +def erf( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def maximum( + x1, + x2, + /, + *, + use_where=True, + out=None +): + pass + + +@use_keras_backend_framework +def minimum( + x1, + x2, + /, + *, + use_where=True, + out=None +): + pass + + +@use_keras_backend_framework +def reciprocal( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def deg2rad( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def rad2deg( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def isreal( + x, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def fmod( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def gcd( + x1, + x2, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def angle( + input, + /, + *, + deg=None, + out=None +): + pass + + +@use_keras_backend_framework +def imag( + val, + /, + *, + out=None +): + pass + + +@use_keras_backend_framework +def nan_to_num( + x, + /, + *, + copy=True, + nan=0.0, + posinf=None, + neginf=None, + out=None +): + pass diff --git a/ivy/functional/backends/keras/experimental/__init__.py b/ivy/functional/backends/keras/experimental/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/func_wrapper.py b/ivy/functional/backends/keras/func_wrapper.py new file mode 100644 index 0000000000000..93f153a27f9bd --- /dev/null +++ b/ivy/functional/backends/keras/func_wrapper.py @@ -0,0 +1,36 @@ +import functools +import ivy +import 
ivy.functional.backends.keras as keras_backend
+import ivy.functional.backends.jax as jax_backend
+import ivy.functional.backends.tensorflow as tf_backend
+import ivy.functional.backends.torch as torch_backend
+import os
+from typing import Callable
+
+
+def use_keras_backend_framework(fn: Callable) -> Callable:
+    """
+    Wraps the function such that it instead calls the equivalent function
+    from the ivy backend equivalent to the keras backend currently set.
+    """
+
+    @functools.wraps(fn)
+    def _use_keras_backend_framework(*args, **kwargs):
+        keras_backend = os.getenv("KERAS_BACKEND", default="tensorflow").lower()
+        # unrecognized backend values fall through to the tensorflow default below
+
+        if keras_backend == "jax":
+            ivy_keras_backend = jax_backend
+        elif keras_backend == "tensorflow":
+            ivy_keras_backend = tf_backend
+        elif keras_backend == "torch":
+            ivy_keras_backend = torch_backend
+        else:
+            # default to tensorflow backend
+            # TODO: raise warning?
+            ivy_keras_backend = tf_backend
+
+        backend_fn = getattr(ivy_keras_backend, fn.__name__)
+        return backend_fn(*args, **kwargs)
+
+    return _use_keras_backend_framework
diff --git a/ivy/functional/backends/keras/general.py b/ivy/functional/backends/keras/general.py
new file mode 100644
index 0000000000000..633ba5a26e694
--- /dev/null
+++ b/ivy/functional/backends/keras/general.py
@@ -0,0 +1,143 @@
+# global
+import keras
+from typing import Union, Optional
+
+# local
+import ivy
+from . 
import backend_version +from ivy.functional.backends.jax import JaxArray, NativeArray +from .func_wrapper import use_keras_backend_framework + + +@use_keras_backend_framework +def is_native_array(x, /, *, exclusive=False): + pass + + +@use_keras_backend_framework +def array_equal(x0, x1, /): + pass + + +def container_types(): + return [] + + +def current_backend_str() -> str: + return "keras" + + +@use_keras_backend_framework +def get_item(x, /, query, *, copy: Optional[bool] = None): + pass + + +@use_keras_backend_framework +def to_numpy(x, /, *, copy=True): + pass + + +@use_keras_backend_framework +def to_scalar(x, /): + pass + + +@use_keras_backend_framework +def to_list(x, /): + pass + + +@use_keras_backend_framework +def gather(params, indices, /, *, axis=-1, batch_dims=0, out=None): + pass + + +@use_keras_backend_framework +def gather_nd(params, indices, /, *, batch_dims=0, out=None): + pass + + +@use_keras_backend_framework +def get_num_dims(x, /, *, as_array=False): + pass + + +@use_keras_backend_framework +def size(x, /): + pass + + +@use_keras_backend_framework +def inplace_arrays_supported(): + pass + + +@use_keras_backend_framework +def inplace_decrement(x, val): + pass + + +@use_keras_backend_framework +def inplace_increment(x, val): + pass + + +@use_keras_backend_framework +def inplace_update(x, val, /, *, ensure_in_backend=False, keep_input_dtype=False): + pass + + +@use_keras_backend_framework +def inplace_variables_supported(): + pass + + +@use_keras_backend_framework +def multiprocessing(context=None): + pass + + +@use_keras_backend_framework +def scatter_flat( + indices, + updates, + /, + *, + size=None, + reduction="sum", + out=None, +): + pass + + +@use_keras_backend_framework +def scatter_nd( + indices, + updates, + /, + shape=None, + *, + reduction="sum", + out=None, +): + pass + + +@use_keras_backend_framework +def shape(x, /, *, as_array=False): + pass + + +@use_keras_backend_framework +def vmap(func, in_axes=0, out_axes=0): + pass + + 
+@use_keras_backend_framework +def isin(elements, test_elements, /, *, assume_unique=False, invert=False): + pass + + +@use_keras_backend_framework +def itemsize(x): + pass diff --git a/ivy/functional/backends/keras/gradients.py b/ivy/functional/backends/keras/gradients.py new file mode 100644 index 0000000000000..129ccb4764ace --- /dev/null +++ b/ivy/functional/backends/keras/gradients.py @@ -0,0 +1,50 @@ +import ivy +from .func_wrapper import use_keras_backend_framework + + +@use_keras_backend_framework +def variable(x, /): + pass + + +@use_keras_backend_framework +def is_variable(x, /, *, exclusive=False): + pass + + +@use_keras_backend_framework +def variable_data(x, /): + pass + + +@use_keras_backend_framework +def execute_with_gradients( + func, + xs, + /, + *, + retain_grads=False, + xs_grad_idxs=((0,),), + ret_grad_idxs=((0,),), +): + pass + + +@use_keras_backend_framework +def value_and_grad(func): + pass + + +@use_keras_backend_framework +def stop_gradient(x, /, *, preserve_type=True, out=None): + pass + + +@use_keras_backend_framework +def jac(func): + pass + + +@use_keras_backend_framework +def grad(f, argnums=0): + pass diff --git a/ivy/functional/backends/keras/layers.py b/ivy/functional/backends/keras/layers.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/linear_algebra.py b/ivy/functional/backends/keras/linear_algebra.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/manipulation.py b/ivy/functional/backends/keras/manipulation.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/random.py b/ivy/functional/backends/keras/random.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/searching.py b/ivy/functional/backends/keras/searching.py new file mode 100644 index 0000000000000..f2af34d996346 --- /dev/null +++ 
b/ivy/functional/backends/keras/searching.py @@ -0,0 +1,74 @@ +import ivy +from ivy.func_wrapper import with_unsupported_dtypes +from . import backend_version +from .func_wrapper import use_keras_backend_framework + + +# Array API Standard # +# ------------------ # + + +@use_keras_backend_framework +def argmax( + x, + /, + *, + axis=None, + keepdims=False, + dtype=None, + select_last_index=False, + out=None, +): + pass + + +@use_keras_backend_framework +def argmin( + x, + /, + *, + axis=None, + keepdims=False, + dtype=None, + select_last_index=False, + out=None, +): + pass + + +@use_keras_backend_framework +def nonzero( + x, + /, + *, + as_tuple=True, + size=None, + fill_value=0, +): + pass + + +@use_keras_backend_framework +def where( + condition, + x1, + x2, + /, + *, + out=None, +): + pass + + +# Extra # +# ----- # + + +@use_keras_backend_framework +def argwhere( + x, + /, + *, + out=None, +): + pass diff --git a/ivy/functional/backends/keras/set.py b/ivy/functional/backends/keras/set.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/sorting.py b/ivy/functional/backends/keras/sorting.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/statistical.py b/ivy/functional/backends/keras/statistical.py new file mode 100644 index 0000000000000..3a3d4c7c9f199 --- /dev/null +++ b/ivy/functional/backends/keras/statistical.py @@ -0,0 +1,59 @@ +import ivy +from .func_wrapper import use_keras_backend_framework + + +# Array API Standard # +# -------------------# + + +@use_keras_backend_framework +def min(x, /, *, axis=None, keepdims=False, initial=None, where=None, out=None): + pass + +@use_keras_backend_framework +def max(x, /, *, axis=None, keepdims=False, out=None): + pass + + +@use_keras_backend_framework +def mean(x, /, axis=None, keepdims=False, *, dtype=None, out=None): + pass + + +@use_keras_backend_framework +def prod(x, /, *, axis=None, dtype=None, keepdims=False, 
out=None): + pass + + +@use_keras_backend_framework +def std(x, /, *, axis=None, correction=0.0, keepdims=False, out=None): + pass + + +@use_keras_backend_framework +def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None): + pass + + +@use_keras_backend_framework +def var(x, /, *, axis=None, correction=0.0, keepdims=False, out=None): + pass + + +# Extra # +# ------# + + +@use_keras_backend_framework +def cumprod(x, /, *, axis=0, exclusive=False, reverse=False, dtype=None, out=None): + pass + + +@use_keras_backend_framework +def cumsum(x, axis=0, exclusive=False, reverse=False, *, dtype=None, out=None): + pass + + +@use_keras_backend_framework +def einsum(equation, *operands, out=None): + pass diff --git a/ivy/functional/backends/keras/sub_backends/__init__.py b/ivy/functional/backends/keras/sub_backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy/functional/backends/keras/utility.py b/ivy/functional/backends/keras/utility.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy_tests/test_ivy/conftest.py b/ivy_tests/test_ivy/conftest.py index 13731b6f5eb4e..2649f4325ff04 100644 --- a/ivy_tests/test_ivy/conftest.py +++ b/ivy_tests/test_ivy/conftest.py @@ -90,7 +90,7 @@ def pytest_configure(config): if not no_mp: # we go multiprocessing, if multiversion - known_backends = {"tensorflow", "torch", "jax"} + known_backends = {"tensorflow", "torch", "jax", "keras"} found_backends = set() for fw in backend_strs: if "/" in fw: diff --git a/ivy_tests/test_ivy/helpers/available_frameworks.py b/ivy_tests/test_ivy/helpers/available_frameworks.py index 8bcd2d793a1b2..486bc50690197 100644 --- a/ivy_tests/test_ivy/helpers/available_frameworks.py +++ b/ivy_tests/test_ivy/helpers/available_frameworks.py @@ -6,7 +6,14 @@ def _available_frameworks(path="/opt/fw/"): ret = [] - for backend in ["numpy", "jax", "tensorflow", "torch", "paddle"]: + for backend in [ + "numpy", + "jax", + "keras", + "paddle", + 
"tensorflow", + "torch", + ]: if find_spec(backend) is not None: ret.append(backend) elif os.path.exists(f"{path}{backend}"): diff --git a/ivy_tests/test_ivy/helpers/globals.py b/ivy_tests/test_ivy/helpers/globals.py index d1a4fb6e83791..faa5dd93f80e7 100644 --- a/ivy_tests/test_ivy/helpers/globals.py +++ b/ivy_tests/test_ivy/helpers/globals.py @@ -11,6 +11,7 @@ available_frameworks = [ "numpy", "jax", + "keras", "tensorflow", "torch", "paddle", @@ -30,6 +31,7 @@ mod_backend = { "numpy": None, "jax": None, + "keras": None, "tensorflow": None, "torch": None, "paddle": None,