diff --git a/parlai/scripts/interactive_web.py b/parlai/scripts/interactive_web.py index 62ab838b372..c6088f2e197 100644 --- a/parlai/scripts/interactive_web.py +++ b/parlai/scripts/interactive_web.py @@ -23,6 +23,7 @@ import parlai.utils.logging as logging import json +import time HOST_NAME = 'localhost' PORT = 8080 @@ -250,12 +251,26 @@ def setup_interweb_args(shared): return parser +def shutdown(): + global SHARED + if 'server' in SHARED: + SHARED['server'].shutdown() + SHARED.clear() + + +def wait(): + global SHARED + while not SHARED.get('ready'): + time.sleep(0.01) + + def interactive_web(opt): + global SHARED - SHARED['opt']['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent' + opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent' # Create model and assign it to the specified task - agent = create_agent(SHARED.get('opt'), requireModelExists=True) + agent = create_agent(opt, requireModelExists=True) agent.opt.log() SHARED['opt'] = agent.opt SHARED['agent'] = agent @@ -263,9 +278,11 @@ def interactive_web(opt): MyHandler.protocol_version = 'HTTP/1.0' httpd = HTTPServer((opt['host'], opt['port']), MyHandler) + SHARED['server'] = httpd logging.info('http://{}:{}/'.format(opt['host'], opt['port'])) try: + SHARED['ready'] = True httpd.serve_forever() except KeyboardInterrupt: pass diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 98ee49b3e45..09093670ccf 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -84,6 +84,46 @@ def _run_test_repeat(self, tmpdir: str, fake_input: FakeInput): self.assertEqual(len(entry), 2 * fake_input.max_turns) +class TestInteractiveWeb(unittest.TestCase): + def test_iweb(self): + import threading + import random + import requests + import json + import parlai.scripts.interactive_web as iweb + + port = random.randint(30000, 40000) + thread = threading.Thread( + target=iweb.InteractiveWeb.main, + kwargs={'model': 'repeat_query', 'port': port}, + daemon=True, + ) + thread.start() + iweb.wait() + + r = requests.get(f'http://localhost:{port}/') + assert '' in r.text + + r = requests.post(f'http://localhost:{port}/interact', data='This is a test') + assert r.status_code == 200 + response = json.loads(r.text) + assert 'text' in response + assert response['text'] == 'This is a test' + + r = requests.post(f'http://localhost:{port}/reset') + assert r.status_code == 200 + response = json.loads(r.text) + assert response == {} + + r = requests.get(f'http://localhost:{port}/bad') + assert r.status_code == 500 + + r = requests.post(f'http://localhost:{port}/bad') + assert r.status_code == 500 + + iweb.shutdown() + + class TestProfileInteractive(unittest.TestCase): def test_profile_interactive(self): from parlai.scripts.profile_interactive import ProfileInteractive