Skip to content

Commit

Permalink
revert test and demo_agent loggers
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed May 27, 2024
1 parent e768dcd commit 88ce1af
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 20 deletions.
5 changes: 1 addition & 4 deletions core/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import pytest


logger = logging.getLogger(__name__)


# setup code, executed ahead of first test
@pytest.fixture(scope="session", autouse=True)
def setup_playwright(playwright: playwright.sync_api.Playwright):
# bugfix: re-use pytest-playwright's playwright instance in browsergym
# https://github.com/microsoft/playwright-python/issues/2053
browsergym.core._set_global_playwright(playwright)
logger.info("Browsergym is using the playwright instance provided by pytest-playwright.")
logging.info("Browsergym is using the playwright instance provided by pytest-playwright.")
7 changes: 2 additions & 5 deletions demo_agent/agents/legacy/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
)


logger = logging.getLogger(__name__)


@dataclass
class Flags:
use_html: bool = True
Expand Down Expand Up @@ -216,7 +213,7 @@ def fit_tokens(
return prompt
shrinkable.shrink()

logger.info(
logging.info(
dedent(
f"""\
After {max_iterations} shrink iterations, the prompt is still
Expand Down Expand Up @@ -377,7 +374,7 @@ def __init__(
self.instructions = ChatInstructions(obs_history[-1]["chat_messages"])
else:
if sum([msg["role"] == "user" for msg in obs_history[-1]["chat_messages"]]) > 1:
logger.warning(
logging.warning(
"Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
)
self.instructions = GoalInstructions(obs_history[-1]["goal"])
Expand Down
19 changes: 8 additions & 11 deletions demo_agent/agents/legacy/utils/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from transformers import GPT2TokenizerFast


logger = logging.getLogger(__name__)


@dataclass
class ChatModelArgs:
"""Serializable object for instantiating a generic chat model.
Expand Down Expand Up @@ -166,21 +163,21 @@ def __init__(

if max_new_tokens is None:
max_new_tokens = max_total_tokens - max_input_tokens
logger.warning(
logging.warning(
f"max_new_tokens is not specified. Setting it to {max_new_tokens} (max_total_tokens - max_input_tokens)."
)

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if isinstance(self.tokenizer, GPT2TokenizerFast):
# TODO: make this less hacky once tokenizer.apply_chat_template is more mature
logger.warning(
logging.warning(
f"No chat template is defined for {model_name}. Resolving to the hard-coded templates."
)
self.tokenizer = None
self.prompt_template = get_prompt_template(model_name)

if temperature < 1e-3:
logger.warning(
logging.warning(
"some weird things might happen when temperature is too low for some models."
)

Expand All @@ -189,17 +186,17 @@ def __init__(
}

if model_url is not None:
logger.info("Loading the LLM from a URL")
logging.info("Loading the LLM from a URL")
client = InferenceClient(model=model_url, token=eai_token)
self.llm = partial(
client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
)
elif hf_hosted:
logger.info("Serving the LLM on HuggingFace Hub")
logging.info("Serving the LLM on HuggingFace Hub")
model_kwargs["max_length"] = max_new_tokens
self.llm = HuggingFaceHub(repo_id=model_name, model_kwargs=model_kwargs)
else:
logger.info("Loading the LLM locally")
logging.info("Loading the LLM locally")
pipe = pipeline(
task="text-generation",
model=model_name,
Expand All @@ -217,7 +214,7 @@ def _call(
**kwargs: Any,
) -> str:
if stop is not None or run_manager is not None or kwargs:
logger.warning(
logging.warning(
"The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation."
)

Expand All @@ -236,7 +233,7 @@ def _call(
except Exception as e:
if itr == self.n_retry_server - 1:
raise e
logger.warning(
logging.warning(
f"Failed to get a response from the server: \n{e}\n"
f"Retrying... ({itr+1}/{self.n_retry_server})"
)
Expand Down

0 comments on commit 88ce1af

Please sign in to comment.