-
Notifications
You must be signed in to change notification settings - Fork 4
/
__init__.py
82 lines (71 loc) · 1.62 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
These imports are shared across all files.
Created by Basile Van Hoorick for TCOW.
'''
# Library imports.
import argparse
import collections
import collections.abc
import copy
import cv2
import imageio
import itertools
import joblib
import json
import lovely_numpy
import lovely_tensors
import matplotlib.colors
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd
import pathlib
import pickle
import platform
import random
import rich
import rich.console
import rich.logging
import rich.progress
import scipy
import seaborn as sns
import shutil
import sklearn
import sklearn.decomposition
import sys
import time
import torch
import torch.nn
import torch.nn.functional
import torch.optim
import torch.utils
import torch.utils.data
import torchvision
import torchvision.datasets
import torchvision.io
import torchvision.models
import torchvision.transforms
import torchvision.utils
import tqdm
import tqdm.rich
import warnings
from collections import defaultdict
from einops import rearrange, repeat
from lovely_numpy import lo
from rich import print
# Project identifier string (its consumers are not visible in this file).
PROJECT_NAME = 'tcow'

# Append the current working directory, followed by its top-level code folders, to the module
# search path. Entries are added at the end, so anything already on sys.path takes precedence.
# NOTE: These paths are resolved against the working directory at import time, not against the
# location of this file.
sys.path.append(os.getcwd())
sys.path.extend(
    os.path.join(os.getcwd(), subdir)
    for subdir in ('data/', 'eval/', 'model/', 'third_party/', 'utils/')
)

# Apply the lovely_tensors monkey patch once at import time (presumably switches tensor printing
# to compact summaries -- see the lovely_tensors documentation to confirm).
lovely_tensors.monkey_patch()
# Quick functions for usage during debugging:
def mmm(x):
    '''
    Debugging shortcut that summarizes the value range of x.
    :param x: Any object exposing min(), mean() and max() methods (presumably a numpy array or
        torch tensor; this is not checked here).
    :return (minimum, mean, maximum) tuple holding whatever those three methods return.
    '''
    return (x.min(), x.mean(), x.max())
def st(x):
    '''
    Debugging shortcut that reports the element type and dimensions of x.
    :param x: Any object exposing dtype and shape attributes (presumably a numpy array or torch
        tensor; this is not checked here).
    :return (dtype, shape) tuple read directly from those two attributes.
    '''
    return (x.dtype, x.shape)
def stmmm(x):
    '''
    Debugging shortcut that combines st() and mmm() into a single flat summary of x.
    :param x: Any object accepted by both st() and mmm(), i.e. one exposing dtype, shape, min(),
        mean() and max().
    :return (dtype, shape, minimum, mean, maximum) tuple.
    '''
    # Unpack both partial summaries so the result is one flat 5-tuple rather than nested tuples.
    return (*st(x), *mmm(x))