Merge pull request #1165 from bghira/main
merge
bghira authored Nov 16, 2024
2 parents 8ff4e97 + 1e0b8d5 commit 04a5a74
Showing 9 changed files with 174 additions and 46 deletions.
20 changes: 13 additions & 7 deletions documentation/LYCORIS.md
@@ -6,12 +6,17 @@

## Using LyCORIS

-To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file:
+To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file.

-```bash
-MODEL_TYPE=lora
-# We use trainer_extra_args for now, as Lycoris support is so new.
-TRAINER_EXTRA_ARGS+=" --lora_type=lycoris --lycoris_config=config/lycoris_config.json"
-```
+The following will go into your `config.json`:
+```json
+{
+    "model_type": "lora",
+    "lora_type": "lycoris",
+    "lycoris_config": "config/lycoris_config.json",
+    "validation_lycoris_strength": 1.0,
+    ...the rest of your settings...
+}
+```


@@ -48,7 +53,7 @@ Optional fields:
- any keyword arguments specific to the selected algorithm, at the end.

Mandatory fields:
-- multiplier
+- multiplier, which should be left at 1.0 unless you know what to expect (a minimal example follows this list)
- linear_dim
- linear_alpha
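
For reference, a minimal `config/lycoris_config.json` containing the mandatory fields can be generated with a few lines of Python. This is a hedged sketch: the `lokr` algorithm choice and the dimension values are illustrative assumptions, not values taken from this commit.

```python
import json
import os

# Hypothetical minimal LyCORIS config; the algo and dims are example values.
lycoris_config = {
    "algo": "lokr",
    "multiplier": 1.0,  # leave at 1.0 unless you know what to expect
    "linear_dim": 10000,
    "linear_alpha": 1,
}

os.makedirs("config", exist_ok=True)
with open("config/lycoris_config.json", "w") as f:
    json.dump(lycoris_config, f, indent=4)
```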

@@ -81,7 +86,8 @@ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer")

lycoris_safetensors_path = 'pytorch_lora_weights.safetensors'
-wrapper, _ = create_lycoris_from_weights(1.0, lycoris_safetensors_path, transformer)
+lycoris_strength = 1.0
+wrapper, _ = create_lycoris_from_weights(lycoris_strength, lycoris_safetensors_path, transformer)
wrapper.merge_to() # using apply_to() will be slower.

transformer.to(device, dtype=dtype)
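
To take the merged transformer the rest of the way to an image, a continuation might look like the sketch below. Treat it as an assumption about the diffusers API rather than part of this diff: `FluxPipeline` and its keyword overrides are not shown in the commit, and the prompt is a placeholder. It reuses `bfl_repo`, `vae`, `transformer`, `device`, and `dtype` from the snippet above.

```python
from diffusers import FluxPipeline

# Build a pipeline around the merged transformer; components we already
# loaded are passed in so they are not fetched a second time.
pipeline = FluxPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    vae=vae,
    torch_dtype=dtype,
)
pipeline.to(device)

image = pipeline(
    prompt="a placeholder test prompt",
    num_inference_steps=28,
).images[0]
image.save("lycoris_sample.png")
```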
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1369,6 +1369,15 @@ def get_argument_parser():
" and submit debug.log to a new Github issue report."
),
)
parser.add_argument(
"--validation_lycoris_strength",
type=float,
default=1.0,
help=(
"When inferencing for validations, the Lycoris model will by default be run at its training strength, 1.0."
" However, this value can be increased to a value of around 1.3 or 1.5 to get a stronger effect from the model."
),
)
parser.add_argument(
"--validation_torch_compile",
action="store_true",
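
The new flag is an ordinary float argument defaulting to 1.0. A self-contained sketch of its behaviour, using a stand-in parser rather than the trainer's full argument parser:

```python
import argparse

# Stand-in parser that reproduces only the new flag and its default.
parser = argparse.ArgumentParser()
parser.add_argument("--validation_lycoris_strength", type=float, default=1.0)

print(parser.parse_args([]).validation_lycoris_strength)  # 1.0
print(parser.parse_args(["--validation_lycoris_strength", "1.3"]).validation_lycoris_strength)  # 1.3
```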
7 changes: 7 additions & 0 deletions helpers/data_backend/factory.py
@@ -1053,6 +1053,13 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
raise ValueError(
f"VAE image embed cache directory {backend.get('cache_dir_vae')} is the same as the text embed cache directory. This is not allowed, the trainer will get confused."
)

if backend["type"] == "local" and (
vae_cache_dir is None or vae_cache_dir == ""
):
raise ValueError(
f"VAE image embed cache directory {backend.get('cache_dir_vae')} is not set. This is required for the VAE image embed cache."
)
init_backend["vaecache"] = VAECache(
id=init_backend["id"],
vae=StateTracker.get_vae(),
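
The new guard is easy to exercise in isolation. A minimal sketch with hypothetical backend entries (only the keys the check inspects are included):

```python
def check_vae_cache_dir(backend: dict) -> None:
    """Mirror of the new guard: local backends must set cache_dir_vae."""
    vae_cache_dir = backend.get("cache_dir_vae")
    if backend["type"] == "local" and (vae_cache_dir is None or vae_cache_dir == ""):
        raise ValueError(
            f"VAE image embed cache directory {vae_cache_dir} is not set."
            " This is required for the VAE image embed cache."
        )

check_vae_cache_dir({"type": "local", "cache_dir_vae": "cache/vae"})  # passes
check_vae_cache_dir({"type": "local"})  # raises ValueError
```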
103 changes: 77 additions & 26 deletions helpers/data_backend/local.py
@@ -7,6 +7,9 @@
import torch
from typing import Any
from regex import regex
import fcntl
import tempfile
import shutil

logger = logging.getLogger("LocalDataBackend")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
@@ -21,29 +24,50 @@ def __init__(self, accelerator, id: str, compress_cache: bool = False):

def read(self, filepath, as_byteIO: bool = False):
"""Read and return the content of the file."""
# Open filepath as BytesIO:
with open(filepath, "rb") as file:
data = file.read()
if not as_byteIO:
return data
return BytesIO(data)
# Acquire a shared lock
fcntl.flock(file, fcntl.LOCK_SH)
try:
data = file.read()
if not as_byteIO:
return data
return BytesIO(data)
finally:
# Release the lock
fcntl.flock(file, fcntl.LOCK_UN)

def write(self, filepath: str, data: Any) -> None:
"""Write the provided data to the specified filepath."""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "wb") as file:
# Check if data is a Tensor, and if so, save it appropriately
if isinstance(data, torch.Tensor):
# logger.debug(f"Writing a torch file to disk.")
return self.torch_save(data, file)
elif isinstance(data, str):
# logger.debug(f"Writing a string to disk as {filepath}: {data}")
data = data.encode("utf-8")
else:
logger.debug(
f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
)
file.write(data)
temp_dir = os.path.dirname(filepath)
temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")

# Open the temporary file for writing
with open(temp_file_path, "wb") as temp_file:
# Acquire an exclusive lock on the temporary file
fcntl.flock(temp_file, fcntl.LOCK_EX)
try:
# Write data to the temporary file
if isinstance(data, torch.Tensor):
# Use the torch_save method, passing the temp file
self.torch_save(data, temp_file)
return # torch_save handles closing the file
elif isinstance(data, str):
data = data.encode("utf-8")
else:
logger.debug(
f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
)
temp_file.write(data)
temp_file.flush()
os.fsync(temp_file.fileno())
finally:
# Release the lock
fcntl.flock(temp_file, fcntl.LOCK_UN)

# Atomically replace the target file with the temporary file
os.rename(temp_file_path, filepath)


def delete(self, filepath):
"""Delete the specified file."""
@@ -212,16 +236,43 @@ def torch_save(self, data, original_location):
Save a torch tensor to a file.
"""
if isinstance(original_location, str):
location = self.open_file(original_location, "wb")
else:
location = original_location
filepath = original_location
os.makedirs(os.path.dirname(filepath), exist_ok=True)
temp_dir = os.path.dirname(filepath)
temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")

if self.compress_cache:
compressed_data = self._compress_torch(data)
location.write(compressed_data)
with open(temp_file_path, "wb") as temp_file:
# Acquire an exclusive lock on the temporary file
fcntl.flock(temp_file, fcntl.LOCK_EX)
try:
if self.compress_cache:
compressed_data = self._compress_torch(data)
temp_file.write(compressed_data)
else:
torch.save(data, temp_file)
temp_file.flush()
os.fsync(temp_file.fileno())
finally:
# Release the lock
fcntl.flock(temp_file, fcntl.LOCK_UN)
# Atomically replace the target file with the temporary file
os.rename(temp_file_path, filepath)
else:
torch.save(data, location)
location.close()
# Handle the case where original_location is a file object
temp_file = original_location
# Acquire an exclusive lock on the file object
fcntl.flock(temp_file, fcntl.LOCK_EX)
try:
if self.compress_cache:
compressed_data = self._compress_torch(data)
temp_file.write(compressed_data)
else:
torch.save(data, temp_file)
temp_file.flush()
os.fsync(temp_file.fileno())
finally:
# Release the lock
fcntl.flock(temp_file, fcntl.LOCK_UN)

def write_batch(self, filepaths: list, data_list: list) -> None:
"""Write a batch of data to the specified filepaths."""
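
Taken together, `read()` now takes a shared lock while reading, and `write()`/`torch_save()` funnel every write through an exclusive lock on a hidden temporary file followed by an atomic rename, so a concurrent reader sees either the old file or the complete new one. A condensed, standalone sketch of that write path (POSIX-only, since `fcntl` does not exist on Windows; the helper name and sample path are invented):

```python
import fcntl
import os

def atomic_write(filepath: str, data: bytes) -> None:
    """Write data so readers never observe a partially written file."""
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    temp_file_path = os.path.join(
        os.path.dirname(filepath), f".{os.path.basename(filepath)}.tmp"
    )
    with open(temp_file_path, "wb") as temp_file:
        # Exclusive lock: cooperating writers serialise on the temp file.
        fcntl.flock(temp_file, fcntl.LOCK_EX)
        try:
            temp_file.write(data)
            temp_file.flush()
            os.fsync(temp_file.fileno())  # force the bytes to stable storage
        finally:
            fcntl.flock(temp_file, fcntl.LOCK_UN)
    # rename() is atomic within a filesystem: the destination is either
    # the old file or the complete new file, never a mixture.
    os.rename(temp_file_path, filepath)

atomic_write("cache/vae/example.pt", b"\x00" * 16)
```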
61 changes: 52 additions & 9 deletions helpers/publishing/metadata.py
@@ -255,7 +255,7 @@ def code_example(args, repo_id: str = None):
image = pipeline(
prompt=prompt,{_negative_prompt(args, in_call=True) if args.model_family.lower() != 'flux' else ''}
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
generator=torch.Generator(device={_torch_device()}).manual_seed({args.validation_seed or args.seed or 42}),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
@@ -293,13 +293,15 @@ def lora_info(args):
lycoris_config = json.load(file)
except:
lycoris_config = {"error": "could not locate or load LyCORIS config."}
return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""
return f"""### LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""


def model_card_note(args):
"""Return a string with the model card note."""
note_contents = args.model_card_note if args.model_card_note else ""
return f"\n{note_contents}\n"
if note_contents is None or note_contents == "":
return ""
return f"\n**Note:** {note_contents}\n"


def flux_schedule_info(args):
@@ -312,6 +314,7 @@ def flux_schedule_info(args):
output_args.append("flux_schedule_auto_shift")
if args.flux_schedule_shift is not None:
output_args.append(f"shift={args.flux_schedule_shift}")
output_args.append(f"flux_guidance_mode={args.flux_guidance_mode}")
if args.flux_guidance_value:
output_args.append(f"flux_guidance_value={args.flux_guidance_value}")
if args.flux_guidance_min:
@@ -324,6 +327,9 @@
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
if args.t5_padding != "unmodified":
output_args.append(f"t5_padding={args.t5_padding}")
output_args.append(f"flow_matching_loss={args.flow_matching_loss}")
if (
args.model_type == "lora"
and args.lora_type == "standard"
@@ -363,11 +369,45 @@ def sd3_schedule_info(args):
return output_str


def ddpm_schedule_info(args):
"""Information about DDPM schedules, eg. rescaled betas or offset noise"""
output_args = []
if args.snr_gamma:
output_args.append(f"snr_gamma={args.snr_gamma}")
if args.use_soft_min_snr:
output_args.append(f"use_soft_min_snr")
if args.soft_min_snr_sigma_data:
output_args.append(
f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}"
)
if args.rescale_betas_zero_snr:
output_args.append(f"rescale_betas_zero_snr")
if args.offset_noise:
output_args.append(f"offset_noise")
output_args.append(f"noise_offset={args.noise_offset}")
output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
output_args.append(
f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}"
)
output_args.append(
f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}"
)
output_str = (
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str


def model_schedule_info(args):
if args.model_family == "flux":
return flux_schedule_info(args)
if args.model_family == "sd3":
return sd3_schedule_info(args)
else:
return ddpm_schedule_info(args)


def save_model_card(
@@ -461,18 +501,19 @@ def save_model_card(
{'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if args.model_family == "sd3" and args.flow_matching_loss == "diffusion" else ''}
{'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).' if args.validation_using_datasets else 'No validation prompt was used during training.'}
{model_card_note(args)}
{'```' if prompt else ''}
{prompt}
{'```' if prompt else ''}
{model_card_note(args)}
## Validation settings
- CFG: `{StateTracker.get_args().validation_guidance}`
- CFG Rescale: `{StateTracker.get_args().validation_guidance_rescale}`
- Steps: `{StateTracker.get_args().validation_num_inference_steps}`
- Sampler: `{StateTracker.get_args().validation_noise_scheduler}`
- Sampler: `{'FlowMatchEulerDiscreteScheduler' if args.model_family in ['sd3', 'flux'] else StateTracker.get_args().validation_noise_scheduler}`
- Seed: `{StateTracker.get_args().validation_seed}`
- Resolution{'s' if ',' in StateTracker.get_args().validation_resolution else ''}: `{StateTracker.get_args().validation_resolution}`
{f"- Skip-layer guidance: {_skip_layers(args)}" if args.model_family in ['sd3', 'flux'] else ''}
Note: The validation settings are not necessarily the same as the [training settings](#training-settings).
@@ -489,17 +530,19 @@
- Training epochs: {StateTracker.get_epoch() - 1}
- Training steps: {StateTracker.get_global_step()}
- Learning rate: {StateTracker.get_args().learning_rate}
- Learning rate schedule: {StateTracker.get_args().lr_scheduler}
- Warmup steps: {StateTracker.get_args().lr_warmup_steps}
- Max grad norm: {StateTracker.get_args().max_grad_norm}
- Effective batch size: {StateTracker.get_args().train_batch_size * StateTracker.get_args().gradient_accumulation_steps * StateTracker.get_accelerator().num_processes}
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
- Gradient checkpointing: {StateTracker.get_args().gradient_checkpointing}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
- Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
- Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}%
{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''}
{lora_info(args=StateTracker.get_args())}
## Datasets
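
To see the kind of summary string the new `ddpm_schedule_info()` helper emits, here is a hedged illustration with a hypothetical argument namespace; the attribute values are invented, and the import path follows the file shown above.

```python
from types import SimpleNamespace

from helpers.publishing.metadata import ddpm_schedule_info

# Invented values covering the attributes the helper reads.
args = SimpleNamespace(
    snr_gamma=5.0,
    use_soft_min_snr=False,
    soft_min_snr_sigma_data=None,
    rescale_betas_zero_snr=True,
    offset_noise=True,
    noise_offset=0.1,
    noise_offset_probability=0.25,
    training_scheduler_timestep_spacing="trailing",
    validation_scheduler_timestep_spacing="trailing",
)

# Prints something like " (extra parameters=['snr_gamma=5.0', ...])".
print(ddpm_schedule_info(args))
```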
2 changes: 1 addition & 1 deletion helpers/training/save_hooks.py
@@ -325,7 +325,7 @@ def _save_full_model(self, models, weights, output_dir):
shutil.copy2(s, d)

# Remove the temporary directory
shutil.rmtree(temporary_dir)
shutil.rmtree(temporary_dir, ignore_errors=True)

def save_model_hook(self, models, weights, output_dir):
# Write "training_state.json" to the output directory containing the training state
2 changes: 1 addition & 1 deletion helpers/training/trainer.py
@@ -2717,7 +2717,7 @@ def train(self):
self.config.output_dir, removing_checkpoint
)
try:
shutil.rmtree(removing_checkpoint)
shutil.rmtree(removing_checkpoint, ignore_errors=True)
except Exception as e:
logger.error(
f"Failed to remove directory: {removing_checkpoint}"
6 changes: 6 additions & 0 deletions helpers/training/validation.py
@@ -957,6 +957,10 @@ def setup_scheduler(self):
return scheduler

def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
if hasattr(self.accelerator, "_lycoris_wrapped_network"):
self.accelerator._lycoris_wrapped_network.set_multiplier(float(getattr(
self.args, "validation_lycoris_strength", 1.0
)))
if validation_type == "intermediary" and self.args.use_ema:
if enable_ema_model:
if self.unet is not None:
@@ -1120,6 +1124,8 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):

def clean_pipeline(self):
"""Remove the pipeline."""
if hasattr(self.accelerator, "_lycoris_wrapped_network"):
self.accelerator._lycoris_wrapped_network.set_multiplier(1.0)
if self.pipeline is not None:
del self.pipeline
self.pipeline = None
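
The two hooks above form a set-then-restore pair: `setup_pipeline()` raises the LyCORIS multiplier for validation renders, and `clean_pipeline()` returns it to the training strength of 1.0. A toy sketch of that discipline, using a stand-in for the wrapped network (the real object exposes a compatible `set_multiplier`):

```python
class StubLycorisNetwork:
    """Stand-in for accelerator._lycoris_wrapped_network in this sketch."""

    def __init__(self) -> None:
        self.multiplier = 1.0

    def set_multiplier(self, multiplier: float) -> None:
        self.multiplier = multiplier

network = StubLycorisNetwork()

validation_lycoris_strength = 1.3  # e.g. parsed from --validation_lycoris_strength
network.set_multiplier(float(validation_lycoris_strength))  # setup_pipeline()
assert network.multiplier == 1.3  # validation renders see the boosted strength
network.set_multiplier(1.0)  # clean_pipeline(): restore training strength
assert network.multiplier == 1.0
```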
