From dae25efb72ddcb584eb7d1c1020b9b5c1ead974a Mon Sep 17 00:00:00 2001 From: Alexander Miller Date: Fri, 8 Feb 2019 17:21:29 -0500 Subject: [PATCH] batch / output objects --- parlai/core/torch_agent.py | 129 ++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 59 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 0530d4f94a2..420650137fc 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -17,7 +17,7 @@ """ from torch import optim -from collections import deque, namedtuple +from collections import deque import json import random import numpy as np @@ -25,7 +25,7 @@ from parlai.core.build_data import modelzoo_path from parlai.core.dict import DictionaryAgent from parlai.core.utils import ( - set_namedtuple_defaults, argsort, padded_tensor, warn_once, round_sigfigs + AttrDict, argsort, padded_tensor, warn_once, round_sigfigs ) from parlai.core.distributed_utils import is_primary_worker @@ -35,93 +35,104 @@ raise ImportError('Need to install Pytorch: go to pytorch.org') -Batch = namedtuple('Batch', [ - 'text_vec', 'text_lengths', 'label_vec', 'label_lengths', 'labels', - 'valid_indices', 'candidates', 'candidate_vecs', 'image', - 'memory_vecs', 'observations' -]) -set_namedtuple_defaults(Batch, default=None) -Batch.__doc__ = """ -Batch is a namedtuple containing data being sent to an agent. +class Batch(AttrDict): + """ + Batch is a namedtuple containing data being sent to an agent. -This is the input type of the train_step and eval_step functions. -Agents can override the batchify function to return an extended namedtuple -with additional fields if they would like, though we recommend calling the -parent function to set up these fields as a base. + This is the input type of the train_step and eval_step functions. + Agents can override the batchify function to return an extended namedtuple + with additional fields if they would like, though we recommend calling the + parent function to set up these fields as a base. -.. py:attribute:: text_vec + .. py:attribute:: text_vec - bsz x seqlen tensor containing the parsed text data. + bsz x seqlen tensor containing the parsed text data. -.. py:attribute:: text_lengths + .. py:attribute:: text_lengths - list of length bsz containing the lengths of the text in same order as - text_vec; necessary for pack_padded_sequence. + list of length bsz containing the lengths of the text in same order as + text_vec; necessary for pack_padded_sequence. -.. py:attribute:: label_vec + .. py:attribute:: label_vec - bsz x seqlen tensor containing the parsed label (one per batch row). + bsz x seqlen tensor containing the parsed label (one per batch row). -.. py:attribute:: label_lengths + .. py:attribute:: label_lengths - list of length bsz containing the lengths of the labels in same order as - label_vec. + list of length bsz containing the lengths of the labels in same order as + label_vec. -.. py:attribute:: labels + .. py:attribute:: labels - list of length bsz containing the selected label for each batch row (some - datasets have multiple labels per input example). + list of length bsz containing the selected label for each batch row (some + datasets have multiple labels per input example). -.. py:attribute:: valid_indices + .. py:attribute:: valid_indices - list of length bsz containing the original indices of each example in the - batch. we use these to map predictions back to their proper row, since e.g. - we may sort examples by their length or some examples may be invalid. + list of length bsz containing the original indices of each example in the + batch. we use these to map predictions back to their proper row, since e.g. + we may sort examples by their length or some examples may be invalid. -.. py:attribute:: candidates + .. py:attribute:: candidates - list of lists of text. outer list has size bsz, inner lists vary in size - based on the number of candidates for each row in the batch. + list of lists of text. outer list has size bsz, inner lists vary in size + based on the number of candidates for each row in the batch. -.. py:attribute:: candidate_vecs + .. py:attribute:: candidate_vecs - list of lists of tensors. outer list has size bsz, inner lists vary in size - based on the number of candidates for each row in the batch. + list of lists of tensors. outer list has size bsz, inner lists vary in size + based on the number of candidates for each row in the batch. -.. py:attribute:: image + .. py:attribute:: image - list of image features in the format specified by the --image-mode arg. + list of image features in the format specified by the --image-mode arg. -.. py:attribute:: memory_vecs + .. py:attribute:: memory_vecs - list of lists of tensors. outer list has size bsz, inner lists vary based - on the number of memories for each row in the batch. these memories are - generated by splitting the input text on newlines, with the last line put - in the text field and the remaining put in this one. + list of lists of tensors. outer list has size bsz, inner lists vary based + on the number of memories for each row in the batch. these memories are + generated by splitting the input text on newlines, with the last line put + in the text field and the remaining put in this one. -.. py:attribute:: observations + .. py:attribute:: observations - the original observations in the batched order -""" + the original observations in the batched order + """ + def __init__(self, text_vec=None, text_lengths=None, + label_vec=None, label_lengths=None, labels=None, + valid_indices=None, + candidates=None, candidate_vecs=None, + image=None, memory_vecs=None, observations=None, + **kwargs): + super().__init__( + text_vec=text_vec, text_lengths=text_lengths, + label_vec=label_vec, label_lengths=label_lengths, labels=labels, + valid_indices=valid_indices, + candidates=candidates, candidate_vecs=candidate_vecs, + image=image, memory_vecs=memory_vecs, observations=observations, + **kwargs) + + +class Output(AttrDict): + """ + Output is a namedtuple containing agent predictions. -Output = namedtuple('Output', ['text', 'text_candidates']) -set_namedtuple_defaults(Output, default=None) -Output.__doc__ = """ -Output is a namedtuple containing agent predictions. + This is the expected return type of the train_step and eval_step functions, + though agents can choose to return None if they do not want to answer. -This is the expected return type of the train_step and eval_step functions, -though agents can choose to return None if they do not want to answer. + .. py:attribute:: text -.. py:attribute:: text + list of strings of length bsz containing the predictions of the model - list of strings of length bsz containing the predictions of the model + .. py:attribute:: text_candidates -.. py:attribute:: text_candidates + list of lists of length bsz containing ranked predictions of the model. + each sub-list is an ordered ranking of strings, of variable length. + """ - list of lists of length bsz containing ranked predictions of the model. - each sub-list is an ordered ranking of strings, of variable length. -""" + def __init__(self, text=None, text_candidates=None, **kwargs): + super().__init__(text=text, text_candidates=text_candidates, **kwargs) class TorchAgent(Agent):