Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cv2: error checking for image load when we hit grayscale images | LoRA: save/load refactor from Sayak | arguments: set --validation_noise_scheduler to None by default so that PixArt scheduler is uninterrupted #555

Merged
merged 6 commits into from
Jul 4, 2024
Merged
4 changes: 2 additions & 2 deletions helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,12 +1204,12 @@ def parse_args(input_args=None):
"--validation_noise_scheduler",
type=str,
choices=["ddim", "ddpm", "euler", "euler-a", "unipc"],
default="euler",
default=None,
help=(
"When validating the model at inference time, a different scheduler may be chosen."
" UniPC can offer better speed, and Euler A can put up with instabilities a bit better."
" For zero-terminal SNR models, DDIM is the best choice. Choices: ['ddim', 'ddpm', 'euler', 'euler-a', 'unipc'],"
" Default: ddim"
" Default: None (use the model default)"
),
)
parser.add_argument(
Expand Down
7 changes: 4 additions & 3 deletions helpers/data_backend/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,16 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False):
import traceback

logger.error(
f"Encountered error opening image: {e}, traceback: {traceback.format_exc()}"
f"Encountered error opening image {filepath}: {e}, traceback: {traceback.format_exc()}"
)
if delete_problematic_images:
logger.error(
f"Deleting image, because --delete_problematic_images is provided."
)
self.delete(filepath)
exit(1)
raise e
else:
exit(1)
raise e

def read_image_batch(
self, filepaths: list, delete_problematic_images: bool = False
Expand Down
20 changes: 12 additions & 8 deletions helpers/image_manipulation/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,28 @@ def decode_image_with_pil(img_data: bytes) -> Image.Image:

# For transparent images, add a white background as this is correct
# most of the time.
if img_pil.mode == 'RGBA':
if img_pil.mode == "RGBA":
canvas = Image.new("RGBA", img_pil.size, (255, 255, 255))
canvas.alpha_composite(img_pil)
img_pil = canvas.convert("RGB")
else:
img_pil = img_pil.convert('RGB')
img_pil = img_pil.convert("RGB")
except (OSError, Image.DecompressionBombError, ValueError) as e:
logger.warning(f'Error decoding image: {e}')
logger.warning(f"Error decoding image: {e}")
raise
return img_pil


def load_image(img_data: Union[bytes, IO[Any], str]) -> Image.Image:
'''
"""
Load an image using CV2. If that fails, fall back to PIL.

The image is returned as a PIL object.
'''
"""
if isinstance(img_data, str):
with open(img_data, 'rb') as file:
with open(img_data, "rb") as file:
img_data = file.read()
elif hasattr(img_data, 'read'):
elif hasattr(img_data, "read"):
# Check if it's file-like object.
img_data = img_data.read()

Expand All @@ -70,7 +70,11 @@ def load_image(img_data: Union[bytes, IO[Any], str]) -> Image.Image:
nparr = np.frombuffer(img_data, np.uint8)
image_preload = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
has_alpha = False
if image_preload is not None and image_preload.shape[2] == 4:
if (
image_preload is not None
and len(image_preload.shape) >= 3
and image_preload.shape[2] == 4
):
has_alpha = True
del image_preload

Expand Down
Loading
Loading