diff --git a/sgm/util.py b/sgm/util.py index 06f48a882..96181ed74 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -64,37 +64,36 @@ def do_autocast(*args, **kwargs): def load_partial_from_config(config): return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) - def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) - txts = list() + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + txts = [] + for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) - nc = int(40 * (wh[0] / 256)) if isinstance(xc[bi], list): text_seq = xc[bi][0] else: text_seq = xc[bi] lines = "\n".join( - text_seq[start : start + nc] for start in range(0, len(text_seq), nc) + text_seq[start: start + nc] for start in range(0, len(text_seq), nc) ) try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") + print("Can't encode string for logging. Skipping.") + + txt_arr = np.array(txt) / 127.5 - 1.0 + txts.append(txt_arr) - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) txts = np.stack(txts) - txts = torch.tensor(txts) return txts - def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)