merge #1039

Merged · 23 commits · Oct 9, 2024
20 changes: 20 additions & 0 deletions INSTALL.md
@@ -17,6 +17,9 @@ python3.11 -m venv .venv
source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to stop Poetry from creating its own virtualenv (it thinks it knows better than us).
poetry config virtualenvs.create false
```

> ℹ️ You can use your own custom venv path by setting `export VENV_PATH=/path/to/.venv` in your `config/config.env` file.
@@ -36,6 +39,23 @@ poetry install
poetry install -C install/rocm
```

#### NVIDIA Hopper / Blackwell follow-up steps

Optionally, Hopper (or newer) hardware can make use of FlashAttention3 for improved inference and training performance when using `torch.compile`.

You'll need to run the following sequence of commands from your SimpleTuner directory, with your venv active:

```bash
git clone https://github.com/Dao-AILab/flash-attention
pushd flash-attention
pushd hopper
python setup.py install
popd
popd
```

> ⚠️ Managing the flash_attn build is currently poorly supported in SimpleTuner. It can break on updates, requiring you to re-run this build procedure manually from time to time.
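
After the build completes, a quick import check can confirm the module is visible inside your venv (a sketch; `flash_attn_interface` is the module name the `hopper/` build is expected to install):

```bash
# Exits silently on success; an ImportError means the FA3 build did not install into this venv.
python -c "import flash_attn_interface"
```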

#### AMD ROCm follow-up steps

The following must be executed for an AMD MI300X to be usable:
8 changes: 7 additions & 1 deletion OPTIONS.md
@@ -84,7 +84,13 @@ A newer library from Pytorch, AO allows us to replace the linears and 2D convolutions
- at the time of writing, runs slightly slower (11s/iter) than Quanto does (9s/iter) on Apple MPS
- When not using `torch.compile`, same speed and memory use as `int8-quanto` on CUDA devices, unknown speed profile on ROCm
- When using `torch.compile`, slower than `int8-quanto`
- `fp8-torchao` is not enabled due to bugs in the implementation.
- `fp8-torchao` is only available for Hopper (H100, H200) or newer (Blackwell B200) accelerators; see the selection sketch below.
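
Any of these precision levels is chosen via `--base_model_precision` (the key that `configure.py` writes). A minimal sketch, assuming `train.py` as the entry point:

```bash
# Quantise the base model to int8 via TorchAO for training.
python train.py --base_model_precision=int8-torchao
```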

##### Optimisers

TorchAO includes generally-available 4bit and 8bit optimisers: `ao-adamw8bit` and `ao-adamw4bit`.

It also provides two optimisers directed toward Hopper (H100 or better) users: `ao-adamfp8` and `ao-adamwfp8`.
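
These are selected like any other optimiser. A hypothetical invocation, assuming the trainer's `--optimizer` flag accepts the TorchAO identifiers listed above:

```bash
# Use TorchAO's 8-bit AdamW during training (flag name assumed).
python train.py --optimizer=ao-adamw8bit
```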

#### Torch Dynamo

2 changes: 1 addition & 1 deletion configure.py
@@ -664,7 +664,7 @@ def configure_env():
if quantization_type:
print(f"Invalid quantization type: {quantization_type}")
quantization_type = prompt_user(
f"Choose quantization type. int4 may only work on A100, H100, or Apple systems. (Options: {'/'.join(quantised_precision_levels)})",
f"Choose quantization type. (Options: {'/'.join(quantised_precision_levels)})",
"int8-quanto",
)
env_contents["--base_model_precision"] = quantization_type
5 changes: 4 additions & 1 deletion documentation/quickstart/FLUX.md
@@ -65,6 +65,9 @@ python3.11 -m venv .venv
source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to stop Poetry from creating its own virtualenv (it thinks it knows better than us).
poetry config virtualenvs.create false
```

**Note:** We're currently installing the `release` branch here; the `main` branch may contain experimental features that might have better results or lower memory use.
@@ -438,7 +441,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing
- **int8** has hardware acceleration and `torch.compile()` support on newer NVIDIA hardware (3090 or better)
- **nf4-bnb** brings VRAM requirements down to 9GB, fitting on a 10G card (with bfloat16 support)
- When loading the LoRA in ComfyUI later, you **must** use the same base model precision as you trained your LoRA on.
- **int4** is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels
- **int4** relies on custom bf16 kernels, and will not work if your card does not support bfloat16 (see the check below)
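
A quick way to probe for bfloat16 support (a sketch using PyTorch's built-in check):

```bash
# Prints True when the active CUDA device supports bfloat16.
python -c "import torch; print(torch.cuda.is_bf16_supported())"
```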

### Crashing
- If you get SIGKILL after the text encoders are unloaded, this means you do not have enough system memory to quantise Flux.
3 changes: 3 additions & 0 deletions documentation/quickstart/KOLORS.md
@@ -44,6 +44,9 @@ python -m venv .venv
source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to stop Poetry from creating its own virtualenv (it thinks it knows better than us).
poetry config virtualenvs.create false
```

Depending on your system, you will run one of 3 commands:
3 changes: 3 additions & 0 deletions documentation/quickstart/SD3.md
@@ -42,6 +42,9 @@ python -m venv .venv
source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to stop Poetry from creating its own virtualenv (it thinks it knows better than us).
poetry config virtualenvs.create false
```

Depending on your system, you will run one of 3 commands:
3 changes: 3 additions & 0 deletions documentation/quickstart/SIGMA.md
@@ -42,6 +42,9 @@ python -m venv .venv
source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to stop Poetry from creating its own virtualenv (it thinks it knows better than us).
poetry config virtualenvs.create false
```

Depending on your system, you will run one of 3 commands:
8 changes: 4 additions & 4 deletions helpers/caching/vae.py
@@ -160,10 +160,6 @@ def generate_vae_cache_filename(self, filepath: str) -> tuple:
def _image_filename_from_vaecache_filename(self, filepath: str) -> tuple[str, str]:
test_filepath, _ = self.generate_vae_cache_filename(filepath)
result = self.vae_path_to_image_path.get(test_filepath, None)
if result is None:
raise ValueError(
f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). This occurs when you toggle the value for hashed_filenames without clearing your VAE cache. If it still occurs after clearing the cache, please open an issue: https://github.com/bghira/simpletuner/issues"
)
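# NOTE: result may be None, e.g. when hashed_filenames was toggled without clearing the VAE cache; callers check for this.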

return result

@@ -369,6 +365,8 @@ def discover_unprocessed_files(self, directory: str = None):
for cache_file in existing_cache_files:
try:
n = self._image_filename_from_vaecache_filename(cache_file)
if n is None:
continue
already_cached_images.append(n)
except Exception as e:
logger.error(
@@ -955,6 +953,8 @@ def process_buckets(self):
test_filepath = self._image_filename_from_vaecache_filename(
filepath
)
if test_filepath is None:
continue
if test_filepath not in self.local_unprocessed_files:
statistics["not_local"] += 1
continue
12 changes: 5 additions & 7 deletions helpers/configuration/cmd_args.py
@@ -160,17 +160,13 @@ def get_argument_parser():
"--flux_beta_schedule_alpha",
type=float,
default=2.0,
help=(
"The alpha value of the flux beta schedule. Default is 2.0"
),
help=("The alpha value of the flux beta schedule. Default is 2.0"),
)
parser.add_argument(
"--flux_beta_schedule_beta",
type=float,
default=2.0,
help=(
"The beta value of the flux beta schedule. Default is 2.0"
),
help=("The beta value of the flux beta schedule. Default is 2.0"),
)
parser.add_argument(
"--flux_schedule_shift",
@@ -2246,7 +2242,9 @@ def parse_cmdline_args(input_args=None):
f"{'PixArt Sigma' if args.model_family == 'pixart_sigma' else 'Stable Diffusion 3'} requires --max_grad_norm=0.01 to prevent model collapse. Overriding value. Set this value manually to disable this warning."
)
args.max_grad_norm = 0.01

if args.gradient_checkpointing:
# Needed so torch.compile works with activation checkpointing; unfortunately, it slows us down.
torch._dynamo.config.optimize_ddp = False
if args.gradient_accumulation_steps > 1:
if args.gradient_precision == "unmodified" or args.gradient_precision is None:
warning_log(
87 changes: 59 additions & 28 deletions helpers/data_backend/aws.py
@@ -8,6 +8,7 @@
)
import fnmatch
import logging
import torch
from torch import Tensor
import concurrent.futures
from botocore.config import Config
@@ -289,46 +290,73 @@ def create_directory(self, directory_path):
# Since S3 doesn't have a traditional directory structure, this is just a pass-through
pass

def torch_load(self, s3_key):
import torch
from io import BytesIO
def _detect_file_format(self, fileobj):
fileobj.seek(0)
magic_number = fileobj.read(4)
fileobj.seek(0)
logger.debug(f"Magic number: {magic_number}")
if magic_number[:2] == b"\x80\x04" or b"PK" in magic_number:
# This is likely a torch-saved object (Pickle protocol 4)
# Need to check whether it's the incorrectly saved compressed data
try:
obj = torch.load(fileobj, map_location="cpu")
if isinstance(obj, bytes):
# If obj is bytes, it means compressed data was saved incorrectly
return "incorrect"
else:
return "correct_uncompressed"
except Exception as e:
# If torch.load fails, it's possibly compressed correctly
return "correct_compressed"
elif magic_number[:2] == b"\x1f\x8b":
# GZIP magic number, compressed data saved correctly
return "correct_compressed"
else:
# Unrecognized format
return "unknown"

# Retry the torch load within the retry limit
def torch_load(self, s3_key):
for i in range(self.read_retry_limit):
try:
stored_tensor = BytesIO(self.read(s3_key))
if self.compress_cache:
try:
stored_tensor = self._decompress_torch(stored_tensor)
except Exception as e:
pass
# Read data from S3
data = self.read(s3_key)
stored_data = BytesIO(data)
stored_data.seek(0)

# Determine if the file was saved incorrectly
file_format = self._detect_file_format(stored_data)
logger.debug(f"File format: {file_format}")
if file_format == "incorrect":
# Load the compressed bytes object serialized by torch.save
stored_data.seek(0)
compressed_data = BytesIO(
torch.load(stored_data, map_location="cpu")
)
# Decompress the data
stored_tensor = self._decompress_torch(compressed_data)
elif file_format == "correct_compressed":
# Data is compressed but saved correctly; decompress the buffer before loading
stored_tensor = self._decompress_torch(stored_data)
else:
# Data is uncompressed and saved correctly
stored_tensor = stored_data

if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)

obj = torch.load(stored_tensor, map_location="cpu")
# logger.debug(f"torch.load found: {obj}")
if type(obj) is tuple:

if isinstance(obj, tuple):
obj = tuple(o.to(torch.float32) for o in obj)
elif type(obj) is Tensor:
elif isinstance(obj, torch.Tensor):
obj = obj.to(torch.float32)

return obj
except Exception as e:
if not self.exists(s3_key):
logger.debug(f"File {s3_key} does not exist in S3 bucket.")
raise FileNotFoundError(f"{s3_key} not found.")
logger.error(f"Error loading torch file (path: {s3_key}): {e}")
if str(e) == "Ran out of input":
logger.error(f"File {s3_key} is empty. Deleting it from S3.")
self.delete(s3_key)
raise FileNotFoundError(f"{s3_key} not found.")
logging.error(f"Failed to load tensor from {s3_key}: {e}")
if i == self.read_retry_limit - 1:
# We have reached our maximum retry count.
raise e
raise
else:
# Sleep for a bit before retrying.
time.sleep(self.read_retry_interval)
logging.info(f"Retrying... ({i+1}/{self.read_retry_limit})")

def torch_save(self, data, s3_key):
import torch
@@ -339,8 +367,11 @@ def torch_save(self, data, s3_key):
try:
buffer = BytesIO()
if self.compress_cache:
data = self._compress_torch(data)
torch.save(data, buffer)
compressed_data = self._compress_torch(data)
buffer.write(compressed_data)
else:
torch.save(data, buffer)
buffer.seek(0) # Reset buffer position to the beginning
logger.debug(f"Writing torch file: {s3_key}")
result = self.write(s3_key, buffer.getvalue())
logger.debug(f"Write completed: {s3_key}")
24 changes: 17 additions & 7 deletions helpers/data_backend/factory.py
@@ -118,12 +118,22 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
if (
output["config"]["crop_aspect"] == "random"
or output["config"]["crop_aspect"] == "closest"
) and "crop_aspect_buckets" not in backend:
raise ValueError(
f"(id={backend['id']}) crop_aspect_buckets must be provided when crop_aspect is set to 'random'."
" This should be a list of float values or a list of dictionaries following the format: {'aspect_bucket': float, 'weight': float}."
" The weight represents how likely this bucket is to be chosen, and all weights should add up to 1.0 collectively."
)
):
if "crop_aspect_buckets" not in backend or not isinstance(
backend["crop_aspect_buckets"], list
):
raise ValueError(
f"(id={backend['id']}) crop_aspect_buckets must be provided when crop_aspect is set to 'random'."
" This should be a list of float values or a list of dictionaries following the format: {'aspect_bucket': float, 'weight': float}."
" The weight represents how likely this bucket is to be chosen, and all weights should add up to 1.0 collectively."
)
for bucket in backend.get("crop_aspect_buckets"):
if type(bucket) not in [float, int, dict]:
raise ValueError(
f"(id={backend['id']}) crop_aspect_buckets must be a list of float values or a list of dictionaries following the format: {'aspect_bucket': float, 'weight': float}."
" The weight represents how likely this bucket is to be chosen, and all weights should add up to 1.0 collectively."
)

output["config"]["crop_aspect_buckets"] = backend.get("crop_aspect_buckets")
else:
output["config"]["crop_aspect"] = "square"
@@ -335,7 +345,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
f"Data backend config file {args.data_backend_config} not found."
)
info_log(f"Loading data backend config from {args.data_backend_config}")
with open(args.data_backend_config, "r") as f:
with open(args.data_backend_config, "r", encoding="utf-8") as f:
data_backend_config = json.load(f)
if len(data_backend_config) == 0:
raise ValueError(
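For reference, a backend entry that satisfies these checks, using the weighted-bucket dictionary format the error messages describe, might look like this (a hypothetical sketch; the `id`, aspects, and weights are illustrative):

```json
{
  "id": "my-dataset",
  "crop_aspect": "random",
  "crop_aspect_buckets": [
    {"aspect_bucket": 0.75, "weight": 0.4},
    {"aspect_bucket": 1.0, "weight": 0.6}
  ]
}
```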
2 changes: 1 addition & 1 deletion helpers/image_manipulation/training_sample.py
@@ -192,7 +192,7 @@ def _trim_aspect_bucket_list(self):
# If any of the aspect buckets will result in that, we'll ignore it.
if type(bucket) is dict:
aspect = bucket["aspect_ratio"]
elif type(bucket) is float:
elif type(bucket) is float or type(bucket) is int:
aspect = bucket
else:
raise ValueError(