Skip to content

Commit

Permalink
adding upstream multi token textual inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
brandontrabucco committed Aug 8, 2023
1 parent 4ed925b commit 1214d27
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 10 deletions.
160 changes: 160 additions & 0 deletions semantic_aug/augmentations/textual_inversion_upstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from semantic_aug.generative_augmentation import GenerativeAugmentation
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer
)
from diffusers.utils import logging
from PIL import Image, ImageOps

from typing import Any, Tuple, Callable
from torch import autocast
from scipy.ndimage import maximum_filter

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from glob import glob

# Message for a token-name collision when registering new textual-inversion
# tokens in the CLIP tokenizer ("{token}" is filled in by the caller).
# NOTE(review): not referenced anywhere in this file — presumably kept for
# parity with the upstream textual-inversion script; verify before removing.
ERROR_MESSAGE = "Tokenizer already contains the token {token}. \
Please pass a different `token` that is not already in the tokenizer."

def format_name(name, num_tokens: int = 1):
    """Build the multi-token placeholder string for a class name.

    The class name is wrapped in angle brackets with spaces replaced by
    underscores (e.g. ``"great white shark"`` -> ``"<great_white_shark>"``).
    The first token is the bare placeholder; every subsequent token gets a
    ``_<index>`` suffix, and the tokens are joined with single spaces.

    Args:
        name: Human-readable class name.
        num_tokens: How many placeholder tokens to emit (0 yields ``""``).

    Returns:
        Space-separated placeholder token string.
    """
    base = f"<{name.replace(' ', '_')}>"

    parts = []
    for idx in range(num_tokens):
        suffix = "" if idx == 0 else f"_{idx}"
        parts.append(base + suffix)

    return " ".join(parts)

