From 24593521ec4d25f2a6ed0b928221e1098811fa59 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 11 May 2022 20:41:56 -0400 Subject: [PATCH 1/6] Add wandb support --- examples/train.py | 73 ++++++++---- poetry.lock | 295 +++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 3 files changed, 342 insertions(+), 27 deletions(-) diff --git a/examples/train.py b/examples/train.py index 804f981..885c703 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,5 +1,6 @@ import os import random +import time import numpy as np import tqdm @@ -29,6 +30,9 @@ 'Number of training steps to start training.') flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.') flags.DEFINE_boolean('save_video', False, 'Save videos during evaluation.') +flags.DEFINE_boolean('track', False, 'Track experiments with Weights and Biases.') +flags.DEFINE_string('wandb_project_name', "jaxrl", "The wandb's project name.") +flags.DEFINE_string('wandb_entity', None, "the entity (team) of wandb's project") config_flags.DEFINE_config_file( 'config', 'configs/sac_default.py', @@ -37,8 +41,26 @@ def main(_): + kwargs = dict(FLAGS.config) + algo = kwargs.pop('algo') + run_name = f"{FLAGS.env_name}__{algo}__{FLAGS.seed}__{int(time.time())}" + print(FLAGS) + if FLAGS.track: + import wandb + + + wandb.init( + project=FLAGS.wandb_project_name, + entity=FLAGS.wandb_entity, + sync_tensorboard=True, + config=FLAGS, + name=run_name, + monitor_gym=True, + save_code=True, + ) + summary_writer = SummaryWriter( - os.path.join(FLAGS.save_dir, 'tb', str(FLAGS.seed))) + os.path.join(FLAGS.save_dir, run_name)) if FLAGS.save_video: video_train_folder = os.path.join(FLAGS.save_dir, 'video', 'train') @@ -53,8 +75,7 @@ def main(_): np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) - kwargs = dict(FLAGS.config) - algo = kwargs.pop('algo') + replay_buffer_size = kwargs.pop('replay_buffer_size') if algo == 'sac': agent = SACLearner(FLAGS.seed, @@ -115,29 +136,29 @@ def main(_): info['is_success'], info['total']['timesteps']) - if i >= FLAGS.start_training: - for _ in range(FLAGS.updates_per_step): - batch = replay_buffer.sample(FLAGS.batch_size) - update_info = agent.update(batch) - - if i % FLAGS.log_interval == 0: - for k, v in update_info.items(): - summary_writer.add_scalar(f'training/{k}', v, i) - summary_writer.flush() - - if i % FLAGS.eval_interval == 0: - eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes) - - for k, v in eval_stats.items(): - summary_writer.add_scalar(f'evaluation/average_{k}s', v, - info['total']['timesteps']) - summary_writer.flush() - - eval_returns.append( - (info['total']['timesteps'], eval_stats['return'])) - np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'), - eval_returns, - fmt=['%d', '%.1f']) + # if i >= FLAGS.start_training: + # for _ in range(FLAGS.updates_per_step): + # batch = replay_buffer.sample(FLAGS.batch_size) + # update_info = agent.update(batch) + + # if i % FLAGS.log_interval == 0: + # for k, v in update_info.items(): + # summary_writer.add_scalar(f'training/{k}', v, i) + # summary_writer.flush() + + # if i % FLAGS.eval_interval == 0: + # eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes) + + # for k, v in eval_stats.items(): + # summary_writer.add_scalar(f'evaluation/average_{k}s', v, + # info['total']['timesteps']) + # summary_writer.flush() + + # eval_returns.append( + # (info['total']['timesteps'], eval_stats['return'])) + # np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'), + # eval_returns, + # fmt=['%d', '%.1f']) if __name__ == '__main__': diff --git a/poetry.lock b/poetry.lock index 2d43c9c..5a5603e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -204,6 +204,17 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "docker-pycreds" +version = "0.4.0" +description = "Python bindings for the docker credentials store API" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +six = ">=1.4.0" + [[package]] name = "fasteners" version = "0.17.3" @@ -304,6 +315,28 @@ requests = {version = "*", extras = ["socks"]} six = "*" tqdm = "*" +[[package]] +name = "gitdb" +version = "4.0.9" +description = "Git Object Database" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +smmap = ">=3.0.1,<6" + +[[package]] +name = "gitpython" +version = "3.1.27" +description = "GitPython is a python library used to interact with Git repositories" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +gitdb = ">=4.0.1,<5" + [[package]] name = "glfw" version = "2.5.3" @@ -628,6 +661,14 @@ python-versions = ">=3.6" [package.dependencies] pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" +[[package]] +name = "pathtools" +version = "0.1.2" +description = "File system general utilities" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "pillow" version = "9.1.0" @@ -640,6 +681,20 @@ python-versions = ">=3.7" docs = ["olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinx-rtd-theme (>=1.0)", "sphinxext-opengraph"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "promise" +version = "2.3" +description = "Promises/A+ implementation for Python" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +six = "*" + +[package.extras] +test = ["pytest (>=2.7.3)", "pytest-cov", "coveralls", "futures", "pytest-benchmark", "mock"] + [[package]] name = "protobuf" version = "3.20.1" @@ -648,6 +703,17 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "psutil" +version = "5.9.0" +description = "Cross-platform lib for process and system monitoring in Python." +category = "main" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"] + [[package]] name = "pybullet" version = "3.2.4" @@ -737,6 +803,47 @@ python-versions = ">=3.8,<3.11" [package.dependencies] numpy = ">=1.17.3,<1.25.0" +[[package]] +name = "sentry-sdk" +version = "1.5.12" +description = "Python client for Sentry (https://sentry.io)" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.10.0" + +[package.extras] +aiohttp = ["aiohttp (>=3.5)"] +beam = ["apache-beam (>=2.12)"] +bottle = ["bottle (>=0.12.13)"] +celery = ["celery (>=3)"] +chalice = ["chalice (>=1.16.0)"] +django = ["django (>=1.8)"] +falcon = ["falcon (>=1.4)"] +flask = ["flask (>=0.11)", "blinker (>=1.1)"] +httpx = ["httpx (>=0.16.0)"] +pure_eval = ["pure-eval", "executing", "asttokens"] +pyspark = ["pyspark (>=2.4.4)"] +quart = ["quart (>=0.16.1)", "blinker (>=1.1)"] +rq = ["rq (>=0.6)"] +sanic = ["sanic (>=0.8)"] +sqlalchemy = ["sqlalchemy (>=1.2)"] +tornado = ["tornado (>=5)"] + +[[package]] +name = "setproctitle" +version = "1.2.3" +description = "A Python module to customize the process title" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +test = ["pytest"] + [[package]] name = "setuptools-scm" version = "6.4.2" @@ -753,6 +860,14 @@ tomli = ">=1.0.0" test = ["pytest (>=6.2)", "virtualenv (>20)"] toml = ["setuptools (>=42)"] +[[package]] +name = "shortuuid" +version = "1.0.9" +description = "A generator library for concise, unambiguous and URL-safe UUIDs." +category = "main" +optional = false +python-versions = ">=3.5" + [[package]] name = "six" version = "1.16.0" @@ -761,6 +876,14 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +[[package]] +name = "smmap" +version = "5.0.0" +description = "A pure Python implementation of a sliding window memory map manager" +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "soupsieve" version = "2.3.2.post1" @@ -865,6 +988,40 @@ brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "wandb" +version = "0.12.16" +description = "A CLI and library for interacting with the Weights and Biases API." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +Click = ">=7.0,<8.0.0 || >8.0.0" +docker-pycreds = ">=0.4.0" +GitPython = ">=1.0.0" +pathtools = "*" +promise = ">=2.0,<3" +protobuf = ">=3.12.0" +psutil = ">=5.0.0" +python-dateutil = ">=2.6.1" +PyYAML = "*" +requests = ">=2.0.0,<3" +sentry-sdk = ">=1.0.0" +setproctitle = "*" +shortuuid = ">=0.5.0" +six = ">=1.13.0" + +[package.extras] +aws = ["boto3"] +azure = ["azure-storage-blob"] +gcp = ["google-cloud-storage"] +grpc = ["grpcio (>=1.27.2)"] +kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"] +launch = ["nbconvert", "nbformat", "chardet", "iso8601", "typing-extensions", "boto3", "google-cloud-storage", "kubernetes"] +media = ["numpy", "moviepy", "pillow", "bokeh", "soundfile", "plotly", "rdkit-pypi"] +sweeps = ["numpy (>=1.15,<1.21)", "scipy (>=1.5.4)", "pyyaml", "scikit-learn (==0.24.1)", "jsonschema (>=3.2.0)", "jsonref (>=0.2)", "pydantic (>=1.8.2)"] + [[package]] name = "zipp" version = "3.8.0" @@ -880,7 +1037,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.10" -content-hash = "2757c70a507147b8d1a637f6f8aa1e91a47dc268066751bec38193d1ff4c82b3" +content-hash = "15bbd0af8944dfdd62769f50bbf0930af30a922d246ad93d00d9857b1306bd1c" [metadata.files] absl-py = [ @@ -1050,6 +1207,10 @@ dm-tree = [ {file = "dm_tree-0.1.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1379a02df36e2bbff9819ceafa55ccd436b15af398803f781f372f8ead7ed871"}, {file = "dm_tree-0.1.7-cp39-cp39-win_amd64.whl", hash = "sha256:3ca0a58e219b7b0bc201fea4679971188d0a9028a2543c16803a84e8f8c7eb2c"}, ] +docker-pycreds = [ + {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, + {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, +] fasteners = [ {file = "fasteners-0.17.3-py3-none-any.whl", hash = "sha256:cae0772df265923e71435cc5057840138f4e8b6302f888a567d06ed8e1cbca03"}, {file = "fasteners-0.17.3.tar.gz", hash = "sha256:a9a42a208573d4074c77d041447336cf4e3c1389a256fd3e113ef59cf29b7980"}, @@ -1080,6 +1241,14 @@ gast = [ gdown = [ {file = "gdown-4.4.0.tar.gz", hash = "sha256:18fc3a4da4a2273deb7aa29c7486be4df3919d904158ad6a6a3e25c8115470d7"}, ] +gitdb = [ + {file = "gitdb-4.0.9-py3-none-any.whl", hash = "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd"}, + {file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"}, +] +gitpython = [ + {file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"}, + {file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"}, +] glfw = [ {file = "glfw-2.5.3-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:797528a7e433c1684cad2807838824e73c0c1fd8e12b01ee0b22ad203f46bf87"}, {file = "glfw-2.5.3-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_11_0_arm64.whl", hash = "sha256:1763af8d00e1b1517fccceb768e5470ce754a993bd8020878dca444f9523726b"}, @@ -1415,6 +1584,9 @@ packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] +pathtools = [ + {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, +] pillow = [ {file = "Pillow-9.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:af79d3fde1fc2e33561166d62e3b63f0cc3e47b5a3a2e5fea40d4917754734ea"}, {file = "Pillow-9.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:55dd1cf09a1fd7c7b78425967aacae9b0d70125f7d3ab973fadc7b5abc3de652"}, @@ -1455,6 +1627,9 @@ pillow = [ {file = "Pillow-9.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8d79c6f468215d1a8415aa53d9868a6b40c4682165b8cb62a221b1baa47db458"}, {file = "Pillow-9.1.0.tar.gz", hash = "sha256:f401ed2bbb155e1ade150ccc63db1a4f6c1909d3d378f7d1235a44e90d75fb97"}, ] +promise = [ + {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, +] protobuf = [ {file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"}, {file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"}, @@ -1481,6 +1656,35 @@ protobuf = [ {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"}, {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"}, ] +psutil = [ + {file = "psutil-5.9.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:55ce319452e3d139e25d6c3f85a1acf12d1607ddedea5e35fb47a552c051161b"}, + {file = "psutil-5.9.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:7336292a13a80eb93c21f36bde4328aa748a04b68c13d01dfddd67fc13fd0618"}, + {file = "psutil-5.9.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:cb8d10461c1ceee0c25a64f2dd54872b70b89c26419e147a05a10b753ad36ec2"}, + {file = "psutil-5.9.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:7641300de73e4909e5d148e90cc3142fb890079e1525a840cf0dfd39195239fd"}, + {file = "psutil-5.9.0-cp27-none-win32.whl", hash = "sha256:ea42d747c5f71b5ccaa6897b216a7dadb9f52c72a0fe2b872ef7d3e1eacf3ba3"}, + {file = "psutil-5.9.0-cp27-none-win_amd64.whl", hash = "sha256:ef216cc9feb60634bda2f341a9559ac594e2eeaadd0ba187a4c2eb5b5d40b91c"}, + {file = "psutil-5.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90a58b9fcae2dbfe4ba852b57bd4a1dded6b990a33d6428c7614b7d48eccb492"}, + {file = "psutil-5.9.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff0d41f8b3e9ebb6b6110057e40019a432e96aae2008951121ba4e56040b84f3"}, + {file = "psutil-5.9.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:742c34fff804f34f62659279ed5c5b723bb0195e9d7bd9907591de9f8f6558e2"}, + {file = "psutil-5.9.0-cp310-cp310-win32.whl", hash = "sha256:8293942e4ce0c5689821f65ce6522ce4786d02af57f13c0195b40e1edb1db61d"}, + {file = "psutil-5.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:9b51917c1af3fa35a3f2dabd7ba96a2a4f19df3dec911da73875e1edaf22a40b"}, + {file = "psutil-5.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3d00a664e31921009a84367266b35ba0aac04a2a6cad09c550a89041034d19a0"}, + {file = "psutil-5.9.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7779be4025c540d1d65a2de3f30caeacc49ae7a2152108adeaf42c7534a115ce"}, + {file = "psutil-5.9.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5"}, + {file = "psutil-5.9.0-cp37-cp37m-win32.whl", hash = "sha256:df2c8bd48fb83a8408c8390b143c6a6fa10cb1a674ca664954de193fdcab36a9"}, + {file = "psutil-5.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1d7b433519b9a38192dfda962dd8f44446668c009833e1429a52424624f408b4"}, + {file = "psutil-5.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c3400cae15bdb449d518545cbd5b649117de54e3596ded84aacabfbb3297ead2"}, + {file = "psutil-5.9.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2237f35c4bbae932ee98902a08050a27821f8f6dfa880a47195e5993af4702d"}, + {file = "psutil-5.9.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1070a9b287846a21a5d572d6dddd369517510b68710fca56b0e9e02fd24bed9a"}, + {file = "psutil-5.9.0-cp38-cp38-win32.whl", hash = "sha256:76cebf84aac1d6da5b63df11fe0d377b46b7b500d892284068bacccf12f20666"}, + {file = "psutil-5.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:3151a58f0fbd8942ba94f7c31c7e6b310d2989f4da74fcbf28b934374e9bf841"}, + {file = "psutil-5.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:539e429da49c5d27d5a58e3563886057f8fc3868a5547b4f1876d9c0f007bccf"}, + {file = "psutil-5.9.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58c7d923dc209225600aec73aa2c4ae8ea33b1ab31bc11ef8a5933b027476f07"}, + {file = "psutil-5.9.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3611e87eea393f779a35b192b46a164b1d01167c9d323dda9b1e527ea69d697d"}, + {file = "psutil-5.9.0-cp39-cp39-win32.whl", hash = "sha256:4e2fb92e3aeae3ec3b7b66c528981fd327fb93fd906a77215200404444ec1845"}, + {file = "psutil-5.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:7d190ee2eaef7831163f254dc58f6d2e2a22e27382b936aab51c835fc080c3d3"}, + {file = "psutil-5.9.0.tar.gz", hash = "sha256:869842dbd66bb80c3217158e629d6fceaecc3a3166d3d1faee515b05dd26ca25"}, +] pybullet = [ {file = "pybullet-3.2.4-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:064d269dda3a5112c0baa4e9d620ffa430155b0df4b32a7741489eb77ca8b503"}, {file = "pybullet-3.2.4-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cc29d139c0129baf76f2149f9028b68b31e260bdb37ec8a61345cac41e632dfe"}, @@ -1577,14 +1781,99 @@ scipy = [ {file = "scipy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:bb7088e89cd751acf66195d2f00cf009a1ea113f3019664032d9075b1e727b6c"}, {file = "scipy-1.8.0.tar.gz", hash = "sha256:31d4f2d6b724bc9a98e527b5849b8a7e589bf1ea630c33aa563eda912c9ff0bd"}, ] +sentry-sdk = [ + {file = "sentry-sdk-1.5.12.tar.gz", hash = "sha256:259535ba66933eacf85ab46524188c84dcb4c39f40348455ce15e2c0aca68863"}, + {file = "sentry_sdk-1.5.12-py2.py3-none-any.whl", hash = "sha256:778b53f0a6c83b1ee43d3b7886318ba86d975e686cb2c7906ccc35b334360be1"}, +] +setproctitle = [ + {file = "setproctitle-1.2.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0a668acec8b61a971de54bc4c733869ea7b0eb1348eae5a32b9477f788908e5c"}, + {file = "setproctitle-1.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52265182fe5ac237d179d8e949248d307882a2e6ec7f189c8dac1c9d1b3631fa"}, + {file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d00ef63a1f78e13c236895badac77b6c8503377467b9c1a4f81fe729d16e03"}, + {file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82a49aaf440232c762539ab3737b5174d31aba0141fd4bf4d8739c28d18624"}, + {file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:791bed39e4ecbdd008b64999a60c9cc560d17b3836ca0c27cd4708e8e1bcf495"}, + {file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8e4da68d4d4ba46d4c5db6ae5eb61b11de9c520f25ae8334570f4d0018a8611"}, + {file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:47f97f591ea2335b7d35f5e9ad7d806385338182dc6de5732d091e9c70ed1cc0"}, + {file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:501c084cf3df7d848e91c97d4f8c44d799ba545858a79c6960326ce6f285b4e4"}, + {file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a39b30d7400c0d50941fe19e1fe0b7d35676186fec4d9c010129ac91b883fd26"}, + {file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b213376fc779c0e1a4b60008f3fd03f74e9baa9665db37fa6646e98d31baa6d8"}, + {file = "setproctitle-1.2.3-cp310-cp310-win32.whl", hash = "sha256:e24fa9251cc22ddb88ef183070063fdca826c9636381f1c4fb9d2a1dccb7c2a4"}, + {file = "setproctitle-1.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:3b1883ccdbee624386dc046cfbcd80c4e75e24c478f35627984a79892e088b88"}, + {file = "setproctitle-1.2.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f9cf1098205c23fbcaaaef798afaff714fa9ffadf24166f5e85e6d16b9ef82a1"}, + {file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a546cd2dfaecb227d24122257b98b2e062762871888835c7b608f1c41c3a77ad"}, + {file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e40c35564081983eab6a07f9eb5693867bc447b0edf9c61b69446223d6593814"}, + {file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d083cae02e344e760bd21c28d591ac5f7ddbd6e1a0ecba62092ae724abd5c28"}, + {file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2fa9f4b382a6cf88f2f345044d0916a92f37cac21355585bd14bc7ee91af187"}, + {file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:38855b06a124361dc73c198853dee3f2b775531c4f4b7472f0e3d441192b3d8a"}, + {file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:a81067bdc015fee1cc148c79b346f24fdad1224a8898b4239c7cbdee1add8a60"}, + {file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:409a39f92e123be061626fdfd3e76625b04db103479bb4ba1c85b587db0b9498"}, + {file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:a993610383028f093112dce7f77b262e88fce9d70127535fcdc78953179857e8"}, + {file = "setproctitle-1.2.3-cp36-cp36m-win32.whl", hash = "sha256:4eed53c12146de5df959d84384ffc2774651cab406ee4854e12728cf0eee5297"}, + {file = "setproctitle-1.2.3-cp36-cp36m-win_amd64.whl", hash = "sha256:335750c9eb5b18326a138a09266862a52b4f474277c3e410b419bea9a1df8bee"}, + {file = "setproctitle-1.2.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7a72bbe53191fbe574c94c0f8b9451dce535b398b7c47ce2e26e21d55eaa1d7e"}, + {file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5464e6812d050c986e6e9b97d54ab88c23dbe9d81151a2fa10b48bb5133a1e2c"}, + {file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec7c3a27460ae7811e868e5494e3d8aee5012912744c48fa2d80b5e614b1b972"}, + {file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01cef383afc7ea7a3b1696818c8712029bf2f1d64f5d4777dbaf0166becf2c00"}, + {file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54c7315e53b49ef2227d47a75c3d28c4c51ea9ee46a066460732c0d0f8e605a7"}, + {file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0b444ed4051161a3b0a85dec2bb9b50922f37c75f5fb86f7784b235cf6754336"}, + {file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:be0b46beeb1c92450079a7f30a025d69b63fd6a5de040ebc478fd6e6bf3b63fc"}, + {file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:60f7a2f5da36a3075dda7edbee2173be5b765b0460b8d401ee01a11f68dee1d2"}, + {file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:138bfa853e607f06d95b0f253e9152b32a00af3d0dbec96abf0871236a483932"}, + {file = "setproctitle-1.2.3-cp37-cp37m-win32.whl", hash = "sha256:e80fc59739a738b5c67afbbb9d1c238aa47b6d290c2ada872b15c819350ec5f8"}, + {file = "setproctitle-1.2.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a912df3f065572cef211e9ed9f157a0dd2bd73d150281f18f00728afa1b1e5d2"}, + {file = "setproctitle-1.2.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d45dbe4171f8c27a515ecb4562f4cd9ef67d98474bea18e0c14dfbdc2b225050"}, + {file = "setproctitle-1.2.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9d905ac84dde5227de6516ec08639759f99684148bb88ba05f4cbdaebff5d69"}, + {file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f272b84d79bbe15af26ecf6f7c129bbe642f628866c9253659cdb519216f138f"}, + {file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc586f002fd5dd8695718e22a83771fd9f744f081a2b8e614bf6b5f44135964a"}, + {file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4051c3a3b07f8a4cca205cd45366a22f322da2f26491c0d6b313a10f8c77b734"}, + {file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25538341e56f9e75e9759229ff674282dccb5b1ce79a974f968d36208d465674"}, + {file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fdb2231db176e0848b757fc5d9bed08bc8a498b5b9abb8b640f39e9720f309fc"}, + {file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0670f2130a7ca0e167d3d5a7c8e3c707340b8693d6af7416ff55c18ab2a0a43f"}, + {file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:9a92978030616f5e20617b7b832efee398df82072b7239c53db41c8026f5fe55"}, + {file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:28e0df80d5069586a08a3cb463fb23503a37cbb805826ef93164bc4bfb5f35b9"}, + {file = "setproctitle-1.2.3-cp38-cp38-win32.whl", hash = "sha256:35b869e416a105c59133a48b569c6e808159485d916f55e80c7394a42667a386"}, + {file = "setproctitle-1.2.3-cp38-cp38-win_amd64.whl", hash = "sha256:f47f6704880869d8e8f52efac2f2f60f5ed4cb9662b98fc1c7e916eefe76e61d"}, + {file = "setproctitle-1.2.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ccb0b5334dbf248f7504d88b5e9e9a09a0da119eeafacd6f7247f7c055443522"}, + {file = "setproctitle-1.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:14641a4ec2f2110cf4afc666eaecc82ba67814e927e02647fa1f4cf74476e752"}, + {file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4a3cb19346a0cd680617742f5e39fdd14596f6fd91d6c9038272663e37441b4"}, + {file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2ac0ebd9c63c3d19f768966be2f771bf088bc7373c63ed6fcbb3444a30d0f62"}, + {file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32a84cc309b9e595f06a55bec2fa335a23c307a55d2989864b60ecd71ea87897"}, + {file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f55493c987935fa540ef9ffb7ee7db03b4a18a9d5cc103681e2e6a6dfbd7054"}, + {file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f2a137984d3436f13e4bf7c8ca6f6f292df119c009c5e39556cabba4f4bfbf92"}, + {file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f06ff922254023eaabef6af6631f89e5f2f420cf0112865d57d7703f933d4e9f"}, + {file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:eb06c1086cf8c8cf12ce45a02450befcb408dfd646d0ccb47d388fd6e73c333a"}, + {file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2c8c245e08f6a296fdaa1b36894ec40e20464a4fc6458e6178c8d55a2f83457a"}, + {file = "setproctitle-1.2.3-cp39-cp39-win32.whl", hash = "sha256:21d6e064b8fee4e58eb00cdd8771c638de1bc30bb6c02d0208af9ca0a1c00898"}, + {file = "setproctitle-1.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:efb3001fd9e71d3ae939d826bf436f0446fd30a6ac01e0ce08cd7eb55ee5ac57"}, + {file = "setproctitle-1.2.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3dbe87e76197f9a303451512088c18c96f09a6fc4f871a92e5bd695f46f94a26"}, + {file = "setproctitle-1.2.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b207de9e4f4aa5265b36dd826a1f6ef6566b064a042033bd7447efb7e9a7664"}, + {file = "setproctitle-1.2.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48ac48a94040ef21be37366cbc8270fcba2ca103d6c64da6099d5a7b034f72d0"}, + {file = "setproctitle-1.2.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:9fb5d2e66f94eebc3d06cda9e71a3fffef24c5273971180a4b5628a37fae05a5"}, + {file = "setproctitle-1.2.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:423f8a6d8116acf975ebf93d6b5c4a752f7d2039fa9aafe175a62de86e17016e"}, + {file = "setproctitle-1.2.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c0be45535e934deab3aa72ed1a8487174af4ea12cec124478c68a312e1c8b13"}, + {file = "setproctitle-1.2.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65a9384cafdfed98f91416e93705ad08f049c298afcb9c515882beba23153bd0"}, + {file = "setproctitle-1.2.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d312a170f539895c8093b5e68ba126aa131c9f0d00f6360410db27ec50bf7afa"}, + {file = "setproctitle-1.2.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c93a2272740e60cddf59d3e1d35dbb89fcc3676f5ca9618bb4e6ae9633fdf13c"}, + {file = "setproctitle-1.2.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76f59444a25fb42ca07f53a4474b1545d97a06f016e6c6b8246eee5b146820b5"}, + {file = "setproctitle-1.2.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06aab65e68163ead9d046b452dd9ad1fc6834ce6bde490f63fdce3be53e9cc73"}, + {file = "setproctitle-1.2.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97accd117392b1e57e09888792750c403d7729b7e4b193005178b3736b325ea0"}, + {file = "setproctitle-1.2.3.tar.gz", hash = "sha256:ecf28b1c07a799d76f4326e508157b71aeda07b84b90368ea451c0710dbd32c0"}, +] setuptools-scm = [ {file = "setuptools_scm-6.4.2-py3-none-any.whl", hash = "sha256:acea13255093849de7ccb11af9e1fb8bde7067783450cee9ef7a93139bddf6d4"}, {file = "setuptools_scm-6.4.2.tar.gz", hash = "sha256:6833ac65c6ed9711a4d5d2266f8024cfa07c533a0e55f4c12f6eff280a5a9e30"}, ] +shortuuid = [ + {file = "shortuuid-1.0.9-py3-none-any.whl", hash = "sha256:b2bb9eb7773170e253bb7ba25971023acb473517a8b76803d9618668cb1dd46f"}, + {file = "shortuuid-1.0.9.tar.gz", hash = "sha256:459f12fa1acc34ff213b1371467c0325169645a31ed989e268872339af7563d5"}, +] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +smmap = [ + {file = "smmap-5.0.0-py3-none-any.whl", hash = "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94"}, + {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"}, +] soupsieve = [ {file = "soupsieve-2.3.2.post1-py3-none-any.whl", hash = "sha256:3b2503d3c7084a42b1ebd08116e5f81aadfaea95863628c80a3b774a11b7c759"}, {file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"}, @@ -1619,6 +1908,10 @@ urllib3 = [ {file = "urllib3-1.26.9-py2.py3-none-any.whl", hash = "sha256:44ece4d53fb1706f667c9bd1c648f5469a2ec925fcf3a776667042d645472c14"}, {file = "urllib3-1.26.9.tar.gz", hash = "sha256:aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e"}, ] +wandb = [ + {file = "wandb-0.12.16-py2.py3-none-any.whl", hash = "sha256:ed7782dadfb5bc457998eccd995f88ae564cdf2a36b12024e4a5d9a47b1b84e8"}, + {file = "wandb-0.12.16.tar.gz", hash = "sha256:a738b5eb61081fa96fc2e16ffaf6dbde67b78f973ff45bda61ed93659ca09912"}, +] zipp = [ {file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"}, {file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"}, diff --git a/pyproject.toml b/pyproject.toml index 1d39a41..40ec181 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ imageio-ffmpeg = "^0.4.7" mujoco-py = "^2.1.2" d4rl = {git = "https://github.com/ikostrikov/d4rl.git"} dm-control = "^1.0.2" +wandb = "^0.12.16" [tool.poetry.dev-dependencies] From d7c25d7c8898a9442eba53122c4daceca8c2c182 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 11 May 2022 20:44:22 -0400 Subject: [PATCH 2/6] Add documentation --- examples/README.md | 7 +++++++ examples/train.py | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index a734e36..2d89480 100644 --- a/examples/README.md +++ b/examples/README.md @@ -6,6 +6,13 @@ OpenAI Gym MuJoCo tasks python train.py --env_name=HalfCheetah-v2 --save_dir=./tmp/ ``` +Experiment tracking with Weights and Biases + +```bash +python train.py --env_name=HalfCheetah-v2 --save_dir=./tmp/ --track +``` + + DeepMind Control suite (--env-name=domain-task) ```bash diff --git a/examples/train.py b/examples/train.py index 885c703..d1ea2af 100644 --- a/examples/train.py +++ b/examples/train.py @@ -47,7 +47,6 @@ def main(_): print(FLAGS) if FLAGS.track: import wandb - wandb.init( project=FLAGS.wandb_project_name, From 59b0bce989eb283fe689b216043b05a81a257bd4 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 11 May 2022 20:45:40 -0400 Subject: [PATCH 3/6] Revert changes --- examples/train.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/examples/train.py b/examples/train.py index d1ea2af..e784796 100644 --- a/examples/train.py +++ b/examples/train.py @@ -135,29 +135,29 @@ def main(_): info['is_success'], info['total']['timesteps']) - # if i >= FLAGS.start_training: - # for _ in range(FLAGS.updates_per_step): - # batch = replay_buffer.sample(FLAGS.batch_size) - # update_info = agent.update(batch) - - # if i % FLAGS.log_interval == 0: - # for k, v in update_info.items(): - # summary_writer.add_scalar(f'training/{k}', v, i) - # summary_writer.flush() - - # if i % FLAGS.eval_interval == 0: - # eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes) - - # for k, v in eval_stats.items(): - # summary_writer.add_scalar(f'evaluation/average_{k}s', v, - # info['total']['timesteps']) - # summary_writer.flush() - - # eval_returns.append( - # (info['total']['timesteps'], eval_stats['return'])) - # np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'), - # eval_returns, - # fmt=['%d', '%.1f']) + if i >= FLAGS.start_training: + for _ in range(FLAGS.updates_per_step): + batch = replay_buffer.sample(FLAGS.batch_size) + update_info = agent.update(batch) + + if i % FLAGS.log_interval == 0: + for k, v in update_info.items(): + summary_writer.add_scalar(f'training/{k}', v, i) + summary_writer.flush() + + if i % FLAGS.eval_interval == 0: + eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes) + + for k, v in eval_stats.items(): + summary_writer.add_scalar(f'evaluation/average_{k}s', v, + info['total']['timesteps']) + summary_writer.flush() + + eval_returns.append( + (info['total']['timesteps'], eval_stats['return'])) + np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'), + eval_returns, + fmt=['%d', '%.1f']) if __name__ == '__main__': From b7dd8dab1c085306d500e5baa6bc2a0423f33eb4 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 11 May 2022 20:46:34 -0400 Subject: [PATCH 4/6] Quick fix --- examples/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/train.py b/examples/train.py index e784796..eb6c0e1 100644 --- a/examples/train.py +++ b/examples/train.py @@ -44,7 +44,6 @@ def main(_): kwargs = dict(FLAGS.config) algo = kwargs.pop('algo') run_name = f"{FLAGS.env_name}__{algo}__{FLAGS.seed}__{int(time.time())}" - print(FLAGS) if FLAGS.track: import wandb From b1fa9dc99957fcc09d236b0e544c0a6053aee663 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 11 May 2022 20:47:18 -0400 Subject: [PATCH 5/6] Add gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index b6e4761..e32ee43 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +*.tfevents.* +tmp +wandb + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 42dd5c505fbec94a49e6814b4b670adc3dada7f0 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 12 May 2022 22:13:50 -0400 Subject: [PATCH 6/6] Use `gym.wrappers.RecordVideo` --- jaxrl/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jaxrl/utils.py b/jaxrl/utils.py index a1ed250..fc8e8fd 100644 --- a/jaxrl/utils.py +++ b/jaxrl/utils.py @@ -5,7 +5,6 @@ from gym.wrappers.pixel_observation import PixelObservationWrapper from jaxrl import wrappers -from jaxrl.wrappers import VideoRecorder def make_env(env_name: str, @@ -44,7 +43,7 @@ def make_env(env_name: str, env = RescaleAction(env, -1.0, 1.0) if save_folder is not None: - env = VideoRecorder(env, save_folder=save_folder) + env = gym.wrappers.RecordVideo(env, save_folder) if from_pixels: if env_name in env_ids: