-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] add model script of cait #547
Conversation
'first_conv': '', 'classifier': '', | ||
**kwargs | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
缺少default_cfg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
mindcv/models/cait.py
Outdated
} | ||
|
||
|
||
class Class_Attention(nn.Cell): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个类名在原作者的实现中就是大写加下划线吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的
mindcv/models/cait.py
Outdated
return x_cls | ||
|
||
|
||
class LayerScale_Block_CA(nn.Cell): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
mindcv/models/cait.py
Outdated
|
||
attn = self.q_matmul_k(q, k) | ||
|
||
"talking head trick" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是什么意思?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a trick used in cait. I have deleted this comment line.
mindcv/models/cait.py
Outdated
self.gamma_2 = Parameter(init_values * ops.ones((dim), ms.float32), requires_grad=True) | ||
|
||
def construct(self, x: Tensor) -> Tensor: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去除不必要空行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deleted
mindcv/models/cait.py
Outdated
|
||
self._init_weights() | ||
|
||
def _init_weights(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名与其他文件统一
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
mindcv/models/cait.py
Outdated
|
||
|
||
@register_model | ||
def cait_XXS24_224(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数名也是
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原作者实现中写的就是这样的函数名
Thank you for your contribution to the MindCV repo.
Before submitting this PR, please make sure:
Motivation
The model script of cait is added.
Test Plan
Please use 'create_model' part for testing models in cait.
Related Issues and PRs
Related issue: #424