Skip to content

Commit

Permalink
Reapply merge of PR #190 excluding unintentional revert of PRs #170 + #…
Browse files Browse the repository at this point in the history
  • Loading branch information
fazelehh committed Jan 13, 2025
1 parent 3453f65 commit 960d69a
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 521 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.9

- name: Install dependencies
- name: Install dependencies from pyproject.toml
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install .[dev]
- name: Ruff Linting
run: |
Expand Down
38 changes: 0 additions & 38 deletions environment.yml

This file was deleted.

18 changes: 9 additions & 9 deletions examples/mia/tabular_mia/audit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ audit: # Configurations for auditing
gamma: 2.0
offline_a: 0.33 # parameter from which we compute p(x) from p_OUT(x) such that p_IN(x) = a p_OUT(x) + b.
offline_b: 0.66
qmia:
training_data_fraction: 1.0 # Fraction of the auxilary dataset (data without train and test indices) to use for training the quantile regressor
epochs: 5 # Number of training epochs for quantile regression
# qmia:
# training_data_fraction: 1.0 # Fraction of the auxilary dataset (data without train and test indices) to use for training the quantile regressor
# epochs: 5 # Number of training epochs for quantile regression
population:
attack_data_fraction: 1.0 # Fraction of the auxilary dataset to use for this attack
lira:
Expand All @@ -26,12 +26,12 @@ audit: # Configurations for auditing
number_of_traj: 10 # Number of epochs (number of points in the loss trajectory)
label_only: False # True or False
mia_classifier_epochs: 100
yoqo:
training_data_fraction: 0.5 # Fraction of the auxilary dataset to use for this attack (in each shadow model training)
num_shadow_models: 8 # Number of shadow models to train
online: True # perform online or offline attack
lr_xprime_optimization: .01
max_iterations: 35
# yoqo:
# training_data_fraction: 0.5 # Fraction of the auxilary dataset to use for this attack (in each shadow model training)
# num_shadow_models: 8 # Number of shadow models to train
# online: True # perform online or offline attack
# lr_xprime_optimization: .01
# max_iterations: 35

output_dir: "./leakpro_output"
attack_type: "mia" #mia, gia
Expand Down
653 changes: 278 additions & 375 deletions examples/mia/tabular_mia/main.ipynb

Large diffs are not rendered by default.

46 changes: 31 additions & 15 deletions leakpro/attacks/attack_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Module that contains the AttackScheduler class, which is responsible for creating and executing attacks."""

from leakpro.attacks.gia_attacks.attack_factory_gia import AttackFactoryGIA
from leakpro.attacks.mia_attacks.abstract_mia import AbstractMIA
from leakpro.attacks.mia_attacks.attack_factory_mia import AttackFactoryMIA
from leakpro.input_handler.abstract_input_handler import AbstractInputHandler
from leakpro.utils.import_helper import Any, Dict, Self
from leakpro.utils.logger import logger
Expand All @@ -11,8 +8,7 @@
class AttackScheduler:
"""Class responsible for creating and executing attacks."""

attack_type_to_factory = {"mia": AttackFactoryMIA,
"gia": AttackFactoryGIA}
attack_type_to_factory = {}

def __init__(
self:Self,
Expand All @@ -26,28 +22,48 @@ def __init__(
"""
configs = handler.configs
if configs["audit"]["attack_type"] not in list(self.attack_type_to_factory.keys()):
raise ValueError(
f"Unknown attack type: {configs['audit']['attack_type']}. "
f"Supported attack types: {self.attack_type_to_factory.keys()}"
)

# Prepare factory
factory = self.attack_type_to_factory[configs["audit"]["attack_type"]]
# Create factory
attack_type = configs["audit"]["attack_type"].lower()
self._initialize_factory(attack_type)

# Create the attacks
self.attack_list = list(configs["audit"]["attack_list"].keys())
self.attacks = []
for attack_name in self.attack_list:
try:
attack = factory.create_attack(attack_name, handler)
attack = self.attack_factory.create_attack(attack_name, handler)
self.add_attack(attack)
logger.info(f"Added attack: {attack_name}")
except ValueError as e:
logger.info(e)
logger.info(f"Failed to create attack: {attack_name}, supported attacks: {factory.attack_classes.keys()}")
logger.info(f"Failed to create attack: {attack_name}, supported attacks: {self.attack_factory.attack_classes.keys()}") # noqa: E501

