Skip to content

Commit

Permalink
Get rid of gin dependency and make policy base class abstract.
Browse the repository at this point in the history
PiperOrigin-RevId: 708966070
  • Loading branch information
jaindeepali authored and copybara-github committed Dec 23, 2024
1 parent c0618ec commit dc480a3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
6 changes: 0 additions & 6 deletions iris/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import copy
from typing import Any, Dict, Optional, Sequence, Union
from absl import logging
import gin
import gym
from gym import spaces
from gym.spaces import utils
Expand Down Expand Up @@ -106,7 +105,6 @@ def state(self, new_state: Dict[str, Any]) -> None:
pass


@gin.configurable
class MeanStdBuffer(Buffer):
"""Collect stats for calculating mean and std online."""

Expand Down Expand Up @@ -283,7 +281,6 @@ def state(self, state: Dict[str, np.ndarray]) -> None:
self._state = state.copy()


@gin.configurable
class NoNormalizer(Normalizer):
"""No Normalization applied to input."""

Expand All @@ -300,7 +297,6 @@ def __call__(
return value


@gin.configurable
class ActionRangeDenormalizer(Normalizer):
"""Actions mapped to given range from [-1, 1]."""

Expand Down Expand Up @@ -341,7 +337,6 @@ def __call__(
return action


@gin.configurable
class ObservationRangeNormalizer(Normalizer):
"""Observations mapped from given range to [-1, 1]."""

Expand Down Expand Up @@ -383,7 +378,6 @@ def __call__(
return observation


@gin.configurable
class RunningMeanStdNormalizer(Normalizer):
"""Standardize observations with mean and std calculated online."""

Expand Down
13 changes: 9 additions & 4 deletions iris/policies/base_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

"""Policy class for computing action from weights and observation vector."""

import abc
from typing import Dict, Union

import gym
from gym.spaces import utils
import numpy as np


class BasePolicy(object):
class BasePolicy(abc.ABC):
"""Base policy class for reinforcement learning."""

def __init__(self, ob_space: gym.Space, ac_space: gym.Space) -> None:
Expand Down Expand Up @@ -55,23 +56,27 @@ def set_iteration(self, value: int | None):
self._iteration = value

def update_weights(self, new_weights: np.ndarray) -> None:
"""Updates the flat weights vector."""
self._weights[:] = new_weights[:]

def get_weights(self) -> np.ndarray:
"""Returns the flat weights vector."""
return self._weights

def get_representation_weights(self):
"""Returns the flat representation weights vector."""
return self._representation_weights

def update_representation_weights(
self, new_representation_weights: np.ndarray) -> None:
"""Updates the flat representation weights vector."""
self._representation_weights[:] = new_representation_weights[:]

@abc.abstractmethod
def reset(self):
pass
"""Resets the internal policy state."""

@abc.abstractmethod
def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]]
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""Maps the observation to action."""
raise NotImplementedError(
"Should be implemented in derived classes for specific policies.")
5 changes: 1 addition & 4 deletions requirements-rl.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,4 @@ tf-agents # NOTE: Requires tensorflow>=2.15.0 for TFP compatibility.
jax # Use latest version.
jaxlib # Use latest version.
flax # Use latest version.
tensorflow # TODO(team): Resolve version conflicts.

# Configuration + Experimentation
gin-config>=0.5.0
tensorflow # TODO(team): Resolve version conflicts.

0 comments on commit dc480a3

Please sign in to comment.