Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions tiny_chat/utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,26 @@
class DataLoader:
"""a class to load data from hugging face"""

def __init__(self, use_official: bool = True):
if use_official:
def __init__(self):
self.hf_set = False
self.agent_profiles = None
self.env_profiles = None
self.relationship_profiles = None


def set_official_hf(self):
if not self.hf_set:
self.hf_repo = "skyyyyks/tiny-chat"
self.agent_profiles_dataset = "agent_profiles.jsonl"
self.env_profiles_dataset = "environment_profiles.jsonl"
self.relationship_profiles_dataset = "relationship_profiles.jsonl"
self.agent_profiles = None
self.env_profiles = None
self.relationship_profiles = None
self.hf_set = True


def load_agent_profiles(self, use_local: bool = False, local_path: str = None):
if not use_local:
# Load the dataset from Hugging Face
self.set_official_hf()
self.agent_profiles = load_dataset(self.hf_repo, data_files=self.agent_profiles_dataset)
else:
# Load the dataset from local file
Expand Down Expand Up @@ -73,6 +79,7 @@ def get_all_agent_profiles(self, use_local: bool = False, local_path: str = None
def load_env_profiles(self, use_local: bool = False, local_path: str = None):
if not use_local:
# Load the dataset from Hugging Face
self.set_official_hf()
self.env_profiles = load_dataset(self.hf_repo, data_files=self.env_profiles_dataset)
else:
# Load the dataset from local file
Expand Down Expand Up @@ -117,6 +124,7 @@ def get_all_env_profiles(self, use_local: bool = False, local_path: str = None)
def load_relationship_profiles(self, use_local: bool = False, local_path: str = None):
if not use_local:
# Load the dataset from Hugging Face
self.set_official_hf()
self.relationship_profiles = load_dataset(self.hf_repo, data_files=self.relationship_profiles_dataset)
else:
# Load the dataset from local file
Expand All @@ -130,7 +138,7 @@ def load_relationship_profiles(self, use_local: bool = False, local_path: str =
except FileNotFoundError:
print(f"File not found: {local_path}")
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
print(f"Error decoding JSON: {e}")


def get_all_relationship_profiles(self, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]:
Expand All @@ -157,17 +165,17 @@ def get_all_relationship_profiles(self, use_local: bool = False, local_path: str

return profiles


def sample_agent_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]:
all_profiles = self.get_all_agent_profiles(use_local, local_path);
all_profiles = self.get_all_agent_profiles(use_local, local_path)
return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles


def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]:
all_profiles = self.get_all_env_profiles(use_local, local_path);
def sample_env_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseEnvironmentProfile]:
all_profiles = self.get_all_env_profiles(use_local, local_path)
return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles


def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseAgentProfile]:
all_profiles = self.get_all_relationship_profiles(use_local, local_path);
def sample_relationship_random(self, n: int, use_local: bool = False, local_path: str = None) -> list[BaseRelationshipProfile]:
all_profiles = self.get_all_relationship_profiles(use_local, local_path)
return random.sample(all_profiles, n) if n <= len(all_profiles) else all_profiles
Empty file.
27 changes: 27 additions & 0 deletions tiny_chat/utils/sampler/base_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any, Generator, Generic, Sequence, Type, TypeVar
from tiny_chat.agents import BaseAgent
from tiny_chat.envs import TinyChatEnvironment
from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile

ObsType = TypeVar('ObsType')
ActType = TypeVar('ActType')
EnvAgentCombo = tuple[TinyChatEnvironment, Sequence[BaseAgent[ObsType, ActType]]]


