diff --git a/bergson/__init__.py b/bergson/__init__.py
index 5a3a568..5ec4228 100644
--- a/bergson/__init__.py
+++ b/bergson/__init__.py
@@ -1,6 +1,7 @@
-from .attributor import Attributor, FaissConfig
+from .attributor import Attributor
 from .collection import collect_gradients
 from .data import IndexConfig, load_gradients
+from .faiss_index import FaissConfig
 from .gradcheck import FiniteDiff
 from .gradients import (
     GradientCollector,
diff --git a/bergson/attributor.py b/bergson/attributor.py
index ddc753c..e8d197e 100644
--- a/bergson/attributor.py
+++ b/bergson/attributor.py
@@ -1,20 +1,12 @@
-import json
-import math
-import os
+from collections import defaultdict
 from contextlib import contextmanager
-from dataclasses import dataclass
-from pathlib import Path
-from time import time
-from typing import Generator, Protocol
+from typing import Generator
 
-import numpy as np
 import torch
-from numpy.lib.recfunctions import structured_to_unstructured
-from numpy.typing import NDArray
 from torch import Tensor, nn
-from tqdm import tqdm
 
-from .data import load_gradients, load_unstructured_gradients
+from .data import load_gradients
+from .faiss_index import FaissConfig, FaissIndex
 from .gradients import GradientCollector, GradientProcessor
 
 
@@ -43,202 +35,6 @@ def scores(self) -> Tensor:
         return self._scores
 
 
-class Index(Protocol):
-    """Protocol for any FAISS index that supports search operations."""
-
-    def search(self, x: NDArray, k: int) -> tuple[NDArray, NDArray]: ...
-    @property
-    def ntotal(self) -> int: ...
-    @property
-    def nprobe(self) -> int: ...
-    @nprobe.setter
-    def nprobe(self, value: int) -> None: ...
-    def train(self, x: NDArray) -> None: ...
-    def add(self, x: NDArray) -> None: ...
-
-
-@dataclass
-class FaissConfig:
-    """Configuration for FAISS index."""
-
-    index_factory: str = "IVF1,SQfp16"
-    """
-    The [FAISS index factory string](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index).
-
-    Common FAISS factory strings:
-    - "IVF1,SQfp16": exact nearest neighbors with brute force search and fp16.
-    - "IVF1024,SQfp16": approximate nearest neighbors with 1024 cluster centers
-      and fp16. Fast approximate queries are produced at the cost of a slower
-      initial index build.
-    - "PQ6720": nearest neighbors with vector product quantization to 6720 elements.
-      Reduces memory usage at the cost of accuracy.
-    """
-    mmap_index: bool = False
-    """Whether to query the gradients on-disk."""
-    max_train_examples: int | None = None
-    """The maximum number of examples to train the index on.
-    If `None`, all examples will be used."""
-    batch_size: int = 1024
-    """The batch size for pre-processing gradients."""
-    num_shards: int = 1
-    """The number of shards to build for an index.
-    Using more shards reduces peak RAM usage."""
-    nprobe: int = 10
-    """The number of FAISS vector clusters to search if using ANN."""
-
-
-def normalize_grads(
-    grads: NDArray,
-    device: str,
-    batch_size: int,
-) -> NDArray:
-    normalized_grads = np.zeros_like(grads).astype(grads.dtype)
-
-    for i in range(0, grads.shape[0], batch_size):
-        batch = torch.from_numpy(grads[i : i + batch_size]).to(device)
-        normalized_grads[i : i + batch_size] = (
-            (batch / batch.norm(dim=1, keepdim=True)).cpu().numpy()
-        )
-
-    return normalized_grads
-
-
-def gradients_loader(root_dir: str):
-    def load_shard(shard_dir: str) -> np.memmap:
-        with open(os.path.join(shard_dir, "info.json")) as f:
-            info = json.load(f)
-
-        if "grad_size" in info:
-            return load_unstructured_gradients(shard_dir)
-
-        dtype = info["dtype"]
-        num_grads = info["num_grads"]
-
-        return np.memmap(
-            os.path.join(shard_dir, "gradients.bin"),
-            dtype=dtype,
-            mode="r",
-            shape=(num_grads,),
-        )
-
-    root_path = Path(root_dir)
-    if (root_path / "info.json").exists():
-        yield load_shard(root_dir)
-    else:
-        for shard_path in sorted(root_path.iterdir()):
-            if shard_path.is_dir():
-                yield load_shard(str(shard_path))
-
-
-def index_to_device(index: Index, device: str) -> Index:
-    assert faiss is not None, "Faiss not found, run `pip install faiss-gpu-cu12`..."
-
-    if device != "cpu":
-        gpus = (
-            list(range(torch.cuda.device_count()))
-            if device == "cuda"
-            else [int(device.split(":")[1])]
-        )
-
-        options = faiss.GpuMultipleClonerOptions()
-        options.shard = True
-        return faiss.index_cpu_to_gpus_list(index, options, gpus=gpus)
-
-    return faiss.index_gpu_to_cpu(index)
-
-
-def load_faiss_index(
-    index_path: str,
-    device: str,
-    unit_norm: bool,
-    faiss_cfg: FaissConfig,
-) -> list[Index]:
-    assert faiss is not None, "Faiss not found, run `pip install faiss-gpu-cu12`..."
-
-    faiss_path = (
-        Path("runs/faiss")
-        / Path(index_path).stem
-        / (
-            f"{faiss_cfg.index_factory.replace(',', '_')}"
-            f"{'_unit_norm' if unit_norm else ''}"
-        )
-    )
-
-    if not faiss_path.exists():
-        print("Building FAISS index...")
-        start = time()
-
-        faiss_path.mkdir(exist_ok=True, parents=True)
-
-        num_dataset_shards = len(list(Path(index_path).iterdir()))
-        shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards)
-
-        dl = gradients_loader(index_path)
-        buffer = []
-        index_idx = 0
-
-        for grads in tqdm(dl, desc="Loading gradients"):
-            if grads.dtype.names is not None:
-                grads = structured_to_unstructured(grads)
-
-            if unit_norm:
-                grads = normalize_grads(grads, device, faiss_cfg.batch_size)
-
-            buffer.append(grads)
-
-            if len(buffer) == shards_per_index:
-                # Build index shard
-                print(f"Building shard {index_idx}...")
-
-                grads = np.concatenate(buffer, axis=0)
-                buffer = []
-
-                index = faiss.index_factory(
-                    grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT
-                )
-                index = index_to_device(index, device)
-                train_examples = faiss_cfg.max_train_examples or grads.shape[0]
-                index.train(grads[:train_examples])
-                index.add(grads)
-
-                # Write index to disk
-                del grads
-                index = index_to_device(index, "cpu")
-                faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
-
-                index_idx += 1
-
-        if buffer:
-            grads = np.concatenate(buffer, axis=0)
-            buffer = []
-            index = faiss.index_factory(
-                grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT
-            )
-            index = index_to_device(index, device)
-            index.train(grads)
-            index.add(grads)
-
-            # Write index to disk
-            del grads
-            index = index_to_device(index, "cpu")
-            faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
-
-        print(f"Built index in {(time() - start) / 60:.2f} minutes.")
-        del buffer, index
-
-    shards = []
-    for i in range(faiss_cfg.num_shards):
-        shard = faiss.read_index(
-            str(faiss_path / f"{i}.faiss"), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
-        )
-        if not faiss_cfg.mmap_index:
-            shard = index_to_device(shard, device)
-
-        shards.append(shard)
-
-    return shards
-
-
 class Attributor:
     def __init__(
         self,
@@ -248,95 +44,91 @@ def __init__(
         unit_norm: bool = False,
         faiss_cfg: FaissConfig | None = None,
     ):
-        if faiss_cfg:
-            self.faiss_shards = load_faiss_index(
-                index_path, device, unit_norm, faiss_cfg
-            )
-        else:
-            mmap = load_gradients(index_path)
-            if mmap.dtype.names is not None:
-                mmap = structured_to_unstructured(mmap)
-
-            self.grads = torch.tensor(mmap, device=device, dtype=dtype)
-
-            # In-place normalize for numerical stability
-            if unit_norm:
-                self.grads /= self.grads.norm(dim=1, keepdim=True)
-
         self.device = device
         self.dtype = dtype
         self.unit_norm = unit_norm
-        self.faiss_cfg = faiss_cfg
+        self.faiss_index = None
 
         # Load the gradient processor
        self.processor = GradientProcessor.load(index_path, map_location=device)
 
-    def search(self, queries: Tensor, k: int) -> tuple[Tensor, Tensor]:
+        # Load the gradient index
+        if faiss_cfg:
+            self.faiss_index = FaissIndex(index_path, faiss_cfg, device, unit_norm)
+            self.N = self.faiss_index.ntotal
+        else:
+            mmap = load_gradients(index_path)
+
+            # Copy gradients into device memory
+            self.grads = {
+                name: torch.tensor(mmap[name], device=device, dtype=dtype)
+                for name in mmap.dtype.names
+            }
+            self.N = mmap[mmap.dtype.names[0]].shape[0]
+
+            if unit_norm:
+                norm = torch.cat([grad for grad in self.grads.values()], dim=1).norm(
+                    dim=1, keepdim=True
+                )
+                for name in self.grads:
+                    self.grads[name] /= norm
+
+    def search(
+        self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None
+    ) -> tuple[Tensor, Tensor]:
         """
         Search for the `k` nearest examples in the index based on the query or
         queries.
-        If fewer than `k` examples are found FAISS will return items with the index -1
-        and the maximum negative distance.
 
         Args:
-            queries: The query tensor of shape [..., d].
+            queries: Mapping from module names to query tensors of shape [..., d].
             k: The number of nearest examples to return for each query.
-            nprobe: The number of FAISS vector clusters to search if using ANN.
+            modules: The names of the modules to search. If `None`,
+                all modules will be searched.
 
         Returns:
             A namedtuple containing the top `k` indices and inner products for
             each query. Both have shape [..., k].
         """
-        q = queries
+        q = {name: item.to(self.device, self.dtype) for name, item in queries.items()}
 
         if self.unit_norm:
-            q /= q.norm(dim=1, keepdim=True)
-
-        if not self.faiss_cfg:
-            return torch.topk(q.to(self.device) @ self.grads.mT, k)
+            norm = torch.cat(list(q.values()), dim=1).norm(dim=1, keepdim=True)
 
-        q = q.cpu().numpy()
+            for name in q:
+                q[name] /= norm + 1e-8
 
-        shard_distances = []
-        shard_indices = []
-        offset = 0
+        if self.faiss_index:
+            if modules:
+                raise NotImplementedError(
+                    "FAISS index does not implement module-specific search."
+                )
 
-        for index in self.faiss_shards:
-            index.nprobe = self.faiss_cfg.nprobe
-            distances, indices = index.search(q, k)
+            q = torch.cat([q[name] for name in q], dim=1).cpu().numpy()
 
-            indices += offset
-            offset += index.ntotal
+            distances, indices = self.faiss_index.search(q, k)
 
-            shard_distances.append(distances)
-            shard_indices.append(indices)
+            return torch.from_numpy(distances.squeeze()), torch.from_numpy(
+                indices.squeeze()
+            )
 
-        distances = np.concatenate(shard_distances, axis=1)
-        indices = np.concatenate(shard_indices, axis=1)
+        modules = modules or list(q.keys())
+        k = min(k, self.N)
 
-        # Rerank results overfetched from multiple shards
-        if len(self.faiss_shards) > 1:
-            topk_indices = np.argsort(distances, axis=1)[:, :k]
-            indices = indices[np.arange(indices.shape[0])[:, None], topk_indices]
-            distances = distances[np.arange(distances.shape[0])[:, None], topk_indices]
+        scores = torch.stack(
+            [q[name] @ self.grads[name].mT for name in modules], dim=-1
+        ).sum(-1)
 
-        return torch.from_numpy(distances.squeeze()), torch.from_numpy(
-            indices.squeeze()
-        )
+        return torch.topk(scores, k)
 
     @contextmanager
     def trace(
-        self,
-        module: nn.Module,
-        k: int,
-        *,
-        precondition: bool = False,
-        unit_norm: bool = True,
+        self, module: nn.Module, k: int, *, precondition: bool = False
     ) -> Generator[TraceResult, None, None]:
         """
-        Context manager to trace the gradients of a module and return the
-        corresponding Attributor instance.
+        Context manager to trace the gradients of a module and populate a
+        TraceResult with the indices and scores of the top `k` matches.
         """
-        mod_grads: list[Tensor] = []
+        mod_grads = defaultdict(list)
         result = TraceResult()
 
         def callback(name: str, g: Tensor):
@@ -349,7 +141,7 @@ def callback(name: str, g: Tensor):
             g = g.flatten(1)
 
             # Store the gradient for later use
-            mod_grads.append(g.to(self.device, self.dtype, non_blocking=True))
+            mod_grads[name].append(g.to(self.device, self.dtype, non_blocking=True))
 
         with GradientCollector(module, callback, self.processor):
             yield result
 
@@ -357,12 +149,9 @@ def callback(name: str, g: Tensor):
         if not mod_grads:
             raise ValueError("No grads collected. Did you forget to call backward?")
 
-        queries = torch.cat(mod_grads, dim=1)
+        queries = {name: torch.cat(g, dim=1) for name, g in mod_grads.items()}
 
-        if queries.isnan().any():
+        if any(q.isnan().any() for q in queries.values()):
             raise ValueError("NaN found in queries.")
 
-        if unit_norm:
-            queries /= queries.norm(dim=1, keepdim=True)
-
         result._scores, result._indices = self.search(queries, k)
diff --git a/bergson/build.py b/bergson/build.py
index ca29d8f..2ba5420 100644
--- a/bergson/build.py
+++ b/bergson/build.py
@@ -154,20 +154,20 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableD
             target_modules=target_modules,
         )
     else:
-        # Convert each chunk to Dataset then collect their gradients
-        buf, chunk_id = [], 0
+        # Convert each shard to a Dataset then collect its gradients
+        buf, shard_id = [], 0
 
         def flush():
-            nonlocal buf, chunk_id
+            nonlocal buf, shard_id
             if not buf:
                 return
 
-            sub_ds = assert_type(Dataset, Dataset.from_list(buf))
-            batches = allocate_batches(sub_ds["length"], cfg.token_batch_size)
+            ds_shard = assert_type(Dataset, Dataset.from_list(buf))
+            batches = allocate_batches(ds_shard["length"][:], cfg.token_batch_size)
             collect_gradients(
                 model,
-                sub_ds,
+                ds_shard,
                 processor,
-                os.path.join(cfg.run_path, f"chunk-{chunk_id:05d}"),
+                os.path.join(cfg.run_path, f"shard-{shard_id:05d}"),
                 batches=batches,
                 kl_divergence=cfg.loss_fn == "kl",
                 loss_reduction=cfg.loss_reduction,
@@ -175,11 +175,11 @@ def flush():
                 target_modules=target_modules,
             )
             buf.clear()
-            chunk_id += 1
+            shard_id += 1
 
         for ex in tqdm(ds, desc="Collecting gradients"):
             buf.append(ex)
-            if len(buf) == cfg.streaming_chunk_size:
+            if len(buf) == cfg.stream_shard_size:
                 flush()
 
         flush()
diff --git a/bergson/data.py b/bergson/data.py
index a310895..4475585 100644
--- a/bergson/data.py
+++ b/bergson/data.py
@@ -102,8 +102,8 @@ class IndexConfig:
     streaming: bool = False
     """Whether to use streaming mode for the dataset."""
 
-    streaming_chunk_size: int = 100_000
-    """Chunk size for streaming the dataset into Dataset objects."""
+    stream_shard_size: int = 100_000
+    """Shard size for streaming the dataset into Dataset objects."""
 
     revision: str | None = None
     """Revision of the model."""
@@ -305,32 +305,12 @@ def load_data_string(
     return ds
 
-
-def load_unstructured_gradients(root_dir: str) -> np.memmap:
-    """Map the gradients stored in `root_dir` into memory."""
-    with open(os.path.join(root_dir, "info.json")) as f:
-        info = json.load(f)
-
-    grad_size = info["grad_size"]
-    num_grads = info["num_grads"]
-
-    mmap = np.memmap(
-        root_dir + "/gradients.bin",
-        dtype=np.float16,
-        mode="r",
-        shape=(num_grads, grad_size),
-    )
-    return mmap
-
-
 def load_gradients(root_dir: str) -> np.memmap:
     """Map the structured gradients stored in `root_dir` into memory."""
     with open(os.path.join(root_dir, "info.json")) as f:
         info = json.load(f)
 
-    # TODO 2025-08-01 Remove legacy loading
-    if "grad_size" in info:
-        return load_unstructured_gradients(root_dir)
-
     dtype = info["dtype"]
     num_grads = info["num_grads"]
diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py
new file mode 100644
index 0000000..970fa62
--- /dev/null
+++ b/bergson/faiss_index.py
@@ -0,0 +1,265 @@
+import json
+import math
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from time import time
+from typing import Protocol
+
+import numpy as np
+import torch
+from numpy.lib.recfunctions import structured_to_unstructured
+from numpy.typing import NDArray
+from tqdm import tqdm
+
+
+@dataclass
+class FaissConfig:
+    """Configuration for a FAISS index."""
+
+    index_factory: str = "Flat"
+    """
+    The [FAISS index factory string](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index).
+
+    Common FAISS factory strings:
+    - "Flat": exact nearest neighbors with brute-force search over uncompressed
+      vectors. This is the default.
+    - "IVF1,SQfp16": exact nearest neighbors with brute force search and fp16.
+      Valid for CPU or memmapped indices.
+    - "IVF1024,SQfp16": approximate nearest neighbors with 1024 cluster centers
+      and fp16. Fast approximate queries are produced at the cost of a slower
+      initial index build.
+    - "PQ6720": nearest neighbors with vector product quantization to 6720 elements.
+      Reduces memory usage at the cost of accuracy.
+    """
+    mmap_index: bool = False
+    """Whether to query the gradients on-disk."""
+    max_train_examples: int | None = None
+    """The maximum number of examples to train the index on.
+    If `None`, all examples will be used."""
+    batch_size: int = 1024
+    """The batch size for pre-processing gradients."""
+    num_shards: int = 1
+    """The number of shards to build for an index.
+    Using more shards reduces peak RAM usage."""
+    nprobe: int = 10
+    """The number of FAISS vector clusters to search if using ANN."""
+
+
+class Index(Protocol):
+    """Protocol for a searchable FAISS index."""
+
+    def search(self, x: NDArray, k: int) -> tuple[NDArray, NDArray]: ...
+    @property
+    def ntotal(self) -> int: ...
+    @property
+    def nprobe(self) -> int: ...
+    @nprobe.setter
+    def nprobe(self, value: int) -> None: ...
+    def train(self, x: NDArray) -> None: ...
+    def add(self, x: NDArray) -> None: ...
+
+
+def normalize_grads(
+    grads: NDArray,
+    device: str,
+    batch_size: int,
+) -> NDArray:
+    normalized_grads = np.zeros_like(grads)
+
+    for i in range(0, grads.shape[0], batch_size):
+        batch = torch.from_numpy(grads[i : i + batch_size]).to(device)
+        normalized_grads[i : i + batch_size] = (
+            (batch / batch.norm(dim=1, keepdim=True)).cpu().numpy()
+        )
+
+    return normalized_grads
+
+
+def gradients_loader(root_dir: str):
+    def load_shard(shard_dir: str) -> np.memmap:
+        with open(os.path.join(shard_dir, "info.json")) as f:
+            info = json.load(f)
+
+        return np.memmap(
+            os.path.join(shard_dir, "gradients.bin"),
+            dtype=info["dtype"],
+            mode="r",
+            shape=(info["num_grads"],),
+        )
+
+    root_path = Path(root_dir)
+    if (root_path / "info.json").exists():
+        yield load_shard(root_dir)
+    else:
+        for shard_path in sorted(root_path.iterdir()):
+            if shard_path.is_dir():
+                yield load_shard(str(shard_path))
+
+
+def index_to_device(index: Index, device: str) -> Index:
+    try:
+        import faiss
+    except ImportError as e:
+        raise ImportError("Faiss not found, run `pip install faiss-gpu-cu12`...") from e
+
+    if device != "cpu":
+        gpus = (
+            list(range(torch.cuda.device_count()))
+            if device == "cuda"
+            else [int(device.split(":")[1])]
+        )
+
+        try:
+            options = faiss.GpuMultipleClonerOptions()
+        except AttributeError as e:
+            raise ImportError(
+                "Faiss not found, you may have faiss-cpu installed instead "
+                "of faiss-gpu with `pip install faiss-gpu-cu12`..."
+            ) from e
+
+        options.shard = True
+        return faiss.index_cpu_to_gpus_list(index, options, gpus=gpus)
+
+    return faiss.index_gpu_to_cpu(index)
+
+
+class FaissIndex:
+    """A sharded FAISS index over stored gradients."""
+
+    shards: list[Index]
+
+    def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool):
+        try:
+            import faiss
+        except ImportError as e:
+            raise ImportError("Faiss not found, run `pip install faiss-gpu-cu12`") from e
+
+        self.faiss_cfg = faiss_cfg
+
+        faiss_path = (
+            Path("runs/faiss")
+            / Path(path).stem
+            / (
+                f"{faiss_cfg.index_factory.replace(',', '_')}"
+                f"{'_unit_norm' if unit_norm else ''}"
+            )
+        )
+
+        if not (faiss_path.exists() and any(faiss_path.iterdir())):
+            print("Building FAISS index...")
+            start = time()
+
+            faiss_path.mkdir(exist_ok=True, parents=True)
+
+            num_dataset_shards = len(list(Path(path).iterdir()))
+            shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards)
+
+            dl = gradients_loader(path)
+            buffer = []
+            index_idx = 0
+
+            for grads in tqdm(dl, desc="Loading gradients"):
+                if grads.dtype.names is not None:
+                    grads = structured_to_unstructured(grads)
+
+                if unit_norm:
+                    grads = normalize_grads(grads, device, faiss_cfg.batch_size)
+
+                buffer.append(grads)
+
+                if len(buffer) == shards_per_index:
+                    # Build index shard
+                    print(f"Building shard {index_idx}...")
+
+                    grads = np.concatenate(buffer, axis=0)
+                    buffer = []
+
+                    index = faiss.index_factory(
+                        grads.shape[1],
+                        faiss_cfg.index_factory,
+                        faiss.METRIC_INNER_PRODUCT,
+                    )
+                    index = index_to_device(index, device)
+                    train_examples = faiss_cfg.max_train_examples or grads.shape[0]
+                    index.train(grads[:train_examples])
+                    index.add(grads)
+
+                    # Write index to disk
+                    del grads
+                    index = index_to_device(index, "cpu")
+                    faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
+
+                    index_idx += 1
+
+            if buffer:
+                grads = np.concatenate(buffer, axis=0)
+                buffer = []
+                index = faiss.index_factory(
+                    grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT
+                )
+                index = index_to_device(index, device)
+                index.train(grads)
+                index.add(grads)
+
+                # Write index to disk
+                del grads
+                index = index_to_device(index, "cpu")
+                faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
+
+            print(f"Built index in {(time() - start) / 60:.2f} minutes.")
+            del buffer, index
+
+        shards = []
+        for i in range(faiss_cfg.num_shards):
+            shard = faiss.read_index(
+                str(faiss_path / f"{i}.faiss"),
+                faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
+            )
+            if not faiss_cfg.mmap_index:
+                shard = index_to_device(shard, device)
+
+            shards.append(shard)
+
+        self.shards = shards
+
+    def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:
+        """Note: if fewer than `k` examples are found FAISS will return items
+        with the index -1 and the maximum negative distance."""
+        shard_distances = []
+        shard_indices = []
+        offset = 0
+
+        for index in self.shards:
+            index.nprobe = self.faiss_cfg.nprobe
+            distances, indices = index.search(q, k)
+
+            indices += offset
+            offset += index.ntotal
+
+            shard_distances.append(distances)
+            shard_indices.append(indices)
+
+        distances = np.concatenate(shard_distances, axis=1)
+        indices = np.concatenate(shard_indices, axis=1)
+
+        # Rerank results overfetched from multiple shards (largest inner products first)
+        if len(self.shards) > 1:
+            topk_indices = np.argsort(-distances, axis=1)[:, :k]
+            indices = indices[np.arange(indices.shape[0])[:, None], topk_indices]
+            distances = distances[np.arange(distances.shape[0])[:, None], topk_indices]
+
+        return distances, indices
+
+    @property
+    def ntotal(self) -> int:
+        return sum(shard.ntotal for shard in self.shards)
+
+    @property
+    def nprobe(self) -> int:
+        return self.shards[0].nprobe
+
+    @nprobe.setter
+    def nprobe(self, value: int) -> None:
+        for shard in self.shards:
+            shard.nprobe = value
diff --git a/examples/query_index.py b/examples/query_index.py
index b2cc448..d22dcea 100644
--- a/examples/query_index.py
+++ b/examples/query_index.py
@@ -3,7 +3,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from bergson import Attributor
+from bergson import Attributor, FaissConfig
 
 
 def main():
@@ -12,16 +12,18 @@ def main():
     parser.add_argument(
         "--model", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct"
     )
-    parser.add_argument("--dataset", type=str, default="EleutherAI/SmolLM2-135M-10B")
+    parser.add_argument("--dataset", type=str, default="RonenEldan/TinyStories")
     parser.add_argument("--text_field", type=str, default="text")
     parser.add_argument("--unit_norm", action="store_true")
+    parser.add_argument("--faiss", action="store_true")
     args = parser.parse_args()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     model = AutoModelForCausalLM.from_pretrained(args.model, device_map={"": "cuda:0"})
     dataset = load_dataset(args.dataset, split="train")
 
-    attr = Attributor(args.index, device="cuda:0")
+    faiss_cfg = FaissConfig() if args.faiss else None
+    attr = Attributor(args.index, device="cuda", unit_norm=args.unit_norm, faiss_cfg=faiss_cfg)
 
     # Query loop
     while True:
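
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the new FaissConfig/Attributor API introduced in this change, mirroring examples/query_index.py. The index path "runs/my-index" is hypothetical, and it assumes TraceResult exposes an `indices` property alongside the `scores` property visible in this diff.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from bergson import Attributor, FaissConfig

    model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map={"": "cuda:0"})

    # Exact brute-force search; swap in e.g. "IVF1024,SQfp16" for approximate search.
    cfg = FaissConfig(index_factory="Flat", num_shards=2)
    attr = Attributor("runs/my-index", device="cuda:0", unit_norm=True, faiss_cfg=cfg)

    # Collect per-module query gradients for one prompt; the search over the
    # index runs when the context manager exits.
    with attr.trace(model, k=5) as result:
        batch = tokenizer("Why is the sky blue?", return_tensors="pt").to("cuda:0")
        model(**batch, labels=batch["input_ids"]).loss.backward()

    print(result.scores)   # top-5 inner products per query
    print(result.indices)  # assumed counterpart of `scores`: rows into the index

With num_shards=2, FaissIndex.search overfetches k results from each shard, offsets the per-shard row indices, and reranks by inner product, so callers see a single flat index space regardless of sharding.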