implement groups for conv1d and conv2d #637
Thanks for putting this together! I'm just dropping a comment here as a heads-up so that you can be on the lookout for any updates and PRs to conv in the meantime.
@@ -34,6 +37,8 @@ def __init__(
         kernel_size: int,
         stride: int = 1,
         padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
         bias: bool = True,
     ):
         super().__init__()
Why doesn't the shape of self.weight change?
I think the shape should be (out_channels, kernel_size, in_channels / groups), with in_channels % groups == 0.
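A minimal sketch of the shape rule suggested in this comment, assuming the channels-last weight layout quoted above. The helper name `grouped_conv1d_weight_shape` is hypothetical and not part of the PR; the extra `out_channels % groups == 0` check is an assumption borrowed from the usual grouped-convolution convention rather than something stated in the review.

```python
def grouped_conv1d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    """Return the weight shape (out_channels, kernel_size, in_channels // groups).

    Hypothetical helper for illustration only; not the PR's code.
    """
    if groups < 1:
        raise ValueError(f"groups must be positive, got {groups}")
    # Condition from the review comment: input channels split evenly into groups.
    if in_channels % groups != 0:
        raise ValueError(
            f"in_channels ({in_channels}) must be divisible by groups ({groups})"
        )
    # Assumption: each group also produces an equal share of the output channels.
    if out_channels % groups != 0:
        raise ValueError(
            f"out_channels ({out_channels}) must be divisible by groups ({groups})"
        )
    return (out_channels, kernel_size, in_channels // groups)
```

With `groups=1` this reduces to the existing ungrouped shape, so the default behaviour would be unchanged.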
             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, dilation),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(
Same question here: shouldn't this use in_channels / groups?
@@ -2767,6 +2768,22 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
         << " input: " << in.shape() << " and weight: " << wt.shape();
If groups > 1, then the last dimension of the weight multiplied by groups should equal the last dimension of the input shape.
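The check described in this comment could look roughly like the following. This is a Python sketch of the condition only, assuming channels-last shapes given as tuples; `check_grouped_conv_shapes` is a hypothetical name and does not mirror the actual C++ `run_conv_checks` body.

```python
def check_grouped_conv_shapes(in_shape, wt_shape, groups=1):
    """Raise ValueError unless wt_shape[-1] * groups == in_shape[-1].

    Hypothetical helper; shapes are channels-last tuples.
    """
    if groups < 1:
        raise ValueError(f"groups must be positive, got {groups}")
    in_channels = in_shape[-1]
    wt_channels = wt_shape[-1]
    # Each filter only sees in_channels / groups input channels.
    if wt_channels * groups != in_channels:
        raise ValueError(
            "Expected weight channels * groups to match input channels, got "
            f"input: {in_shape}, weight: {wt_shape}, groups: {groups}"
        )
```

For `groups=1` this is the familiar requirement that the weight and input agree on the channel dimension.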
After I fixed the bugs, I tried depthwise conv1d, and it was very slow.
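For reference, depthwise conv1d is the grouped case where groups equals the number of input channels. The following pure-Python sketch only illustrates those semantics (single sample, channels-last, stride 1, no padding or dilation); it is not the PR's implementation and says nothing about its performance. The function name `grouped_conv1d` and the nested-list layout are assumptions made for the example.

```python
def grouped_conv1d(x, w, groups=1):
    """Naive grouped 1D convolution (cross-correlation, no kernel flip).

    x: [length][in_channels]
    w: [out_channels][kernel_size][in_channels // groups]
    Returns: [length - kernel_size + 1][out_channels]
    """
    length, c_in = len(x), len(x[0])
    c_out, k = len(w), len(w[0])
    if c_in % groups or c_out % groups:
        raise ValueError("channels must be divisible by groups")
    c_in_g = c_in // groups    # input channels seen by each filter
    c_out_g = c_out // groups  # filters per group
    if len(w[0][0]) != c_in_g:
        raise ValueError("weight last dim must equal in_channels // groups")

    out = [[0.0] * c_out for _ in range(length - k + 1)]
    for t in range(length - k + 1):
        for o in range(c_out):
            base = (o // c_out_g) * c_in_g  # first input channel of this group
            acc = 0.0
            for j in range(k):
                for c in range(c_in_g):
                    acc += x[t + j][base + c] * w[o][j][c]
            out[t][o] = acc
    return out


# Depthwise example: 2 input channels, 2 groups, one filter per channel.
x = [[1, 10], [2, 20], [3, 30]]
w = [[[1], [1]], [[1], [-1]]]
y = grouped_conv1d(x, w, groups=2)
```

Each output channel reads only its own group's slice of the input, which is why the weight's last dimension shrinks to in_channels / groups.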
@jimexist any updates on this?
What's the status of this PR? I think we could move it forward, as no one else is working on groups right now. Wdyt? Otherwise, let's close it so someone else can pick this up.
In that case, let me see if I can move on and implement this, also rebasing after:
Is that PR still active? I would like to work on it.
@Stealeristaken feel free to take over, since I am a bit busy and engaged in March.
Proposed changes
Checklist
Put an x in the boxes that apply.
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes