Skip to content

Commit

Permalink
Couple fixes.
Browse files (browse the repository at this point in the history)
  • Loading branch information
jondurbin committed Aug 12, 2023
1 parent 54a8c84 commit 883a3d1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
15 changes: 12 additions & 3 deletions airoboros/instructors/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,16 @@ def parse_response(response, current_name, user_name, names, action_delim):
if name == user_name:
name = "USER"
elif name not in names:
matches = get_close_matches(name, names)
matches = get_close_matches(name, list(set(names) | set(["USER"])))
if not matches:
name = random.choice(names)
else:
name = matches[0]
if name not in list(names) + ["USER"]:
if current_name.startswith("USER"):
name = random.choice(names)
else:
name = "USER"
return response, name


Expand Down Expand Up @@ -222,7 +229,7 @@ async def generate_first_message(
instructor, user_card, characters, topic, **api_params
):
"""Generate the first message for the chat."""
messages = {name: [] for name in list(characters) + ["USER"]}
messages = {name: [] for name in set(list(characters) + ["USER"])}
flesch = (
instructor.instructors.get("chat", {}).get("flesch")
or instructor.default_flesch
Expand Down Expand Up @@ -435,7 +442,9 @@ async def generate_chat(instructor, cards, topic, **api_params):
else:
training.append(
{
"role": "assistant" if current_name != "USER" else "user",
"role": "assistant"
if current_name not in ("USER", user_name)
else "user",
"content": f"{prefix}{response}"
if current_name != "USER"
else response,
Expand Down
13 changes: 9 additions & 4 deletions airoboros/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ class SelfInstructor:
"default": "config.yaml",
"help": "path to the airobors configuration file",
},
"--debug": {
"action": "store_true",
"help": "enable debug logging",
},
}

def __init__(self, *, config_path: str = "config.yaml"):
def __init__(self, *, config_path: str = "config.yaml", debug: bool = False):
"""Constructor."""
if not debug:
logger.remove()
logger.add(sys.stdout, level="INFO")
self.used_tokens = 0
self.config_path = config_path
self.load_config()
Expand Down Expand Up @@ -144,7 +151,7 @@ def initialize_index(self):
for line in infile.readlines():
task = json.loads(line)
category = task.get("category", "general")
if category != "chat" or "chat" in category:
if category != "chat" or "chat" in task:
self.instructor_counts[category] += 1
if task["category"] != "chat":
docs.append(task["instruction"])
Expand Down Expand Up @@ -448,13 +455,11 @@ async def cull(self, input_paths: List[str], output_path: str) -> None:
:type output_path: str
"""
original = []
categories = defaultdict(list)
for path in input_paths:
with open(path) as infile:
for line in infile.readlines():
item = json.loads(line)
original.append(item)
category = item.get("category", "general")
if category == "reasoning_or_math":
category = "orca"
Expand Down

0 comments on commit 883a3d1

Please sign in to comment.