
[Question]: The number of convolution multiplications decreases but the communication cost increases in SPU #678

Closed
warpoons opened this issue May 10, 2024 · 4 comments

@warpoons

Issue Type

Performance

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0.dev20240311

OS Platform and Distribution

Ubuntu 18.04.6 LTS on WSL

Python Version

3.10

Compiler Version

GCC 11.3.0

Current Behavior?

Not a bug, just a performance question: I have tested the communication cost of evaluating the first conv layer of ResNet18 on CIFAR-10 on its own, Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False). It cost 759296 bytes of communication and 0.015497988 s of latency.

Since convolution is multiplication-intensive, one way to reduce the communication/latency cost is to reduce the number of multiplications with the Winograd algorithm. Winograd uses pre-defined matrices to transform the weight and input into their Winograd-domain counterparts and performs an element-wise matrix multiplication (EWMM) between the transformed weight and input. After an additional inverse transformation, the output of the EWMM is equivalent to that of the standard convolution. On average, the number of multiplications is reduced by 2.25x with Winograd, without any accuracy loss.
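
For concreteness, here is a tiny plain-NumPy sketch of a single F(2x2, 3x3) Winograd tile (the transform matrices are the standard F(2,3) ones; this is only to illustrate the EWMM structure, not SPU code):

import numpy as np

# Standard F(2,3) transform matrices.
B_T = np.array([[1, 0, -1, 0],
                [0, 1,  1, 0],
                [0, -1, 1, 0],
                [0, 1,  0, -1]], dtype=float)
G = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0, 0.0, 1.0]])
A_T = np.array([[1, 1,  1,  0],
                [0, 1, -1, -1]], dtype=float)

rng = np.random.default_rng(0)
d = rng.standard_normal((4, 4))          # one 4x4 input tile
g = rng.standard_normal((3, 3))          # one 3x3 filter

U = G @ g @ G.T                          # filter in the Winograd domain
V = B_T @ d @ B_T.T                      # input tile in the Winograd domain
Y = A_T @ (U * V) @ A_T.T                # EWMM (16 multiplies) + inverse transform -> 2x2 output

# Direct sliding-window convolution over the same tile uses 2*2*9 = 36 multiplies.
ref = np.array([[(d[i:i+3, j:j+3] * g).sum() for j in range(2)] for i in range(2)])
assert np.allclose(Y, ref)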

I have tested the communication/latency cost of the standard and Winograd convolutions. Curiously, the cost of the Winograd convolution is significantly higher: 6291456 bytes of communication and 0.0487127 s of latency.

Theoretically, for the first layer of ResNet18 on CIFAR-10, the standard convolution needs 1,769,472 multiplications and the Winograd convolution needs 786,432 (a 2.25x reduction), yet the communication increases by 8.2859x.
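
The arithmetic behind these counts (assuming F(2x2, 3x3) tiling with the weight transform done offline):

# Multiplication counts for Conv2d(3, 64, kernel_size=3, stride=1, padding=1) on a 32x32 input.
H = W = 32
Cin, Cout, K = 3, 64, 3
std_muls = H * W * Cout * Cin * K * K               # 1,769,472
tiles = (H // 2) * (W // 2)                         # 256 tiles of 2x2 outputs per (in, out) channel pair
wino_muls = tiles * 4 * 4 * Cin * Cout              # 786,432 (16 multiplies per 4x4 tile)
print(std_muls, wino_muls, std_muls / wino_muls)    # -> 1769472 786432 2.25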

May I ask whether you understand the underlying reason, or whether there are convolution-specific optimizations that I am not aware of?

Thanks a lot.

Standalone code to reproduce the issue

N/A

Relevant log output

Here I report the SPU logs relevant to standard and Winograd conv evaluation:
Standard conv:
[2024-05-10 18:26:50,734] [Process-1] Starting grpc server at 127.0.0.1:61320
[2024-05-10 18:26:50,734] [Process-2] Starting grpc server at 127.0.0.1:61321
[2024-05-10 18:26:59,661] [Process-2] Run : builtin_spu_init at node:1
[2024-05-10 18:26:59,661] [Process-1] Run : builtin_spu_init at node:0
I0510 18:26:59.665524 12394 external/com_github_brpc_brpc/src/brpc/server.cpp:1158] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61331.
W0510 18:26:59.665545 12394 external/com_github_brpc_brpc/src/brpc/server.cpp:1164] Builtin services are disabled according to ServerOptions.has_builtin_services
I0510 18:26:59.666197 12396 external/com_github_brpc_brpc/src/brpc/server.cpp:1158] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61330.
W0510 18:26:59.666213 12396 external/com_github_brpc_brpc/src/brpc/server.cpp:1164] Builtin services are disabled according to ServerOptions.has_builtin_services
[2024-05-10 18:26:59,667] [Process-2] spu-runtime (SPU) initialized
[2024-05-10 18:26:59,668] [Process-1] spu-runtime (SPU) initialized
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 18:26:59,986] [Process-1] Run : <lambda> at node:0
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 18:27:00,003] [Process-2] Run : <lambda> at node:1
[2024-05-10 18:27:00,005] [Process-2] Run : make_shares at node:1
[2024-05-10 18:27:00,006] [Process-2] RunR: builtin_fetch_meta at node:1
[2024-05-10 18:27:00,008] [Process-2] Run : make_shares at node:1
[2024-05-10 18:27:00,010] [Process-2] RunR: builtin_fetch_meta at node:1
[2024-05-10 18:27:00,011] [Process-1] Run : make_shares at node:0
[2024-05-10 18:27:00,012] [Process-1] RunR: builtin_fetch_meta at node:0
[2024-05-10 18:27:00,021] [Process-2] Run : builtin_spu_run at node:1
[2024-05-10 18:27:00,023] [Process-1] RunR: builtin_fetch_object at node:0
[2024-05-10 18:27:00,023] [Process-1] Run : builtin_spu_run at node:0
[2024-05-10 18:27:00,025] [Process-2] RunR: builtin_fetch_object at node:1
[2024-05-10 18:27:00,026] [Process-2] RunR: builtin_fetch_object at node:1
[2024-05-10 18:27:00.031] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 18:27:00.044] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 18:27:00.049] [info] [api.cc:158] [Profiling] SPU execution infer completed, input processing took 8.61e-07s, execution took 0.022561069s, output processing took 1.745e-06s, total time 0.022563675s.
[2024-05-10 18:27:00.049] [info] [api.cc:191] HLO profiling: total time 0.022231636
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.add, executed 1 times, duration 0.001296596s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.broadcast, executed 1 times, duration 6.47e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.constant, executed 1 times, duration 6.612e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.convert, executed 1 times, duration 2.1625e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.convolution, executed 1 times, duration 0.020818828s, send bytes 759296
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.free, executed 5 times, duration 4.6146e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.pad, executed 1 times, duration 3.5359e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:191] HAL profiling: total time 0.018824205
[2024-05-10 18:27:00.049] [info] [api.cc:194] - f_add, executed 1 times, duration 0.001292103s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - f_tensordot, executed 1 times, duration 0.017514063s, send bytes 759296
[2024-05-10 18:27:00.049] [info] [api.cc:194] - seal, executed 1 times, duration 1.8039e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:191] MPC profiling: total time 0.021005868000000004
[2024-05-10 18:27:00.049] [info] [api.cc:194] - add_aa, executed 1 times, duration 0.001286959s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - broadcast, executed 1 times, duration 2.828e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - concatenate, executed 1 times, duration 0.000931655s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - extract_slice, executed 1024 times, duration 0.000790953s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - mmul_aa, executed 1 times, duration 0.011660626s, send bytes 235008
[2024-05-10 18:27:00.049] [info] [api.cc:194] - p2a, executed 1 times, duration 1.3275e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pad, executed 1 times, duration 3.3511e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - reshape, executed 1029 times, duration 0.000453675s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - transpose, executed 2 times, duration 3.981e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - trunc_a, executed 1 times, duration 0.005828405s, send bytes 524288
[2024-05-10 18:27:00.049] [info] [api.cc:204] Link details: total send bytes 759296, send actions 2

-------------------------------------------------
Winograd conv:
[2024-05-10 17:50:15,814] [ForkServerProcess-2] Starting grpc server at 127.0.0.1:61321
[2024-05-10 17:50:15,814] [ForkServerProcess-1] Starting grpc server at 127.0.0.1:61320
[2024-05-10 17:50:21,376] [ForkServerProcess-1] Run : builtin_spu_init at node:0
[2024-05-10 17:50:21,377] [ForkServerProcess-2] Run : builtin_spu_init at node:1
I0510 17:50:21.562828  8755 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61331.
W0510 17:50:21.562851  8755 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
I0510 17:50:21.563300  8753 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61330.
W0510 17:50:21.563313  8753 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
[2024-05-10 17:50:21,564] [ForkServerProcess-2] spu-runtime (SPU) initialized
[2024-05-10 17:50:21,564] [ForkServerProcess-1] spu-runtime (SPU) initialized
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 17:50:21,614] [ForkServerProcess-1] Run : <lambda> at node:0
[2024-05-10 17:50:21,616] [ForkServerProcess-2] Run : <lambda> at node:1
[2024-05-10 17:50:21,617] [ForkServerProcess-1] Run : make_shares at node:0
[2024-05-10 17:50:21.617] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 17:50:21,636] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
[2024-05-10 17:50:21,645] [ForkServerProcess-1] Run : builtin_spu_run at node:0
[2024-05-10 17:50:21,646] [ForkServerProcess-2] Run : builtin_spu_run at node:1
[2024-05-10 17:50:21,647] [ForkServerProcess-1] RunR: builtin_fetch_object at node:0
[2024-05-10 17:50:21.671] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 17:50:21.695] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 6.96e-07s, execution took 0.0487127s, output processing took 1.804e-06s, total time 0.0487152s.
[2024-05-10 17:50:21.696] [info] [api.cc:209] HLO profiling: total time 1.952e-06
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.constant, executed 1 times, duration 1.77e-06s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.free, executed 2 times, duration 9.2e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.multiply, executed 1 times, duration 4.7e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.broadcast, executed 1 times, duration 4.3e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:209] HAL profiling: total time 0.048321089
[2024-05-10 17:50:21.696] [info] [api.cc:212] - f_mul, executed 1 times, duration 0.048321089s, send bytes 6291456 recv bytes 6291456
[2024-05-10 17:50:21.696] [info] [api.cc:209] MPC profiling: total time 0.048304739
[2024-05-10 17:50:21.696] [info] [api.cc:212] - trunc_a, executed 1 times, duration 0.044309967s, send bytes 6291456 recv bytes 6291456
[2024-05-10 17:50:21.696] [info] [api.cc:212] - mul_ap, executed 1 times, duration 0.003989605s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - broadcast, executed 1 times, duration 5.167e-06s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:222] Link details: total send bytes 6291456, recv bytes 6291456, send actions 1
@fionser
Collaborator

fionser commented May 10, 2024

As you can see, the number of truncations increases a lot; that should be expected.
For a matmul, the number of truncations is quadratic in the matrix size (one truncation per output element). For Winograd, I think you would need to extend SPU in C++ to reduce the number of truncations, for example by adding several multiplication results first and then performing a single truncation, instead of truncating each product separately.
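
To illustrate the idea in plain fixed-point arithmetic (not SPU internals; the number of fractional bits is just an example):

# Each product of two fixed-point values with F fractional bits carries 2F fractional
# bits; summing the raw products and truncating the sum once gives (up to rounding)
# the same result as truncating every product, with a single truncation.
F = 18                                               # fractional bits (example value)
enc = lambda x: int(round(x * (1 << F)))
dec = lambda x: x / (1 << F)

a = [0.25, -1.5, 3.0]
b = [2.0, 0.5, -0.125]

per_product = sum((enc(x) * enc(y)) >> F for x, y in zip(a, b))   # 3 truncations
fused       = sum(enc(x) * enc(y) for x, y in zip(a, b)) >> F     # 1 truncation
print(dec(per_product), dec(fused))                  # both give -0.625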

@llCurious

Hi @warpoons, interesting idea. As pointed out by @fionser, the problem is due to the increased amount of truncation.

In the Winograd algorithm, the matmul is split into several parts (currently, each part incurs additional truncations), which I believe is not SPU-friendly.

In my opinion, to get the most out of Winograd, you would need to add a backend op for Winograd and implement the algorithm in C++.

@warpoons
Author

warpoons commented May 14, 2024

Hi @llCurious @fionser! Thanks for your responses!

As pointed out by @fionser, the number of truncations in a matmul is quadratic in the matrix size. In Winograd, the input feature map is split into several overlapping parts (called tiles) and an element-wise matmul is performed on each tile separately, so reasonably there are more EWMMs across all the tiles than in the standard convolution. Is this understanding correct?

I have another question: is there a way to estimate the theoretical communication cost of the standard and Winograd convolutions (considering only the EWMMs in Winograd, with the weight transformation done offline) in SPU?
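
My own rough attempt, counting only the truncation traffic and assuming SEMI2K over FM64 (8-byte ring elements, one truncation per fixed-point multiplication output; the Beaver-triple traffic of mmul_aa is ignored), is below. Is this the right way to think about it?

BYTES_PER_ELEMENT = 8                                # FM64 ring element

def trunc_bytes_matmul(m, k, n):
    # an (m x k) @ (k x n) secret matmul truncates each of its m*n output elements once
    return m * n * BYTES_PER_ELEMENT

def trunc_bytes_ewmm(num_products):
    # an element-wise multiplication truncates every single product
    return num_products * BYTES_PER_ELEMENT

# Standard conv lowered to an im2col-style matmul: (32*32) x (3*3*3) times (3*3*3) x 64
print(trunc_bytes_matmul(32 * 32, 3 * 3 * 3, 64))    # 524288 = trunc_a bytes in the standard log
# EWMM-based Winograd: 786,432 products, each truncated
print(trunc_bytes_ewmm(786_432))                     # 6291456 = trunc_a bytes in the Winograd log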

Thanks!

@warpoons
Author

warpoons commented May 31, 2024

Hi @llCurious @fionser! This week I have further tested the Winograd convolution for reducing multiplications in SPU.

As described earlier in this issue, Winograd converts the standard convolution into an EWMM with fewer multiplications, at the cost of low parallelism in the EWMM.

There is another way to convert Winograd's EWMM into a general matrix multiplication (GEMM) by transposing the Winograd weights and inputs.

As suggested in the NeurIPS 2023 paper CoPriv: Network/protocol co-optimization for communication-efficient private inference (figure not reproduced here), the communication increases when Winograd is used only for multiplication reduction; to reach the expected communication improvement, the EWMM-to-GEMM conversion should be applied.
This finding somewhat confirms why the communication abnormally increases by ~8x with the EWMM-based Winograd.

Hence, I further tested the GEMM-based Winograd to see whether the expected 2.25x communication reduction shows up, but the answer is no. The profiling below uses ("SEMI2K", "FM64"):

  • jnp.dtype = jnp.float32
[2024-05-31 15:56:54.026] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.297e-06s, execution took 0.059690046s, output processing took 1.735e-06s, total time 0.059693078s.
[2024-05-31 15:56:54.026] [info] [api.cc:209] HLO profiling: total time 5.8410000000000005e-06
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.constant, executed 6 times, duration 2.155e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.free, executed 50 times, duration 1.814e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.reshape, executed 18 times, duration 7.56e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.transpose, executed 7 times, duration 3.05e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.broadcast, executed 6 times, duration 2.09e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.dot, executed 4 times, duration 1.95e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.iota, executed 2 times, duration 8.2e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.convolution, executed 1 times, duration 4e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:209] HAL profiling: total time 0.054427574
[2024-05-31 15:56:54.026] [info] [api.cc:212] - f_mmul, executed 20 times, duration 0.041460792s, send bytes 3866624 recv bytes 3866624
[2024-05-31 15:56:54.026] [info] [api.cc:212] - f_tensordot, executed 1 times, duration 0.012745663s, send bytes 98304 recv bytes 98304
[2024-05-31 15:56:54.026] [info] [api.cc:212] - i_equal, executed 2 times, duration 0.000147861s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mixed_mul, executed 1 times, duration 4.8814e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - seal, executed 1 times, duration 1.9777e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - int2fxp, executed 1 times, duration 4.667e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:209] MPC profiling: total time 0.056513282
[2024-05-31 15:56:54.026] [info] [api.cc:212] - trunc_a, executed 21 times, duration 0.044545624s, send bytes 3964928 recv bytes 3964928
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mmul_ap, executed 55 times, duration 0.007790367s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - reshape, executed 332 times, duration 0.001793458s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - concatenate, executed 2 times, duration 0.001615727s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - extract_slice, executed 360 times, duration 0.00036179s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - transpose, executed 132 times, duration 0.00019234s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - equal_pp, executed 2 times, duration 0.000101518s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mul_pp, executed 1 times, duration 4.7267e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pad, executed 1 times, duration 3.4435e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - p2a, executed 1 times, duration 1.8228e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - broadcast, executed 6 times, duration 9.442e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - lshift_p, executed 1 times, duration 3.086e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:222] Link details: total send bytes 3964928, recv bytes 3964928, send actions 21

  • jnp.dtype = jnp.integer
[2024-05-31 15:59:36.265] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 2.39e-07s, execution took 0.069512618s, output processing took 2.288e-06s, total time 0.069515145s.
[2024-05-31 15:59:36.265] [info] [api.cc:209] HLO profiling: total time 5.878e-06
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.constant, executed 6 times, duration 1.957e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.free, executed 50 times, duration 1.8e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.reshape, executed 18 times, duration 6.85e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.dot, executed 4 times, duration 5.09e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.transpose, executed 7 times, duration 2.92e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.broadcast, executed 6 times, duration 2.19e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 7.9e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.iota, executed 2 times, duration 7.8e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.equal, executed 2 times, duration 7.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.dot_general, executed 1 times, duration 5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.convolution, executed 1 times, duration 4.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.pad, executed 1 times, duration 4.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.multiply, executed 1 times, duration 4.4e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:209] HAL profiling: total time 0.064209222
[2024-05-31 15:59:36.265] [info] [api.cc:212] - f_mmul, executed 2 times, duration 0.05382723s, send bytes 1572864 recv bytes 1572864
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_tensordot, executed 1 times, duration 0.009371561s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mixed_mmul, executed 16 times, duration 0.000773674s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_equal, executed 2 times, duration 0.000100369s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_mmul, executed 2 times, duration 6.5911e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_mul, executed 1 times, duration 4.9618e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - seal, executed 1 times, duration 2.0859e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:209] MPC profiling: total time 0.066035378
[2024-05-31 15:59:36.265] [info] [api.cc:212] - trunc_a, executed 2 times, duration 0.05136436s, send bytes 1572864 recv bytes 1572864
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mmul_ap, executed 55 times, duration 0.010664855s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - concatenate, executed 2 times, duration 0.001737654s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - reshape, executed 332 times, duration 0.001582893s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - extract_slice, executed 360 times, duration 0.000322834s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - transpose, executed 132 times, duration 0.000154002s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - equal_pp, executed 2 times, duration 9.7572e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mul_pp, executed 1 times, duration 4.8151e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pad, executed 1 times, duration 3.5026e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - p2a, executed 1 times, duration 1.9398e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - broadcast, executed 6 times, duration 8.633e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:222] Link details: total send bytes 1572864, recv bytes 1572864, send actions 2

We can see that the communication is reduced compared with the EWMM-based Winograd, but it is still far from the expected improvement.

Another issue is that with jnp.integer the profile still shows trunc_a and communication, and I cannot figure out the reason behind it.

Also, with jnp.dtype = jnp.float32 the profile has f_tensordot with communication, while with jnp.dtype = jnp.integer it has i_tensordot without communication.

To make it clear, this is my model.py defining the Winograd conv layer:

import warnings

import jax.numpy as jnp
import flax.linen as nn
# ZeroPad2dFlax, winUtils and params (the A_T, B_T, G transform matrices) come from
# the rest of my model.py and are omitted here.

class FlaxConvWino(nn.Module):
    inCh: int
    outCh: int
    filterDim: int
    outTileDim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input):
        padding = int((self.filterDim - 1)/2)
        temp_padding = ZeroPad2dFlax(padding)
        input_ = temp_padding(input)
        number_tiling_positions = (input_.shape[3] - 2 * padding) / self.outTileDim
        if number_tiling_positions.is_integer():
            Pad_tiling = ZeroPad2dFlax(0)
        else:
            decimal_part = number_tiling_positions - int(number_tiling_positions)
            to_pad = round((1.0 - decimal_part) * self.outTileDim)
            to_pad_even = round(to_pad / 2)
            Pad_tiling = ZeroPad2dFlax(to_pad_even)

        expected_output_width = input_.shape[2] - 2 * padding
        input_ = Pad_tiling(input_)
        Tiler = winUtils.TilerFlax(self.outTileDim, self.filterDim)
        input_ = Tiler.tile(input_)
        weight = jnp.ones((1, self.outCh, self.inCh, self.filterDim, self.filterDim), dtype=self.dtype)  # ★★★ all-ones weight created inside the module (see note below)

        A_t = params.A_T
        B_t = params.B_T
        G = params.G

        # Note that the PI communication increases by over 10x without Tile Transposition
        # Therefore, next we transpose the winograd input and weight for converting EWMM to GEMM
        
        # Weight/Input transformation
        weight_winograd = jnp.matmul(jnp.matmul(G, weight), jnp.transpose(G, (1, 0)))
        input_winograd = jnp.matmul(jnp.matmul(B_t, input_), jnp.transpose(B_t, (1, 0)))

        # Tile Transposition
        weight_winograd_TTrans = jnp.transpose(weight_winograd, (0, 3, 4, 1, 2))
        input_winograd_TTrans = jnp.transpose(input_winograd, (0, 3, 4, 2, 1))

        GEMM = jnp.matmul(weight_winograd_TTrans, input_winograd_TTrans)

        output = jnp.transpose(GEMM, (0, 4, 3, 1, 2))
        output = jnp.matmul(jnp.matmul(A_t, output), jnp.transpose(A_t, (1, 0)))
        output = Tiler.untile(output)

        if output.shape[3] != expected_output_width:  # '!=' rather than 'is not': identity comparison on ints is unreliable
            warnings.warn('output dim is not expected. Error may occur !!!')
            padding = Pad_tiling.padding
            output = output[:, :, padding[0]:-padding[1], padding[2]:-padding[3]]
        return output

Note that the line marked with ★★★ initializes an all-ones weight inside the model definition, since I think the specific parameter values will not significantly affect the communication results. As a result, flax_model.init(jax.random.PRNGKey(1), jnp.ones(input_shape)) returns an empty {}. I don't know whether this has an impact.
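
For completeness, this is roughly how the layer is driven for the first ResNet18 layer on CIFAR-10 (the NCHW input shape and outTileDim=2 are my assumptions for the F(2x2, 3x3) case; ZeroPad2dFlax, winUtils and params live elsewhere in model.py):

import jax
import jax.numpy as jnp

layer = FlaxConvWino(inCh=3, outCh=64, filterDim=3, outTileDim=2)
x = jnp.ones((1, 3, 32, 32))                         # NCHW input
variables = layer.init(jax.random.PRNGKey(1), x)     # returns {} since the weight is a constant
y = layer.apply(variables, x)                        # expected shape (1, 64, 32, 32)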

Sorry for taking up your time. Thanks!
