From 1bb6cadb336919f20dc299740332540204fdd8ab Mon Sep 17 00:00:00 2001
From: "a.gazzaev"
Date: Fri, 12 Jul 2024 00:26:21 +0300
Subject: [PATCH 1/2] add PickAndPlace training

---
 examples/PickAndPlace_train.ipynb | 107 ++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 examples/PickAndPlace_train.ipynb

diff --git a/examples/PickAndPlace_train.ipynb b/examples/PickAndPlace_train.ipynb
new file mode 100644
index 00000000..0a1cddb3
--- /dev/null
+++ b/examples/PickAndPlace_train.ipynb
@@ -0,0 +1,107 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install panda-gym stable-baselines3"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Train agent"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gymnasium as gym\n",
+    "from stable_baselines3 import SAC, HerReplayBuffer\n",
+    "import panda_gym\n",
+    "\n",
+    "env = gym.make(\"PandaPickAndPlace-v3\", render_mode=\"human\")\n",
+    "\n",
+    "model = SAC(\n",
+    "    policy=\"MultiInputPolicy\",\n",
+    "    env=env,\n",
+    "    replay_buffer_class=HerReplayBuffer,\n",
+    "    verbose=1,\n",
+    "    buffer_size=1000000,\n",
+    "    replay_buffer_kwargs=dict(\n",
+    "        n_sampled_goal=4,\n",
+    "        goal_selection_strategy='future',\n",
+    "    ),\n",
+    "    gamma=0.95,\n",
+    "    learning_starts=1000,\n",
+    "    train_freq=1,\n",
+    ")\n",
+    "\n",
+    "model.learn(total_timesteps=4000000)\n",
+    "model.save(\"./her_robot\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Pick and place"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gymnasium as gym\n",
+    "from stable_baselines3 import SAC\n",
+    "import panda_gym\n",
+    "\n",
+    "\n",
+    "from time import sleep\n",
+    "\n",
+    "env = gym.make(\"PandaPush-v3\", render_mode=\"human\")\n",
+    "\n",
+    "model = SAC.load(\"./her_robot\", env=env)\n",
+    "vec_env = model.get_env()\n",
+    "obs = vec_env.reset()\n",
+    "for i in range(10000):\n",
+    "    sleep(0.08)\n",
+    "    action, _states = model.predict(obs, deterministic=True)\n",
+    "    obs, reward, done, info = vec_env.step(action)\n",
+    "    vec_env.render()\n",
+    "    if done:\n",
+    "        obs = vec_env.reset()\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

From de8fa59cc13e5ab6f14c0421365e5fa499f4bcdd Mon Sep 17 00:00:00 2001
From: "a.gazzaev"
Date: Fri, 12 Jul 2024 00:32:52 +0300
Subject: [PATCH 2/2] add PickAndPlace training

---
 examples/PickAndPlace_train.ipynb | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/PickAndPlace_train.ipynb b/examples/PickAndPlace_train.ipynb
index 0a1cddb3..f180572d 100644
--- a/examples/PickAndPlace_train.ipynb
+++ b/examples/PickAndPlace_train.ipynb
@@ -24,6 +24,7 @@
    "source": [
     "import gymnasium as gym\n",
     "from stable_baselines3 import SAC, HerReplayBuffer\n",
+    "\n",
     "import panda_gym\n",
     "\n",
     "env = gym.make(\"PandaPickAndPlace-v3\", render_mode=\"human\")\n",
@@ -51,7 +52,7 @@
    "cell_type": "markdown",
    "metadata": {},
"source": [ - "### Pick and place" + "### Evaluate" ] }, { @@ -63,11 +64,9 @@ "import gymnasium as gym\n", "from stable_baselines3 import SAC, HerReplayBuffer\n", "import panda_gym\n", - "\n", - "\n", "from time import sleep\n", "\n", - "env = gym.make(\"PandaPush-v3\", render_mode=\"human\")\n", + "env = gym.make(\"PandaPickAndPlace\", render_mode=\"human\")\n", "\n", "model = SAC.load(\"./her_robot\", env=env)\n", "vec_env = model.get_env()\n", @@ -77,8 +76,9 @@ " action, _states = model.predict(obs, deterministic=True)\n", " obs, reward, done, info = vec_env.step(action)\n", " vec_env.render()\n", + "\n", " if done:\n", - " obs = vec_env.reset()\n" + " obs = vec_env.reset()" ] } ],