Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Batch & Output objects #1437

Merged
merged 1 commit into from
Feb 12, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 70 additions & 59 deletions parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
"""

from torch import optim
from collections import deque, namedtuple
from collections import deque
import json
import random
import numpy as np
from parlai.core.agents import Agent
from parlai.core.build_data import modelzoo_path
from parlai.core.dict import DictionaryAgent
from parlai.core.utils import (
set_namedtuple_defaults, argsort, padded_tensor, warn_once, round_sigfigs
AttrDict, argsort, padded_tensor, warn_once, round_sigfigs
)
from parlai.core.distributed_utils import is_primary_worker

Expand All @@ -35,93 +35,104 @@
raise ImportError('Need to install Pytorch: go to pytorch.org')


Batch = namedtuple('Batch', [
'text_vec', 'text_lengths', 'label_vec', 'label_lengths', 'labels',
'valid_indices', 'candidates', 'candidate_vecs', 'image',
'memory_vecs', 'observations'
])
set_namedtuple_defaults(Batch, default=None)
Batch.__doc__ = """
Batch is a namedtuple containing data being sent to an agent.
class Batch(AttrDict):
"""
Batch is an AttrDict containing data being sent to an agent.

This is the input type of the train_step and eval_step functions.
Agents can override the batchify function to return an extended namedtuple
with additional fields if they would like, though we recommend calling the
parent function to set up these fields as a base.
This is the input type of the train_step and eval_step functions.
Agents can override the batchify function to return a Batch with
additional fields (passed as extra keyword arguments) if they would like,
though we recommend calling the parent function to set up these fields as a base.

.. py:attribute:: text_vec
.. py:attribute:: text_vec

bsz x seqlen tensor containing the parsed text data.
bsz x seqlen tensor containing the parsed text data.

.. py:attribute:: text_lengths
.. py:attribute:: text_lengths

list of length bsz containing the lengths of the text in same order as
text_vec; necessary for pack_padded_sequence.
list of length bsz containing the lengths of the text in same order as
text_vec; necessary for pack_padded_sequence.

.. py:attribute:: label_vec
.. py:attribute:: label_vec

bsz x seqlen tensor containing the parsed label (one per batch row).
bsz x seqlen tensor containing the parsed label (one per batch row).

.. py:attribute:: label_lengths
.. py:attribute:: label_lengths

list of length bsz containing the lengths of the labels in same order as
label_vec.
list of length bsz containing the lengths of the labels in same order as
label_vec.

.. py:attribute:: labels
.. py:attribute:: labels

list of length bsz containing the selected label for each batch row (some
datasets have multiple labels per input example).
list of length bsz containing the selected label for each batch row (some
datasets have multiple labels per input example).

.. py:attribute:: valid_indices
.. py:attribute:: valid_indices

list of length bsz containing the original indices of each example in the
batch. we use these to map predictions back to their proper row, since e.g.
we may sort examples by their length or some examples may be invalid.
list of length bsz containing the original indices of each example in the
batch. we use these to map predictions back to their proper row, since e.g.
we may sort examples by their length or some examples may be invalid.

.. py:attribute:: candidates
.. py:attribute:: candidates

list of lists of text. outer list has size bsz, inner lists vary in size
based on the number of candidates for each row in the batch.
list of lists of text. outer list has size bsz, inner lists vary in size
based on the number of candidates for each row in the batch.

.. py:attribute:: candidate_vecs
.. py:attribute:: candidate_vecs

list of lists of tensors. outer list has size bsz, inner lists vary in size
based on the number of candidates for each row in the batch.
list of lists of tensors. outer list has size bsz, inner lists vary in size
based on the number of candidates for each row in the batch.

.. py:attribute:: image
.. py:attribute:: image

list of image features in the format specified by the --image-mode arg.
list of image features in the format specified by the --image-mode arg.

.. py:attribute:: memory_vecs
.. py:attribute:: memory_vecs

list of lists of tensors. outer list has size bsz, inner lists vary based
on the number of memories for each row in the batch. these memories are
generated by splitting the input text on newlines, with the last line put
in the text field and the remaining put in this one.
list of lists of tensors. outer list has size bsz, inner lists vary based
on the number of memories for each row in the batch. these memories are
generated by splitting the input text on newlines, with the last line put
in the text field and the remaining put in this one.

.. py:attribute:: observations
.. py:attribute:: observations

the original observations in the batched order
"""
the original observations in the batched order
"""

def __init__(self, text_vec=None, text_lengths=None,
label_vec=None, label_lengths=None, labels=None,
valid_indices=None,
candidates=None, candidate_vecs=None,
image=None, memory_vecs=None, observations=None,
**kwargs):
super().__init__(
text_vec=text_vec, text_lengths=text_lengths,
label_vec=label_vec, label_lengths=label_lengths, labels=labels,
valid_indices=valid_indices,
candidates=candidates, candidate_vecs=candidate_vecs,
image=image, memory_vecs=memory_vecs, observations=observations,
**kwargs)


class Output(AttrDict):
"""
Output is an AttrDict containing agent predictions.

Output = namedtuple('Output', ['text', 'text_candidates'])
set_namedtuple_defaults(Output, default=None)
Output.__doc__ = """
Output is a namedtuple containing agent predictions.
This is the expected return type of the train_step and eval_step functions,
though agents can choose to return None if they do not want to answer.

This is the expected return type of the train_step and eval_step functions,
though agents can choose to return None if they do not want to answer.
.. py:attribute:: text

.. py:attribute:: text
list of strings of length bsz containing the predictions of the model

list of strings of length bsz containing the predictions of the model
.. py:attribute:: text_candidates

.. py:attribute:: text_candidates
list of lists of length bsz containing ranked predictions of the model.
each sub-list is an ordered ranking of strings, of variable length.
"""

list of lists of length bsz containing ranked predictions of the model.
each sub-list is an ordered ranking of strings, of variable length.
"""
def __init__(self, text=None, text_candidates=None, **kwargs):
super().__init__(text=text, text_candidates=text_candidates, **kwargs)


class TorchAgent(Agent):
Expand Down