class TextualInversion(GenerativeAugmentation):
    """Semantic augmentation via Stable Diffusion with textual inversion.

    Edits an input image with an img2img (or inpainting, when ``mask=True``)
    Stable Diffusion pipeline whose tokenizer has been extended with learned
    textual-inversion embeddings, so the prompt can reference per-class
    placeholder tokens produced by ``format_name``.
    """

    # Shared class-level pipeline: a single model serves every instance.
    # Global sharing is a deliberate hack to avoid GPU out-of-memory when
    # many augmentations are constructed.
    pipe = None

    def __init__(self, embed_path: str,
                 model_path: str = "CompVis/stable-diffusion-v1-4",
                 prompt: str = "a photo of a {name}",
                 format_name: Callable = format_name,
                 strength: float = 0.5,
                 guidance_scale: float = 7.5,
                 mask: bool = False,
                 inverted: bool = False,
                 mask_grow_radius: int = 16,
                 erasure_ckpt_path: str = None,
                 disable_safety_checker: bool = True,
                 tokens_per_class: int = 1,
                 **kwargs):
        """Load (once) the diffusion pipeline and textual-inversion embeds.

        Args:
            embed_path: Directory searched recursively for
                ``learned_embeds.bin`` files to register with the pipeline.
            model_path: Hugging Face model id or local path of the base
                Stable Diffusion checkpoint.
            prompt: Format string with a ``{name}`` placeholder.
            format_name: Callable mapping (name, num_tokens) to the
                placeholder token string inserted into ``prompt``.
            strength: img2img strength passed to the pipeline.
            guidance_scale: Classifier-free guidance scale.
            mask: If True, use the inpainting pipeline and require a
                per-example mask in ``metadata``.
            inverted: If True, invert the focal mask before inpainting.
            mask_grow_radius: Size of the maximum filter used to dilate
                the mask.
            erasure_ckpt_path: Optional directory of concept-erasure UNet
                checkpoints, loaded lazily per class name in ``forward``.
            disable_safety_checker: If True, remove the NSFW checker.
            tokens_per_class: Number of placeholder tokens per class name.
        """

        super(TextualInversion, self).__init__()

        if TextualInversion.pipe is None:

            PipelineClass = (StableDiffusionInpaintPipeline
                             if mask else
                             StableDiffusionImg2ImgPipeline)

            TextualInversion.pipe = PipelineClass.from_pretrained(
                model_path, use_auth_token=True,
                revision="fp16",
                torch_dtype=torch.float16
            ).to('cuda')

            logging.disable_progress_bar()
            self.pipe.set_progress_bar_config(disable=True)

            if disable_safety_checker:
                self.pipe.safety_checker = None

            # BUG FIX: glob's "**" only recurses when recursive=True;
            # without it "**" degrades to "*" and embeddings nested more
            # than one level deep (or directly in embed_path) were missed.
            embeds_list = glob(os.path.join(
                embed_path, '**', 'learned_embeds.bin'), recursive=True)

            for e in embeds_list:
                self.pipe.load_textual_inversion(e)

        self.prompt = prompt
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.format_name = format_name
        self.tokens_per_class = tokens_per_class

        self.mask = mask
        self.inverted = inverted
        self.mask_grow_radius = mask_grow_radius

        self.erasure_ckpt_path = erasure_ckpt_path
        # Tracks which class's erasure checkpoint is currently loaded so
        # the UNet weights are only reloaded when the class changes.
        self.erasure_word_name = None

    def forward(self, image: Image.Image, label: int,
                metadata: dict) -> Tuple[Image.Image, int]:
        """Generate one augmented variant of ``image``.

        Args:
            image: Source image (any size; resized to 512x512 for the
                pipeline and back to the original size afterwards).
            label: Class label, returned unchanged.
            metadata: Must contain "name" (class name) and, when
                ``self.mask`` is True, "mask" (boolean focal-object mask).

        Returns:
            Tuple of (augmented PIL image at the original size, label).
        """

        canvas = image.resize((512, 512), Image.BILINEAR)
        name = self.format_name(
            metadata.get("name", ""),
            num_tokens=self.tokens_per_class)
        prompt = self.prompt.format(name=name)

        if self.mask: assert "mask" in metadata, \
            "mask=True but no mask present in metadata"

        word_name = metadata.get("name", "").replace(" ", "")

        # Lazily swap in the concept-erasure UNet weights whenever the
        # class being augmented changes.
        if self.erasure_ckpt_path is not None and (
                self.erasure_word_name is None
                or self.erasure_word_name != word_name):

            self.erasure_word_name = word_name
            ckpt_name = "method_full-sg_3-ng_1-iter_1000-lr_1e-05"

            ckpt_path = os.path.join(
                self.erasure_ckpt_path,
                f"compvis-word_{word_name}-{ckpt_name}",
                f"diffusers-word_{word_name}-{ckpt_name}.pt")

            self.pipe.unet.load_state_dict(torch.load(
                ckpt_path, map_location='cuda'))

        kwargs = dict(
            image=canvas,
            prompt=[prompt],
            strength=self.strength,
            guidance_scale=self.guidance_scale
        )

        if self.mask:  # use focal object mask

            mask_image = Image.fromarray((
                np.where(metadata["mask"], 255, 0)
            ).astype(np.uint8)).resize((512, 512), Image.NEAREST)

            # Dilate the mask so the inpainted region extends slightly
            # past the object boundary.
            mask_image = Image.fromarray(
                maximum_filter(np.array(mask_image),
                               size=self.mask_grow_radius))

            if self.inverted:

                mask_image = ImageOps.invert(
                    mask_image.convert('L')).convert('1')

            kwargs["mask_image"] = mask_image

        # Re-sample until the safety checker (if enabled) accepts the
        # output; with the checker disabled this loop runs exactly once.
        has_nsfw_concept = True
        while has_nsfw_concept:
            with autocast("cuda"):
                outputs = self.pipe(**kwargs)

            has_nsfw_concept = (
                self.pipe.safety_checker is not None
                and outputs.nsfw_content_detected[0]
            )

        canvas = outputs.images[0].resize(
            image.size, Image.BILINEAR)

        return canvas, label
30 changes: 23 additions & 7 deletions semantic_aug/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@

class ImageNetDataset(FewShotDataset):

class_names = ['steel arch bridge', 'ram', 'great white shark', 'sombrero',
'hamster', 'racket', 'chain mail', 'ski mask', 'potpie', 'cocktail shaker',
'Indian cobra', 'green snake', 'orange', 'Great Pyrenees', 'minibus', 'wall clock',
"yellow lady's slipper", 'vacuum', 'guillotine', 'redshank', 'pajama',
'tile roof', 'hen of the woods', 'oboe', 'overskirt', 'slug', 'running shoe',
'harp', 'strawberry', 'sturgeon', 'leatherback turtle', 'malamute', 'ladybug',
'mink', 'bulletproof vest', 'walking stick', 'can opener', 'pelican',
'projectile', 'gorilla', 'green mamba', 'drilling platform',
'black and gold garden spider', 'suit', 'volcano', 'hoopskirt',
'meat loaf', 'scuba diver', 'armadillo', 'crane', 'throne', 'barrel',
'golfcart', 'Border collie', 'fire engine', 'Indian elephant',
"carpenter's kit", 'black-and-tan coonhound', 'ballplayer', 'earthstar',
'Italian greyhound', 'confectionery', 'warthog', 'dishwasher', 'American egret',
'bald eagle', 'beagle', 'pinwheel', 'wombat', 'disk brake', 'pole', 'sandbar', 'drake',
'cheeseburger', 'sea anemone', 'computer keyboard', 'suspension bridge', 'ibex',
'toilet seat', 'vulture', 'coffee mug', 'Bouvier des Flandres',
'honeycomb', 'African chameleon', 'barn spider', 'ladle', 'Airedale',
'maze', 'scoreboard', 'fly', 'Bedlington terrier',
'yawl', 'revolver', 'racer', 'croquet ball', 'obelisk', 'mosque',
'dowitcher', 'shovel', 'sleeping bag']

