Skip to content

Commit

Permalink
Merge pull request #555 from bghira/main
Browse files Browse the repository at this point in the history
cv2: error checking for image load when we hit grayscale images | LoRA: save/load refactor from Sayak | arguments: set --inference_noise_scheduler to None by default so that PixArt scheduler is uninterrupted
  • Loading branch information
bghira authored Jul 4, 2024
2 parents f37de18 + 0af27c0 commit 007409d
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 187 deletions.
4 changes: 2 additions & 2 deletions helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,12 +1204,12 @@ def parse_args(input_args=None):
"--validation_noise_scheduler",
type=str,
choices=["ddim", "ddpm", "euler", "euler-a", "unipc"],
default="euler",
default=None,
help=(
"When validating the model at inference time, a different scheduler may be chosen."
" UniPC can offer better speed, and Euler A can put up with instabilities a bit better."
" For zero-terminal SNR models, DDIM is the best choice. Choices: ['ddim', 'ddpm', 'euler', 'euler-a', 'unipc'],"
" Default: ddim"
" Default: None (use the model default)"
),
)
parser.add_argument(
Expand Down
7 changes: 4 additions & 3 deletions helpers/data_backend/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,16 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False):
import traceback

logger.error(
f"Encountered error opening image: {e}, traceback: {traceback.format_exc()}"
f"Encountered error opening image {filepath}: {e}, traceback: {traceback.format_exc()}"
)
if delete_problematic_images:
logger.error(
f"Deleting image, because --delete_problematic_images is provided."
)
self.delete(filepath)
exit(1)
raise e
else:
exit(1)
raise e

def read_image_batch(
self, filepaths: list, delete_problematic_images: bool = False
Expand Down
20 changes: 12 additions & 8 deletions helpers/image_manipulation/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,28 @@ def decode_image_with_pil(img_data: bytes) -> Image.Image:

# For transparent images, add a white background as this is correct
# most of the time.
if img_pil.mode == 'RGBA':
if img_pil.mode == "RGBA":
canvas = Image.new("RGBA", img_pil.size, (255, 255, 255))
canvas.alpha_composite(img_pil)
img_pil = canvas.convert("RGB")
else:
img_pil = img_pil.convert('RGB')
img_pil = img_pil.convert("RGB")
except (OSError, Image.DecompressionBombError, ValueError) as e:
logger.warning(f'Error decoding image: {e}')
logger.warning(f"Error decoding image: {e}")
raise
return img_pil


def load_image(img_data: Union[bytes, IO[Any], str]) -> Image.Image:
'''
"""
Load an image using CV2. If that fails, fall back to PIL.
The image is returned as a PIL object.
'''
"""
if isinstance(img_data, str):
with open(img_data, 'rb') as file:
with open(img_data, "rb") as file:
img_data = file.read()
elif hasattr(img_data, 'read'):
elif hasattr(img_data, "read"):
# Check if it's file-like object.
img_data = img_data.read()

Expand All @@ -70,7 +70,11 @@ def load_image(img_data: Union[bytes, IO[Any], str]) -> Image.Image:
nparr = np.frombuffer(img_data, np.uint8)
image_preload = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
has_alpha = False
if image_preload is not None and image_preload.shape[2] == 4:
if (
image_preload is not None
and len(image_preload.shape) >= 3
and image_preload.shape[2] == 4
):
has_alpha = True
del image_preload

Expand Down
Loading

0 comments on commit 007409d

Please sign in to comment.