Hello. I can't get hard triplet mining to work with balanced batches. I think we discussed this here before, but the embeddings still always collapse to a single point. I have tried many combinations of `margin`, `num_classes_per_batch`, and `num_images_per_class`, but nothing seems to work. Could you please take a look at the code and see if there is an obvious problem? Note that it works well with the `batch_all` strategy.
Thanks,
Tom
import pathlib

import numpy as np
import tensorflow as tf


def train_input_fn(data_dir, params):
    data_root = pathlib.Path(data_dir)
    all_image_paths = list(data_root.glob('**/*.jpg'))
    all_directories = {'/'.join(str(i).split("/")[:-1]) for i in all_image_paths}
    print("num of labels: ", len(all_directories))
    labels_index = [i.split("/")[-1] for i in all_directories]

    # Create one filename dataset per class directory
    datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir), shuffle=False)
                for image_dir in all_directories]
    num_labels = len(all_directories)
    num_classes_per_batch = params.num_classes_per_batch
    num_images_per_class = params.num_images_per_class

    def get_label_index(s):
        return labels_index.index(s.numpy().decode("utf-8").split("/")[-2])

    def preprocess_image(image):
        # Decode the raw JPEG bytes first; casting a string tensor
        # directly to float32 would fail
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.math.divide(image, 255.0)
        return image

    def load_and_preprocess_image(path):
        image = tf.io.read_file(path)
        return (tf.py_function(preprocess_image, [image], tf.float32),
                tf.py_function(get_label_index, [path], tf.int64))

    def generator():
        while True:
            # Sample the labels that will compose the batch
            labels = np.random.choice(range(num_labels),
                                      num_classes_per_batch,
                                      replace=False)
            for label in labels:
                for _ in range(num_images_per_class):
                    yield label

    choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
    dataset = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
    dataset = dataset.map(load_and_preprocess_image,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    batch_size = num_classes_per_batch * num_images_per_class
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.prefetch(1)
    return dataset
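For reference, batch-hard mining selects, for each anchor, the farthest same-label embedding and the closest different-label embedding in the batch. A minimal NumPy sketch (the function name and shapes are my own, not the repo's implementation) shows why fully collapsed embeddings produce a loss exactly equal to the margin:

```python
import numpy as np

def batch_hard_triplet_loss(embeddings, labels, margin=0.5):
    # Pairwise Euclidean distances between all embeddings in the batch.
    sq = np.sum(embeddings ** 2, axis=1)
    dist = np.sqrt(np.maximum(sq[:, None] + sq[None, :]
                              - 2.0 * embeddings @ embeddings.T, 0.0))

    same = labels[:, None] == labels[None, :]
    # Hardest positive: farthest embedding with the same label.
    hardest_pos = np.max(np.where(same, dist, 0.0), axis=1)
    # Hardest negative: closest embedding with a different label.
    hardest_neg = np.min(np.where(same, np.inf, dist), axis=1)

    return np.mean(np.maximum(hardest_pos - hardest_neg + margin, 0.0))
```

If all embeddings collapse to one point, both hardest distances are zero and the loss plateaus at exactly `margin`; a training loss stuck at the margin value is the usual collapse signature.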
batrlatom changed the title from "Batch hard and balabced batch" to "Batch hard and balanced batch" on Aug 5, 2019
If the batch-all loss works but the batch-hard triplet loss does not, it might indicate that your dataset is a bit noisy, so the hardest triplets are actually mislabeled examples.
You can also train first with batch all, then fine-tune at the end with batch hard.
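Another mitigation sometimes used for this kind of collapse (the soft-margin variant from Hermans et al., "In Defense of the Triplet Loss", not something in this thread's code) is to replace the hard hinge with a softplus, which never saturates to zero:

```python
import numpy as np

def soft_margin_triplet(d_pos, d_neg):
    # Softplus of (d_pos - d_neg): smooth everywhere, so even a
    # collapsed batch (d_pos == d_neg == 0) still yields a nonzero
    # loss and gradient, instead of a flat plateau at the margin.
    return np.log1p(np.exp(d_pos - d_neg))
```

With well-separated classes the loss decays toward zero; at collapse it sits at `log(2)` but still provides a gradient to push negatives apart.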