class BaseSampler(Generic[ObsType, ActType]):
    """Base class for samplers that yield (environment, agents) combinations.

    The class only stores the candidate profile pools; the sampling strategy
    itself is supplied by subclasses through :meth:`sample`.
    """

    def __init__(
        self,
        agent_list: Sequence[BaseAgentProfile] | None = None,
        env_list: Sequence[BaseEnvironmentProfile] | None = None,
    ):
        """Store the candidate profile pools.

        Args:
            agent_list: Agent profiles to sample from. ``None`` is stored
                as-is; how a missing pool is handled is up to the subclass.
            env_list: Environment profiles to sample from. ``None`` is stored
                as-is; how a missing pool is handled is up to the subclass.
        """
        self.agent_list = agent_list
        self.env_list = env_list

    def sample(
        self,
        agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]],
        agent_num: int = 2,
        replacement: bool = True,
        size: int = 1,
        # ``None`` instead of ``{}`` / ``[{}, {}]``: a mutable default is
        # created once and shared by every call that relies on it.
        env_params: dict[str, Any] | None = None,
        agents_params: list[dict[str, Any]] | None = None,
    ) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]:
        """Yield (environment, agents) combinations.

        Args:
            agent_classes: One agent class, or a list of agent classes.
            agent_num: Number of agents per combination.
            replacement: Whether sampling is done with replacement.
            size: Number of combinations to yield.
            env_params: Keyword arguments for the environment; ``None`` means
                no extra arguments.
            agents_params: Per-agent keyword arguments; ``None`` means no
                extra arguments.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

126 changes: 126 additions & 0 deletions tiny_chat/utils/sampler/constraint_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import ast
import random
from typing import Any, Generator, Type, TypeVar
from ..data_loader import DataLoader
from tiny_chat.agents import BaseAgent
from tiny_chat.envs import TinyChatEnvironment
from tiny_chat.profiles import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile
from .base_sampler import BaseSampler, EnvAgentCombo

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")


class ConstraintSampler(BaseSampler[ObsType, ActType]):
    """A sampler that uses the constraints in the environment profile.

    Agent pairs are taken from relationship profiles and kept only when both
    agents satisfy the age and occupation constraints of the chosen
    environment profile.
    """

    def sample(
        self,
        agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]],
        agent_num: int = 2,
        replacement: bool = True,
        size: int = 1,
        env_params: dict[str, Any] | None = None,
        agents_params: list[dict[str, Any]] | None = None,
    ) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]:
        """Yield ``size`` (environment, agents) combinations.

        Args:
            agent_classes: One agent class (used for every agent) or a list
                with exactly ``agent_num`` classes.
            agent_num: Number of agents per combination; only 2 is supported.
            replacement: If true, every combination independently picks an
                environment profile and one matching agent pair, so repeats
                are possible. If false, a single environment profile is
                picked and ``size`` distinct matching pairs are drawn for it.
            size: Number of combinations to yield.
            env_params: Keyword arguments for ``TinyChatEnvironment``;
                defaults to no arguments.
            agents_params: One keyword-argument dict per agent; defaults to
                an empty dict for each of the two agents.

        Raises:
            ValueError: If ``agent_classes`` or ``agents_params`` does not
                have ``agent_num`` entries, or if too few agent pairs satisfy
                the constraints.
            NotImplementedError: If ``agent_num`` is not 2.
        """
        # Build the defaults per call so no mutable object is shared between
        # calls through the signature.
        if env_params is None:
            env_params = {}
        if agents_params is None:
            agents_params = [{}, {}]

        # check agent_classes
        if not isinstance(agent_classes, list):
            agent_classes = [agent_classes] * agent_num
        elif len(agent_classes) != agent_num:
            raise ValueError("Length of agent_classes must match agent_num")

        if len(agents_params) != agent_num:
            raise ValueError("Length of agents_params must match agent_num")

        # only support 2 agents for now
        if agent_num != 2:
            raise NotImplementedError("Only support 2 agents for now")

        # Load profiles that were not provided, using DataLoader's defaults.
        data_loader = DataLoader()
        if self.agent_list is None:
            self.agent_list = data_loader.get_all_agent_profiles()
        if self.env_list is None:
            self.env_list = data_loader.get_all_env_profiles()
        relationships = data_loader.get_all_relationship_profiles()

        if not replacement:
            # One environment profile for the whole batch.
            env_profile = random.choice(self.env_list)
            # NOTE(review): this single environment instance is shared by
            # every yielded combination in this branch -- confirm intended.
            env = TinyChatEnvironment(**env_params)

            # find agents that fulfill the constraints in env_profile
            candidate_pairs = self.find_agent_pairs(env_profile, relationships)
            # Checked once up front; the message states the real problem
            # (too few pairs), not that none were found.
            if len(candidate_pairs) < size:
                raise ValueError(
                    "Not enough agent pairs satisfy the constraints: "
                    f"requested {size}, found {len(candidate_pairs)}"
                )
            random.shuffle(candidate_pairs)

            for i in range(size):
                agents = self._build_agents(
                    agent_classes, candidate_pairs[i], agents_params, env_profile
                )
                yield (env, agents)
        else:
            for _ in range(size):
                # pick one environment profile per combination
                env_profile = random.choice(self.env_list)
                env = TinyChatEnvironment(**env_params)

                # find agents that fulfill the constraints in env_profile
                candidate_pairs = self.find_agent_pairs(env_profile, relationships)
                if len(candidate_pairs) == 0:
                    raise ValueError("No agent pairs found that satisfy the constraints")

                agents = self._build_agents(
                    agent_classes, random.choice(candidate_pairs), agents_params, env_profile
                )
                yield (env, agents)

    @staticmethod
    def _build_agents(
        agent_classes: list,
        agent_profiles: list[BaseAgentProfile],
        agents_params: list[dict[str, Any]],
        env_profile: BaseEnvironmentProfile,
    ) -> list:
        """Instantiate one agent per profile and assign the environment's goals."""
        agents = [
            agent_class(agent_profile=agent_profile, **agent_params)
            for agent_class, agent_profile, agent_params in zip(
                agent_classes, agent_profiles, agents_params
            )
        ]
        # set goal for each agent
        for agent, goal in zip(agents, env_profile.agent_goals):
            agent.goal = goal
        return agents

    def find_agent_pairs(
        self,
        env: BaseEnvironmentProfile,
        relationships: list[BaseRelationshipProfile],
    ) -> list[list[BaseAgentProfile]]:
        """Return the related agent pairs that satisfy ``env``'s constraints.

        Args:
            env: Environment profile whose ``age_constraint`` and
                ``occupation_constraint`` strings are applied.
            relationships: Relationship profiles; the first two entries of
                each ``agent_ids`` name the candidate pair.

        Returns:
            A list of ``[agent1, agent2]`` pairs, in relationship order.
            Relationships that reference an agent missing from
            ``self.agent_list`` are skipped.
        """
        # Parse each constraint once instead of once per relationship.
        # "nan" is skipped for age as well, mirroring the occupation check.
        age_constraint = None
        if (
            env.age_constraint
            and env.age_constraint != "nan"
            and env.age_constraint != "[(18, 70), (18, 70)]"
        ):
            age_constraint = ast.literal_eval(env.age_constraint)

        occupation_constraint = None
        if (
            env.occupation_constraint
            and env.occupation_constraint != "nan"
            and env.occupation_constraint != "[[], []]"
        ):
            occupation_constraint = ast.literal_eval(env.occupation_constraint)

        # Index agents by primary key once. ``setdefault`` keeps the first
        # profile for a duplicated key, matching ``get_agent``.
        agents_by_pk: dict[Any, BaseAgentProfile] = {}
        for agent in self.agent_list:
            agents_by_pk.setdefault(agent.pk, agent)

        sampled_agent_pairs = []
        for rel in relationships:
            agent1 = agents_by_pk.get(rel.agent_ids[0])
            agent2 = agents_by_pk.get(rel.agent_ids[1])
            if agent1 is None or agent2 is None:
                continue

            # check the age constraint
            if age_constraint is not None and not (
                age_constraint[0][0] <= agent1.age <= age_constraint[0][1]
                and age_constraint[1][0] <= agent2.age <= age_constraint[1][1]
            ):
                continue

            # check the occupation constraint
            if occupation_constraint is not None and not (
                agent1.occupation.lower() in occupation_constraint[0]
                and agent2.occupation.lower() in occupation_constraint[1]
            ):
                continue

            # check agent constraint: not supported yet

            # add the pair
            sampled_agent_pairs.append([agent1, agent2])

        return sampled_agent_pairs

    def get_agent(self, agent_id: str) -> BaseAgentProfile | None:
        """Return the first profile in ``self.agent_list`` whose ``pk`` equals ``agent_id``, or ``None``."""
        return next(
            (agent for agent in self.agent_list if agent.pk == agent_id),
            None,
        )