def add_attack(self:Self, attack: AbstractMIA) -> None:
def _initialize_factory(self:Self, attack_type:str) -> None:
"""Conditionally import attack factories based on attack."""
if attack_type == "mia":
try:
from leakpro.attacks.mia_attacks.attack_factory_mia import AttackFactoryMIA
self.attack_factory = AttackFactoryMIA
logger.info("MIA attack factory loaded.")
except ImportError as e:
logger.error("Failed to import MIA attack module.")
raise ImportError("MIA attack module is not available.") from e

elif attack_type == "gia":
try:
from leakpro.attacks.gia_attacks.attack_factory_gia import AttackFactoryGIA
self.attack_factory = AttackFactoryGIA
logger.info("GIA attack factory loaded.")
except ImportError as e:
logger.error("Failed to import GIA attack module.")
raise ImportError("GIA attack module is not available.") from e

else:
logger.error(f"Unsupported attack type: {self.attack_type}")
raise ValueError(f"Unsupported attack type: {self.attack_type}. Must be 'mia' or 'gia'.")

def add_attack(self:Self, attack: Any) -> None:
"""Add an attack to the list of attacks."""
self.attacks.append(attack)

Expand Down
84 changes: 23 additions & 61 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,56 +1,42 @@
lint = ["ruff>=0.0.220"] # MIT License (MIT)

[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "leakpro"
version = "0.1"
version = "0.1.0"
description = "A package for privacy risk analysis"
authors = [{ name = "LeakPro team", email = "johan.ostman@ai.se" }]
authors = [{name = "LeakPro team", email = "johan.ostman@ai.se"}]
readme = "README.md"
license = {file="LICENSE"}
keywords = [
"Privacy",
"Risk analysis",
"Centralized learning",
"Federated learning",
"Synthethic data",
"Machine learning",
]
classifiers = [
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

requires-python = '>=3.8,<3.13'

dependencies = [
"numpy",
"pandas",
"scipy",
"scikit-learn",
"matplotlib",
"seaborn",
"pillow",
"torch",
"torchvision",
"torchmetrics",
"dotmap",
"loguru",
"jinja2",
"tqdm",
"pyyaml",
"numba",
"pydantic",
"joblib",
"pyyaml",
"scikit-learn",
]

[project.optional-dependencies]
mia = ["torch", "torchvision"]
synthetic = ["numba", "pydantic"]
federated = ["torch", "torchvision", "torchmetrics",]
dev = [
"pytest",
"gdown",
"pytest-cov",
"pytest-mock",
"coverage-badge",
"ruff",
"torch", # from mia & federated
"torchvision", # from mia & federated
"torchmetrics", # from federated
"numba", # from synthetic
"pydantic", # from synthetic
]

[tool.setuptools]
Expand All @@ -64,33 +50,9 @@ include = ["leakpro*"]
line-length = 130
target-version = "py39"

lint.select = [
"ANN", # flake8-annotations
"ARG", # flake8-unused-arguments
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"C90", # mccabe
"D", # pydocstyle
"DTZ", # flake8-datetimez
"E", # pycodestyle
"ERA", # eradicate
"F", # Pyflakes
"I", # isort
"N", # pep8-naming
"PD", # pandas-vet
"PGH", # pygrep-hooks
"PLC", # Pylint
"PLE", # Pylint
"PLR", # Pylint
"PLW", # Pylint
"PT", # flake8-pytest-style
"Q", # flake8-quotes
"RET", # flake8-return
"S", # flake8-bandit
"SIM", # flake8-simplify
"T20", # flake8-print
"TID", # flake8-tidy-imports
"W", # pycodestyle
select = [
"ANN", "ARG", "B", "C4", "C90", "D", "DTZ", "E", "ERA", "F", "I", "N", "PD", "PGH", "PLC",
"PLE", "PLR", "PLW", "PT", "Q", "RET", "S", "SIM", "T20", "TID", "W",
]

exclude = [
Expand Down
20 changes: 0 additions & 20 deletions requirements.txt

This file was deleted.

0 comments on commit 960d69a

Please sign in to comment.