diff --git a/matrax/__init__.py b/matrax/__init__.py index d493a6d..fff4b76 100644 --- a/matrax/__init__.py +++ b/matrax/__init__.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jumanji.registration import register +from jumanji.registration import make, register +from jumanji.version import __version__ +from matrax.env import MatrixGame from matrax.games import climbing_game, conflict_games, no_conflict_games, penalty_games +from matrax.types import Observation, State """Environment Registration""" diff --git a/matrax/env_test.py b/matrax/env_test.py index 934380b..6a3e17a 100644 --- a/matrax/env_test.py +++ b/matrax/env_test.py @@ -21,9 +21,8 @@ from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep -from matrax.env import MatrixGame +from matrax import MatrixGame, State from matrax.games import climbing_game -from matrax.types import State @pytest.fixture