This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

bb2 #3790

Merged 1 commit on Jul 16, 2021
95 changes: 95 additions & 0 deletions parlai/agents/fid/fid.py
@@ -8,16 +8,19 @@

See https://arxiv.org/abs/2007.01282
"""
from copy import deepcopy
import torch
from typing import Tuple, Union, Optional, List, Dict, Any

from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt
from parlai.agents.transformer.transformer import TransformerGeneratorModel

from parlai.agents.rag.args import RetrieverType
from parlai.agents.rag.modules import RagModel, Document, T5RagModel
from parlai.agents.rag.rag import RagAgent
from parlai.agents.rag.model_types import RagToken, get_forced_decoder_inputs
from parlai.utils.typing import TShared


class Fid(RagToken):
@@ -200,6 +203,98 @@ def build_model(self) -> FidModel:
return model


RETRIEVER_DOC_LEN_TOKENS = 256


class SearchQueryFiDAgent(FidAgent):
@classmethod
def add_cmdline_args(cls, parser, partial_opt=None):
super().add_cmdline_args(parser, partial_opt=partial_opt)
group = parser.add_argument_group('Search Query FiD Params')

# Search Query generator
group.add_argument(
'--search-query-generator-model-file',
type=str,
help='Path to a query generator model.',
)
group.add_argument(
'--search-query-generator-inference',
type=str,
default='greedy',
help='Generation algorithm for the search query generator model',
)
group.add_argument(
'--search-query-generator-beam-min-length',
type=int,
default=1,
help='The beam_min_length opt for the search query generator model',
)
group.add_argument(
'--search-query-generator-beam-size',
type=int,
default=1,
help='The beam_size opt for the search query generator model',
)
group.add_argument(
'--search-query-generator-text-truncate',
type=int,
default=512,
help='Truncates the input to the search query generator model',
)

        # Creating chunks and splitting the documents
group.add_argument(
'--splitted-chunk-length',
type=int,
default=RETRIEVER_DOC_LEN_TOKENS,
help='The number of tokens in each document split',
)
group.add_argument(
'--doc-chunk-split-mode',
type=str,
choices=['word', 'token'],
default='word',
            help='Split the docs by whitespace (word) or dict tokens.',
)
group.add_argument(
'--n-ranked-doc-chunks',
type=int,
default=1,
            help='Number of document chunks to keep if the document is too long and has to be split.',
)
group.add_argument(
'--doc-chunks-ranker',
type=str,
choices=['tfidf', 'head'],
default='head',
help='How to rank doc chunks.',
)

return parser
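
As a usage sketch of how these flags surface (assuming ParlAI's standard ParlaiParser wiring; the model-file path below is a hypothetical placeholder, not from this PR):

    from parlai.core.params import ParlaiParser

    # A minimal sketch, assuming the standard ParlAI parser setup; the
    # query-generator path is a hypothetical placeholder.
    parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
    SearchQueryFiDAgent.add_cmdline_args(parser, partial_opt=None)
    opt = parser.parse_args(
        [
            '--search-query-generator-model-file', '/path/to/query_generator',
            '--search-query-generator-beam-size', '3',
            '--doc-chunks-ranker', 'tfidf',
        ]
    )
    print(opt['splitted_chunk_length'])  # 256, i.e. RETRIEVER_DOC_LEN_TOKENS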


class SearchQuerySearchEngineFiDAgent(SearchQueryFiDAgent):
def __init__(self, opt: Opt, shared: TShared = None):
opt = deepcopy(opt)
opt['rag_retriever_type'] = RetrieverType.SEARCH_ENGINE.value
super().__init__(opt, shared=shared)

@classmethod
def add_cmdline_args(cls, parser, partial_opt=None):
super().add_cmdline_args(parser, partial_opt=partial_opt)
group = parser.add_argument_group('Search Engine FiD Params')
        group.add_argument('--search-server', type=str, help='A search server address.')
return parser


class SearchQueryFAISSIndexFiDAgent(SearchQueryFiDAgent):
def __init__(self, opt: Opt, shared: TShared = None):
opt = deepcopy(opt)
opt['rag_retriever_type'] = RetrieverType.SEARCH_TERM_FAISS.value
super().__init__(opt, shared=shared)


def concat_enc_outs(
input: torch.LongTensor,
enc_out: torch.Tensor,
2 changes: 2 additions & 0 deletions parlai/agents/rag/args.py
@@ -86,6 +86,8 @@ class RetrieverType(Enum):
TFIDF = 'tfidf'
DPR_THEN_POLY = 'dpr_then_poly'
POLY_FAISS = 'poly_faiss'
SEARCH_ENGINE = 'search_engine'
SEARCH_TERM_FAISS = 'search_term_faiss'


def setup_rag_args(parser: ParlaiParser) -> ParlaiParser:
132 changes: 132 additions & 0 deletions parlai/agents/rag/retrieve_api.py
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
APIs for retrieving a list of "Contents" using a "Search Query".

The term "Search Query" here refers to any abstract form of input string. The definition
of "Contents" is also loose and depends on the API.
"""

from abc import ABC, abstractmethod
import requests
from typing import Any, Dict, List

from parlai.core.opt import Opt
from parlai.utils import logging


CONTENT = 'content'
DEFAULT_NUM_TO_RETRIEVE = 5


class RetrieverAPI(ABC):
"""
Provides the common interfaces for retrievers.

    Every retriever in this module must implement the `retrieve` method.
"""

def __init__(self, opt: Opt):
self.skip_query_token = opt['skip_retrieval_token']

@abstractmethod
def retrieve(
self, queries: List[str], num_ret: int = DEFAULT_NUM_TO_RETRIEVE
) -> List[Dict[str, Any]]:
"""
Implements the underlying retrieval mechanism.
"""

def create_content_dict(self, content: list, **kwargs) -> Dict:
resp_content = {CONTENT: content}
resp_content.update(**kwargs)
return resp_content


class SearchEngineRetrieverMock(RetrieverAPI):
"""
For unit tests and debugging (does not need a running server).
"""

def retrieve(
self, queries: List[str], num_ret: int = DEFAULT_NUM_TO_RETRIEVE
) -> List[Dict[str, Any]]:
all_docs = []
for query in queries:
if query == self.skip_query_token:
docs = None
else:
docs = []
for idx in range(num_ret):
doc = self.create_content_dict(
f'content {idx} for query "{query}"',
url=f'url_{idx}',
title=f'title_{idx}',
)
docs.append(doc)
all_docs.append(docs)
return all_docs
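
A quick usage sketch (the '__skip__' token value here is an arbitrary assumption; Opt wraps a plain dict):

    from parlai.core.opt import Opt

    # Minimal sketch: the mock needs no server, only an opt carrying the
    # skip token. '__skip__' is an arbitrary assumed value.
    mock = SearchEngineRetrieverMock(Opt({'skip_retrieval_token': '__skip__'}))
    results = mock.retrieve(['weather in paris', '__skip__'], num_ret=2)
    # results[0] is a list of two dicts shaped like:
    #   {'content': 'content 0 for query "weather in paris"',
    #    'url': 'url_0', 'title': 'title_0'}
    # results[1] is None, because the second query equals the skip token.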


class SearchEngineRetriever(RetrieverAPI):
"""
    Queries a server (e.g., a search engine) for a set of documents.

    This module relies on a running HTTP server. For each retrieval it sends the query
    to this server and receives a JSON response, which it parses into the returned contents.
"""

def __init__(self, opt: Opt):
super().__init__(opt=opt)
self.server_address = self._validate_server(opt.get('search_server'))

def _query_search_server(self, query_term, n):
server = self.server_address
req = {'q': query_term, 'n': n}
logging.debug(f'sending search request to {server}')
server_response = requests.post(server, data=req)
resp_status = server_response.status_code
if resp_status == 200:
return server_response.json().get('response', None)
logging.error(
f'Failed to retrieve data from server! Search server returned status {resp_status}'
)

def _validate_server(self, address):
if not address:
raise ValueError('Must provide a valid server for search')
if address.startswith('http://') or address.startswith('https://'):
return address
PROTOCOL = 'http://'
        logging.warning(f'No protocol provided, using "{PROTOCOL}"')
return f'{PROTOCOL}{address}'

def _retrieve_single(self, search_query: str, num_ret: int):
if search_query == self.skip_query_token:
return None

retrieved_docs = []
search_server_resp = self._query_search_server(search_query, num_ret)
if not search_server_resp:
logging.warning(
                f'Search server did not produce any results for query "{search_query}".'
                ' Returning an empty set of results for this query.'
)
return retrieved_docs

for rd in search_server_resp:
url = rd.get('url', '')
title = rd.get('title', '')
sentences = [s.strip() for s in rd[CONTENT].split('\n') if s and s.strip()]
retrieved_docs.append(
self.create_content_dict(url=url, title=title, content=sentences)
)
return retrieved_docs

def retrieve(
self, queries: List[str], num_ret: int = DEFAULT_NUM_TO_RETRIEVE
) -> List[Dict[str, Any]]:
# TODO: update the server (and then this) for batch responses.
return [self._retrieve_single(q, num_ret) for q in queries]
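
For reference, the server contract implied by the client above: a POST with form fields 'q' (the query) and 'n' (number of results), answered with JSON of the shape {'response': [{'url': ..., 'title': ..., 'content': ...}, ...]}, where 'content' holds newline-separated sentences. A minimal stand-in server under those assumptions (stdlib only; everything beyond the field names is illustrative, not part of this PR):

    import json
    from http.server import BaseHTTPRequestHandler, HTTPServer
    from urllib.parse import parse_qs

    class MockSearchHandler(BaseHTTPRequestHandler):
        # Answers SearchEngineRetriever's POST: form fields 'q' and 'n' in,
        # {'response': [{'url', 'title', 'content'}, ...]} JSON out.
        def do_POST(self):
            length = int(self.headers.get('Content-Length', 0))
            form = parse_qs(self.rfile.read(length).decode('utf-8'))
            query = form.get('q', [''])[0]
            n = int(form.get('n', ['1'])[0])
            docs = [
                {
                    'url': f'https://example.com/{i}',
                    'title': f'result {i} for {query}',
                    'content': 'first sentence.\nsecond sentence.',
                }
                for i in range(n)
            ]
            body = json.dumps({'response': docs}).encode('utf-8')
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    if __name__ == '__main__':
        HTTPServer(('localhost', 8080), MockSearchHandler).serve_forever()

With such a server running, setting opt['search_server'] = 'localhost:8080' is enough for SearchEngineRetriever; _validate_server prepends 'http://' when no protocol is given.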