Use mock data and mock models in tests (#135)
* Unit tests now use mock data and mock models

* Fixed typo

* Refactored to reduce redundancy

* Fixed test errors

* Fixed plotting error

* Fixed plotting error (2)

* Fixed plotting error (3)

* Fixed test errors (2)

* Fixed test errors (3)

* Fixed test errors (4)
NeelayS authored May 15, 2022
1 parent 9483299 commit ae72e3b
Showing 40 changed files with 312 additions and 402 deletions.
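The commit message says the unit tests now train on mock data and mock models rather than real datasets and full-size networks. The fixtures actually added by this commit are not shown on this page, but a minimal sketch of that pattern in PyTorch could look like the following (the names MockDataset and MockModel are illustrative, not taken from the repository):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class MockDataset(Dataset):
    # Random tensors standing in for real images and labels.
    def __init__(self, num_samples=16, num_features=4, num_classes=2):
        self.data = torch.randn(num_samples, num_features)
        self.labels = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class MockModel(nn.Module):
    # A single linear layer standing in for a teacher or student network.
    def __init__(self, num_features=4, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.fc(x)


loader = DataLoader(MockDataset(), batch_size=4)
teacher, student = MockModel(), MockModel()

A test can then hand teacher, student, and loader (plus optimizers) to a distiller such as VanillaKD and run one short epoch, which keeps the test suite fast and avoids downloading real datasets.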
31 changes: 15 additions & 16 deletions KD_Lib/KD/__init__.py
@@ -1,21 +1,20 @@
from .common import BaseClass
from .text import BERT2LSTM, get_essentials
from .vision import (
VanillaKD,
VirtualTeacher,
SelfTraining,
TAKD,
RKDLoss,
RCO,
NoisyTeacher,
SoftRandom,
MessyCollab,
MeanTeacher,
LabelSmoothReg,
ProbShift,
DML,
BANN,
CSKD,
DML,
RCO,
TAKD,
Attention,
LabelSmoothReg,
MeanTeacher,
MessyCollab,
NoisyTeacher,
ProbShift,
RKDLoss,
SelfTraining,
SoftRandom,
VanillaKD,
VirtualTeacher,
)

from .text import BERT2LSTM, get_essentials
from .common import BaseClass
10 changes: 5 additions & 5 deletions KD_Lib/KD/common/base_class.py
@@ -1,11 +1,11 @@
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from copy import deepcopy
import os


class BaseClass:
"""
@@ -120,7 +120,7 @@ def train_teacher(
loss.backward()
self.optimizer_teacher.step()

epoch_loss += loss
epoch_loss += loss.item()

epoch_acc = correct / length_of_dataset

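Note on the epoch_loss change in this hunk: switching from `epoch_loss += loss` to `epoch_loss += loss.item()` accumulates a plain Python float instead of a tensor that keeps autograd history attached to the running total. A self-contained illustration of the difference (generic training loop, not KD_Lib code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

epoch_loss = 0.0
for _ in range(3):
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # epoch_loss += loss       # tensor sum: retains graph references across iterations
    epoch_loss += loss.item()  # detached float: nothing from the graph is kept
print(epoch_loss)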
13 changes: 6 additions & 7 deletions KD_Lib/KD/text/BERT2LSTM/bert2lstm.py
@@ -1,13 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

from transformers import BertForSequenceClassification, AdamW, BertTokenizer
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AdamW, BertForSequenceClassification, BertTokenizer

from KD_Lib.KD.common import BaseClass
from KD_Lib.KD.text.utils import get_bert_dataloader
18 changes: 11 additions & 7 deletions KD_Lib/KD/text/BERT2LSTM/utils.py
@@ -1,19 +1,23 @@
# coding: utf-8

from __future__ import unicode_literals, print_function
from __future__ import print_function, unicode_literals

import random
from contextlib import closing
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, random_split
from tqdm import tqdm
import numpy as np
import random

from torch.utils.data import (
DataLoader,
RandomSampler,
SequentialSampler,
TensorDataset,
random_split,
)
from torchtext.legacy import data
from tqdm import tqdm


class InputExample(object):
14 changes: 10 additions & 4 deletions KD_Lib/KD/text/utils/bert.py
@@ -1,9 +1,15 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import matplotlib.pyplot as plt
from copy import deepcopy
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, random_split
from torch.utils.data import TensorDataset
from torch.utils.data import (
DataLoader,
RandomSampler,
SequentialSampler,
TensorDataset,
random_split,
)

"""
DATALOADER UTILITIES
8 changes: 4 additions & 4 deletions KD_Lib/KD/vision/BANN/BANN.py
@@ -1,11 +1,11 @@
import glob
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import glob
from copy import deepcopy

from KD_Lib.KD.common import BaseClass


8 changes: 4 additions & 4 deletions KD_Lib/KD/vision/CSKD/cskd.py
@@ -1,11 +1,11 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from copy import deepcopy

from KD_Lib.KD.common import BaseClass


@@ -77,7 +77,7 @@ def calculate_kd_loss(self, y_pred_pair_1, y_pred_pair_2):
q = torch.softmax(y_pred_pair_2 / self.temp, dim=1)
loss = (
nn.KLDivLoss(reduction="sum")(log_p, q)
* (self.temp ** 2)
* (self.temp**2)
/ y_pred_pair_1.size(0)
)

20 changes: 12 additions & 8 deletions KD_Lib/KD/vision/CSKD/sampler.py
@@ -1,17 +1,21 @@
import csv, torchvision, numpy as np, random, os
from PIL import Image
import csv
import os
import random
from collections import defaultdict

import numpy as np
import torchvision
from PIL import Image
from torch.utils.data import (
Sampler,
Dataset,
DataLoader,
BatchSampler,
SequentialSampler,
DataLoader,
Dataset,
RandomSampler,
Sampler,
SequentialSampler,
Subset,
)
from torchvision import transforms, datasets
from collections import defaultdict
from torchvision import datasets, transforms


class PairBatchSampler(Sampler):
10 changes: 5 additions & 5 deletions KD_Lib/KD/vision/DML/dml.py
@@ -1,13 +1,13 @@
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from copy import deepcopy
import os


class DML:
"""
@@ -117,7 +117,7 @@ def train_students(

correct += max(correct_preds)

epoch_loss += avg_student_loss
epoch_loss += avg_student_loss.item()

epoch_acc = correct / length_of_dataset

6 changes: 3 additions & 3 deletions KD_Lib/KD/vision/RCO/rco.py
@@ -1,12 +1,12 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from copy import deepcopy

from KD_Lib.KD.common import BaseClass


2 changes: 1 addition & 1 deletion KD_Lib/KD/vision/RKD/loss_metric.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn


def pairwaise_distance(output):
6 changes: 3 additions & 3 deletions KD_Lib/KD/vision/TAKD/takd.py
@@ -1,11 +1,11 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from copy import deepcopy

from KD_Lib.KD.common import BaseClass


20 changes: 10 additions & 10 deletions KD_Lib/KD/vision/__init__.py
@@ -1,12 +1,12 @@
from .vanilla import VanillaKD
from .teacher_free import VirtualTeacher, SelfTraining
from .TAKD import TAKD
from .attention import Attention
from .BANN import BANN
from .CSKD import CSKD
from .RKD import RKDLoss
from .RCO import RCO
from .noisy import NoisyTeacher, SoftRandom, MessyCollab
from .mean_teacher import MeanTeacher
from .KA import LabelSmoothReg, ProbShift
from .DML import DML
from .BANN import BANN
from .attention import Attention
from .KA import LabelSmoothReg, ProbShift
from .mean_teacher import MeanTeacher
from .noisy import MessyCollab, NoisyTeacher, SoftRandom
from .RCO import RCO
from .RKD import RKDLoss
from .TAKD import TAKD
from .teacher_free import SelfTraining, VirtualTeacher
from .vanilla import VanillaKD
2 changes: 1 addition & 1 deletion KD_Lib/KD/vision/attention/__init__.py
@@ -1,2 +1,2 @@
from .loss_metric import ATLoss
from .attention import Attention
from .loss_metric import ATLoss
3 changes: 2 additions & 1 deletion KD_Lib/KD/vision/attention/attention.py
@@ -1,6 +1,7 @@
import torch.nn.functional as F

from KD_Lib.KD.common import BaseClass

from .loss_metric import ATLoss


@@ -66,7 +67,7 @@ def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
loss = (
(1.0 - self.distil_weight)
* self.temp
* F.cross_entropy(y_pred_student[0] / self.temp, y_true)
* F.cross_entropy(y_pred_student / self.temp, y_true)
)
loss += self.distil_weight * self.loss_fn(y_pred_teacher, y_pred_student)
return loss
3 changes: 2 additions & 1 deletion KD_Lib/KD/vision/mean_teacher/mean_teacher.py
@@ -1,6 +1,7 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn

from KD_Lib.KD.common import BaseClass


2 changes: 1 addition & 1 deletion KD_Lib/KD/vision/noisy/__init__.py
@@ -1,3 +1,3 @@
from .messy_collab import MessyCollab
from .noisy_teacher import NoisyTeacher
from .soft_random import SoftRandom
from .messy_collab import MessyCollab
10 changes: 5 additions & 5 deletions KD_Lib/KD/vision/noisy/messy_collab.py
@@ -1,11 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import random
from copy import deepcopy

import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from KD_Lib.KD.common import BaseClass

9 changes: 5 additions & 4 deletions KD_Lib/KD/vision/noisy/noisy_teacher.py
@@ -1,13 +1,14 @@
import random
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import random
from copy import deepcopy
import matplotlib.pyplot as plt
from KD_Lib.KD.common import BaseClass

from .utils import add_noise
from KD_Lib.KD.common import BaseClass


class NoisyTeacher(BaseClass):
7 changes: 4 additions & 3 deletions KD_Lib/KD/vision/noisy/soft_random.py
@@ -1,11 +1,12 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
import matplotlib.pyplot as plt

from KD_Lib.KD.common import BaseClass

from .utils import add_noise


2 changes: 1 addition & 1 deletion KD_Lib/KD/vision/noisy/utils.py
@@ -9,4 +9,4 @@ def add_noise(x, variance=0.1):
:param variance (float): Variance for adding noise
"""

return x * (1 + (variance ** 0.5) * torch.randn_like(x))
return x * (1 + (variance**0.5) * torch.randn_like(x))
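For context on the add_noise helper reformatted above: it scales its input elementwise by one plus zero-mean Gaussian noise with the given variance. A quick, hypothetical usage sketch (not a test from this commit):

import torch

def add_noise(x, variance=0.1):
    # Multiply x elementwise by 1 + N(0, variance) noise.
    return x * (1 + (variance**0.5) * torch.randn_like(x))

teacher_logits = torch.tensor([[2.0, -1.0, 0.5]])
noisy_logits = add_noise(teacher_logits, variance=0.1)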
2 changes: 1 addition & 1 deletion KD_Lib/KD/vision/teacher_free/__init__.py
@@ -1,2 +1,2 @@
from .virtual_teacher import VirtualTeacher
from .self_training import SelfTraining
from .virtual_teacher import VirtualTeacher
