Skip to content

Commit

Permalink
chnage numeric check logic
Browse files Browse the repository at this point in the history
  • Loading branch information
alantian committed Jun 15, 2022
1 parent ddf0d80 commit 399fccf
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions evojax/algo/cma_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class CMA_ES_JAX(NEAlgorithm):
logger:
A logging.Logger instance (optional).
If not specified, a new one will be created.
enable_numeric_check:
A bool indicating whether to enable numeric check (optional).
If True, a numeric check for mean and standard deviation is enabled.
The default value is False.
"""

def __init__(
Expand All @@ -104,6 +108,7 @@ def __init__(
seed: Optional[int] = 0,
cov: Optional[jnp.ndarray] = None,
logger: logging.Logger = None,
enable_numeric_check: Optional[bool] = False,
):
if mean is None:
mean = jnp.zeros(param_size)
Expand All @@ -114,7 +119,14 @@ def __init__(
f"In this case, mean (whose shape is {mean.shape}) must have a dimension of (param_size, )" \
f" (i.e. {(param_size, )}), which is not true."
mean = ensure_jnp(mean)
mean_max = ensure_jnp(_MEAN_MAX_X64 if jax.config.jax_enable_x64 else _MEAN_MAX_X32)

if enable_numeric_check:
mean_max = ensure_jnp(_MEAN_MAX_X64 if jax.config.jax_enable_x64 else _MEAN_MAX_X32)
sigma_max = ensure_jnp(_SIGMA_MAX_X64 if jax.config.jax_enable_x64 else _SIGMA_MAX_X32)
else:
mean_max = ensure_jnp(jnp.inf)
sigma_max = ensure_jnp(jnp.inf)

assert jnp.all(
jnp.abs(mean) < mean_max
), f"Abs of all elements of mean vector must be less than {mean_max}"
Expand Down Expand Up @@ -207,7 +219,7 @@ def __init__(
1.0 / (21.0 * (n_dim ** 2))
),
weights=weights,
sigma_max=ensure_jnp(_SIGMA_MAX_X64 if jax.config.jax_enable_x64 else _SIGMA_MAX_X32),
sigma_max=sigma_max,
)

# evolution path (state)
Expand Down

0 comments on commit 399fccf

Please sign in to comment.