Skip to content
This repository has been archived by the owner on Jan 8, 2023. It is now read-only.

Commit

Permalink
Remove the keras requirement from the example
Browse files Browse the repository at this point in the history
  • Loading branch information
Bogdan Kulynych committed Jan 31, 2019
1 parent d7d3334 commit d389d30
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions examples/cifar10.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
"""
Example membership inference attack against a deep net classifier on the CIFAR10 dataset
"""

try:
import keras
except ImportError as e:
raise ImportError(
"You need to have keras installed for this example: pip install keras"
)


import numpy as np

from absl import app
from absl import flags

from keras import layers
from keras.datasets import cifar10
import tensorflow as tf
from tensorflow.keras import layers

from sklearn.model_selection import train_test_split

Expand All @@ -41,9 +32,9 @@

def get_data():
"""Prepare CIFAR10 data."""
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
y_train = y_train.astype("float32")
Expand All @@ -58,7 +49,7 @@ def target_model_fn():
The attack is white-box, hence the attacker is assumed to know this architecture too."""

model = keras.models.Sequential()
model = tf.keras.models.Sequential()

model.add(
layers.Conv2D(
Expand Down Expand Up @@ -95,7 +86,7 @@ def attack_model_fn():
Following the original paper, this attack model is specific to the class of the input.
AttachModelBundle creates multiple instances of this model for each class.
"""
model = keras.models.Sequential()
model = tf.keras.models.Sequential()

model.add(layers.Dense(128, activation="relu", input_shape=(NUM_CLASSES,)))

Expand Down

0 comments on commit d389d30

Please sign in to comment.