diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 3dd331650412..6ddf90164ba7 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -687,7 +687,6 @@ async def _inner_run(self): model = self.args.model_name_or_path + "@" + self.args.model_revision host = "http://localhost" if self.args.host == "localhost" else self.args.host - client = AsyncInferenceClient(f"{host}:{self.args.port}") args = self.args if args.examples_path is None: @@ -710,48 +709,47 @@ async def _inner_run(self): # Starts the session with a minimal help message at the top, so that a user doesn't get stuck interface.print_help(minimal=True) - while True: - try: - user_input = interface.input() - - # User commands - if user_input.startswith("!"): - # `!exit` is special, it breaks the loop - if user_input == "!exit": - break - else: - chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands( - user_input=user_input, - args=args, - interface=interface, - examples=examples, - generation_config=generation_config, - model_kwargs=model_kwargs, - chat=chat, - ) - # `!example` sends a user message to the model - if not valid_command or not user_input.startswith("!example"): - continue - else: - chat.append({"role": "user", "content": user_input}) - - stream = client.chat_completion( - chat, - stream=True, - extra_body={ - "generation_config": generation_config.to_json_string(), - "model": model, - }, - ) - model_output = await interface.stream_output(stream) + async with AsyncInferenceClient(f"{host}:{self.args.port}") as client: + while True: + try: + user_input = interface.input() + + # User commands + if user_input.startswith("!"): + # `!exit` is special, it breaks the loop + if user_input == "!exit": + break + else: + chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands( + user_input=user_input, + args=args, + interface=interface, + examples=examples, + generation_config=generation_config, + model_kwargs=model_kwargs, + chat=chat, + ) + # `!example` sends a user message to the model + if not valid_command or not user_input.startswith("!example"): + continue + else: + chat.append({"role": "user", "content": user_input}) + + stream = client.chat_completion( + chat, + stream=True, + extra_body={ + "generation_config": generation_config.to_json_string(), + "model": model, + }, + ) - chat.append({"role": "assistant", "content": model_output}) + model_output = await interface.stream_output(stream) - except KeyboardInterrupt: - break - finally: - await client.close() + chat.append({"role": "assistant", "content": model_output}) + except KeyboardInterrupt: + break if __name__ == "__main__":