
Limit torch_agent.Batch to tensors #3389

Merged (33 commits) on Mar 16, 2021

Conversation

@stephenroller (Contributor) commented on Jan 16, 2021

Patch description
Where possible, limit the attributes of our Batch object to tensors only. Additionally, delay cudafying until late, and switch many operations to use batch.batchsize.

This is in preparation for a fresh attempt at Background Preprocessing.
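As a rough sketch of the contract this moves toward (a hypothetical stand-in class, not the real `Batch`):

```python
import torch

# Hypothetical stand-in for parlai.core.torch_agent.Batch, illustrating
# the tensor-only contract this patch moves toward (not the real class).
class Batch(dict):
    def __init__(self, text_vec=None, label_vec=None, **kwargs):
        super().__init__(text_vec=text_vec, label_vec=label_vec, **kwargs)
        self.__dict__ = self
        # batchsize is computed once, so callers stop writing len(batch.text_vec)
        self.batchsize = 0 if text_vec is None else text_vec.size(0)

batch = Batch(text_vec=torch.zeros(4, 7, dtype=torch.long))
assert batch.batchsize == 4
# cudafying is delayed, e.g. to just before the forward pass:
# batch.text_vec = batch.text_vec.to('cuda')
```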

Testing steps
CI

@stephenroller marked this pull request as ready for review on February 26, 2021 at 13:50
@klshuster (Contributor) left a comment:

sweet

| `ctpb` | Context tokens per batch |
| `ctps` | Context tokens per second |
| `ctrun` | Fraction of samples with some context truncation |
Contributor:

nit: can we have it be `ctrunc`? `ctrun` reads to me like "ct run", which is a bit confusing

| `loss` | Loss |
| `lr` | The most recent learning rate applied |
| `ltpb` | Label tokens per batch |
| `ltps` | Label tokens per second |
| `ltrun` | Fraction of samples with some label truncation |
Contributor:

ditto. totally fine to leave it if we want to keep abbreviations to 5 chars or fewer

padded_context_vec,
torch.tensor(hist_lens, dtype=torch.long, device=self.device),
# sum here is list concat, not addition
context_vec, hist_lens_ = self._pad_tensor(
Contributor:

perhaps unrelated, but `_pad_tensor` returning the lengths as simply the lengths of the input lists is not super intuitive, especially when the pad token is an optional argument; ideally you'd want it to return the lengths of the unpadded input lists

is this why you recompute hist_lens below?

stephenroller (Author):

idk bro I'm just trying to get that one to work lol

Contributor:

my bad, I didn't realize this was in the hred agent
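For reference, the semantics under discussion, as a self-contained sketch (not ParlAI's actual `_pad_tensor`): the helper pads on CPU and returns the lengths of the unpadded inputs.

```python
import torch
from typing import List, Tuple

def pad_tensor_sketch(
    items: List[List[int]], pad_idx: int = 0
) -> Tuple[torch.LongTensor, List[int]]:
    # Record lengths before padding, so callers get the true item sizes.
    lens = [len(item) for item in items]
    out = torch.full((len(items), max(lens)), pad_idx, dtype=torch.long)
    for i, item in enumerate(items):
        out[i, : lens[i]] = torch.tensor(item, dtype=torch.long)
    return out, lens

padded, lens = pad_tensor_sketch([[1, 2, 3], [4]])
# padded -> [[1, 2, 3], [4, 0, 0]]; lens -> [3, 1]
```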

@@ -114,18 +114,16 @@ def train_step(self, batch):
"""
Return confirmation of training.
"""
return Output(['Training {}!'.format(i) for i in range(len(batch.text_vec))])
return Output(['Training {}!'.format(i) for i in range(batch.batchsize)])
Contributor:

nit: can use an f-string here
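That is, something like (the diff line above, rewritten):

```python
return Output([f'Training {i}!' for i in range(batch.batchsize)])
```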

i, batch.observations[i]['text']
)
for i in range(len(batch.text_vec))
'Evaluating {} (responding to {})!'.format(i, batch.text_vec.tolist())
Contributor:

ditto

@@ -123,8 +113,12 @@ def __init__(
valid_indices=None,
candidates=None,
candidate_vecs=None,
reward=None,
Contributor:

where is this reward coming from?

stephenroller (Author):

Some teachers provide it now and we batchify it. It's used by Unlikelihood and others.

Contributor:

could you include it in the Batch docstring?
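Something along these lines could work in the docstring (wording is just a suggestion, following the existing `:param:` style):

```
:param reward:
    optional tensor of per-example rewards, provided by some teachers and
    consumed by agents such as Unlikelihood.
```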

"""
Move all tensors in the batch to a device.

Happens in place. Note that valid_indices and fields starting with an underscore are not moved.
Contributor:

i like these semantics. can we make that more clear in the Batch object description? specifically underscored fields being exempt
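A minimal sketch of those semantics (hypothetical helper name, not the PR's exact code):

```python
import torch

def cudafy_in_place(batch: dict, device: str) -> None:
    # Move tensor fields to the device in place, but leave valid_indices
    # and any field whose name starts with an underscore untouched.
    for key, value in batch.items():
        if key == 'valid_indices' or key.startswith('_'):
            continue
        if torch.is_tensor(value):
            batch[key] = value.to(device)
```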

_context_original_length: Optional[torch.LongTensor]
_context_truncate_rate: Optional[torch.LongTensor]
_label_original_length: Optional[torch.LongTensor]
_label_truncate_rate: Optional[torch.LongTensor]

def __init__(
Contributor:

perhaps we can warn here if we're passing in a non-tensor?

stephenroller (Author):

For specifically these 4 or others? We have non-tensors (bools, Nones, etc).

Contributor:

wait so... we're not fully getting rid of non-tensors here?

stephenroller (Author):

Nope. I benchmarked them and they're not painful. Only the complex objects are.

Contributor:

how complex is complex? does this mean I can just make my own batch object and keep strings in there?

stephenroller (Author):

Ya I mean, we don't have any hard limitation. You can even do full observations if you want, you'll just pay a penalty with background workers. But any Batch object you manually want to add things to is still allowed.

I'll leave notes in batch descriptions
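If the warning floated above ever gets added, a minimal sketch might look like this (hypothetical helper, with the cheap types guessed from this thread):

```python
import warnings
import torch

# Scalars, bools, strings, and None are cheap; complex objects (lists of
# dicts, full observations, ...) are what hurt with background workers.
_CHEAP_TYPES = (int, float, bool, str, type(None))

def warn_on_complex_fields(**fields) -> None:
    for key, value in fields.items():
        if torch.is_tensor(value) or isinstance(value, _CHEAP_TYPES):
            continue
        warnings.warn(
            f'Batch field {key!r} is a {type(value).__name__}; complex '
            'non-tensor fields are slow with background preprocessing.'
        )
```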

@@ -65,11 +65,9 @@ def atomic_save(state_dict: Any, path: str) -> None:
def padded_tensor(
items: List[Union[List[int], torch.LongTensor]],
pad_idx: int = 0,
use_cuda: bool = False,
Contributor:

this will almost certainly break some internal stuff but i suppose those can be dealt with individually
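For those internal call sites, the migration implied by this diff is roughly as follows (import path as I recall it from this PR; treat as a sketch):

```python
from parlai.utils.torch import padded_tensor

items = [[1, 2, 3], [4]]

# before: padded_tensor could cudafy for you
#   padded, lens = padded_tensor(items, pad_idx=0, use_cuda=True)

# after: pad on CPU, then move explicitly only when needed
padded, lens = padded_tensor(items, pad_idx=0)
padded = padded.to('cuda')
```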

@EricMichaelSmith (Contributor) left a comment:

Seems reasonable - great to have this streamlining!

@@ -96,22 +88,20 @@ class Batch(AttrDict):

:param image:
list of image features in the format specified by the --image-mode arg.

:param observations:
the original observations in the batched order
"""
Contributor:

Nit: we might not need to define the new underscored args if we don't think the user should mess with them, but maybe we should add a sentence explaining generally what they are?

Additional resolved review threads on parlai/core/torch_agent.py (two threads) and projects/style_gen/classifier.py.