Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/743.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement events for logging model load events to detect common failures and improve the user experience
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
# Once Python 3.10 is the minimum version, this can be removed.
"eval-type-backport>=0.2.2",
"joblib>=1.2.0",
"tabpfn-common-utils[telemetry-interactive]>=0.2.13",
"tabpfn-common-utils[telemetry-interactive]>=0.2.15",
]
requires-python = ">=3.9"
authors = [
Expand Down
55 changes: 54 additions & 1 deletion src/tabpfn/model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,47 @@

# Copyright (c) Prior Labs GmbH 2025.

from __future__ import annotations

import contextlib
import inspect
import json
import logging
import os
import shutil
import sys
import tempfile
import urllib.request
import warnings
import zipfile
from collections.abc import Callable
from dataclasses import asdict, dataclass
from enum import Enum
from functools import wraps
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, overload
from urllib.error import URLError

import joblib
import torch
from tabpfn_common_utils.telemetry import set_model_config
from tabpfn_common_utils.telemetry import (
capture_event,
ModelLoadEvent,
set_model_config,
)
from torch import nn

from tabpfn.architectures import ARCHITECTURES
from tabpfn.architectures.base.bar_distribution import (
BarDistribution,
FullSupportBarDistribution,
)
from tabpfn.constants import ModelVersion
from tabpfn.errors import TabPFNHuggingFaceGatedRepoError
from tabpfn.inference import InferenceEngine
from tabpfn.inference_config import InferenceConfig
from tabpfn.settings import settings

Check failure on line 45 in src/tabpfn/model_loading.py

View workflow job for this annotation

GitHub Actions / Ruff Linting & Formatting

Ruff (I001)

src/tabpfn/model_loading.py:5:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -160,6 +166,53 @@
)


def _log_huggingface_download_errors(func: Callable[..., None]) -> Callable[..., None]:
"""Decorator that catches exceptions and logs them with model information.

This is used for detecting and logging failures into our telemetry system,
to keep track of the most common failure reasons and improve the user experience
based on that information.

Args:
func: The function to decorate.
"""

@wraps(func)
def wrapper(
base_path: Path,
source: ModelSource,
model_name: str | None = None,
*,
suppress_warnings: bool = True,
) -> None:
# Extract model information
filename = model_name or source.default_filename
logged_model_name = Path(filename).parts[-1]

try:
r = func(base_path, source, model_name, suppress_warnings=suppress_warnings)

# Log success to the telemetry system
event = ModelLoadEvent(status="success", model_name=logged_model_name)
capture_event(event)

return r
except Exception as e:
# Detect failure reason and log it using the telemetry system
event = ModelLoadEvent(
status="failed",
failure_reason=e.__class__.__name__,
model_name=logged_model_name,
)
capture_event(event)

# Re-raise the original exception caught by the wrapper
raise

return wrapper


@_log_huggingface_download_errors
def _try_huggingface_downloads(
base_path: Path,
source: ModelSource,
Expand Down
Loading