Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement transpose for complex64 type #6410

Merged
merged 7 commits into from
May 20, 2022
Merged

Implement transpose for complex64 type #6410

merged 7 commits into from
May 20, 2022

Conversation

DCtheTall
Copy link
Contributor

@DCtheTall DCtheTall commented May 15, 2022

For #4649

This PR implements tf.transpose for the complex64 type using higher level ops so that it does not need to be implemented separately on each backend


This change is Reviewable

@rthadur rthadur requested review from pyu10055 and lina128 May 16, 2022 15:22
Copy link
Collaborator

@lina128 lina128 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution!

Reviewable status: 0 of 1 approvals obtained (waiting on @DCtheTall and @pyu10055)


tfjs-core/src/ops/transpose.ts line 52 at r1 (raw file):

 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function transpose_<T extends Tensor>(x: T|TensorLike, perm?: number[], conjugate?: boolean): T {

Add a doccomment about this parameter above.

Code quote:

conjugate?: boolean

tfjs-core/src/ops/transpose.ts line 81 at r1 (raw file):

    $real = ENGINE.runKernel(
        Transpose, {x: $real} as {} as NamedTensorMap,
        attrs as {} as NamedAttrMap);

Wrap the whole thing in tf.tidy(...), otherwise, there's memory leak.

Code quote:

    let $real = real($x);
    let $imag = imag($x);
    $real = ENGINE.runKernel(
        Transpose, {x: $real} as {} as NamedTensorMap,
        attrs as {} as NamedAttrMap);
    $imag = ENGINE.runKernel(
        Transpose, {x: $imag} as {} as NamedTensorMap,
        attrs as {} as NamedAttrMap);
    if (conjugate) $imag = neg($imag);
    return complex($real, $imag);

tfjs-core/src/ops/transpose.ts line 85 at r1 (raw file):

        Transpose, {x: $imag} as {} as NamedTensorMap,
        attrs as {} as NamedAttrMap);
    if (conjugate) $imag = neg($imag);

style sugg: Wrap in {} braces.

Code quote:

$imag = neg($imag);

@DCtheTall
Copy link
Contributor Author

@lina128 made your suggested changes in the most recent commit.

Copy link
Collaborator

@lina128 lina128 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @DCtheTall and @pyu10055)

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution!

Reviewed 1 of 2 files at r1, 1 of 1 files at r2, all commit messages.
Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @DCtheTall)

@DCtheTall
Copy link
Contributor Author

Thanks all, happy to make another contribution to TFJS 😄

@pyu10055 pyu10055 merged commit 42fcd6f into tensorflow:master May 20, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants