
CeecnetV1 Performance #3

Open · Farihaa opened this issue Nov 3, 2020 · 22 comments

@Farihaa commented Nov 3, 2020

I have been training CeecnetV1 on the LEVIRCD dataset. Even after 120 epochs I am getting an average F1 score between 0.35 and 0.45 for the segmentation task, while the average loss is 0.22. Around how many epochs should I see the F1 score going up?

@feevos (Owner) commented Nov 3, 2020

Hi @Farihaa ,

In about 50 epochs on the LEVIRCD dataset you should be seeing the first peak of performance (without learning rate reduction), which is already excellent. You can see this from the training curves in Fig. 15 / page 16 of the manuscript (see docs/manuscript.pdf). There is probably a bug somewhere in your setup. If you can, please post your training code and setup (hardware resources, batch size, etc.), and I should be able to help more.

Regards

@Farihaa (Author) commented Nov 3, 2020

I am using 4 Tesla T4 GPUs, and I could only get to a batch size of 2 per GPU for distributed training. All the other hyperparameters are the same as mentioned in the paper. However, I was not doing on-the-fly data augmentation because GPU memory could not handle it. Could that be the main cause?

def train(train_data_path, val_data_path):
    '''
    Parameters
    ----------
    epochs : int
        num iterations to be done over the dataset.
    train_data_path : str
        path to training rec/idx-pass filename.
    val_data_path : str
        path to validation rec/idx-pass filename.

    Returns
    -------
    None.

    '''
    gpu_count = mx.context.num_gpus()
    ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()
    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    context = mx.cpu(hvd.local_rank()) if no_cuda else mx.gpu(hvd.local_rank())
    num_workers = hvd.size()
    CPU_COUNT = cpu_count()
    sw = SummaryWriter(logdir='./logs')

    epochs = 200
    lr= 0.0001
    lr_decay_count = 0
    log_interval = 1
   
    NClasses= 2 
    batch_size=2
    
    if load_ckpt:
      
        net = load_checkpoint("checkpoint",context)
       
    else :     
        # initializing network
        net = createModel()
       # net.cast('float16')
    net.hybridize()
    # initializing loss
    myMTSKL = mtsk_loss()

    #f1s= f1_score()
    # perform metric
    calc_f1 = f1_score()

   
    # Define our trainer for net

    # Set parameters
    adam_optimizer = optimizer.Adam(learning_rate=lr * hvd.size(), beta1=0.9, beta2=0.999, epsilon=1e-08)
    net.initialize(mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2), ctx=context)
    #net.collect_params().reset_ctx(context)
    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)
    
    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params,adam_optimizer )
    #trainer = gluon.Trainer(net.collect_params(), adam_optimizer)
   # amp.init()
   # amp.init_trainer(trainer)


    train_history = TrainingHistory(['training-f1-score', 'validation-f1-score'])
    # load dataset
    tnorm = LVRCDNormal() # image normalization for each chip
    ttransform = None
    train_data=LVRCDDataset(root=r'/LEVIRCD/Files/', mode='train',transform=ttransform,norm=tnorm)
    ntrain = train_data.__len__()
    print("Total samples : %d" % ntrain)
    sampler=SplitSampler(ntrain,num_workers,hvd.rank())
    datagen=gluon.data.DataLoader(train_data,batch_size=batch_size,sampler=SplitSampler(ntrain,num_workers,hvd.rank()))
    
    val_prev=0
    # training loop
    for epoch in range(ep,epochs):
        tic = time.time()
        train_loss = 0
        j=0
        fout=0
        tlout=0
        
        for img1,img2,labels in datagen:

                    img1 = img1.as_in_context(context)
                    img2=img2.as_in_context(context)
                    labels = labels.as_in_context(context)

                    with autograd.record():
                          outputs = net( img1,  img2)
                          loss = myMTSKL.loss(outputs,labels)
                   # with amp.scale_loss(loss, trainer) as scaled_loss:
                    #    autograd.backward(scaled_loss)
                    loss.backward()
                    trainer.step(batch_size)
                    print("Loss for a batch : " , loss)
                    train_loss = sum([l.sum() for l in loss.asnumpy()])
                    tlout =tlout + train_loss
                    #print(train_loss)
                    label_segm=labels[:,:NClasses,:,:]
                    label_bound=labels[:,NClasses:2*NClasses,:,:]
                    label_dists=labels[:, 2*NClasses:,:,:]
                    ground = [label_segm ,label_bound,label_dists]
                    
                    f1= calc_f1.score(ground,outputs,NClasses)
                    
                    # f1_score.update(labels= ground, preds=outputs)
                    # name,f1= f1_score.get()
                    fout = fout + f1
                
                   # net.export("model-cdet", epoch)

                    # Update history and print metrics
                  #train_history.update([1-f1, 1-val_f1])
                    print('[Epoch %d] f1-score=%f train_loss=%f time: %f' %
                      (epoch, f1, train_loss/batch_size, time.time()-tic))
                    j=0
        
        sw.add_scalar(tag='train_f1score', value=fout/(ntrain/(batch_size *hvd.size())), global_step=epoch)
        # saving model parameters and weights
        print('[Epoch %d] Train-f1 = %f Train-loss = %f'% (epoch,
                     fout/(ntrain/(batch_size*hvd.size())), tlout/(ntrain/hvd.size())))
        if (epoch % 10==0):
            net.export("model-cdet",epoch)
            
       # if hvd.rank() == 0:
         #   logging.info('Epoch[%d]\tTrain-f1 = %f\tTrain-loss = %f', epoch,
          #           fout/ntrain/batch_size, train_loss/ntrain)
            
        if fout/(ntrain/batch_size*hvd.size()) > 0.7 and tlout/(ntrain/hvd.size()) < 0.1:
            net.export("model-cdet", epoch)
            print('performing validation')
            val_f1, val_loss= test(net, context, val_data_path,hvd.size(),num_workers,hvd)
            print('[Epoch %d] val_loss=%f f1-val=%f' %
              (epoch, val_loss, val_f1))
            if abs(val_loss - val_prev) < 0.0005:
                trainer.set_learning_rate(lr/10)
                myMTSKL.depth= myMTSKL.depth + 10
                val_prev = val_loss   

@feevos (Owner) commented Nov 3, 2020

Hi @Farihaa, here are comments on things I did differently; I do not know how they may affect your training:

  1. You calculate the F1 score from all of the segmentation, boundary, and distance predictions, but you should only be using segmentation: the boundary prediction is usually lower, and the distance is a continuous regression variable, so F1 does not apply to it. This drags the computed F1 down, so it does not reflect the actual performance of the network (do some inference visualizations and compare with the manuscript). The F1 reported in the manuscript is only on the segmentation change class.
  2. The models are provided with default values, but these defaults are not the ones mentioned in the manuscript. For example, the mantis models use BatchNorm by default, while you need to use GroupNorm, especially with such a small batch size per GPU, as mentioned in the manuscript. I cannot tell from your code whether you follow this practice. The default values used for training can be found in the demo notebook (cell 4), and these are:
# D6nf32 example 
depth=6
norm_type='GroupNorm'
norm_groups=4
ftdepth=5
NClasses=2
nfilters_init=32
psp_depth=4
nheads_start=4

net = mantis_dn_cmtsk(nfilters_init=nfilters_init, NClasses=NClasses,depth=depth, ftdepth=ftdepth, model='CEECNetV1',psp_depth=psp_depth,norm_type=norm_type,norm_groups=norm_groups,nheads_start=nheads_start)
net.initialize()
  3. The initialization we use is the default net.initialize(); I don't know whether your modifications affect performance.
  4. I do not scale the learning rate with the number of workers as proposed in the Horovod documentation. Instead, I follow the guidelines of Smith 2018: increase the learning rate over a single epoch (starting from a very low value) while monitoring the training loss; see figure A.16 in the appendix of the resunet-a paper (https://arxiv.org/pdf/1904.00592.pdf). The learning rate used is therefore lr=1.e-3 (without worker scaling), in our case with 24 x 4 = 96 P100 GPUs and batch_size = 4 per GPU (adding some nd.waitall() calls may allow you to increase the batch size):
with autograd.record():
    outputs = net( img1,  img2)
    nd.waitall() # mxnet < 2.0            
    loss = myMTSKL.loss(outputs,labels)
    nd.waitall() # mxnet < 2.0            

You should run the same procedure for the number of GPUs you have; it won't take more than a few minutes (1 h max) of compute time to find a good value. I found this much better than scaling the learning rate linearly in all my experiments. It is also relevant to your problem: you are only using 4 GPUs (total batch size = 8, I think?), which is very low, and training may be unstable. We used >200, and this affects the choice of learning rate, so you probably need a lower one. You may also try the gradient accumulation trick (it works with Horovod as well) to increase the effective batch size and stabilize training. A minimal range-test sketch follows this list.
5. It is not clear to me how you calculate the validation loss value. Usually a Tanimoto (with dual) loss value of ~0.1-0.2 for the segmentation task (only) gives nice results (not SOTA, but respectable). If you are also including the boundary and distance loss values, the total will be higher, as these are more difficult to estimate.
6. If you cannot overfit the training data, it is almost certain you have a bug somewhere. Looking at your code, make sure the ground truth labels are in the appropriate order and format (in 1hot encoding). I do not see an obvious bug (I have not tested amp with this model, so I don't know how it will perform; you should debug on float32 first, as your code does, I believe, with the amp-related commands commented out).
7. Data augmentation cannot affect performance that much; it is not the source of error. In my code it happens in CPU memory, prior to loading onto the GPU, so it does not affect your GPU memory load.
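
For reference, here is a minimal sketch of such a learning-rate range test (not the exact code we used; it assumes the net, mtsk_loss, datagen and context objects from your training script, and the lr_min/lr_max bounds are illustrative):

import numpy as np
from mxnet import autograd, gluon, optimizer

def lr_range_test(net, datagen, context, lr_min=1e-7, lr_max=1e-1):
    # Sweep the learning rate multiplicatively from lr_min to lr_max over
    # one epoch while recording the training loss (Smith 2018 style).
    n_batches = max(len(datagen), 2)
    gamma = (lr_max / lr_min) ** (1.0 / (n_batches - 1))
    trainer = gluon.Trainer(net.collect_params(),
                            optimizer.Adam(learning_rate=lr_min))
    myMTSKL = mtsk_loss()
    lrs, losses = [], []
    for img1, img2, labels in datagen:
        img1 = img1.as_in_context(context)
        img2 = img2.as_in_context(context)
        labels = labels.as_in_context(context)
        with autograd.record():
            loss = myMTSKL.loss(net(img1, img2), labels)
        loss.backward()
        trainer.step(img1.shape[0])
        lrs.append(trainer.learning_rate)
        losses.append(loss.mean().asscalar())
        trainer.set_learning_rate(trainer.learning_rate * gamma)
    # Choose a learning rate somewhat below the value where the loss blows up.
    return lrs, losses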

I propose doing some visualizations of the ground truth labels (to validate they are in the correct order, segmentation/boundary/distance, in 1hot) and of the inference results, to judge the performance of the network by eye; see the sketch below. Feel free to post results here.
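
For example, a quick sanity check along these lines (a sketch only; it assumes train_data returns CHW numpy arrays with the labels stacked as [segmentation, boundary, distance], two channels each for NClasses=2):

import matplotlib.pyplot as plt

img1, img2, labels = train_data[0]  # one chip pair plus its stacked labels
NClasses = 2
fig, ax = plt.subplots(1, 5, figsize=(20, 4))
ax[0].imshow(img1.transpose(1, 2, 0)); ax[0].set_title('image t1')
ax[1].imshow(img2.transpose(1, 2, 0)); ax[1].set_title('image t2')
# channel 1 of each block is the "change" class of the 1hot encoding
ax[2].imshow(labels[1]); ax[2].set_title('segmentation')
ax[3].imshow(labels[NClasses + 1]); ax[3].set_title('boundary')
ax[4].imshow(labels[2 * NClasses + 1]); ax[4].set_title('distance')
plt.show()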

Hope the above helps
Foivos

@Farihaa (Author) commented Nov 4, 2020

Thank you so much for such a detailed answer.
1- I am only calculating the F1 score for segmentation; from the code it may seem like I am using boundary and distance as well.
2- I am using the default values for training as mentioned in the demo notebook.
3- I used net.initialize() as well, but to export the model checkpoints I had to use net.hybridize() too.
4- Yes, using 4 GPUs it's a total batch size of 8; I will work on the learning rate part.
5- Yes, for the loss I am using all three; I will change that.
I will work on your suggestions and see if they help.

Thanks once again.

@feevos (Owner) commented Nov 4, 2020

My pleasure @Farihaa; let me know how it works out for you, and feel free to post results/ask questions.

Best of luck on your experiments.

@Farihaa (Author) commented Nov 11, 2020

I tried gradient accumulation, accumulating gradients over 32 samples. The training is still not stable, and I tried with a really low learning rate as well. I change the learning rate every 4 epochs by multiplying it by a factor, but there was no change. I also got the augmentation running as per the paper, and I created the .rec files as per the Data_slicing.py and chopchop.py files. The 0 class is obviously good because it represents the negative class, but the 1 class has a bad F1 score; for some samples it is good, but overall it's bad.

@feevos (Owner) commented Nov 11, 2020

Hi @Farihaa, please post your complete code so I can take a look. There is most probably a bug somewhere. I'll try my best to help.

@Farihaa (Author) commented Nov 11, 2020

Here, a '#' before a line means it is commented out.

from mxnet import nd
from models.changedetection.mantis.mantis_dn import *
from mxnet import autograd,optimizer,gluon
import mxnet as mx
import horovod.mxnet as hvd
from multiprocessing import cpu_count
from mxboard import SummaryWriter


from nn.loss.mtsk_loss import *
from src.LVRCDDataset import *
from src.LVRCDNormal import*
from src.semseg_aug_cv2 import*
from gluoncv.utils import makedirs, TrainingHistory, export_block
from pycm import *


import os
import cv2
from tqdm import tqdm
import time
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import logging
import random
import warnings
import argparse


from skimage.io import imread
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
from skimage.measure import label, regionprops
import matplotlib.patches as mpatches
from skimage.util import montage
from skimage.io import imread , imsave
import numpy as np
montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)

from skimage.morphology import label
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '1'
os.environ['MXNET_GPU_MEM_POOL_RESERVE'] = '19'
os.environ['MXNET_BACKWARD_DO_MIRROR'] = '1'
os.environ['MXNET_GPU_COPY_NTHREADS'] = '1'

no_cuda=False
load_ckpt =True
ep=131

if not no_cuda:
    # Disable CUDA if there are no GPUs.
    if mx.context.num_gpus() == 0:
        no_cuda = True

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', type=int, default=2,
                    help='training batch size per worker (default: 2)')
    parser.add_argument('--epochs', type=int, default=30,
                    help='number of training epochs (default: 30)')
    parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (default: 0.0001)')

    parser.add_argument('--load_ckpt', action="store_true", default=False,
                    help='load checkpoints from directory')


    args = parser.parse_args()
    return args


class f1_score(object):
    """
    Here NClasses = 2 by default, for a binary segmentation problem in 1hot representation
    """

    def __init__(self,depth=0, NClasses=2):


        self.skip = NClasses

    def score(self,ground,outputs,NClasses=2):

        pred_segm  = np.argmax(outputs[0].asnumpy(), axis=0).flatten()
        pred_bound = np.argmax(outputs[1].asnumpy(), axis=0).flatten()
        pred_dists = np.argmax(outputs[2].asnumpy(), axis=0).flatten()

        # In our implementation of the labels, we stack together the [segmentation, boundary, distance] labels,
        # along the channel axis.
        # reshaping labels
        label_segm  = np.argmax(ground[0].asnumpy(), axis=0).flatten()
        label_bound = np.argmax(ground[1].asnumpy(), axis=0).flatten()
        label_dists = np.argmax(ground[2].asnumpy(), axis=0).flatten()


        # metric calculation
        cm_segm = ConfusionMatrix(label_segm,pred_segm,digit=3)
        cm_bound = ConfusionMatrix(label_bound,pred_bound,digit=3)
        cm_dists = ConfusionMatrix(label_dists,pred_dists,digit=3)

        print ("Segment F1 score : " , cm_segm.F1)
        segm_f1 =sum([cm_segm.F1[f] for f in range(NClasses)])/NClasses
       # bound_f1 =sum([cm_bound.F1[f] for f in range(NClasses)])/NClasses
       # dists_f1 = sum([cm_dists.F1[f] for f in range(NClasses)])/NClasses



        return segm_f1

def createModel():
    # D6nf32 example
    depth = 6
    norm_type = 'GroupNorm'
    norm_groups = 4
    ftdepth = 5
    NClasses = 2
    nfilters_init = 32
    psp_depth = 4
    nheads_start = 4

    #intialize net
    net = mantis_dn_cmtsk(nfilters_init=nfilters_init, NClasses=NClasses, depth=depth,
                          ftdepth=ftdepth, model='CEECNetV1', psp_depth=psp_depth, norm_type=norm_type,
                          norm_groups=norm_groups, nheads_start=nheads_start)

    return net



class SplitSampler(gluon.data.sampler.Sampler):
    """ Split the dataset into `num_parts` parts and sample from the part with
    index `part_index`
    Parameters
    ----------
    length: int
      Number of examples in the dataset
    num_parts: int
      Partition the data into multiple parts
    part_index: int
      The index of the part to read from
    """
    def __init__(self, length, num_parts=1, part_index=0):
        # Compute the length of each partition
        self.part_len = length // num_parts
        # Compute the start index for this partition
        self.start = self.part_len * part_index
        # Compute the end index for this partition
        self.end = self.start + self.part_len

    def __iter__(self):
        # Extract examples between `start` and `end`, shuffle and return them.
        indices = list(range(self.start, self.end))
        random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return self.part_len


def load_checkpoint(model_dir,ctx):
    '''
    Parameters
    ----------
    model_dir : str
        directory containing the exported model-cdet symbol/params files.

    Returns
    -------
    net : gluon.nn.SymbolBlock
        the restored network.

    '''

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        logging.info('loading checkpoint')
       # sm = mx.sym.load(model_dir +'/'+ 'model-cdet-symbol.json')
       # inputs = mx.sym.var('data0','data1')
       # net = mx.gluon.SymbolBlock(sm, inputs)
       # net.collect_params().load(model_dir + '/' + 'model-cdet-0004.params', ctx)
       # net = gluon.nn.SymbolBlock( outputs=mx.sym.load_json(model_dir +'/'+ 'model-cdet-symbol.json'), inputs=mx.sym.var('mantis_dn_cmtsk0_mantis_dn_features0_conv2dnormed0_conv0_weight'))
        #net.load_params(model_dir +'/'+ 'model-cdet-0004.params')

        net = gluon.nn.SymbolBlock.imports(model_dir +'/'+ 'model-cdet-symbol.json', ['data0','data1'], model_dir +'/'+ 'model-cdet-0148.params', ctx=ctx)

    return net

def test(model, ctx, val_data_path,size,num_workers,hvd):
    '''
    Parameters
    ----------
    model : initialized network
    ctx : int/list
        cpu/gpu id list
    val_data : str
        path to validation rec filename
    size : int
        number of devices available

    Returns
    -------
    float
        validation f1 score
    float
        validation loss

    '''
    batch_size = 2
    f1_val = f1_score()

    val_data = LVRCDDataset(root=r'/LEVIRCD/Files/', mode='val')
    nval=val_data.__len__()
    datagen=gluon.data.DataLoader(val_data,batch_size=batch_size,sampler=SplitSampler(nval,num_workers,hvd.rank()))

   # nval = val_data.__len__()
    myMTSKL = mtsk_loss()
    val_loss=0
    val_f1=0

    for img1,img2,labels in datagen:
        img1 = img1.as_in_context(ctx)
        img2=img2.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        outputs= model(img1,img2)
        loss = myMTSKL.loss(outputs,labels)
        val_loss += sum([l.sum() for l in loss.asnumpy()])
        label_segm=labels[:,:NClasses,:,:]
        label_bound=labels[:,NClasses:2*NClasses,:,:]
        label_dists=labels[:, 2*NClasses:,:,:]
        ground = [label_segm,label_bound,label_dists]

        f1= f1_val.score(ground, outputs)
        val_f1 += f1

    return  val_f1/(nval/(batch_size*size)), val_loss/(nval/size)


def train(train_data_path, val_data_path):
    '''
    Parameters
    ----------
    epochs : int
        num iterations to be done over the dataset.
    train_data_path : str
        path to training rec/idx-pass filename.
    val_data_path : str
        path to validation rec/idx-pass filename.

    Returns
    -------
    None.

    '''
    gpu_count = mx.context.num_gpus()
    ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()
    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    context = mx.cpu(hvd.local_rank()) if no_cuda else mx.gpu(hvd.local_rank())
    num_workers = hvd.size()
    CPU_COUNT = cpu_count()

    sw = SummaryWriter(logdir='./logs')

    epochs = 200
    lr= 0.00000001
    lr_decay_count = 0
    log_interval = 1

    NClasses= 2
    batch_size=2
    NLarge_batch = 32

    if load_ckpt :
       # net = createModel()
       # net.load_parameters('checkpoint/model-cdet-0004.params')
        ep=153
        net = load_checkpoint("checkpoint",context)
      
    else :
        # initializing network
        ep=0
        net = createModel()
      
    net.hybridize()
    # initializing loss
    myMTSKL = mtsk_loss()
    schedule = LearningSchedule(lr=lr,epoch=ep)
    #f1s= f1_score()
    # perform metric
    calc_f1 = f1_score()


    # Define our trainer for net

    # Set parameters
    adam_optimizer = optimizer.Adam(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-08)
    net.initialize(ctx=context)
    #net.collect_params().reset_ctx(context)
    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params,adam_optimizer )



    train_history = TrainingHistory(['training-f1-score', 'validation-f1-score'])
    # load dataset
    tnorm = LVRCDNormal() # image normalization for each chip
    ttransform = SemSegAugmentor_CV()
    train_data=LVRCDDataset(root=r'/LEVIRCD/Files/', mode='train',transform=ttransform,norm=tnorm)
    ntrain = train_data.__len__()
    print("Total samples : %d" % ntrain)

    datagen=gluon.data.DataLoader(train_data,batch_size=NLarge_batch,last_batch='discard',num_workers=CPU_COUNT,pin_memory=True,sampler=SplitSampler(ntrain,num_workers,hvd.rank()))

    val_prev=0
    # training loop
    for epoch in range(ep,epochs):
        tic = time.time()
        train_loss = 0
        j=0
        fout=0
        tlout=0



        for img1_b,img2_b,label_b in datagen:
                  

                tl_batch=0
                f1_batch=0
                img1_split = np.array_split(img1_b,16)
                img2_split = np.array_split(img2_b,16)
                label_split = np.array_split(label_b,16)


                for i in range(len(img1_split)):
                    img1 = mx.nd.array(img1_split[i],context)
                    img2= mx.nd.array(img2_split[i],context)
                    labels = mx.nd.array(label_split[i],context)
                    with autograd.record():
                          outputs=net(img1,  img2)
                          loss = myMTSKL.loss(outputs,labels)

                 

                    loss.backward()
                    train_loss = sum([l.sum() for l in loss.asnumpy()])
                    print("Loss for a batch : " , loss)

                    tlout = tlout + train_loss
                    tl_batch = tl_batch + train_loss
                    #print(train_loss)
                    label_segm=labels[:,:NClasses,:,:]
                    label_bound=labels[:,NClasses:2*NClasses,:,:]
                    label_dists=labels[:, 2*NClasses:,:,:]
                    ground = [label_segm ,label_bound,label_dists]

                    f1= calc_f1.score(ground,outputs,NClasses)

                    # f1_score.update(labels= ground, preds=outputs)
                    # name,f1= f1_score.get()
                    fout = fout + f1
                    f1_batch= f1_batch + f1

                  

                trainer.step(NLarge_batch)

                # net.export("model-cdet", epoch)

                # Update history and print metrics
                #train_history.update([1-f1, 1-val_f1])
                print('[Epoch %d] f1-score=%f train_loss=%f time: %f' %
                      (epoch, f1_batch/(NLarge_batch/batch_size), tl_batch/NLarge_batch, time.time()-tic))
                j=0
                # manually zero the gradients before# the next batch evaluation

                for weight_variable in net.collect_params().values():
                    weight_variable.zero_grad()

        #sw.add_scalar(tag='train_f1score', value=fout/(ntrain/(batch_size *hvd.size())), global_step=epoch)
        # saving model parameters and weights
        print('[Epoch %d] Train-f1 = %f Train-loss = %f'% (epoch,
                     fout/(ntrain/(batch_size * hvd.size())), tlout/(ntrain/hvd.size())))
        if (epoch % 4==0):
            net.export("model-cdet",epoch)
           # lr= trainer.learning_rate
           # trainer.set_learning_rate(lr * 2)
            trainer.save_states("states" + str(epoch))
       # if hvd.rank() == 0:
         #   logging.info('Epoch[%d]\tTrain-f1 = %f\tTrain-loss = %f', epoch,
          #           fout/ntrain/batch_size, train_loss/ntrain)

        if fout/(ntrain/batch_size*hvd.size()) > 0.7 and tlout/(ntrain/hvd.size()) < 0.1:
            net.export("model-cdet", epoch)
            print('performing validation')
            val_f1, val_loss= test(net, context, val_data_path,hvd.size(),num_workers,hvd)
            print('[Epoch %d] val_loss=%f f1-val=%f' %
              (epoch, val_loss, val_f1))
            if abs(val_loss - val_prev) < 0.0005:
                lr=trainer.learning_rate
                trainer.set_learning_rate(lr/10)
                myMTSKL.depth= myMTSKL.depth + 10
                val_prev = val_loss



if __name__ == "__main__":
    train_data_path = "training_LVRCD_F256"
    val_data_path = "validation_LVRCD_F256"
    # Training settings
    #args = parse_args()
    train( train_data_path, val_data_path)
    #main()

@feevos
Copy link
Owner

feevos commented Nov 11, 2020

Hi @Farihaa, you are doing several things differently, and for some of them I cannot know how they affect performance.

One of the bugs in the code is that your training loop does not actually implement the delayed gradients routine. For this you need to explicitly set grad_req='add' for all network parameters. The default value is 'write', which means you are still using a batch size of 2 (per iteration): every time you evaluate the loss (and call backward), the new gradient values overwrite the previous ones. You need to add the following after the network definition:

net = createModel()
# Increase batch size under memory limitations
for weight_variable in net.collect_params().values():
    weight_variable.grad_req = 'add'

Here is an example training routine (but for mxnet 2.0; there are differences in how you manually zero the gradients). opt.batch_size is the batch size per GPU, and opt.update_delay_rate is how many iterations to wait before updating the weight variables.

def train(epochs,ctx,flname_write):
    counter=1
    with open(flname_write,"w") as f:
        print('epoch','train_mse','val_mse','train_loss',file=f,flush=True)

        ref_metric = 1000
        for epoch in range(epoch_start,epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0

            # Loop through each batch of training data
            for i, (data,label) in enumerate(datagen_train):
                print("\rWithin epoch completion:: {}/{}".format(i+1,len(datagen_train)),end='',flush=True)
                # Extract data and label
                data = mx.np.array(data,ctx=ctx)
                label = mx.np.array(label,ctx=ctx)

                # AutoGrad
                with autograd.record():

                    outputs  = net(data)
                    mx.npx.waitall()
                    loss = loss_fn(outputs,label)
                loss.backward()
                mx.npx.waitall()


                # Optimize
                increase_counter=True
                if (counter % opt.update_delay_rate==0):
                    trainer.step(opt.batch_size *opt.update_delay_rate)
                    net.zero_grad()
                    # reset internal counter 
                    counter = 1
                    increase_counter = False

                if increase_counter:
                    counter += 1



                # Update metrics
                train_loss += loss.sum() #sum(losses)
                train_metric.update(label, outputs)
            train_loss = train_loss / len(datagen_train)
            name, train_mse = train_metric.get()
            # Evaluate on Validation data
            mx.npx.waitall() # necessary to avoid cuda malloc (mxnet 2.0 npx API)
            name, val_mse = test(ctx, net, datagen_dev)

            # Print metrics
            # print both on screen and in file 
            print("\n")
            print('epoch={} train_mse={} val_mse={} train_loss={} time={}'.format(epoch, train_mse, val_mse, train_loss, time.time()-tic))
            print(epoch, train_mse, val_mse, train_loss, file=f,flush=True)

            net.save_parameters(flname_save_weights.replace('best_model','epoch-{}'.format(epoch)))
            if val_mse < ref_metric:
                # Save best model parameters, according to minimum val_mse
                net.save_parameters(flname_save_weights)
                ref_metric = val_mse

I do not know how these variables affect training (I use default values):

os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '1'
os.environ['MXNET_GPU_MEM_POOL_RESERVE'] = '19'
os.environ['MXNET_BACKWARD_DO_MIRROR'] = '1'
os.environ['MXNET_GPU_COPY_NTHREADS'] = '1'

In the evaluation of the score you perform the argmax along the batch axis, not the channel (classes) axis; this gives an erroneous score (however, this does not mean the model is not training; it is the visualizations that will show you the real performance).

def score(self,ground,outputs,NClasses=2):

    pred_segm  = np.argmax(outputs[0].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 
    pred_bound = np.argmax(outputs[1].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 
    pred_dists = np.argmax(outputs[2].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 

    # In our implementation of the labels, we stack together the [segmentation, boundary, distance] labels,
    # along the channel axis.
    # reshaping labels
    label_segm  = np.argmax(ground[0].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 
    label_bound = np.argmax(ground[1].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 
    label_dists = np.argmax(ground[2].asnumpy(), axis=0).flatten() # <==== axis 0 is batch axis, not classes axes. 


    # metric calculation
    cm_segm = ConfusionMatrix(label_segm,pred_segm,digit=3)
    cm_bound = ConfusionMatrix(label_bound,pred_bound,digit=3)
    cm_dists = ConfusionMatrix(label_dists,pred_dists,digit=3) # as explained in previous comment, these are regression variables. 

    print ("Segment F1 score : " , cm_segm.F1)
    segm_f1 =sum([cm_segm.F1[f] for f in range(NClasses)])/NClasses
    # bound_f1 =sum([cm_bound.F1[f] for f in range(NClasses)])/NClasses
    # dists_f1 = sum([cm_dists.F1[f] for f in range(NClasses)])/NClasses

    return segm_f1
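
For reference, a sketch with just the axis fix (argmax over the channel axis of NCHW tensors, scoring only the segmentation head); note that pycm keys the per-class F1 dict by the label values it actually saw:

def score(self, ground, outputs, NClasses=2):
    # axis 1 is the channel (classes) axis for NCHW tensors
    pred_segm  = np.argmax(outputs[0].asnumpy(), axis=1).flatten()
    label_segm = np.argmax(ground[0].asnumpy(), axis=1).flatten()
    cm_segm = ConfusionMatrix(label_segm, pred_segm, digit=3)
    # assumes both classes appear in the batch, so cm_segm.F1 has both keys
    segm_f1 = sum([cm_segm.F1[f] for f in range(NClasses)]) / NClasses
    return segm_f1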

I do not know what learning schedule you are using, but you don't need one to get SOTA; babysitting the learning rate is better (we did not use an lr schedule).

I will look more carefully tomorrow and get back to you. It would help if you provided visualizations of the data.

Regards,
Foivos

@Farihaa (Author) commented Nov 12, 2020

Thank you once again for such a detailed reply.
Okay, so when I calculate the F1 score over the class axis, it is actually good.
As for visualising results and data: I did visualise, but since patches are extracted from each image, what I see is not a complete image. To visualise inference results, would I also have to extract patches and then pass them through the model? How do I create a complete image like the ones in the paper?

@feevos (Owner) commented Nov 12, 2020

Hi @Farihaa, happy you got it working (I hope you also sorted out the delayed gradients issue). With regards to inference, the process is the one described in the change detection paper (this repository), and in a bit more detail in the resunet-a paper. The idea is that you do inference on overlapping windows and then average the prediction "probabilities" on the parts of the image that overlap. You will need to write custom code for that, as we do not provide code for inference over large rasters; a rough sketch follows.
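
Something along these lines (illustrative only, not code from this repository; it assumes CHW numpy inputs whose height and width align with the window/stride so every pixel is covered, and the window/stride values are placeholders):

import numpy as np
import mxnet as mx

def sliding_window_inference(net, img1, img2, ctx, window=256, stride=128, nclasses=2):
    C, H, W = img1.shape
    probs = np.zeros((nclasses, H, W), dtype=np.float32)
    counts = np.zeros((1, H, W), dtype=np.float32)
    for y in range(0, H - window + 1, stride):
        for x in range(0, W - window + 1, stride):
            p1 = mx.nd.array(img1[None, :, y:y+window, x:x+window], ctx)
            p2 = mx.nd.array(img2[None, :, y:y+window, x:x+window], ctx)
            segm = net(p1, p2)[0]  # segmentation head only
            probs[:, y:y+window, x:x+window] += segm[0].asnumpy()
            counts[:, y:y+window, x:x+window] += 1.0
    probs /= counts  # average the overlapping predictions
    return np.argmax(probs, axis=0)  # per-pixel change mask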

Best of luck and post some of your results :)

@Farihaa (Author) commented Nov 17, 2020

[image: img42]
[image: img43]

These are the results I got on a test image. I used the sliding window method to perform inference. The third and fourth images represent the ground truth and the inference result, respectively.

However, I am getting a 31.5-second inference time for a single image on GPU. What was your inference time? Are there any optimisations to lower it?

@feevos (Owner) commented Nov 17, 2020

Hi Fariha, the results show that your model is training properly. You need to train (probably much) longer to get optimal performance, but this looks great for a starting point; my compliments! I do not remember the inference time, but it took me ~1-2 hours on CSIRO HPC facilities.

@Farihaa (Author) commented Nov 17, 2020

Thank you, you helped a lot. Yes, that is true, it needs more training. I wanted to ask: if I lower the image size and then apply the sliding-window method on the resized image, do you think the model performance would degrade? I want to experiment with lowering the inference time.

@feevos (Owner) commented Nov 17, 2020

In principle, especially given that you are doing zoom in/out operations as data augmentation during training, image resolution does not affect performance that much (assuming you stay within the training scaling range), but you should anticipate a performance drop at sizes other than the original resolution. I cannot quantify how much; it really is a question of cost versus performance.

@feevos (Owner) commented Feb 17, 2021

Hi @Farihaa, please see issue #6; I updated the repository to fix this. You should see improved performance now that this bug is fixed.

My apologies, it was introduced when I was writing a clean version of the code.

@humayunah commented

(quotes the full training script from @Farihaa's comment above; running it produces the following output)
depth:= 0, nfilters: 32, nheads::4, widths::1
depth:= 1, nfilters: 64, nheads::8, widths::1
depth:= 2, nfilters: 128, nheads::16, widths::1
depth:= 3, nfilters: 256, nheads::32, widths::1
depth:= 4, nfilters: 512, nheads::64, widths::1
depth:= 5, nfilters: 1024, nheads::128, widths::1
depth:= 6, nfilters: 512, nheads::128, widths::1
depth:= 7, nfilters: 256, nheads::64, widths::1
depth:= 8, nfilters: 128, nheads::32, widths::1
depth:= 9, nfilters: 64, nheads::16, widths::1
depth:= 10, nfilters: 32, nheads::8, widths::1
Total samples : 16288
[12:12:27] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[12:12:30] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[12:12:31] src/operator/nn/./cudnn/cudnn_pooling-inl.h:375: 0D pooling is not supported by cudnn, MXNet 0D pooling is applied.
[12:12:32] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Loss for a batch :
[0.5339367]
<NDArray 1 @gpu(0)>
Segment F1 score : {'0': 1.0, 'other': 'None'}
Traceback (most recent call last):
  File "issues-code.py", line 426, in <module>
    train( train_data_path, val_data_path)
  File "issues-code.py", line 371, in train
    f1= calc_f1.score(ground,outputs,NClasses)
  File "issues-code.py", line 98, in score
    segm_f1 =sum([cm_segm.F1[f] for f in range(NClasses)])/NClasses
  File "issues-code.py", line 98, in <listcomp>
    segm_f1 =sum([cm_segm.F1[f] for f in range(NClasses)])/NClasses
KeyError: 0

Hi @Farihaa @feevos, I've been trying to train on the LEVIRCD dataset using @Farihaa's code above, and I'm stuck at this error. I would be grateful if you could help me figure out a solution. Thank you.

@feevos (Owner) commented Feb 25, 2021

Hi @humayunah, the error suggests the dictionary cm_segm.F1 does not have the key 0, so something is wrong there.

@humayunah commented

Thanks for the reply @feevos. But the problem is in the output cm_segm.F1 = {'0': 1.0, 'other': 'None'} when evaluating
segm_f1 =sum([cm_segm.F1[f] for f in range(NClasses)])/NClasses

Even if I loop through the dictionary using dict.items(), the sum of the two values is 1.0 + None, which raises the TypeError. Could you explain why the second value of cm_segm.F1 is 'None'? Thanks.

@feevos (Owner) commented Feb 25, 2021

I have no idea @humayunah, and since I haven't written this training code, I am afraid I cannot help here. But you can easily remove this monitoring function, replace it with one of your own, and take it from there. I will add a simple training script to the repository in the revision of the paper (almost there).

@Farihaa (Author) commented Mar 3, 2021

@humayunah The 'None' value appears when the score for a label does not exist because that class does not occur in the data point. So you will have to decide how you want to compute the overall F1 score when a certain class is absent from a sample.
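
One illustrative way to handle this (not code from the thread) is to average only over the classes pycm actually scored:

# skip pycm's undefined entries (reported as the string 'None') when averaging
f1_vals = [v for v in cm_segm.F1.values() if v not in (None, 'None')]
segm_f1 = sum(f1_vals) / len(f1_vals) if f1_vals else 0.0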

@coocooon commented Jul 5, 2021

(quoting @Farihaa's earlier comment about lowering the image size to reduce inference time)

Dear Farihaa, I recently saw this excellent project and want to run it, but I found that there is no train.py. I saw your communication here. Could you share the final version of the training script you are running? Many thanks, or to my email: 196715133@qq.com. Thank you again.
