From 7d79a5c2bdc8427a292022a4b1c8efb983a89d25 Mon Sep 17 00:00:00 2001 From: X <77029666+William-f-12@users.noreply.github.com> Date: Wed, 27 Aug 2025 19:22:08 -0500 Subject: [PATCH 1/5] update DataLoader to set huggingface dynamically --- tiny_chat/utils/data_loader.py | 20 ++++++++++++++------ tiny_chat/utils/sampler.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 tiny_chat/utils/sampler.py diff --git a/tiny_chat/utils/data_loader.py b/tiny_chat/utils/data_loader.py index 7ab56f5..930e168 100644 --- a/tiny_chat/utils/data_loader.py +++ b/tiny_chat/utils/data_loader.py @@ -9,20 +9,26 @@ class DataLoader: """a class to load data from hugging face""" - def __init__(self, use_official: bool = True): - if use_official: + def __init__(self): + self.hf_set = False + self.agent_profiles = None + self.env_profiles = None + self.relationship_profiles = None + + + def set_official_hf(self): + if not self.hf_set: self.hf_repo = "skyyyyks/tiny-chat" self.agent_profiles_dataset = "agent_profiles.jsonl" self.env_profiles_dataset = "environment_profiles.jsonl" self.relationship_profiles_dataset = "relationship_profiles.jsonl" - self.agent_profiles = None - self.env_profiles = None - self.relationship_profiles = None + self.hf_set = True def load_agent_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.agent_profiles = load_dataset(self.hf_repo, data_files=self.agent_profiles_dataset) else: # Load the dataset from local file @@ -73,6 +79,7 @@ def get_all_agent_profiles(self, use_local: bool = False, local_path: str = None def load_env_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.env_profiles = load_dataset(self.hf_repo, data_files=self.env_profiles_dataset) else: # Load the dataset from local file @@ -117,6 +124,7 @@ def get_all_env_profiles(self, use_local: bool = False, local_path: str = None) def load_relationship_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.relationship_profiles = load_dataset(self.hf_repo, data_files=self.relationship_profiles_dataset) else: # Load the dataset from local file @@ -130,7 +138,7 @@ def load_relationship_profiles(self, use_local: bool = False, local_path: str = except FileNotFoundError: print(f"File not found: {local_path}") except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") + print(f"Error decoding JSON: {e}") def get_all_relationship_profiles(self, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]: diff --git a/tiny_chat/utils/sampler.py b/tiny_chat/utils/sampler.py new file mode 100644 index 0000000..d9418b7 --- /dev/null +++ b/tiny_chat/utils/sampler.py @@ -0,0 +1,14 @@ +import random +from .data_loader import DataLoader +from tiny_chat.profiles.agent_profile import BaseAgentProfile +from tiny_chat.profiles.enviroment_profile import BaseEnvironmentProfile +from tiny_chat.profiles.relationship_profile import BaseRelationshipProfile + + +class Sampler: + def __init__(self): + pass + + + def sample_agent_random(self): + pass From 1dc80c5eab1d1078e9b202e216ad35802e16d8e2 Mon Sep 17 00:00:00 2001 From: X <77029666+William-f-12@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:41:47 -0500 Subject: [PATCH 2/5] finish base sampler --- tiny_chat/utils/data_loader.py | 12 +++++------ tiny_chat/utils/sampler.py | 14 ------------- tiny_chat/utils/sampler/__init__.py | 0 tiny_chat/utils/sampler/base_sampler.py | 28 +++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 20 deletions(-) delete mode 100644 tiny_chat/utils/sampler.py create mode 100644 tiny_chat/utils/sampler/__init__.py create mode 100644 tiny_chat/utils/sampler/base_sampler.py diff --git a/tiny_chat/utils/data_loader.py b/tiny_chat/utils/data_loader.py index 930e168..100288b 100644 --- a/tiny_chat/utils/data_loader.py +++ b/tiny_chat/utils/data_loader.py @@ -165,17 +165,17 @@ def get_all_relationship_profiles(self, use_local: bool = False, local_path: str return profiles - + def sample_agent_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_agent_profiles(use_local, local_path); + all_profiles = self.get_all_agent_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles - def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_env_profiles(use_local, local_path); + def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseEnvironmentProfile]: + all_profiles = self.get_all_env_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles - def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_relationship_profiles(use_local, local_path); + def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]: + all_profiles = self.get_all_relationship_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles diff --git a/tiny_chat/utils/sampler.py b/tiny_chat/utils/sampler.py deleted file mode 100644 index d9418b7..0000000 --- a/tiny_chat/utils/sampler.py +++ /dev/null @@ -1,14 +0,0 @@ -import random -from .data_loader import DataLoader -from tiny_chat.profiles.agent_profile import BaseAgentProfile -from tiny_chat.profiles.enviroment_profile import BaseEnvironmentProfile -from tiny_chat.profiles.relationship_profile import BaseRelationshipProfile - - -class Sampler: - def __init__(self): - pass - - - def sample_agent_random(self): - pass diff --git a/tiny_chat/utils/sampler/__init__.py b/tiny_chat/utils/sampler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tiny_chat/utils/sampler/base_sampler.py b/tiny_chat/utils/sampler/base_sampler.py new file mode 100644 index 0000000..e98eaf8 --- /dev/null +++ b/tiny_chat/utils/sampler/base_sampler.py @@ -0,0 +1,28 @@ +import random +from typing import Any, Generator, Generic, Sequence, Type, TypeVar +from ..data_loader import DataLoader +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile + +ObsType = TypeVar('ObsType') +ActType = TypeVar('ActType') +EnvAgentCombo = tuple[TinyChatEnvironment, Sequence[BaseAgent[ObsType, ActType]]] + + +class BaseSampler: + def __init__(self, + agent_list: Sequence[BaseAgentProfile | str] | None = None, + env_list: Sequence[BaseEnvironmentProfile | str] | None = None,): + self.agent_list = agent_list + self.env_list = env_list + + + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + raise NotImplementedError \ No newline at end of file From cd3e52d9cdf6614b0016bd7ec31f3b7a753b032a Mon Sep 17 00:00:00 2001 From: X <77029666+William-f-12@users.noreply.github.com> Date: Thu, 28 Aug 2025 22:12:19 -0500 Subject: [PATCH 3/5] finish uniform_sampler? --- tiny_chat/utils/sampler/base_sampler.py | 13 ++--- tiny_chat/utils/sampler/constraint_sampler.py | 0 tiny_chat/utils/sampler/uniform_sampler.py | 56 +++++++++++++++++++ 3 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 tiny_chat/utils/sampler/constraint_sampler.py create mode 100644 tiny_chat/utils/sampler/uniform_sampler.py diff --git a/tiny_chat/utils/sampler/base_sampler.py b/tiny_chat/utils/sampler/base_sampler.py index e98eaf8..1bfccb0 100644 --- a/tiny_chat/utils/sampler/base_sampler.py +++ b/tiny_chat/utils/sampler/base_sampler.py @@ -1,19 +1,17 @@ -import random from typing import Any, Generator, Generic, Sequence, Type, TypeVar -from ..data_loader import DataLoader from tiny_chat.agents import BaseAgent from tiny_chat.envs import TinyChatEnvironment -from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile ObsType = TypeVar('ObsType') ActType = TypeVar('ActType') EnvAgentCombo = tuple[TinyChatEnvironment, Sequence[BaseAgent[ObsType, ActType]]] -class BaseSampler: +class BaseSampler(Generic[ObsType, ActType]): def __init__(self, - agent_list: Sequence[BaseAgentProfile | str] | None = None, - env_list: Sequence[BaseEnvironmentProfile | str] | None = None,): + agent_list: Sequence[BaseAgentProfile] | None = None, + env_list: Sequence[BaseEnvironmentProfile] | None = None,): self.agent_list = agent_list self.env_list = env_list @@ -25,4 +23,5 @@ def sample(self, size: int = 1, env_params: dict[str, Any] = {}, agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + \ No newline at end of file diff --git a/tiny_chat/utils/sampler/constraint_sampler.py b/tiny_chat/utils/sampler/constraint_sampler.py new file mode 100644 index 0000000..e69de29 diff --git a/tiny_chat/utils/sampler/uniform_sampler.py b/tiny_chat/utils/sampler/uniform_sampler.py new file mode 100644 index 0000000..abc2abd --- /dev/null +++ b/tiny_chat/utils/sampler/uniform_sampler.py @@ -0,0 +1,56 @@ +import random +from typing import Any, Generator, Type, TypeVar +from ..data_loader import DataLoader +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile +from .base_sampler import BaseSampler, EnvAgentCombo + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class UniformSampler(BaseSampler[ObsType, ActType]): + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + # check agent_classes + if not isinstance(agent_classes, list): + agent_classes = [agent_classes] * agent_num + elif len(agent_classes) != agent_num: + raise ValueError("Length of agent_classes must match agent_num") + + if len(agents_params) != agent_num: + raise ValueError("Length of agents_params must match agent_num") + + # Load profiles if not provided, use official dataset + data_loader = DataLoader() + if self.agent_list is None: + self.agent_list = data_loader.get_all_agent_profiles() + if self.env_list is None: + self.env_list = data_loader.get_all_env_profiles() + + # set up environment + for _ in range(size): + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # Sample agent profiles + if len(self.agent_list) < agent_num: + raise ValueError("Not enough agent profiles") + sampled_agent_profiles = random.sample(self.agent_list, agent_num) + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) \ No newline at end of file From a8067dcb948c4511a7c77f8227cda3dc88cf0987 Mon Sep 17 00:00:00 2001 From: X <77029666+William-f-12@users.noreply.github.com> Date: Thu, 28 Aug 2025 22:50:13 -0500 Subject: [PATCH 4/5] update uniform_sampler --- tiny_chat/utils/sampler/uniform_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiny_chat/utils/sampler/uniform_sampler.py b/tiny_chat/utils/sampler/uniform_sampler.py index abc2abd..9c70416 100644 --- a/tiny_chat/utils/sampler/uniform_sampler.py +++ b/tiny_chat/utils/sampler/uniform_sampler.py @@ -3,7 +3,6 @@ from ..data_loader import DataLoader from tiny_chat.agents import BaseAgent from tiny_chat.envs import TinyChatEnvironment -from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile from .base_sampler import BaseSampler, EnvAgentCombo ObsType = TypeVar("ObsType") From 1ab7f26201e1e21ade97082a9db384b828b4245f Mon Sep 17 00:00:00 2001 From: X <77029666+William-f-12@users.noreply.github.com> Date: Sat, 30 Aug 2025 23:50:01 -0500 Subject: [PATCH 5/5] 1 draft of constraint sampler --- tiny_chat/utils/sampler/constraint_sampler.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/tiny_chat/utils/sampler/constraint_sampler.py b/tiny_chat/utils/sampler/constraint_sampler.py index e69de29..fd0618e 100644 --- a/tiny_chat/utils/sampler/constraint_sampler.py +++ b/tiny_chat/utils/sampler/constraint_sampler.py @@ -0,0 +1,126 @@ +import ast +import random +from typing import Any, Generator, Type, TypeVar +from ..data_loader import DataLoader +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile +from .base_sampler import BaseSampler, EnvAgentCombo + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class ConstraintSampler(BaseSampler[ObsType, ActType]): + """A sampler that use the constraints in environment_profile""" + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + # check agent_classes + if not isinstance(agent_classes, list): + agent_classes = [agent_classes] * agent_num + elif len(agent_classes) != agent_num: + raise ValueError("Length of agent_classes must match agent_num") + + if len(agents_params) != agent_num: + raise ValueError("Length of agents_params must match agent_num") + + # only support 2 agents for now + if agent_num != 2: + raise NotImplementedError("Only support 2 agents for now") + + # Load profiles if not provided, use official dataset + data_loader = DataLoader() + if self.agent_list is None: + self.agent_list = data_loader.get_all_agent_profiles() + if self.env_list is None: + self.env_list = data_loader.get_all_env_profiles() + relationships = data_loader.get_all_relationship_profiles() + + if not replacement: + # pick one environment profile + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # find agents that fullfill the constraints in env_profile + sampled_agent_pairs = self.find_agent_pairs(env_profile, relationships) + random.shuffle(sampled_agent_pairs) + + for i in range(size): + if len(sampled_agent_pairs) < size: + raise ValueError("No agent pairs found that satisfy the constraints") + sampled_agent_profiles = sampled_agent_pairs[i] + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) + + else: + for _ in range(size): + # pick one environment profile + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # find agents that fullfill the constraints in env_profile + sampled_agent_pairs = self.find_agent_pairs(env_profile, relationships) + if len(sampled_agent_pairs) == 0: + raise ValueError("No agent pairs found that satisfy the constraints") + sampled_agent_profiles = random.choice(sampled_agent_pairs) + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) + + + def find_agent_pairs(self, + env: BaseEnvironmentProfile, + relationships: BaseRelationshipProfile) -> list[list[BaseAgentProfile]]: + sampled_agent_pairs = [] + for rel in relationships: + agent1 = self.get_agent(rel.agent_ids[0]) + agent2 = self.get_agent(rel.agent_ids[1]) + if agent1 is None or agent2 is None: + continue + # check the age constraint + if env.age_constraint and env.age_constraint != "[(18, 70), (18, 70)]": + age_contraint = ast.literal_eval(env.age_contraint) + if not (age_contraint[0][0] <= agent1.age <= age_contraint[0][1] and + age_contraint[1][0] <= agent2.age <= age_contraint[1][1]): + continue + # check the occupation constraint + if env.occupation_constraint and env.occupation_constraint != "nan" and env.occupation_constraint != "[[], []]": + occupation_constraint = ast.literal_eval(env.occupation_constraint) + if not (agent1.occupation.lower() in occupation_constraint[0] and + agent2.occupation.lower() in occupation_constraint[1]): + continue + # check agent constraint: not supported yet + + # add the pair + sampled_agent_pairs.append([agent1, agent2]) + + return sampled_agent_pairs + + + def get_agent(self, agent_id: str) -> BaseAgentProfile | None: + for agent in self.agent_list: + if agent.pk == agent_id: + return agent + return None \ No newline at end of file