num_classes: int = len(class_names)

def __init__(self, *args, split: str = "train", seed: int = 0,
train_image_dir: str = TRAIN_IMAGE_DIR,
val_image_dir: str = VAL_IMAGE_DIR,
Expand All @@ -35,7 +58,6 @@ def __init__(self, *args, split: str = "train", seed: int = 0,
examples_per_class: int = None,
generative_aug: GenerativeAugmentation = None,
synthetic_probability: float = 0.5,
max_classes: int = 100,
use_randaugment: bool = False,
image_size: Tuple[int] = (256, 256), **kwargs):

Expand All @@ -50,15 +72,13 @@ def __init__(self, *args, split: str = "train", seed: int = 0,
with open(label_synset, "r") as f:
label_synset_lines = f.readlines()

self.class_names = []
self.dir_to_class_names = dict()

for synset in label_synset_lines:

dir_name, synset = synset.split(" ", maxsplit=1)
class_name = synset.split(",")[0].strip()

self.class_names.append(class_name)
self.dir_to_class_names[dir_name] = class_name

class_to_images = defaultdict(list)
Expand All @@ -75,10 +95,6 @@ def __init__(self, *args, split: str = "train", seed: int = 0,
os.path.join(image_dir, path + ".JPEG"))

rng = np.random.default_rng(seed)
self.class_names = [self.class_names[i] for i in (
rng.permutation(len(self.class_names))[:max_classes])]

self.num_classes = len(self.class_names)
class_to_ids = {key: rng.permutation(
len(class_to_images[key])) for key in self.class_names}

Expand Down
15 changes: 12 additions & 3 deletions train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from semantic_aug.augmentations.compose import ComposeSequential
from semantic_aug.augmentations.real_guidance import RealGuidance
from semantic_aug.augmentations.textual_inversion import TextualInversion
from semantic_aug.augmentations.textual_inversion_upstream \
import TextualInversion as MultiTokenTextualInversion
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoImageProcessor, DeiTModel
Expand Down Expand Up @@ -57,7 +59,8 @@

AUGMENTATIONS = {
"real-guidance": RealGuidance,
"textual-inversion": TextualInversion
"textual-inversion": TextualInversion,
"multi-token-inversion": MultiTokenTextualInversion
}


Expand All @@ -80,6 +83,7 @@ def run_experiment(examples_per_class: int = 0,
embed_path: str = DEFAULT_EMBED_PATH,
model_path: str = DEFAULT_MODEL_PATH,
prompt: str = DEFAULT_PROMPT,
tokens_per_class: int = 4,
use_randaugment: bool = False,
use_cutmix: bool = False,
erasure_ckpt_path: str = None,
Expand All @@ -102,7 +106,8 @@ def run_experiment(examples_per_class: int = 0,
guidance_scale=guidance_scale,
mask=mask,
inverted=inverted,
erasure_ckpt_path=erasure_ckpt_path
erasure_ckpt_path=erasure_ckpt_path,
tokens_per_class=tokens_per_class
)

for (aug, guidance_scale,
Expand Down Expand Up @@ -392,7 +397,8 @@ def forward(self, image):
choices=["spurge", "imagenet", "coco", "pascal", "flowers", "caltech"])

parser.add_argument("--aug", nargs="+", type=str, default=None,
choices=["real-guidance", "textual-inversion"])
choices=["real-guidance", "textual-inversion",
"multi-token-inversion"])

parser.add_argument("--strength", nargs="+", type=float, default=None)
parser.add_argument("--guidance-scale", nargs="+", type=float, default=None)
Expand All @@ -409,6 +415,8 @@ def forward(self, image):

parser.add_argument("--use-randaugment", action="store_true")
parser.add_argument("--use-cutmix", action="store_true")

parser.add_argument("--tokens-per-class", type=int, default=4)

args = parser.parse_args()

Expand Down Expand Up @@ -443,6 +451,7 @@ def forward(self, image):
synthetic_probability=args.synthetic_probability,
num_synthetic=args.num_synthetic,
prompt=args.prompt,
tokens_per_class=args.tokens_per_class,
aug=args.aug,
strength=args.strength,
guidance_scale=args.guidance_scale,
Expand Down

0 comments on commit 1214d27

Please sign in to comment.