55 changes: 55 additions & 0 deletions tiny_chat/utils/sampler/uniform_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import random
from typing import Any, Generator, Type, TypeVar
from ..data_loader import DataLoader
from tiny_chat.agents import BaseAgent
from tiny_chat.envs import TinyChatEnvironment
from .base_sampler import BaseSampler, EnvAgentCombo

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")


class UniformSampler(BaseSampler[ObsType, ActType]):
    """A sampler that picks environment and agent profiles uniformly at random."""

    def sample(
        self,
        agent_classes: Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]],
        agent_num: int = 2,
        replacement: bool = True,
        size: int = 1,
        env_params: dict[str, Any] | None = None,
        agents_params: list[dict[str, Any]] | None = None,
    ) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]:
        """Yield ``size`` (environment, agents) combinations.

        Args:
            agent_classes: One agent class (used for every agent) or a list
                with exactly ``agent_num`` classes.
            agent_num: Number of agents per combination.
            replacement: Not read by this sampler; every combination is
                drawn independently.
            size: Number of combinations to yield.
            env_params: Keyword arguments for ``TinyChatEnvironment``;
                defaults to no arguments.
            agents_params: One keyword-argument dict per agent; defaults to
                an empty dict for each of the ``agent_num`` agents.

        Raises:
            ValueError: If ``agent_classes`` or ``agents_params`` does not
                have ``agent_num`` entries, or if fewer than ``agent_num``
                agent profiles are available.
        """
        # Build the defaults per call so no mutable object is shared between
        # calls; sizing by ``agent_num`` lets ``agent_num != 2`` work without
        # the caller hand-building a list of empty dicts.
        if env_params is None:
            env_params = {}
        if agents_params is None:
            agents_params = [{} for _ in range(agent_num)]

        # check agent_classes
        if not isinstance(agent_classes, list):
            agent_classes = [agent_classes] * agent_num
        elif len(agent_classes) != agent_num:
            raise ValueError("Length of agent_classes must match agent_num")

        if len(agents_params) != agent_num:
            raise ValueError("Length of agents_params must match agent_num")

        # Load profiles that were not provided, using DataLoader's defaults.
        data_loader = DataLoader()
        if self.agent_list is None:
            self.agent_list = data_loader.get_all_agent_profiles()
        if self.env_list is None:
            self.env_list = data_loader.get_all_env_profiles()

        for _ in range(size):
            # set up environment
            env_profile = random.choice(self.env_list)
            env = TinyChatEnvironment(**env_params)

            # Sample distinct agent profiles for this combination.
            if len(self.agent_list) < agent_num:
                raise ValueError("Not enough agent profiles")
            sampled_agent_profiles = random.sample(self.agent_list, agent_num)

            agents = [
                agent_class(agent_profile=agent_profile, **agent_params)
                for agent_class, agent_profile, agent_params in zip(
                    agent_classes, sampled_agent_profiles, agents_params
                )
            ]

            # set goal for each agent
            for agent, goal in zip(agents, env_profile.agent_goals):
                agent.goal = goal

            yield (env, agents)