-
Notifications
You must be signed in to change notification settings - Fork 131
/
conda_uint8_patch.sh
19 lines (16 loc) · 1.97 KB
/
conda_uint8_patch.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# This patch fixes two bugs in ray when using a conda environment.
#
# Locate <prefix>/lib/pythonX.Y of the `python` found on PATH, so the patch
# works for any Python version. The interpreter is asked for its own
# sys.prefix rather than cutting a fixed 10 characters ("bin/python") off the
# output of `which python`. That old approach broke for symlinked or shim
# interpreters and relied on bash 4.2+ negative substring offsets.
#
# Sets:
#   python_loc_name    - path of the `python` command on PATH
#   python_version     - "pythonX.Y"
#   python_folder_name - <sys.prefix>/lib/pythonX.Y
# Stop on the first failing command instead of letting later sed calls run
# against bogus paths.
set -eu

python_loc_name=$(command -v python) || {
  echo "error: no 'python' found on PATH" >&2
  exit 1
}
python_version=python$(python -c "import sys; print(sys.version_info[0], sys.version_info[1], sep='.')")
python_folder_name=$(python -c "import sys; print(sys.prefix)")/lib/${python_version}
echo "$python_folder_name"

# Fail early with one clear message rather than one sed error per patch.
if [ ! -d "$python_folder_name/site-packages/ray" ]; then
  echo "error: ray is not installed under $python_folder_name/site-packages" >&2
  exit 1
fi
# Apply patches for https://github.com/ray-project/ray/issues/7946
# Every edit below is anchored to a fixed line number of the installed ray
# sources. These line numbers presumably match the exact ray release this
# project pins (TODO confirm). Against another release they will not match,
# or will patch the wrong line.
# Each substitution is safe to re-run: once applied, its pattern no longer
# matches that line. The insertion at the end is guarded so a re-run does
# not add a second copy.
ray_dir="$python_folder_name"/site-packages/ray
sed -i '119s/tf.float32/tf.uint8/' "$ray_dir"/rllib/policy/dynamic_tf_policy.py # Hardcoded observation space to uint8.
sed -i '76s/np.float32/np.uint8/' "$ray_dir"/rllib/models/preprocessors.py # Same as above.
sed -i '231s/np.zeros(self.shape)/np.zeros(self.shape, dtype=self.observation_space.dtype)/' "$ray_dir"/rllib/models/preprocessors.py # Change observation shape to what we actually provide
sed -i '214s/tf.int64/action_space.dtype/' "$ray_dir"/rllib/models/catalog.py # Change action shape to what we actually provide
sed -i '56s/tf.math.argmax(self.inputs, axis=1)/tf.math.argmax(self.inputs, axis=1, output_type=tf.int32)/' "$ray_dir"/rllib/models/tf/tf_action_dist.py # Actions should not sample at int64, int32 is the lowest that multinomial takes
sed -i '84s/tf.multinomial(self.inputs, 1)/tf.multinomial(self.inputs, 1, output_dtype=tf.int32)/' "$ray_dir"/rllib/models/tf/tf_action_dist.py # Same as above
# Insert action to uint8 conversion to save even more memory.
# Skipped when the line is already present, so re-running the script does
# not insert it again.
# NOTE(review): the whitespace after `i\` becomes the inserted line's
# indentation; confirm it matches the surrounding Python code at line 656.
sampler_py="$ray_dir"/rllib/evaluation/sampler.py
if ! grep -qF 'actions = np.array(actions, dtype=policy.action_space.dtype)' "$sampler_py"; then
  sed -i '656i\ actions = np.array(actions, dtype=policy.action_space.dtype)' "$sampler_py"
fi
# Apply patch for https://github.com/ray-project/ray/pull/8491 (fixed in ray 0.8.6, remove this when upgrading to ray >= 0.8.6)
# Inserts `return self.sess.run(self.variables)` before line 164 of
# ray/experimental/tf_utils.py. Skipped when that statement is already in the
# file, so re-running the script does not stack duplicate lines.
# NOTE(review): the whitespace after `i\` becomes the inserted line's
# indentation; confirm it matches the surrounding Python code at line 164.
tf_utils_py="$python_folder_name"/site-packages/ray/experimental/tf_utils.py
if ! grep -qF 'return self.sess.run(self.variables)' "$tf_utils_py"; then
  sed -i '164i\ return self.sess.run(self.variables)' "$tf_utils_py"
fi