You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
학습과정에서 KD를 사용하는 것이 성능 향상과, 학습 안정성에 큰 도움이 된다는 것을 밝히 논문입니다.
대부분의 모델에서 동일한 하이퍼파라미터를 사용해서 SOTA에 근접한 성능을 달성하였습니다.
한줄 평을 하자면 'KD는 학습 안정성에 도움을 주고, Augmentation이 부정확한 라벨을 정제해서 성능 향상에 도움이 됬다.'라고 볼 수 있을거 같습니다.
Training scheme은 매우 중요한데, 모델별로 3.3%에서 최대 6%까지도 성능 차이가 발생할 수 있습니다.
예를 들어서 ResNet50의 하이퍼파라미터로 EfficientNetV2를 학습하면 3.3%의 성능하락을 발견할 수 있고,
ViT-L의 경우에는 사람에따라 6%까지도 차이가 날 수 있습니다.
모델에 따른 학습 방법들은 backbone에 따라 다음과 같이 4가지의 분류로 나눌 수 있습니다.
ResNet like Model : TResNet이나, SEResNet, Renet-D등 다양한 ResNet 기반의 모델들은, 대부분의 traning scheme에서 좋은 성능을 보여주는 것으로 알려져 있습니다.
Mobile oriented Model : 해당 모델들의 경우에는 depth-wise convolution과, 효율적인 cpu 지향적 구조에 전적으로 의존하고 있습니다. 그리고 이러한 방법들은 보통 RMSProp Optimizer를 사용하거나, waterfall learning rate schedular를 사용하기도 하고, EMA를 사용하기도 합니다.
Transformer based Model : Transformer의 경우에는 학습이 불안정한 편에 속하고, 기본적으로 많은 에폭(1000 epochs)을 필요로 하며, 많은 데이터가 필요합니다.
MLP only Model : Transformer와 마찬가지입니다.
최근 대부분의 학습 방법들은 ImageNet-21K에서 ImageNet-1K로 transfer learning을 진행했습니다.
하지만 이는 사전에 Imagenet-21K에서 학습된 모델이 있어야 하고, 모델의 구조가 동일해야 하며, 학습에 오랜 시간이 걸리는 문제들이 있었습니다.
이 논문에서는 USI(Unified Scheme for Imagenet)이라는 모델에 관계없이 좋은 성능을 얻을 수 있는, 학습 방법을 제안합니다.
아래 그림에서 볼 수 있듯, USI는 대부분의 모델에서 높은 성능 향상을 보여줍니다.
처음에 저자들이 이 논문을 작성할 때, 주목한 점은 KD가 backbone에 상관없이 어떤 모델에서도 잘 동작한다는 점이었습니다.
그리고, KD가 잘 작동하는 이유는 기존에 Label에서는 없던 정보들이, Teacher Model로부터 추가되었다는 점입니다.
그리고 이렇게 만들어진 정보를 통해서 학습된 Student 모델은 심지어 기존 모델보다 좋은 성능을 달성하기도 합니다.
이는 라벨이 더 많은 정보를 포함하고 있고, 라벨의 오류 또한 보정하며, 클래스들간의 상관관계 또한 포함하기 때문입니다.
이렇게 만들어진 KD는 augmentation에서 더 적합하며, label smoothing을 제거할 수 있고, 더 적은 training trick과, training epoch(maximum 300epoch), 심지어 regularization 성능마저 좋아지게 됩니다.
일반적으로 전이학습이 아닌, from scratch로부터의 학습은 일반적으로 더 어렵고, 더 높은 learning rate와, 강력한 정규화 및 더 많은 epoch을 통해 학습을 진행해야 합니다. 이는 모델들 간 학습과정에서 하이퍼파라미터가 상당부분 달라지는 원인이기도 합니다.
우선, KD의 장점을 설명하기 위해 아래 그림을 첨부합니다.
사진(a)를 보면, GT(빨간색)이 가리키고 있는것이, Teacher의 예측값과 동일합니다.
사진(b)를 보면, GT가 여객기이므로 Prediction도 여객기를 가장 높은 확률로 맞춥니다. 하지만, 날개 또한 11.3%로 예측을 하고 있습니다. 실제로 여객기는 날개를 포함하고 있고, 그 주변의 비행기들도 여러개의 날개를 가지고 있으므로, 이것은 잘못 예측된 게 아닙니다.
즉, Teahcer의 Prediction이 GT보다 많은 정보를 제공하는 하나의 예시라고 볼 수 있습니다.
이는, 이미지에 대한 보다 정확한 정보를 Teacher의 Prediction값이 가지고 있다는 것을 의미합니다.
그림(c)는 암탉을 나타내고 있습니다. 그러나 암탉은 매우 작고, 수탉이랑 헷갈립니다.
그러므로 Teacher는 55.5%로 낮은 확률로 정답을 맞췄지만, 사람도 헷갈릴 정도로 복잡한 문제이므로, 이는 논리적 오류에 해당한다고 볼 수 있습니다.
그림 (d)를 보시면, GT와 Prediction 결과가 동일하지는 않습니다.
GT는 아이스크림을 나타내지만, Prediction은 개가 이미지 안에 더 많은 부분을 포함하고 있으므로, 개를 main으로 아이스크림을 그 다음으로 많은 확률로 맞추고 있습니다.
이것은 사람에 따라 GT가 잘못되었다고 판단할수도 있는 부분입니다.
즉, 잘 학습된 모델이라면, 대부분의 경우 Prediction값이 GT보다 많은 정보를 포함하게 됩니다.
심지어는 라벨이 가진 오류를 보정하기도 합니다.
제안하는 학습 방법은 다음과 같습니다.
이 학습 방법은, 대부분의 backbone에서 잘 작동하는 것으로 확인되었습니다.
대부분의 모델들은 Parameter에 따라 batch_size가 달라지므로, 서로 성능을 비교하는 것이 어려웠습니다.
또한, batch_size가 커지게 되면, 더 큰 Learning rate나, 전용 옵티마이저를 사용해야 했습니다. batch_size는 성능에 영향을 주는 요인이기 떄문입니다.
본 논문에서 실험된 모델들은 한정된 자원 안에서 최대 112~504의 batch size들을 사용할 수 있지만, 제안된 방법은 batch size에 영향을 받지 않으므로, 공정한 실험 비교가 가능해집니다.
또한, 교사와 학생은 모델(CNN, Transformer)에 관계없이 Knowledge Distillation이 잘 적용되는 것을 관찰할 수 있었습니다.
보시게 되면, Teacher와 Student는 어떤 구조를 사용 하든지, 성능이 오르게 됩니다.
심지어, Teacher가 Student보다 작은 경우에도 성능은 오른다는 것을 알 수 있습니다.
USI를 적용해서 성능을 측정한 결과표는 다음과 같습니다.
그리고 추가로, vanilla KD에서 학습 시, hard label에 대한 정보를 제거해도 성능은 동일하게 유지 됩니다.
이는 teacher의 prediction value가 GT보다 좋다는 것을 의미합니다.
KD Temperature는 적용하지 않는게 성능이 가장 좋았으며, Drop-Path와 같은 Regularization Technique에도 매우 강인한것을 보여줍니다.
결론 : KD 짱짱
The text was updated successfully, but these errors were encountered:
궁금증 : Student는 Teacher를 모방한다고 알고 있었는데, Teacher가 성능이 더 낮음에도 불구하고 Student가 더 높은 성능을 보여줄 수 있는 것은 Noise를 넣어줬기 때문이라고 생각이 됩니다.
Noise를 넣어주지 않는다면, Student는 단순히 Teacher를 모방하게 되지만, Noise가 들어간 경우에도 동일한 추론값을 맞추도록 학습이 되기 때문에, Noise의 종류에 따라 성능이 차이가 나게 됩니다.
그러면, KD 자체는 성능을 복제하는 과정이고 이건 성능을 안정적으로 뽑기 위한 테크닉인데, 결국 성능을 올린건 Mixup-Cutmix Augmentation이라고 볼 수 있을거 같습니다.
그럼 여기서 궁금한건 왜 Augmentation이 성능 향상에 도움이 됬을까입니다.
classification 과정에서 영역별로 Teacher가 추론하고, Student가 영역별로 추론한 값을 따라가서 그런것인가?
그러니까 무슨 얘기냐면, 위에 강아지가 아이스크림을 먹고 있는 그림에서, 강아지만 Cut된 이미지가 들어오면 Teacher는 강아지에 대해서 높은 라벨 정보를 줄 것이고, Student도 마찬가지로 해당 부분은 강아지라고 학습을 할 수 있기 때문이라고 생각합니다. 라벨 정보를 더 refine했다고 볼 수 있겠네요.
즉. Augmentation이 Label 정보를 Refine했다가 성능 향상의 주요 쟁점인거 같습니다.
!!Batch size에 왜 둔감다하는 것인지? 이 부분은 도무지 이해가 안갑니다.!!
이미 학습된 Weight를 따라가는 것이기 때문에 둔감한 건가요?
Student가 Teacher의 결과값을 모방하는 것이다 보니, Sample의 갯수가 Contrastive Learning이나 다른 SSL 방법들보다 안정적이라는 것은 이해가 갑니다.
하지만 이건 vanilla KD라서, 최종 추론 결과만 주어지는 것이고, 중간에 있는 layer 각각의 feature들을 모방하는 것이 아니기 때문에, batch의 영향을 안 받기는 어려울 것이라는 생각이 듭니다.
그런데 왜 batch에 둔감한 걸까요?
Contrastive Learning의 경우에는 서로 다른 Augmentation을 줘서, 모델이 헷갈려서 그런건가?
KD는 같은 Augmentation이라서 그런 것이고?
흠..;; 잘 모르겠네요
일반적으로 hard label에 대해서 학습을 진행하게 되면, 이미 정해진 결과에 대해서 추론을 하는 것인데도 불구하고,
성능이 들쭉 날쭉한 것을 볼 수 있습니다. 이는 regnet에서 실험을 진행했듯이, 같은 모델에 대해서도 hard label로 학습을 진행 했을때
2~3%정도의 성능차이가 나게 됩니다.
KD의 경우에도 Teacher의 출력값은 Hard Label이 Teacher에 따라 Adaptive 하게 Smoothing된 Label 정보지만, 간단하게 Teacher가 Refine한 prediction value라고 부르게 되면, 같은 모델 내에서도 하이퍼 파라미터가 변경됨에 따라서 성능 차이가 발생하는 것이 원론적으로는 맞아야 할 것입니다.
그런데 왜 KD 방법론이 안정적인가?에 대해서는 연구들이 좀 더 필요할 거 같습니다.
이는 KD로 학습된 얘들이 더 Loss Surface를 안정적으로 만들어서 그런게 아닐까? 싶습니다.
학습과정에서 KD를 사용하는 것이 성능 향상과, 학습 안정성에 큰 도움이 된다는 것을 밝히 논문입니다.
대부분의 모델에서 동일한 하이퍼파라미터를 사용해서 SOTA에 근접한 성능을 달성하였습니다.
한줄 평을 하자면 'KD는 학습 안정성에 도움을 주고, Augmentation이 부정확한 라벨을 정제해서 성능 향상에 도움이 됬다.'라고 볼 수 있을거 같습니다.
Training scheme은 매우 중요한데, 모델별로 3.3%에서 최대 6%까지도 성능 차이가 발생할 수 있습니다.
예를 들어서 ResNet50의 하이퍼파라미터로 EfficientNetV2를 학습하면 3.3%의 성능하락을 발견할 수 있고,
ViT-L의 경우에는 사람에따라 6%까지도 차이가 날 수 있습니다.
모델에 따른 학습 방법들은 backbone에 따라 다음과 같이 4가지의 분류로 나눌 수 있습니다.
최근 대부분의 학습 방법들은 ImageNet-21K에서 ImageNet-1K로 transfer learning을 진행했습니다.
하지만 이는 사전에 Imagenet-21K에서 학습된 모델이 있어야 하고, 모델의 구조가 동일해야 하며, 학습에 오랜 시간이 걸리는 문제들이 있었습니다.
이 논문에서는 USI(Unified Scheme for Imagenet)이라는 모델에 관계없이 좋은 성능을 얻을 수 있는, 학습 방법을 제안합니다.
아래 그림에서 볼 수 있듯, USI는 대부분의 모델에서 높은 성능 향상을 보여줍니다.
처음에 저자들이 이 논문을 작성할 때, 주목한 점은 KD가 backbone에 상관없이 어떤 모델에서도 잘 동작한다는 점이었습니다.
그리고, KD가 잘 작동하는 이유는 기존에 Label에서는 없던 정보들이, Teacher Model로부터 추가되었다는 점입니다.
그리고 이렇게 만들어진 정보를 통해서 학습된 Student 모델은 심지어 기존 모델보다 좋은 성능을 달성하기도 합니다.
이는 라벨이 더 많은 정보를 포함하고 있고, 라벨의 오류 또한 보정하며, 클래스들간의 상관관계 또한 포함하기 때문입니다.
이렇게 만들어진 KD는 augmentation에서 더 적합하며, label smoothing을 제거할 수 있고, 더 적은 training trick과, training epoch(maximum 300epoch), 심지어 regularization 성능마저 좋아지게 됩니다.
일반적으로 전이학습이 아닌, from scratch로부터의 학습은 일반적으로 더 어렵고, 더 높은 learning rate와, 강력한 정규화 및 더 많은 epoch을 통해 학습을 진행해야 합니다. 이는 모델들 간 학습과정에서 하이퍼파라미터가 상당부분 달라지는 원인이기도 합니다.
우선, KD의 장점을 설명하기 위해 아래 그림을 첨부합니다.
사진(a)를 보면, GT(빨간색)이 가리키고 있는것이, Teacher의 예측값과 동일합니다.
사진(b)를 보면, GT가 여객기이므로 Prediction도 여객기를 가장 높은 확률로 맞춥니다. 하지만, 날개 또한 11.3%로 예측을 하고 있습니다. 실제로 여객기는 날개를 포함하고 있고, 그 주변의 비행기들도 여러개의 날개를 가지고 있으므로, 이것은 잘못 예측된 게 아닙니다.
즉, Teahcer의 Prediction이 GT보다 많은 정보를 제공하는 하나의 예시라고 볼 수 있습니다.
이는, 이미지에 대한 보다 정확한 정보를 Teacher의 Prediction값이 가지고 있다는 것을 의미합니다.
그림(c)는 암탉을 나타내고 있습니다. 그러나 암탉은 매우 작고, 수탉이랑 헷갈립니다.
그러므로 Teacher는 55.5%로 낮은 확률로 정답을 맞췄지만, 사람도 헷갈릴 정도로 복잡한 문제이므로, 이는 논리적 오류에 해당한다고 볼 수 있습니다.
그림 (d)를 보시면, GT와 Prediction 결과가 동일하지는 않습니다.
GT는 아이스크림을 나타내지만, Prediction은 개가 이미지 안에 더 많은 부분을 포함하고 있으므로, 개를 main으로 아이스크림을 그 다음으로 많은 확률로 맞추고 있습니다.
이것은 사람에 따라 GT가 잘못되었다고 판단할수도 있는 부분입니다.
즉, 잘 학습된 모델이라면, 대부분의 경우 Prediction값이 GT보다 많은 정보를 포함하게 됩니다.
심지어는 라벨이 가진 오류를 보정하기도 합니다.
제안하는 학습 방법은 다음과 같습니다.
이 학습 방법은, 대부분의 backbone에서 잘 작동하는 것으로 확인되었습니다.
대부분의 모델들은 Parameter에 따라 batch_size가 달라지므로, 서로 성능을 비교하는 것이 어려웠습니다.
또한, batch_size가 커지게 되면, 더 큰 Learning rate나, 전용 옵티마이저를 사용해야 했습니다. batch_size는 성능에 영향을 주는 요인이기 떄문입니다.
본 논문에서 실험된 모델들은 한정된 자원 안에서 최대 112~504의 batch size들을 사용할 수 있지만, 제안된 방법은 batch size에 영향을 받지 않으므로, 공정한 실험 비교가 가능해집니다.
또한, 교사와 학생은 모델(CNN, Transformer)에 관계없이 Knowledge Distillation이 잘 적용되는 것을 관찰할 수 있었습니다.
보시게 되면, Teacher와 Student는 어떤 구조를 사용 하든지, 성능이 오르게 됩니다.
심지어, Teacher가 Student보다 작은 경우에도 성능은 오른다는 것을 알 수 있습니다.
USI를 적용해서 성능을 측정한 결과표는 다음과 같습니다.
그리고 추가로, vanilla KD에서 학습 시, hard label에 대한 정보를 제거해도 성능은 동일하게 유지 됩니다.
이는 teacher의 prediction value가 GT보다 좋다는 것을 의미합니다.
KD Temperature는 적용하지 않는게 성능이 가장 좋았으며, Drop-Path와 같은 Regularization Technique에도 매우 강인한것을 보여줍니다.
결론 : KD 짱짱
The text was updated successfully, but these errors were encountered: