diff --git a/tiny_chat/utils/data_loader.py b/tiny_chat/utils/data_loader.py index 7ab56f5..100288b 100644 --- a/tiny_chat/utils/data_loader.py +++ b/tiny_chat/utils/data_loader.py @@ -9,20 +9,26 @@ class DataLoader: """a class to load data from hugging face""" - def __init__(self, use_official: bool = True): - if use_official: + def __init__(self): + self.hf_set = False + self.agent_profiles = None + self.env_profiles = None + self.relationship_profiles = None + + + def set_official_hf(self): + if not self.hf_set: self.hf_repo = "skyyyyks/tiny-chat" self.agent_profiles_dataset = "agent_profiles.jsonl" self.env_profiles_dataset = "environment_profiles.jsonl" self.relationship_profiles_dataset = "relationship_profiles.jsonl" - self.agent_profiles = None - self.env_profiles = None - self.relationship_profiles = None + self.hf_set = True def load_agent_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.agent_profiles = load_dataset(self.hf_repo, data_files=self.agent_profiles_dataset) else: # Load the dataset from local file @@ -73,6 +79,7 @@ def get_all_agent_profiles(self, use_local: bool = False, local_path: str = None def load_env_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.env_profiles = load_dataset(self.hf_repo, data_files=self.env_profiles_dataset) else: # Load the dataset from local file @@ -117,6 +124,7 @@ def get_all_env_profiles(self, use_local: bool = False, local_path: str = None) def load_relationship_profiles(self, use_local: bool = False, local_path: str = None): if not use_local: # Load the dataset from Hugging Face + self.set_official_hf() self.relationship_profiles = load_dataset(self.hf_repo, data_files=self.relationship_profiles_dataset) else: # Load the dataset from local file @@ -130,7 +138,7 @@ def load_relationship_profiles(self, use_local: bool = False, local_path: str = except FileNotFoundError: print(f"File not found: {local_path}") except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") + print(f"Error decoding JSON: {e}") def get_all_relationship_profiles(self, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]: @@ -157,17 +165,17 @@ def get_all_relationship_profiles(self, use_local: bool = False, local_path: str return profiles - + def sample_agent_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_agent_profiles(use_local, local_path); + all_profiles = self.get_all_agent_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles - def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_env_profiles(use_local, local_path); + def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseEnvironmentProfile]: + all_profiles = self.get_all_env_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles - def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]: - all_profiles = self.get_all_relationship_profiles(use_local, local_path); + def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]: + all_profiles = self.get_all_relationship_profiles(use_local, local_path) return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles diff --git a/tiny_chat/utils/sampler/__init__.py b/tiny_chat/utils/sampler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tiny_chat/utils/sampler/base_sampler.py b/tiny_chat/utils/sampler/base_sampler.py new file mode 100644 index 0000000..1bfccb0 --- /dev/null +++ b/tiny_chat/utils/sampler/base_sampler.py @@ -0,0 +1,27 @@ +from typing import Any, Generator, Generic, Sequence, Type, TypeVar +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile + +ObsType = TypeVar('ObsType') +ActType = TypeVar('ActType') +EnvAgentCombo = tuple[TinyChatEnvironment, Sequence[BaseAgent[ObsType, ActType]]] + + +class BaseSampler(Generic[ObsType, ActType]): + def __init__(self, + agent_list: Sequence[BaseAgentProfile] | None = None, + env_list: Sequence[BaseEnvironmentProfile] | None = None,): + self.agent_list = agent_list + self.env_list = env_list + + + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + raise NotImplementedError + \ No newline at end of file diff --git a/tiny_chat/utils/sampler/constraint_sampler.py b/tiny_chat/utils/sampler/constraint_sampler.py new file mode 100644 index 0000000..fd0618e --- /dev/null +++ b/tiny_chat/utils/sampler/constraint_sampler.py @@ -0,0 +1,126 @@ +import ast +import random +from typing import Any, Generator, Type, TypeVar +from ..data_loader import DataLoader +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile +from .base_sampler import BaseSampler, EnvAgentCombo + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class ConstraintSampler(BaseSampler[ObsType, ActType]): + """A sampler that use the constraints in environment_profile""" + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + # check agent_classes + if not isinstance(agent_classes, list): + agent_classes = [agent_classes] * agent_num + elif len(agent_classes) != agent_num: + raise ValueError("Length of agent_classes must match agent_num") + + if len(agents_params) != agent_num: + raise ValueError("Length of agents_params must match agent_num") + + # only support 2 agents for now + if agent_num != 2: + raise NotImplementedError("Only support 2 agents for now") + + # Load profiles if not provided, use official dataset + data_loader = DataLoader() + if self.agent_list is None: + self.agent_list = data_loader.get_all_agent_profiles() + if self.env_list is None: + self.env_list = data_loader.get_all_env_profiles() + relationships = data_loader.get_all_relationship_profiles() + + if not replacement: + # pick one environment profile + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # find agents that fullfill the constraints in env_profile + sampled_agent_pairs = self.find_agent_pairs(env_profile, relationships) + random.shuffle(sampled_agent_pairs) + + for i in range(size): + if len(sampled_agent_pairs) < size: + raise ValueError("No agent pairs found that satisfy the constraints") + sampled_agent_profiles = sampled_agent_pairs[i] + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) + + else: + for _ in range(size): + # pick one environment profile + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # find agents that fullfill the constraints in env_profile + sampled_agent_pairs = self.find_agent_pairs(env_profile, relationships) + if len(sampled_agent_pairs) == 0: + raise ValueError("No agent pairs found that satisfy the constraints") + sampled_agent_profiles = random.choice(sampled_agent_pairs) + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) + + + def find_agent_pairs(self, + env: BaseEnvironmentProfile, + relationships: BaseRelationshipProfile) -> list[list[BaseAgentProfile]]: + sampled_agent_pairs = [] + for rel in relationships: + agent1 = self.get_agent(rel.agent_ids[0]) + agent2 = self.get_agent(rel.agent_ids[1]) + if agent1 is None or agent2 is None: + continue + # check the age constraint + if env.age_constraint and env.age_constraint != "[(18, 70), (18, 70)]": + age_contraint = ast.literal_eval(env.age_contraint) + if not (age_contraint[0][0] <= agent1.age <= age_contraint[0][1] and + age_contraint[1][0] <= agent2.age <= age_contraint[1][1]): + continue + # check the occupation constraint + if env.occupation_constraint and env.occupation_constraint != "nan" and env.occupation_constraint != "[[], []]": + occupation_constraint = ast.literal_eval(env.occupation_constraint) + if not (agent1.occupation.lower() in occupation_constraint[0] and + agent2.occupation.lower() in occupation_constraint[1]): + continue + # check agent constraint: not supported yet + + # add the pair + sampled_agent_pairs.append([agent1, agent2]) + + return sampled_agent_pairs + + + def get_agent(self, agent_id: str) -> BaseAgentProfile | None: + for agent in self.agent_list: + if agent.pk == agent_id: + return agent + return None \ No newline at end of file diff --git a/tiny_chat/utils/sampler/uniform_sampler.py b/tiny_chat/utils/sampler/uniform_sampler.py new file mode 100644 index 0000000..9c70416 --- /dev/null +++ b/tiny_chat/utils/sampler/uniform_sampler.py @@ -0,0 +1,55 @@ +import random +from typing import Any, Generator, Type, TypeVar +from ..data_loader import DataLoader +from tiny_chat.agents import BaseAgent +from tiny_chat.envs import TinyChatEnvironment +from .base_sampler import BaseSampler, EnvAgentCombo + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class UniformSampler(BaseSampler[ObsType, ActType]): + def sample(self, + agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]], + agent_num: int = 2, + replacement: bool = True, + size: int = 1, + env_params: dict[str, Any] = {}, + agents_params: list[dict[str, Any]] = [{}, {}]) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]: + # check agent_classes + if not isinstance(agent_classes, list): + agent_classes = [agent_classes] * agent_num + elif len(agent_classes) != agent_num: + raise ValueError("Length of agent_classes must match agent_num") + + if len(agents_params) != agent_num: + raise ValueError("Length of agents_params must match agent_num") + + # Load profiles if not provided, use official dataset + data_loader = DataLoader() + if self.agent_list is None: + self.agent_list = data_loader.get_all_agent_profiles() + if self.env_list is None: + self.env_list = data_loader.get_all_env_profiles() + + # set up environment + for _ in range(size): + env_profile = random.choice(self.env_list) + env = TinyChatEnvironment(**env_params) + + # Sample agent profiles + if len(self.agent_list) < agent_num: + raise ValueError("Not enough agent profiles") + sampled_agent_profiles = random.sample(self.agent_list, agent_num) + + agents = [] + for agent_class, agent_profile, agent_params in zip(agent_classes, sampled_agent_profiles, agents_params): + agent = agent_class(agent_profile=agent_profile, **agent_params) + agents.append(agent) + + # set goal for each agent + for agent, goal in zip(agents, env_profile.agent_goals): + agent.goal = goal + + yield (env, agents) \ No newline at end of file