Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BB3] Fix Module Level Tasks #4798

Merged
merged 1 commit into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/bb3/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import projects.bb3.tasks.mutators # type: ignore
107 changes: 104 additions & 3 deletions projects/bb3/tasks/module_level_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.teachers import MultiTaskTeacher
from parlai.tasks.convai2.agents import NormalizedTeacher
from parlai.tasks.fits.agents import FitsBaseTeacher
from parlai.tasks.light_dialog.agents import DefaultTeacher as LightTeacher
from parlai.tasks.light_dialog_wild.agents import DefaultTeacher as LightWildTeacher
from parlai.tasks.msc.agents import (
SessionBaseMscTeacher,
SessionBasePersonaSummaryTeacher,
)
from parlai.tasks.saferdialogues.agents import SaferDialoguesTeacher
from parlai.tasks.taskmaster2.agents import Taskmaster2Parser

#########
# Mixin #
Expand Down Expand Up @@ -110,6 +120,15 @@ def get_multitask_weights(self) -> Union[List[int], str]:


class MaybeSearchTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'WowSearchDecisionTeacher',
Expand All @@ -127,6 +146,15 @@ def get_task_type(self) -> str:


class MemoryDecisionTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'Convai2MemoryDecisionTeacher',
Expand All @@ -143,6 +171,14 @@ def get_task_type(self) -> str:


class SearchQueryGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
FitsBaseTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return ['WoiSearchQueryTeacher', 'FitsSearchQueryTeacher']

Expand All @@ -154,8 +190,16 @@ def get_task_type(self) -> str:


class MemoryGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
SessionBasePersonaSummaryTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return ['MSCMemoryGeneratorTeacher']
return ['MscMemoryGenerationTeacher']

def get_multitask_weights(self) -> Union[List[int], str]:
return [1]
Expand All @@ -165,6 +209,15 @@ def get_task_type(self) -> str:


class MemoryKnowledgeGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'BSTMemoryKnowledgePersOverlapTeacher',
Expand Down Expand Up @@ -214,6 +267,15 @@ def get_task_type(self) -> str:


class EntityKnowledgeGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'BSTEntityKnowledgeTeacher',
Expand All @@ -230,6 +292,15 @@ def get_task_type(self) -> str:


class SearchDialogueGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
Taskmaster2Parser.add_cmdline_args(parser, partial_opt)
FitsBaseTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'MsMarcoSearchDialogueTeacher',
Expand All @@ -240,7 +311,7 @@ def get_teachers(self) -> List[str]:
'Taskmaster2SearchDialogueTeacher',
'Taskmaster3SearchDialogueTeacher',
'FitsSearchDialogueTeacher',
'FunpediaWithStyleSearchDialogueTeacher',
'FunpediaSearchDialogueTeacher',
]

def get_multitask_weights(self) -> Union[List[int], str]:
Expand All @@ -251,6 +322,15 @@ def get_task_type(self) -> str:


class EntityDialogueGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'BSTEntityDialogueTeacher',
Expand All @@ -274,6 +354,15 @@ def get_task_type(self) -> str:


class MemoryDialogueGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'BSTMemoryDialogueFromPersOverlapTeacher',
Expand All @@ -299,6 +388,18 @@ def get_task_type(self) -> str:


class VanillaDialogueGenerationTeacher(BB3MultitaskTeacher):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
NormalizedTeacher.add_cmdline_args(parser, partial_opt)
SessionBaseMscTeacher.add_cmdline_args(parser, partial_opt)
SaferDialoguesTeacher.add_cmdline_args(parser, partial_opt)
LightTeacher.add_cmdline_args(parser, partial_opt)
LightWildTeacher.add_cmdline_args(parser, partial_opt)
return parser

def get_teachers(self) -> List[str]:
return [
'WowVanillaDialogueTeacher',
Expand All @@ -317,7 +418,7 @@ def get_multitask_weights(self) -> Union[List[int], str]:
"""
Justification:

Split up Convai2 into half with personas, half without
All equal
"""
return [1] * len(self.get_teachers())

Expand Down
2 changes: 1 addition & 1 deletion projects/bb3/tasks/mutators.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def build_summarizer(opt: Opt) -> Agent:
Build the Persona Summarizer.
"""
return create_agent_from_model_file(
modelzoo_path(opt['datapath', 'zoo:bb3/persona_summarizer/model']),
modelzoo_path(opt['datapath'], 'zoo:bb3/persona_summarizer/model'),
opt_overrides={
'skip_generation': False,
'inference': 'beam',
Expand Down