Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Flag to enable polarity categories #3132

Merged
merged 3 commits into from
Oct 1, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion parlai/tasks/image_chat/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"""
import json
import os
from typing import Tuple
import random
from typing import Tuple, Dict, List

from parlai.core.message import Message
from parlai.core.opt import Opt
Expand Down Expand Up @@ -210,11 +211,24 @@ def __init__(self, opt: Opt, shared: TShared = None):
self.idx_to_ep = shared['idx_to_ep']
self.prepend_personality = opt.get('prepend_personality', True)
self.include_dialogue_history = opt.get('include_dialogue_history', True)
self.category_frac = opt.get('category_frac', 0.0)
super().__init__(opt, shared)
self.num_eps = len(self.data) + len(
[d for d in self.data if len(d['dialog']) > 1]
)

# Replace personalities with polarity categories ("positive/neutral" or
# "negative"), with probability self.category_frac
if not shared:
category_map = get_category_map(self.personalities)
for i, d in enumerate(self.data):
use_category_rand = random.random()
if use_category_rand < self.category_frac:
self.data[i]['dialog'] = [
[category_map[personality], label]
for personality, label in d['dialog']
]

@staticmethod
def add_cmdline_args(argparser):
ImageChatTeacher.add_cmdline_args(argparser)
Expand All @@ -231,6 +245,12 @@ def add_cmdline_args(argparser):
default=True,
help='if false, remove the dialogue history',
)
agent.add_argument(
'--category-frac',
type=float,
default=0.0,
help='Fraction of the time to replace the personality with its polarity category ("positive/neutral" or "negative")',
)

def num_episodes(self) -> int:
return self.num_eps
Expand Down Expand Up @@ -358,3 +378,36 @@ def get(self, episode_idx, entry_idx=0):

class DefaultTeacher(ImageChatTeacher):
pass


def get_category_map(personalities: Dict[str, List[str]]) -> Dict[str, str]:
"""
Map personalities to polarity categories: "positive/neutral" and "negative".

Given a dictionary mapping Image-Chat categories (positive/neutral/negative) to
personalities, return a dictionary mapping each personality to its category.
Categories are merged into only two buckets: "positive/neutral", for personalities
that are more likely to be safe to use, and "negative". Add in rare personalities.
"""

category_map = {
personality: _get_final_category(category)
for category, personalities in personalities.items()
for personality in personalities
}
category_map['Crude'] = _get_final_category('negative')
category_map['Earnest'] = _get_final_category('positive')
# These personalities occasionally appear but are not in personalities
return category_map


def _get_final_category(category: str) -> str:
"""
Given the input raw category label, return the final one.
"""
if category in ['positive', 'neutral']:
return 'positive/neutral'
elif category == 'negative':
return 'negative'
else:
raise ValueError(f'Category "{category}" unrecognized!')