From a2a8547c33db74be08c8c769f1a144246b0f2b10 Mon Sep 17 00:00:00 2001 From: Dibya Ghosh Date: Tue, 10 Dec 2024 16:10:50 -0800 Subject: [PATCH] Update beta sampling code in augment.py --- official/vision/ops/augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/official/vision/ops/augment.py b/official/vision/ops/augment.py index 90c7266d4de..dc45324a034 100644 --- a/official/vision/ops/augment.py +++ b/official/vision/ops/augment.py @@ -2709,8 +2709,8 @@ def distort(self, images: tf.Tensor, @staticmethod def _sample_from_beta(alpha, beta, shape): - sample_alpha = tf.random.gamma(shape, 1., beta=alpha) - sample_beta = tf.random.gamma(shape, 1., beta=beta) + sample_alpha = tf.random.gamma(shape, alpha, beta=1.0) + sample_beta = tf.random.gamma(shape, beta, beta=1.0) return sample_alpha / (sample_alpha + sample_beta) def _cutmix(self, images: tf.Tensor,