Log Context source as source in wandb #539

Merged 5 commits on Jan 14, 2025
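In short: dataset entries now carry an explicit source field set to the URL or page the content was fetched from, shared/base.py stops overwriting Context.source with the dataset class name, and RewardLoggingEvent gains a source field so the value reaches wandb. The unused CACHED_ARTICLES queue and the WikiDateDataset that consumed it are removed along the way.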
prompting/datasets/huggingface_github.py (4 changes: 3 additions & 1 deletion)

@@ -20,6 +20,7 @@ class HuggingFaceGithubDatasetEntry(DatasetEntry):
     github_url: str
     file_path: str
     file_content: str
+    source: str | None = None


 class HuggingFaceGithubDataset(BaseDataset):
@@ -46,8 +47,9 @@ def _filter_function(self, example):

     def _process_entry(self, entry: dict) -> HuggingFaceGithubDatasetEntry:
         file_content = "\n".join(entry["content"].split("\n")[:MAX_LINES])
+        url = f"https://github.com/{entry['repo_name']}"
         return HuggingFaceGithubDatasetEntry(
-            github_url=f"https://github.com/{entry['repo_name']}", file_path=entry["path"], file_content=file_content
+            github_url=url, file_path=entry["path"], file_content=file_content, source=url
         )

     def get(self) -> HuggingFaceGithubDatasetEntry:
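A minimal sketch of the new behavior (the raw entry dict below is hypothetical; the field names and URL construction come from the diff):

    from prompting.datasets.huggingface_github import HuggingFaceGithubDatasetEntry

    # Hypothetical raw entry, shaped like the HuggingFace dataset rows used above.
    entry = {"repo_name": "octocat/hello-world", "path": "README.md", "content": "# Hello\n"}

    url = f"https://github.com/{entry['repo_name']}"
    dataset_entry = HuggingFaceGithubDatasetEntry(
        github_url=url,
        file_path=entry["path"],
        file_content=entry["content"],
        source=url,  # new: the entry records where its content came from
    )
    assert dataset_entry.source == dataset_entry.github_url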
prompting/datasets/random_website.py (5 changes: 4 additions & 1 deletion)

@@ -18,6 +18,7 @@ class DDGDatasetEntry(DatasetEntry):
     website_url: str = None
     website_content: str = None
     query: str | None = None
+    source: str | None = None


 class DDGDataset(BaseDataset):
@@ -55,7 +56,9 @@ def next(self) -> Optional[DDGDatasetEntry]:
             logger.debug(f"Failed to extract content from website {website_url}")
             return None

-        return DDGDatasetEntry(search_term=search_term, website_url=website_url, website_content=website_content)
+        return DDGDatasetEntry(
+            search_term=search_term, website_url=website_url, website_content=website_content, source=website_url
+        )

     def get(self) -> Optional[DDGDatasetEntry]:
         return self.next()
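Same pattern as the GitHub dataset: source duplicates website_url, so downstream consumers read one uniform field regardless of which dataset produced the entry.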
prompting/datasets/wiki.py (121 changes: 2 additions & 119 deletions)

@@ -2,20 +2,15 @@
 import re
 import sys
 from functools import lru_cache
-from queue import Empty, Full, Queue
-from typing import ClassVar, Optional
+from typing import ClassVar

 import requests
 import wikipedia
 from bs4 import BeautifulSoup
 from loguru import logger
-from pydantic import ConfigDict, model_validator

 from shared.base import BaseDataset, Context

-# Create a queue called CACHED_ARTICLES to store wikipedia articles that have been fetched
-CACHED_ARTICLES: Queue[Context] = Queue(maxsize=300)


 # speed up page loading
 @lru_cache(maxsize=1000)
@@ -183,17 +178,13 @@ def get(
             internal_links=list(filter(lambda x: x not in exclude, page.sections)),
             external_links=most_relevant_links(page, num_links=self.max_links),
             tags=filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES),
-            source="Wikipedia",
+            source=page.url,
             extra={
                 "url": page.url,
                 "page_length": len(page.content.split()),
                 "section_length": section_length,
             },
         )
-        try:
-            CACHED_ARTICLES.put(context, block=False)
-        except Full:
-            logger.debug("Cache is full. Skipping article until cache is emptied.")
         return context

     def search(self, name, results=3) -> Context:
@@ -207,111 +198,3 @@ def random(self, pages=10) -> dict:
             if context := self.get(title):
                 return context
         return None
-
-
-class DateContext(Context):
-    date: str = None
-
-    @classmethod
-    def from_context(cls, context: Context, date: str) -> "DateContext":
-        return cls(
-            **context.model_dump(),
-            date=date,
-        )
-
-
-class WikiDateDataset(BaseDataset):
-    name: ClassVar[str] = "wikipedia_date"
-    INCLUDE_HEADERS: tuple = ("Events", "Births", "Deaths")
-    MONTHS: tuple = (
-        "January",
-        "February",
-        "March",
-        "April",
-        "May",
-        "June",
-        "July",
-        "August",
-        "September",
-        "October",
-        "November",
-        "December",
-    )
-    EXCLUDE_CATEGORIES: tuple = ("articles", "wikipedia", "pages", "cs1")
-    seed: int | None = None
-    rng: Optional[random.Random] = None
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-    @model_validator(mode="after")
-    def create_rng(self) -> "WikiDateDataset":
-        self.rng = random.Random(self.seed)
-        return self
-
-    def _extract_dates_and_sentences(self, text: str) -> tuple[str, str]:
-        # Regular expression to find dates in various formats
-        date_pattern = r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember))\s+\d{4}\b|\b\d{4}\b"
-
-        # Compile the regex pattern
-        date_regex = re.compile(date_pattern)
-
-        # Split text into sentences
-        sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
-
-        # Iterate through sentences and find dates
-        for sentence in sentences:
-            # Find all dates in the sentence
-            dates = date_regex.findall(sentence)
-            # If dates are found, add them to the result dictionary with the corresponding sentence
-            if dates:
-                for date in dates:
-                    # Return the first date found
-                    return (str(date), sentence.replace(str(date), "<date>").strip())
-
-        # If no dates are found, search for dates in the form of "Month DD"
-        secondary_date_pattern = r"\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?\b"
-        secondary_date_regex = re.compile(secondary_date_pattern)
-
-        for sentence in sentences:
-            # Find all dates in the sentence
-            dates = secondary_date_regex.findall(sentence)
-            # If dates are found, add them to the result dictionary with the corresponding sentence
-            if dates:
-                for date in dates:
-                    # Return the first date found
-                    return (str(date), sentence.replace(str(date), "<date>").strip())
-
-        return None
-
-    def _random_date(self) -> DateContext:
-        for i in range(self.max_tries):
-            try:
-                context = CACHED_ARTICLES.get(block=False)
-                if not context:
-                    continue
-
-                date_sentence = self._extract_dates_and_sentences(context.content)
-
-                if date_sentence and all(date_sentence):
-                    content, date = date_sentence
-                    date_context = DateContext.from_context(context, date=date)
-                    date_context.content = content
-                    return date_context
-
-            except Empty:
-                logger.debug(f"Retry {i} Cache is empty. Skipping date until cache is filled.")
-                return None
-
-            except Exception as e:
-                logger.exception(f"Error fetching date: {e}")
-                continue
-
-    def get(
-        self,
-    ) -> dict:
-        raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}")
-
-    def search(self, name: str, results: int = 5) -> dict:
-        raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}")
-
-    def random(self) -> DateContext:
-        return self._random_date()
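Besides switching source from the literal string "Wikipedia" to page.url, this deletes the CACHED_ARTICLES queue together with DateContext and WikiDateDataset, the queue's only consumers; the queue, typing, and pydantic imports they needed go with them.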
prompting/miner_availability/miner_availability.py (2 changes: 1 addition & 1 deletion)

@@ -79,7 +79,7 @@ async def run_step(self):
                     llm_model_availabilities=response["llm_model_availabilities"],
                 )
             except Exception:
-                logger.debug("Availability Response Invalid")
+                # logger.debug("Availability Response Invalid")
                 miner_availabilities.miners[uid] = MinerAvailability(
                     task_availabilities={task: True for task in task_config},
                     llm_model_availabilities={model: False for model in model_config},
prompting/rewards/scoring.py (1 change: 1 addition & 0 deletions)

@@ -106,6 +106,7 @@ async def run_step(self) -> RewardLoggingEvent:
                     step=scoring_config.step,
                     task_id=scoring_config.task_id,
                     task_dict=scoring_config.task.model_dump(),
+                    source=scoring_config.dataset_entry.source,
                 )
             )
             logger.info("Adding scores to rewards_and_uids")
shared/base.py (1 change: 0 additions & 1 deletion)

@@ -79,7 +79,6 @@ def next(self, method: Literal["random", "search", "get"] = "random", **kwargs)
             logger.error(f"Failed to fetch context after {RETRIES} tries.")
             return None

-        context.source = self.__class__.__name__
         context.stats = {
             "fetch_time": timer.final_time,
             "num_tries": tries,
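This assignment was why source reached wandb as a dataset class name (e.g. DDGDataset): it ran after every fetch and clobbered whatever the dataset had set. Removing it lets the URL assigned by each dataset survive through to logging.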
shared/logging.py (1 change: 1 addition & 0 deletions)

@@ -181,6 +181,7 @@ class RewardLoggingEvent(BaseEvent):
     challenge: str | list[dict]
     task: str
     task_dict: dict
+    source: str | None = None

     model_config = ConfigDict(arbitrary_types_allowed=True)

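Because the new field defaults to None, existing RewardLoggingEvent call sites that do not pass source keep validating. A standalone sketch of the pattern (a stand-in model, not the real class, which has more fields):

    from pydantic import BaseModel, ConfigDict

    class RewardLoggingEventSketch(BaseModel):
        # Stand-in for shared.logging.RewardLoggingEvent, reduced to the fields in this diff.
        task: str
        task_dict: dict
        source: str | None = None  # optional, so call sites that omit it still validate

        model_config = ConfigDict(arbitrary_types_allowed=True)

    old_style = RewardLoggingEventSketch(task="qa", task_dict={})  # source stays None
    new_style = RewardLoggingEventSketch(task="qa", task_dict={}, source="https://en.wikipedia.org/wiki/Python")
    print(old_style.source, new_style.source)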