[PPO] fix corner cases with PPO batch size and forward_batch_size (#563)

* fix corner cases PPO

* forward contrib credits from initial contribution

* forward contrib credits from initial discussions

---------

Co-authored-by: 1485840691-eng <1485840691-eng@users.noreply.github.com>
Co-authored-by: shubhlohiya <shubhlohiya@users.noreply.github.com>
3 people authored Jul 28, 2023
1 parent 3b0a1b5 commit 1b46c61
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions trl/trainer/ppo_trainer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 import os
 import time
 import typing
@@ -898,7 +899,7 @@ def batched_forward_pass(
         all_masks = []
         all_values = []

-        for i in range(int(bs / fbs)):
+        for i in range(math.ceil(bs / fbs)):
             input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
             query_batch = queries[i * fbs : (i + 1) * fbs]
             response_batch = responses[i * fbs : (i + 1) * fbs]
@@ -915,7 +916,7 @@ def batched_forward_pass(
             masks = torch.zeros_like(attention_mask)
             masks[:, :-1] = attention_mask[:, 1:]

-            for j in range(fbs):
+            for j in range(len(query_batch)):
                 if self.is_encoder_decoder:
                     # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                     start = 1
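
The two changed loops can be illustrated in isolation. The sketch below is not code from trl; `chunk_bounds` and the sizes are hypothetical, chosen only to show what happens when `fbs` does not divide `bs`. With `int(bs / fbs)` the trailing partial chunk is dropped (and nothing runs at all when `bs < fbs`), while `math.ceil(bs / fbs)` covers every sample. Because the last chunk can then be shorter than `fbs`, an inner loop over `range(fbs)` would index past its end, which is presumably why the commit switches to `len(query_batch)`.

```python
import math
from typing import List, Tuple


def chunk_bounds(bs: int, fbs: int, use_ceil: bool = True) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) pairs for forward mini-batches of size fbs."""
    n_chunks = math.ceil(bs / fbs) if use_ceil else int(bs / fbs)
    # Clamp stop to bs; Python slicing does the same implicitly for the last chunk.
    return [(i * fbs, min((i + 1) * fbs, bs)) for i in range(n_chunks)]


bs, fbs = 10, 4  # illustrative sizes where fbs does not divide bs
queries = list(range(bs))

old = chunk_bounds(bs, fbs, use_ceil=False)  # behaviour before the change
new = chunk_bounds(bs, fbs, use_ceil=True)   # behaviour after the change

old_covered = sum(stop - start for start, stop in old)
new_covered = sum(stop - start for start, stop in new)

# The final chunk is shorter than fbs, so an inner loop has to use its real length.
last_chunk = queries[new[-1][0]:new[-1][1]]

# When the whole batch is smaller than fbs, truncation yields no chunks at all.
small_old = chunk_bounds(3, 8, use_ceil=False)
small_new = chunk_bounds(3, 8, use_ceil=True)

print(old, old_covered)
print(new, new_covered)
print(len(last_chunk), small_old, small_new)
```

In this sketch the truncating version processes 8 of the 10 samples and the ceiling version processes all 10, with a final chunk of length 2.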
