-
Notifications
You must be signed in to change notification settings - Fork 17
/
detect_fingerprints.py
133 lines (104 loc) · 3.79 KB
/
detect_fingerprints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import glob
import PIL
# Command-line interface. Parsed before torch is imported so that the CUDA
# environment variables derived from --cuda are in place first.
parser = argparse.ArgumentParser(
    description="Decode fingerprints from images with a trained StegaStamp decoder."
)
parser.add_argument(
    "--data_dir",
    type=str,
    required=True,
    help="Directory with images.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory to write detected_fingerprints.txt to.",
)
parser.add_argument(
    "--image_resolution",
    type=int,
    required=True,
    help="Height and width of square images.",
)
parser.add_argument(
    "--decoder_path",
    type=str,
    required=True,
    help="Path to trained StegaStamp decoder.",
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
parser.add_argument(
    "--cuda",
    type=int,
    default=0,
    help="GPU index to expose via CUDA_VISIBLE_DEVICES; -1 runs on the CPU.",
)
# NOTE(review): --seed is accepted but not read anywhere in this script.
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
import os
# Enumerate GPUs in PCI bus order and expose only the one chosen with --cuda.
# This runs before torch is imported, and the selected GPU is later addressed
# as "cuda:0".
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(args.cuda),
    }
)
from time import time
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
from torchvision import transforms
# --cuda -1 means "run on the CPU". Any other value has already been mapped
# through CUDA_VISIBLE_DEVICES, so the only visible GPU is "cuda:0".
device = torch.device("cpu" if args.cuda == -1 else "cuda:0")
class CustomImageFolder(Dataset):
    """Flat (non-recursive) dataset over the image files of one directory.

    Collects every file directly inside ``data_dir`` whose name ends with one
    of ``extensions`` (``.png``, ``.jpeg`` and ``.jpg`` by default) and sorts
    the paths. Matching follows ``glob`` semantics, so it is case-sensitive on
    POSIX filesystems.

    Each item is ``(image, 0)``: the image is converted to RGB and passed
    through ``transform`` (if given), and ``0`` is a constant dummy label.

    Args:
        data_dir: Directory to scan for images.
        transform: Optional callable applied to each RGB PIL image.
        extensions: Filename suffixes to include.
    """

    def __init__(self, data_dir, transform=None, extensions=(".png", ".jpeg", ".jpg")):
        self.data_dir = data_dir
        filenames = []
        for ext in extensions:
            filenames.extend(glob.glob(os.path.join(data_dir, "*" + ext)))
        # Sorted so the item order (and hence the output order) is deterministic.
        self.filenames = sorted(filenames)
        self.transform = transform

    def __getitem__(self, idx):
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` is loaded.
        from PIL import Image

        # The context manager closes the file handle that Image.open keeps
        # open. convert() reads the pixel data before the file is closed and
        # normalizes grayscale/palette/RGBA files to 3-channel RGB. This matches
        # the 3 passed to StegaStampDecoder in load_decoder (presumably its
        # channel count).
        with Image.open(self.filenames[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, 0

    def __len__(self):
        return len(self.filenames)
def load_decoder():
    """Load the trained StegaStamp decoder into the module-level ``RevealNet``.

    Reads the checkpoint at ``--decoder_path`` once and infers the fingerprint
    length from it. It then builds ``StegaStampDecoder`` for
    ``--image_resolution`` with 3 as its channel argument, copies the weights
    in, and moves the network to ``device`` in eval mode.

    Side effects: sets the globals ``RevealNet`` and ``FINGERPRINT_SIZE``.
    """
    global RevealNet
    global FINGERPRINT_SIZE

    from models import StegaStampDecoder

    # Load a single time, onto the CPU. This works whether the checkpoint was
    # saved from a GPU or a CPU run. The whole module is moved to ``device``
    # after the weights are copied in, so GPU runs end up in the same place.
    state_dict = torch.load(args.decoder_path, map_location="cpu")

    # Presumably "dense.2" is the decoder's final Linear layer. A Linear
    # weight has shape (out_features, in_features), so dim 0 is the number of
    # fingerprint bits. TODO confirm against models.StegaStampDecoder.
    FINGERPRINT_SIZE = state_dict["dense.2.weight"].shape[0]
    RevealNet = StegaStampDecoder(args.image_resolution, 3, FINGERPRINT_SIZE)
    RevealNet.load_state_dict(state_dict)
    # Inference only: disable any train-time layer behaviour (e.g. dropout or
    # batch-norm statistics updates) the decoder may contain.
    RevealNet = RevealNet.to(device).eval()
def load_data():
    """Index the images in ``--data_dir`` into the module-level ``dataset``.

    Files are only listed here. Each one is opened later, when the dataset is
    indexed, and converted to a tensor with ``transforms.ToTensor()``. No
    resizing is applied, so the images are assumed to already be
    ``--image_resolution`` pixels square (TODO confirm).

    Side effects: sets the global ``dataset`` and prints progress/timing.
    """
    global dataset

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    s = time()
    print(f"Loading image folder {args.data_dir} ...")
    dataset = CustomImageFolder(args.data_dir, transform=transform)
    if len(dataset) == 0:
        # Otherwise the run would silently produce an empty output file.
        print(f"Warning: no .png/.jpeg/.jpg images found in {args.data_dir}")
    print(f"Finished. Loading took {time() - s:.2f}s")
def extract_fingerprints():
    """Decode a fingerprint from every image and write them to a text file.

    Runs ``RevealNet`` over ``dataset`` in order, ``--batch_size`` images at a
    time. Each output is thresholded at 0 to obtain a 0/1 bit vector. The
    results go to ``<output_dir>/detected_fingerprints.txt``, one line per
    image::

        <image basename with its extension replaced by .png> <bits, e.g. 0110...>

    An empty dataset produces an empty output file.

    Relies on the module globals ``args``, ``dataset``, ``RevealNet`` and
    ``device``.
    """
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
    )

    # Inference only: no_grad avoids building an autograd graph per batch.
    # Only the small bit tensors are kept; the images are dropped after each
    # batch instead of being accumulated for the whole dataset.
    fingerprint_batches = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader):
            outputs = RevealNet(images.to(device))
            fingerprint_batches.append((outputs > 0).long().cpu())

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "detected_fingerprints.txt")

    # shuffle=False keeps batches in dataset order, so rows line up with
    # dataset.filenames. Walking the batches row by row (instead of
    # torch.cat) also handles an empty dataset.
    fingerprints = (row for batch in fingerprint_batches for row in batch)
    with open(output_path, "w", encoding="utf-8") as f:
        for image_path, fingerprint in zip(dataset.filenames, fingerprints):
            # splitext keeps dotted stems intact ("img.v2.jpg" -> "img.v2.png").
            stem, _ = os.path.splitext(os.path.basename(image_path))
            bits = "".join(map(str, fingerprint.tolist()))
            f.write(f"{stem}.png {bits}\n")
if __name__ == "__main__":
    # Pipeline, in order: restore the decoder, index the images, then decode
    # and write one fingerprint per image.
    for step in (load_decoder, load_data, extract_fingerprints):
        step()