Skip to content

Commit

Permalink
implied tag dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed May 28, 2024
1 parent f75e0a8 commit ad29e98
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,46 @@ def __init__(
# caching
self.caching_mode = None # None, 'latents', 'text'

# implication dropout
import pandas as pd
tag_stats = pd.read_csv(
"/home/hope/src/sd/pikaft-e621-posts-downloader/rr-e621-tags-2024-05-05.csv",
keep_default_na=False
)
assert len(tag_stats[tag_stats.duplicated(subset="tag", keep="first")]) == 0
self.tagdepth = dict(zip(tag_stats["tag"], tag_stats["avg_depth"]))
self.tag_implications = {}
for tag, impls in zip(tag_stats["tag"], tag_stats["implications"]):
if not impls or pd.isna(impls):
continue
impls = [impl.strip() for impl in impls.split(",")]
self.tag_implications[tag] = frozenset(impls)

def tag_dropout(self, tags: str, sep: str):
    """Randomly drop tags that are implied by other tags in the caption.

    The caption is split on ``sep`` and each tag is stripped. Tags starting
    with ``"by "`` are always removed. Every tag reachable through
    ``self.tag_implications`` from the remaining tags is a dropout
    candidate and is removed with probability ``1 - 1 / (1 + depth)``,
    where ``depth`` is looked up in ``self.tagdepth`` (0 when the tag is
    unknown, which means it is never dropped).

    Args:
        tags: Caption text made of tags separated by ``sep``.
        sep: Separator used to split the caption and to join it back.

    Returns:
        The surviving tags in their original order, joined with ``sep``
        followed by one space (``", "`` for the usual ``","`` separator),
        or ``""`` when ``tags`` is empty.
    """
    if not tags:
        return ""

    kept = [tag.strip() for tag in tags.split(sep)]
    # drop artist tags
    kept = [tag for tag in kept if not tag.startswith("by ")]

    # Collect everything implied (directly or transitively) by the caption.
    # ``visited`` guarantees termination even if the implication table
    # contains a cycle such as a -> b -> a.
    implied_tags = set()
    visited = set()
    pending = list(kept)
    while pending:
        tag = pending.pop()
        if tag in visited:
            continue
        visited.add(tag)
        implied = self.tag_implications.get(tag, ())
        implied_tags.update(implied)
        pending.extend(implied)

    # Sorted so that a fixed NumPy seed maps the same random draw to the
    # same tag regardless of per-process string-hash randomization.
    candidates = sorted(implied_tags)

    # dropout: deeper tags in the implication tree are dropped more often
    avg_depths = np.array([self.tagdepth.get(tag, 0) for tag in candidates])
    dropout_rate = 1 - 1 / (1 + avg_depths)
    drop_mask = np.random.rand(len(candidates)) < dropout_rate
    to_remove = {tag for tag, drop in zip(candidates, drop_mask) if drop}

    # Re-join with the caller's separator so later code that splits on the
    # same separator still sees individual tags; "," keeps producing ", ".
    joiner = sep if sep.endswith(" ") else sep + " "
    return joiner.join(tag for tag in kept if tag not in to_remove)

def set_seed(self, seed):
    """Record ``seed`` on this instance as ``self.seed``."""
    self.seed = seed

Expand Down Expand Up @@ -731,6 +771,8 @@ def replace_wildcard(match):
# if caption is multiline, use the first line
caption = caption.split("\n")[0]

caption = self.tag_dropout(caption, subset.caption_separator)

if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
flex_tokens = []
Expand Down

0 comments on commit ad29e98

Please sign in to comment.