
implement groups for conv1d and conv2d #637

Closed
wants to merge 1 commit from the conv-groups branch

Conversation

@jimexist commented Feb 6, 2024

Proposed changes

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jimexist changed the title implement groups for conv1d and conv2d Feb 6, 2024
@jimexist marked this pull request as draft Feb 6, 2024
@jagrit06 (Member) commented Feb 8, 2024

Thanks for putting this together!
I'm putting together an update to the convolutions that won't have groups handled in the backend right away, so this would be a good way to add that functionality until that specialization is ready!

I'm just dropping a comment here to give you a heads-up so that you can be on the lookout for any updates and PRs to conv in the meantime.
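
One way to bridge the gap jagrit06 describes is to emulate groups on top of the existing ungrouped kernel: split the input channels and the filters into groups, convolve each pair, and concatenate the results. A minimal sketch, assuming MLX's channels-last layout; conv1d_grouped is a hypothetical helper, not code from this PR:

    import mlx.core as mx

    def conv1d_grouped(x, w, stride=1, padding=0, groups=1):
        # x: (N, L, C_in), w: (C_out, K, C_in // groups) -- MLX is channels-last.
        if groups == 1:
            return mx.conv1d(x, w, stride=stride, padding=padding)
        # Split the channels and the filters into `groups` chunks, convolve
        # each pair with the ungrouped kernel, and stitch the outputs back.
        x_parts = mx.split(x, groups, axis=-1)
        w_parts = mx.split(w, groups, axis=0)
        outs = [
            mx.conv1d(xg, wg, stride=stride, padding=padding)
            for xg, wg in zip(x_parts, w_parts)
        ]
        return mx.concatenate(outs, axis=-1)  # (N, L_out, C_out)

This trades one large conv for `groups` smaller ones, which is correct but can be slow for large group counts.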

@@ -34,6 +37,8 @@ def __init__(
         kernel_size: int,
         stride: int = 1,
         padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
         bias: bool = True,
     ):
         super().__init__()
@tic-top commented Feb 11, 2024

Why doesn't the shape of self.weight change?
I think the shape should be (out_channels, kernel_size, in_channels // groups), with the constraint in_channels % groups == 0.
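
For reference, a sketch of the reviewer's suggested fix inside Conv1d.__init__ from the hunk above, assuming a channels-last weight of shape (out_channels, kernel_size, in_channels // groups); this is the suggestion, not the PR's code:

    # Each filter only sees in_channels // groups input channels, so both the
    # divisibility check and the weight's last dimension must use groups.
    assert in_channels % groups == 0, "in_channels must be divisible by groups"
    scale = math.sqrt(1 / (in_channels // groups * kernel_size))
    self.weight = mx.random.uniform(
        low=-scale,
        high=scale,
        shape=(out_channels, kernel_size, in_channels // groups),
    )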

             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, dilation),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(

Same question here: this should use in_channels // groups.
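
The same change, sketched for the Conv2d snippet above (assuming the usual channels-last weight layout; again the reviewer's suggestion, not the PR's code):

    # With groups, the fan-in per filter is (in_channels // groups) * kH * kW,
    # so both the init scale and the weight shape should use it.
    fan_in = (in_channels // groups) * kernel_size[0] * kernel_size[1]
    scale = math.sqrt(1 / fan_in)
    self.weight = mx.random.uniform(
        low=-scale,
        high=scale,
        shape=(out_channels, kernel_size[0], kernel_size[1], in_channels // groups),
    )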

@@ -2767,6 +2768,22 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
<< " input: " << in.shape() << " and weight: " << wt.shape();

If groups > 1, then the last dimension of the weight times groups should equal the last dimension of the input shape.
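
In other words, the check being asked for could look like the following; this is a hypothetical Python rendering of the C++ run_conv_checks logic, with check_group_shapes as an illustrative name:

    def check_group_shapes(in_shape, wt_shape, groups):
        # With groups, each filter covers only a slice of the input channels,
        # so weight channels * groups must equal the input channels.
        if wt_shape[-1] * groups != in_shape[-1]:
            raise ValueError(
                f"Expected weight channels ({wt_shape[-1]}) times groups "
                f"({groups}) to equal input channels ({in_shape[-1]})."
            )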

@tic-top commented Feb 11, 2024

After I fixed the bugs, I tried depthwise conv1d; it was very slow.
I can't wait to see the convolution update!
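
For context, depthwise convolution is the groups == in_channels special case, which is exactly the configuration that stresses a split-and-concatenate fallback (one tiny conv per channel). A sketch against the nn.Conv1d signature proposed in this PR:

    import mlx.core as mx
    import mlx.nn as nn

    # 32 input channels, one filter per channel: groups == in_channels.
    x = mx.random.normal(shape=(4, 128, 32))  # (N, L, C), channels-last
    conv = nn.Conv1d(
        in_channels=32, out_channels=32, kernel_size=3, padding=1, groups=32
    )
    y = conv(x)  # (4, 128, 32)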

@a1eaiactaest (Contributor) commented
@jimexist any updates on this?

@jimexist (Author) commented

> @jimexist any updates on this?

I paused working on this due to @jagrit06's message; I'm not sure if our paths will cross and duplicate work.

@awni (Member) commented Mar 5, 2024

What's the status of this PR? I think we could move it forward, as no one else is working on groups right now? Wdyt?

Otherwise let's close it so someone else can pick this up.

@jimexist (Author) commented Mar 7, 2024

> What's the status of this PR? I think we could move it forward, as no one else is working on groups right now? Wdyt?
>
> Otherwise let's close it so someone else can pick this up.

In that case, let me see if I can move on and implement this, also rebasing after:

@Stealeristaken (Contributor) commented
Is this PR still active? I would like to work on it.

@jimexist (Author) commented

> Is this PR still active? I would like to work on it.

@Stealeristaken feel free to take over, since I am a bit busy and engaged in March.

@awni closed this Mar 19, 2024
@jimexist deleted the conv-groups branch Mar 20, 2024