Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed ignore_existing flag not working as expected. #224

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/taggers.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The following parameters are supported either via CLI (e.g. `dolma tag --paramet
|`taggers`|Yes| One or more taggers to run. |
|`tagger_modules`|No| List of one or more Python modules to load taggers from. See section [*"Using Custom Taggers"*](#using-custom-taggers) for more details. |
|`processes`|No| Number of processes to use for tagging. One process is used by default. |
|`ignore_existing`|No| If true, ignore existing outputs and re-run the taggers. |
|`skip_existing`|No| If true, ignore existing outputs and re-run the taggers. |
|`dryrun`|No| If true, only print the configuration and exit without running the taggers. |
|`debug`|No| If true, run in debug mode (i.e., disable parallelism). Useful when developing new taggers. |
|`profile.enable`|No| If true, enable profiling. Useful when benchmarking taggers during development. |
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ dev = [
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
"types-PyYAML",
"types-dateparser"
]
# extension to process code
code = ["detect-secrets==1.4.0", "beautifulsoup4>=4", "pygments", "regex"]
Expand Down Expand Up @@ -227,7 +229,6 @@ aggressive = 3
[tool.mypy]
python_version = "3.9"
ignore_missing_imports = true
no_site_packages = true
allow_redefinition = false
warn_unused_configs = true
warn_unused_ignores = true
Expand All @@ -238,5 +239,6 @@ show_error_codes = true
pretty = true
plugins = ["numpy.typing.mypy_plugin"]


[tool.mypy-tests]
strict_optional = false
4 changes: 2 additions & 2 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config:

untyped_config: DictConfig = om.merge(
om.create(config or {}), om.create(nested_config_dict)
) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig)
) # type: ignore # (pylance is confused because om.create might return a DictConfig or a ListConfig)

base_structured_config: DictConfig = om.structured(structured)
merged_config = om.merge(base_structured_config, untyped_config)
Expand All @@ -159,7 +159,7 @@ def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config:
except OmegaConfBaseException as ex:
raise DolmaConfigError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex

return merged_config # pyright: ignore
return merged_config # type: ignore # (pylance because same error as above)


def print_config(config: Any, console: Optional[Console] = None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions python/dolma/cli/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TaggerConfig:
default=1,
help="Number of parallel processes to use.",
)
ignore_existing: bool = field(
skip_existing: bool = field(
default=False,
help="Whether to ignore existing outputs and re-run the taggers.",
)
Expand Down Expand Up @@ -132,7 +132,7 @@ def run(cls, parsed_config: TaggerConfig):
metadata=work_dirs.output,
taggers=taggers,
taggers_modules=parsed_config.tagger_modules,
ignore_existing=parsed_config.ignore_existing,
skip_existing=parsed_config.skip_existing,
num_processes=parsed_config.processes,
experiment=parsed_config.experiment,
debug=parsed_config.debug,
Expand Down
4 changes: 2 additions & 2 deletions python/dolma/cli/warc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class WarcExtractorConfig:
default=1,
help="Number of parallel processes to use.",
)
ignore_existing: bool = field(
skip_existing: bool = field(
default=False,
help="Whether to ignore existing outputs and re-run the taggers.",
)
Expand Down Expand Up @@ -107,7 +107,7 @@ def run(cls, parsed_config: WarcExtractorConfig):
destination=(destination[0] if len(destination) == 1 else destination),
metadata=work_dirs.output,
num_processes=parsed_config.processes,
ignore_existing=parsed_config.ignore_existing,
skip_existing=parsed_config.skip_existing,
debug=parsed_config.debug,
source_name=source_name,
pre_taggers=parsed_config.pre.taggers,
Expand Down
2 changes: 1 addition & 1 deletion python/dolma/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def create_and_run_analyzer(
metadata_prefix=metadata_path,
debug=debug,
seed=seed,
ignore_existing=True,
skip_existing=True,
retries_on_error=0,
num_processes=num_processes,
)
Expand Down
6 changes: 4 additions & 2 deletions python/dolma/core/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, *args, metadata: Optional[Dict[str, Any]] = None, **kwargs) -
self.metadata = metadata or {}

@classmethod
def from_spec(cls, spec: InputSpecWithMetadata) -> "DocumentWithMetadata":
def from_spec(cls, spec: InputSpecWithMetadata) -> "DocumentWithMetadata": # type: ignore[override]
return DocumentWithMetadata(
source=spec.source,
version=spec.version,
Expand Down Expand Up @@ -125,7 +125,9 @@ def __init__(
self.attributes = attributes or {}

@classmethod
def from_spec(cls, spec: InputSpecWithMetadataAndAttributes) -> "DocumentWithMetadataAndAttributes":
def from_spec( # type: ignore[override]
cls, spec: InputSpecWithMetadataAndAttributes
) -> "DocumentWithMetadataAndAttributes":
return DocumentWithMetadataAndAttributes(
source=spec.source,
version=spec.version,
Expand Down
12 changes: 6 additions & 6 deletions python/dolma/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
debug: bool = False,
seed: int = 0,
pbar_timeout: float = 1e-3,
ignore_existing: bool = False,
skip_existing: bool = False,
include_paths: Optional[List[str]] = None,
exclude_paths: Optional[List[str]] = None,
files_regex_pattern: Optional[str] = None,
Expand All @@ -87,15 +87,15 @@ def __init__(
file names will also be the same.
metadata_prefix (str): The prefix of the metadata files to save. This can be a local path or an
S3 path. Metadata output will be created for each file after it is processed. Filenames are
checked to verify if a file has been processed and can be skipped unless `ignore_existing` is
checked to verify if a file has been processed and can be skipped unless `skip_existing` is
set to true.
num_processes (int, optional): The number of processes to use. Defaults to 1.
debug (bool, optional): Whether to run in debug mode; if true, no multiprocessing will be used.
Defaults to False.
seed (int, optional): The random seed to use when shuffling input files. Defaults to 0.
pbar_timeout (float, optional): How often to update progress bars in seconds.
Defaults to 0.01 seconds.
ignore_existing (bool, optional): Whether to ignore files that have been already processed and
skip_existing (bool, optional): Whether to ignore files that have been already processed and
re-run the processor on all files from scratch. Defaults to False.
include_paths (Optional[List[str]], optional): A list of paths to include. If provided, only files
that match one of the paths will be processed. Defaults to None.
Expand All @@ -118,7 +118,7 @@ def __init__(
self.debug = debug
self.seed = seed
self.pbar_timeout = pbar_timeout
self.ignore_existing = ignore_existing
self.skip_existing = skip_existing

self.include_paths = set(include_paths) if include_paths is not None else None
self.exclude_paths = set(exclude_paths) if exclude_paths is not None else None
Expand Down Expand Up @@ -354,7 +354,7 @@ def __add__(self: BPP, other: BPP) -> BPP:
debug=self.debug or other.debug,
seed=self.seed,
pbar_timeout=max(self.pbar_timeout, other.pbar_timeout),
ignore_existing=self.ignore_existing or other.ignore_existing,
skip_existing=self.skip_existing or other.skip_existing,
include_paths=include_paths,
exclude_paths=exclude_paths,
files_regex_pattern=regex_pattern,
Expand Down Expand Up @@ -484,7 +484,7 @@ def _get_all_paths(self) -> AllPathsTuple:
)

for path in rel_paths:
if not self.ignore_existing and path in existing_metadata_names:
if not self.skip_existing and path in existing_metadata_names:
continue

if not self._valid_path(path):
Expand Down
54 changes: 45 additions & 9 deletions python/dolma/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,17 @@
TaggerOutputDictType,
)
from .errors import DolmaFatalError, DolmaRetryableFailure, DolmaShardError
from .loggers import get_logger
from .parallel import BaseParallelProcessor, QueueType
from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path
from .paths import (
delete_dir,
exists,
join_path,
make_relative,
mkdir_p,
split_glob,
split_path,
)
from .registry import TaggerRegistry
from .utils import import_modules, make_variable_name

Expand Down Expand Up @@ -178,10 +187,10 @@ def _make_output_streams(
mkdir_p(parent)

# open a new file and create a new encoder
io = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
io_ = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
encoder = msgspec.json.Encoder()
opened[loc.path] = TaggerOutputIO(
exp=loc.exp, taggers=set(), path=loc.path, io=io, encoder=encoder
exp=loc.exp, taggers=set(), path=loc.path, io=io_, encoder=encoder
)

# keep track of which taggers are writing to this paths
Expand Down Expand Up @@ -223,7 +232,7 @@ def _write_sample_to_streams(

class TaggerProcessor(BaseParallelProcessor):
@classmethod
def increment_progressbar( # type: ignore
def increment_progressbar( # type: ignore # pylint: disable=arguments-differ
cls,
queue: QueueType, # queue must be the first argument, and it should be a positional-only argument
/,
Expand All @@ -245,6 +254,10 @@ def process_single(
**kwargs,
):
"""Lets count run the taggers! We will use the destination path to save each tagger output."""

# get a logger
logger = get_logger(cls.__name__)

# import tagger modules
taggers_modules = kwargs.get("taggers_modules", None)
if taggers_modules is not None:
Expand All @@ -264,7 +277,9 @@ def process_single(

# this is the dictionary that will hold the output of each tagger
taggers_paths = _determine_output_paths_for_taggers(
experiment_name=experiment_name, destination=destination_path, taggers=taggers
experiment_name=experiment_name,
destination=destination_path,
taggers=taggers,
)

# skip on failure
Expand All @@ -283,6 +298,27 @@ def process_single(
# total number of documents processed
total_docs_cnt = 0

if kwargs.get("skip_existing", False):
# we group taggers by their path (this is for cases when two taggers are going to same file)
# and then remove all taggers if any of the paths exists and skip_existing is True
_taggers_by_path: Dict[str, list[str]] = {}
for tagger_name, tagger_location in taggers_paths.items():
_taggers_by_path.setdefault(tagger_location.path, []).append(tagger_name)

# actually take care of removal here
for tagger_path, tagger_names in _taggers_by_path.items():
if exists(tagger_path):
for tagger_name in tagger_names:
logger.info("Skipping %s because %s already exists.", tagger_name, tagger_path)
taggers.pop(tagger_name)
taggers_paths.pop(tagger_name)

if not taggers:
# if all taggers have been removed, we return early
cls.increment_progressbar(queue, files=1)
logger.info("All taggers for %s have been skipped.", source_path)
return

# creating dedicated decoder speeds up the process
# if any of the taggers require metadata, we use a decoder that can handle it
# otherwise, we use a decoder that does not parse metadata, which is faster
Expand Down Expand Up @@ -327,7 +363,7 @@ def process_single(
# double the update interval if the queue is full
update_interval *= 2

except Exception as exp:
except Exception as exp: # pylint: disable=broad-except
# handle any exception that might have occurred
msg = f"Failed to process {source_path} due to {exp.__class__.__name__}: {' '.join(exp.args)}"
if exp.__class__.__name__ == "IncompleteReadError":
Expand Down Expand Up @@ -383,7 +419,7 @@ def create_and_run_tagger(
metadata: Union[None, str, List[str]] = None,
debug: bool = False,
seed: int = 0,
ignore_existing: bool = False,
skip_existing: bool = False,
skip_on_failure: bool = False,
retries_on_error: int = 0,
num_processes: int = 1,
Expand Down Expand Up @@ -411,7 +447,7 @@ def create_and_run_tagger(
which documents have been processed. If `None`, the metadata will be saved in a temporary directory.
debug (bool, optional): Whether to run in debug mode. Defaults to False.
seed (int, optional): The seed to use for the random number generator. Defaults to 0.
ignore_existing (bool, optional): Whether to ignore existing outputs and re-run the taggers.
skip_existing (bool, optional): Whether to ignore existing outputs and re-run the taggers.
Defaults to False.
skip_on_failure (bool, optional): Whether to skip a document if it fails to process. Defaults to False.
retries_on_error (int, optional): Number of times to retry processing a document if it fails.
Expand Down Expand Up @@ -466,7 +502,7 @@ def create_and_run_tagger(
metadata_prefix=metadata,
debug=debug or profile_enable, # if profile is true, debug must be true
seed=seed,
ignore_existing=ignore_existing,
skip_existing=skip_existing,
retries_on_error=retries_on_error,
num_processes=num_processes,
)
Expand Down
2 changes: 1 addition & 1 deletion python/dolma/core/taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class BaseTaggerWithMetadata(BaseTagger):
def predict(self, doc: DocumentWithMetadata) -> DocResult: # type: ignore
raise NotImplementedError

def tag(self, row: InputSpecWithMetadata) -> TaggerOutputDictType:
def tag(self, row: InputSpecWithMetadata) -> TaggerOutputDictType: # type: ignore
"""Internal function that is used by the tagger to get data"""
doc = DocumentWithMetadata.from_spec(row)
doc_result = self.predict(doc)
Expand Down
5 changes: 3 additions & 2 deletions python/dolma/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,5 +184,6 @@ def _handle_zstd(file_obj, mode):

register_compressor(".zstd", _handle_zstd)
else:
# add zstd compression
add_compression()
# add zstd compression; in case smart_open has zstd support already, this will error out
# with mypy, so we need the type: ignore[unreachable] comment
add_compression() # type: ignore[unreachable]
8 changes: 4 additions & 4 deletions python/dolma/taggers/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
from ..core.utils import split_paragraphs

with necessary.necessary("cld3", soft=True) as CLD3_AVAILABLE:
if CLD3_AVAILABLE or TYPE_CHECKING:
if CLD3_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import cld3 # pyright:ignore pylint:disable=import-error

with necessary.necessary("pycld2", soft=True) as CLD2_AVAILABLE:
if CLD2_AVAILABLE or TYPE_CHECKING:
if CLD2_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import pycld2 as cld2 # pyright:ignore pylint:disable=import-error


with necessary.necessary("langdetect", soft=True) as LANGDETECT_AVAILABLE:
if LANGDETECT_AVAILABLE or TYPE_CHECKING:
if LANGDETECT_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from langdetect import PROFILES_DIRECTORY, DetectorFactory, LangDetectException


with necessary.necessary("lingua", soft=True) as LINGUA_AVAILABLE:
if LINGUA_AVAILABLE or TYPE_CHECKING:
if LINGUA_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from lingua import Language, LanguageDetectorBuilder


Expand Down
20 changes: 14 additions & 6 deletions python/dolma/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from os import PathLike
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
from typing import ( # type: ignore[unreachable,unused-ignore]
TYPE_CHECKING,
Generator,
List,
Optional,
Tuple,
Union,
)

import msgspec
import numpy as np
Expand All @@ -25,8 +32,10 @@
from .data_types import InputSpec, TokenizerOutput

with necessary("transformers", soft=True) as TRANSFORMERS_AVAILABLE:
if TYPE_CHECKING or TRANSFORMERS_AVAILABLE:
from transformers import AutoTokenizer # pylint: disable=import-error
if TYPE_CHECKING or TRANSFORMERS_AVAILABLE: # type: ignore[unreachable,unused-ignore]
from transformers import ( # pyright: ignore # pylint: disable=import-error
AutoTokenizer,
)

PathOrStr = Union[str, PathLike]

Expand Down Expand Up @@ -365,7 +374,6 @@ def tokenize_file(
file, each containing a field named `text`.
"""
tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
dtype = deepcopy(tokenizer.dtype)
decoder = msgspec.json.Decoder(InputSpec)
with smart_open.open(path, mode="rt") as input_stream:
for i, line in enumerate(input_stream, start=1):
Expand All @@ -376,8 +384,8 @@ def tokenize_file(
tokens = tokenizer.encode(text, add_special_tokens=True)
if refresh_tokenizer_every:
# extra copy to prevent memory leaks
tokens = np.array(tokens, dtype=dtype)
yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens) # pyright: ignore
tokens = deepcopy(tokens)
yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)

if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
# to prevent memory leaks, we refresh the tokenizer every so often
Expand Down
Loading
Loading