Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Diana Rico committed Jul 15, 2020
1 parent 02c3f07 commit 1af5b9e
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions parlai/scripts/self_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
Allows a model to self-chat on a given task.
"""
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.core.agents import create_agent, create_agent_from_model_file
from parlai.core.worlds import create_task
from parlai.utils.world_logging import WorldLogger
from parlai.utils.misc import TimeLogger
from parlai.scripts.script import ParlaiScript
import parlai.utils.logging as logging

import math
import json
import random


Expand Down Expand Up @@ -60,6 +61,17 @@ def setup_args(parser=None):
choices=['conversations', 'parlai'],
help='Format to save logs in. conversations is a jsonl format, parlai is a text format.',
)
parser.add_argument(
'-pmf',
'--partner-model-file',
default=None,
help='Define a different partner for self chat',
)
parser.add_argument(
'--partner-opt-file',
default=None,
help='Path to file containing opts to override for partner',
)
parser.set_defaults(interactive_mode=True, task='self_chat')
WorldLogger.add_cmdline_args(parser)
return parser
Expand All @@ -86,15 +98,32 @@ def _run_self_chat_episode(opt, world, world_logger):

def self_chat(opt):
random.seed(opt['seed'])
partner = opt['partner_model_file']
partner_opt_file = opt['partner_opt_file']

# Create agents
agent1 = create_agent(opt, requireModelExists=True)
agent2 = agent1.clone()
if not partner:
# Self chat with same model
agent2 = agent1.clone()
else:
# Self chat with different models
if partner_opt_file:
print(f"WARNING: Loading override opts from: {partner_opt_file}")
with open(partner_opt_file) as f:
partner_opt = json.load(f)

print(
f"WARNING: Setting partner interactive mode to: {opt['interactive_mode']}"
)
partner_opt['interactive_mode'] = opt['interactive_mode']
agent2 = create_agent_from_model_file(partner, partner_opt)

# Set IDs
model_id = agent1.id
agent1.id = model_id + "_1"
agent2.id = model_id + "_2"
agent1.id = agent1.id + "_1"
agent2.id = agent2.id + "_2"

model_id = agent1.id + "_" + agent2.id

world = create_task(opt, user_agents=[agent1, agent2])

Expand Down

0 comments on commit 1af5b9e

Please sign in to comment.