Implement preprocessing algorithms in SML #382

Closed
Candicepan opened this issue Nov 2, 2023 · 20 comments · Fixed by #470 or #575
Assignees
Labels
enhancement New feature or request OSCP SecretFlow Open Source Contribution Plan

Comments

@Candicepan
Contributor

Candicepan commented Nov 2, 2023

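This issue is a task in Phase 3 of the SecretFlow Open Source Contribution Plan (SF OSCP). Community developers are welcome to take part.
If you would like to claim a task but have not registered yet, please complete the registration first.

Task overview

  • Task name: Implement preprocessing algorithms in SML
  • Technical area: SPU/SML
  • Difficulty: Advanced 🌟🌟

Detailed requirements

  • Add some simple preprocessing algorithms to SML, including:
    a. labelBinarizer
    b. Binarizer
    c. Normalizer
    d. MinMaxScaler
    e. MaxAbsScaler
    f. KBinsDiscretizer
    For the expected behavior, refer to the classes of the same name in sklearn.
  • Correctness: make sure the submitted code runs as-is.
  • Code style: Python code must be formatted with black + isort (the CI pipeline includes a code-style check); Bazel files must be formatted with buildifier.
  • Each claim must implement at least 3 algorithms (or all remaining algorithms if fewer than 3 are left).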
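
For orientation, the semantics of four of the listed classes can be sketched in plain NumPy as below. This is only a rough reference under simplifying assumptions, not the SML implementation and not sklearn's code: the function names are illustrative, the fit/transform split is omitted, and constant columns or all-zero rows are simply left unscaled.

```python
import numpy as np


def binarize(X, threshold=0.0):
    """Binarizer: 1 where the value is strictly above the threshold, else 0."""
    return (X > threshold).astype(X.dtype)


def normalize_l2(X):
    """Normalizer (L2): scale every row to unit Euclidean norm."""
    norms = np.sqrt(np.sum(X * X, axis=1, keepdims=True))
    # Leave all-zero rows unchanged instead of dividing by zero.
    return X / np.where(norms == 0, 1.0, norms)


def min_max_scale(X):
    """MinMaxScaler: map every column linearly onto [0, 1]."""
    x_min, x_max = X.min(axis=0), X.max(axis=0)
    # Constant columns get a span of 1 so that they map to 0.
    span = np.where(x_max == x_min, 1.0, x_max - x_min)
    return (X - x_min) / span


def max_abs_scale(X):
    """MaxAbsScaler: divide every column by its largest absolute value."""
    max_abs = np.abs(X).max(axis=0)
    return X / np.where(max_abs == 0, 1.0, max_abs)


X = np.array([[1.0, -2.0], [3.0, 4.0]])
binarized = binarize(X)       # [[1, 0], [1, 1]]
normalized = normalize_l2(X)  # second row becomes [0.6, 0.8]
min_maxed = min_max_scale(X)  # [[0, 0], [1, 1]]
max_absed = max_abs_scale(X)  # [[1/3, -0.5], [1, 1]]
```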

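If you have other algorithms to suggest, feel free to reply under this issue.

Required skills

  • Familiarity with classic machine learning algorithms
  • Familiarity with JAX or NumPy, and the ability to implement algorithms with NumPy

Instructions

How to claim

  • After claiming the task, comment on this issue with your concrete design.
  • Design notes: briefly describe which algorithm and which technical approach you plan to use.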
@Candicepan Candicepan added enhancement New feature or request OSCP SecretFlow Open Source Contribution Plan labels Nov 2, 2023
@Candicepan Candicepan moved this to Needs Triage in OSCP Phase 3 Nov 3, 2023
@winnylyc
Contributor

winnylyc commented Jan 3, 2024

winnylyc Give it to me

@winnylyc
Contributor

winnylyc commented Jan 3, 2024

I'll start with labelBinarizer, Binarizer, and Normalizer.

@Candicepan
Contributor Author

> winnylyc Give it to me

Thank you for claiming this. Please complete the event registration first: https://www.wjx.top/vm/QKio0dq.aspx#

@winnylyc
Contributor

winnylyc commented Jan 3, 2024

> Thank you for claiming this. Please complete the event registration first: https://www.wjx.top/vm/QKio0dq.aspx#

Sorry, I believe I already registered on December 31. Is there a problem with my registration?

@Candicepan
Contributor Author

> Sorry, I believe I already registered on December 31. Is there a problem with my registration?

Sorry about that. The GitHub ID you entered when registering does not match your current one, so we could not identify you. An assistant will contact you to confirm.

@deadlywing
Contributor

> I'll start with labelBinarizer, Binarizer, and Normalizer.

Hello, thanks for claiming this. You can go ahead and implement them, then open a PR once you are done.

@deadlywing
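Contributor

BTW, you can create a new preprocessing directory under the SML directory and put the code there, following how the other directories are organized. Please also implement the corresponding tests and emulation, and wire them up with Bazel.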

@winnylyc
Contributor

winnylyc commented Jan 3, 2024

OK, thanks for the guidance.

@Candicepan Candicepan moved this from Needs Triage to In Progress in OSCP Phase 3 Jan 3, 2024
@winnylyc
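Contributor

winnylyc commented Jan 5, 2024

While implementing labelBinarizer, I found that the hard part is producing a matrix with a dynamic shape. Specifically, the output has shape (n_samples, n_classes), where n_classes is a value computed from the elements of the input matrix.
Is this currently hard for SPU to support? #309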

@deadlywing
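Contributor

> While implementing labelBinarizer, I found that the hard part is producing a matrix with a dynamic shape. Specifically, the output has shape (n_samples, n_classes), where n_classes is a value computed from the elements of the input matrix.
> Is this currently hard for SPU to support? #309

Hi, that's right: dynamic shapes cannot be implemented.
The usual approach is to introduce extra parameters, e.g. have the user specify n_classes and require the labels to take the values 0, 1, 2, ...

Just make sure to explain this clearly in the code comments.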
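
A minimal NumPy sketch of that workaround, assuming the caller passes n_classes as a plain Python integer and the labels are exactly 0, 1, ..., n_classes - 1. The function name is illustrative, and sklearn's special handling of two-class input (a single output column) is deliberately left out:

```python
import numpy as np


def label_binarize_static(y, n_classes):
    """One-hot encode labels assumed to lie in {0, 1, ..., n_classes - 1}.

    n_classes is supplied by the caller, so the output shape
    (n_samples, n_classes) is known without inspecting the data.
    """
    classes = np.arange(n_classes)
    # Broadcasting (n_samples, 1) against (1, n_classes) gives (n_samples, n_classes).
    return (y[:, None] == classes[None, :]).astype(np.int64)


y = np.array([0, 2, 1, 2])
encoded = label_binarize_static(y, n_classes=3)
# encoded is [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]
```

Because the comparison is against a fixed arange, the same expression should carry over to jax.numpy without any data-dependent shape.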

@winnylyc
Contributor

winnylyc commented Jan 5, 2024

OK, thanks for your reply.

@Candicepan Candicepan linked a pull request Jan 15, 2024 that will close this issue
@github-project-automation github-project-automation bot moved this from In Progress to Done in OSCP Phase 3 Jan 15, 2024
@deadlywing deadlywing reopened this Jan 15, 2024
@deadlywing
Contributor

[UPDATE]:
Only the following remain:
d. MinMaxScaler
e. MaxAbsScaler
f. KBinsDiscretizer

Anyone who wants to claim this must implement all three of the above in one go.

@Candicepan Candicepan moved this from Done to Needs Triage in OSCP Phase 3 Jan 24, 2024
@Candicepan Candicepan moved this from Needs Triage to In Review in OSCP Phase 3 Jan 29, 2024
@Candicepan Candicepan moved this from In Review to Done in OSCP Phase 3 Jan 29, 2024
@Candicepan Candicepan moved this from Done to Needs Triage in OSCP Phase 3 Jan 29, 2024
@winnylyc
Contributor

winnylyc Give it to me

@winnylyc
Contributor

I'll implement MinMaxScaler, MaxAbsScaler, and KBinsDiscretizer.

@deadlywing
Contributor

Hello, thanks for claiming this. The code can go in the same location as before.

@winnylyc
Contributor

OK.

@Candicepan Candicepan moved this from Needs Triage to In Review in OSCP Phase 3 Feb 1, 2024
@Candicepan Candicepan moved this from In Review to In Progress in OSCP Phase 3 Feb 1, 2024
@winnylyc
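Contributor

winnylyc commented Feb 1, 2024

Sorry to bother you. While implementing the uniform mode of KBinsDiscretizer, I implemented a method that greatly reduces the cost, but precision issues keep it from producing the expected result.
The original algorithm uses linspace to compute the range of each bin, then uses searchsorted to match each input value to a bin.
Below is the output from running my SPU implementation of it on the emulator.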
[2024-02-01 07:09:14.064] [info] [api.cc:158] [Profiling] SPU execution kbinsdiscretize completed, input processing took 1.367e-06s, execution took 0.38528338s, output processing took 2.151e-06s, total time 0.385286898s.
[2024-02-01 07:09:14.064] [info] [api.cc:191] HLO profiling: total time 0.38133484
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.add, executed 1 times, duration 1.1177e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.broadcast, executed 4 times, duration 3.0525e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.concatenate, executed 1 times, duration 1.7554e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.constant, executed 13 times, duration 0.000200352s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.convert, executed 4 times, duration 0.002004551s, send bytes 272
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.free, executed 26 times, duration 1.7038e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.iota, executed 1 times, duration 1.0227e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.multiply, executed 3 times, duration 0.000308453s, send bytes 768
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.reduce, executed 2 times, duration 0.018541286s, send bytes 1920
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.reshape, executed 1 times, duration 5.356e-06s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.slice, executed 1 times, duration 6.72e-06s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.subtract, executed 1 times, duration 4.189e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.transpose, executed 4 times, duration 3.784e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pphlo.while, executed 1 times, duration 0.360101871s, send bytes 56768
[2024-02-01 07:09:14.065] [info] [api.cc:191] HAL profiling: total time 0.297433468
[2024-02-01 07:09:14.065] [info] [api.cc:194] - _and, executed 4 times, duration 0.004100559s, send bytes 3616
[2024-02-01 07:09:14.065] [info] [api.cc:194] - _mux, executed 86 times, duration 0.04633184s, send bytes 18304
[2024-02-01 07:09:14.065] [info] [api.cc:194] - _rshift, executed 2 times, duration 0.003237852s, send bytes 3584
[2024-02-01 07:09:14.065] [info] [api.cc:194] - _sign, executed 2 times, duration 0.003509741s, send bytes 1664
[2024-02-01 07:09:14.065] [info] [api.cc:194] - f_add, executed 1 times, duration 7.263e-06s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - f_less, executed 8 times, duration 0.016467231s, send bytes 2304
[2024-02-01 07:09:14.065] [info] [api.cc:194] - f_mul, executed 3 times, duration 0.000207802s, send bytes 768
[2024-02-01 07:09:14.065] [info] [api.cc:194] - f_sub, executed 1 times, duration 3.5923e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - i_add, executed 44 times, duration 0.001007859s, send bytes 768
[2024-02-01 07:09:14.065] [info] [api.cc:194] - i_equal, executed 70 times, duration 0.072963777s, send bytes 6000
[2024-02-01 07:09:14.065] [info] [api.cc:194] - i_less, executed 105 times, duration 0.098162461s, send bytes 6912
[2024-02-01 07:09:14.065] [info] [api.cc:194] - i_negate, executed 6 times, duration 0.003421302s, send bytes 4352
[2024-02-01 07:09:14.065] [info] [api.cc:194] - logical_not, executed 4 times, duration 0.000117004s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mixed_mmul, executed 64 times, duration 0.036492644s, send bytes 2992
[2024-02-01 07:09:14.065] [info] [api.cc:194] - seal, executed 72 times, duration 0.01137021s, send bytes 1296
[2024-02-01 07:09:14.065] [info] [api.cc:191] MPC profiling: total time 0.2913409269999999
[2024-02-01 07:09:14.065] [info] [api.cc:194] - a2b, executed 68 times, duration 0.076462737s, send bytes 14336
[2024-02-01 07:09:14.065] [info] [api.cc:194] - add_aa, executed 85 times, duration 0.000338282s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - add_ap, executed 246 times, duration 0.000839773s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - add_pp, executed 210 times, duration 0.00085503s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - and_bb, executed 2 times, duration 0.000137758s, send bytes 32
[2024-02-01 07:09:14.065] [info] [api.cc:194] - and_bp, executed 2 times, duration 1.6616e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - b2a, executed 86 times, duration 0.024341501s, send bytes 19760
[2024-02-01 07:09:14.065] [info] [api.cc:194] - broadcast, executed 115 times, duration 0.000396947s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - cast_type_b, executed 64 times, duration 0.000239684s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - common_type_b, executed 64 times, duration 0.000122411s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - concatenate, executed 35 times, duration 0.000429897s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - extract_slice, executed 233 times, duration 0.00083118s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - make_p, executed 217 times, duration 0.000785983s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mmul_aa, executed 64 times, duration 0.024191822s, send bytes 768
[2024-02-01 07:09:14.065] [info] [api.cc:194] - msb_a2b, executed 78 times, duration 0.113729456s, send bytes 10368
[2024-02-01 07:09:14.065] [info] [api.cc:194] - msb_p, executed 37 times, duration 8.5856e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mul_a1b, executed 72 times, duration 0.030121619s, send bytes 4608
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mul_aa, executed 14 times, duration 0.00318294s, send bytes 1792
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mul_ap, executed 4 times, duration 3.2011e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - mul_pp, executed 1 times, duration 1.0044e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - not_a, executed 70 times, duration 0.000267505s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - not_p, executed 138 times, duration 0.000592593s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - p2a, executed 72 times, duration 0.011004042s, send bytes 1296
[2024-02-01 07:09:14.065] [info] [api.cc:194] - pad, executed 64 times, duration 0.000650356s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - reshape, executed 397 times, duration 0.001137509s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - rshift_b, executed 2 times, duration 7.264e-06s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - transpose, executed 6 times, duration 5.4502e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - trunc_a, executed 2 times, duration 0.000128553s, send bytes 768
[2024-02-01 07:09:14.065] [info] [api.cc:194] - trunc_p, executed 1 times, duration 1.5236e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - update_slice, executed 32 times, duration 0.000286748s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:194] - xor_bp, executed 4 times, duration 4.5072e-05s, send bytes 0
[2024-02-01 07:09:14.065] [info] [api.cc:204] Link details: total send bytes 59728, send actions 2221
The final result of this implementation is correct.
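
For reference, the linspace + searchsorted approach can be sketched in plaintext NumPy roughly as follows. This is an illustration rather than the SPU code that was profiled; the function name is made up, and searching only the interior edges with side="right" is an assumption chosen so that the column maximum falls into the last bin.

```python
import numpy as np


def kbins_uniform_reference(X, n_bins):
    """Uniform-width binning per column via linspace + searchsorted."""
    out = np.empty(X.shape, dtype=np.int64)
    for j in range(X.shape[1]):
        col = X[:, j]
        edges = np.linspace(col.min(), col.max(), n_bins + 1)
        # Only interior edges are searched, so the column maximum lands
        # in the last bin (index n_bins - 1).
        out[:, j] = np.searchsorted(edges[1:-1], col, side="right")
    return out


X = np.array([
    [-2.0, 1.0, -4.0, -1.0],
    [-1.0, 2.0, -3.0, -0.5],
    [0.0, 3.0, -2.0, 0.5],
    [1.0, 4.0, -1.0, 2.0],
])
bins = kbins_uniform_reference(X, n_bins=3)
# bins is [[0, 0, 0, 0], [1, 1, 1, 0], [2, 2, 2, 1], [2, 2, 2, 2]]
```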

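My optimized algorithm uses MinMaxScale to map the values of X into the range [0, n_bins] and then applies floor. This maps the maximum value to n_bins (it should end up as n_bins - 1), so a clip operation is also needed (which is cheaper than where).
Below is the output from running my SPU implementation of it on the emulator.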
[2024-02-01 07:25:56.408] [info] [api.cc:158] [Profiling] SPU execution kbinsdiscretize completed, input processing took 6.41e-07s, execution took 0.052413957s, output processing took 1.924e-06s, total time 0.052416522s.
[2024-02-01 07:25:56.408] [info] [api.cc:191] HLO profiling: total time 0.050960056999999996
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.add, executed 1 times, duration 2.4537e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.broadcast, executed 2 times, duration 3.08e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.clamp, executed 1 times, duration 0.005674089s, send bytes 1920
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.constant, executed 9 times, duration 9.1639e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.convert, executed 3 times, duration 0.000512452s, send bytes 24
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.divide, executed 1 times, duration 0.014864071s, send bytes 2592
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.equal, executed 1 times, duration 0.004818007s, send bytes 288
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.floor, executed 1 times, duration 0.001749568s, send bytes 1792
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.free, executed 32 times, duration 3.2869e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.multiply, executed 2 times, duration 0.000997252s, send bytes 544
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.negate, executed 1 times, duration 5.0246e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.not, executed 1 times, duration 4.5164e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.prefer_a, executed 2 times, duration 0.000648598s, send bytes 1536
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.reduce, executed 3 times, duration 0.019555057s, send bytes 2384
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.select, executed 4 times, duration 0.001823398s, send bytes 320
[2024-02-01 07:25:56.408] [info] [api.cc:194] - pphlo.subtract, executed 1 times, duration 4.231e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:191] HAL profiling: total time 0.049926874999999996
[2024-02-01 07:25:56.408] [info] [api.cc:194] - _and, executed 3 times, duration 0.002828277s, send bytes 464
[2024-02-01 07:25:56.408] [info] [api.cc:194] - _mux, executed 12 times, duration 0.007180927s, send bytes 1856
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_add, executed 1 times, duration 1.8107e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_div, executed 1 times, duration 0.014856038s, send bytes 2592
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_equal, executed 1 times, duration 0.004794003s, send bytes 288
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_floor, executed 1 times, duration 0.001744575s, send bytes 1792
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_less, executed 8 times, duration 0.016307413s, send bytes 2304
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_mul, executed 2 times, duration 0.000978185s, send bytes 544
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_negate, executed 1 times, duration 4.4026e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - f_sub, executed 1 times, duration 3.536e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - i_add, executed 2 times, duration 0.000608019s, send bytes 1536
[2024-02-01 07:25:56.408] [info] [api.cc:194] - logical_not, executed 1 times, duration 4.3038e-05s, send bytes 0
[2024-02-01 07:25:56.408] [info] [api.cc:194] - seal, executed 3 times, duration 0.000488907s, send bytes 24
[2024-02-01 07:25:56.408] [info] [api.cc:191] MPC profiling: total time 0.043778508
[2024-02-01 07:25:56.408] [info] [api.cc:194] - a2b, executed 3 times, duration 0.005431767s, send bytes 2688
[2024-02-01 07:25:56.408] [info] [api.cc:194] - add_aa, executed 27 times, duration 0.000112406s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - add_ap, executed 41 times, duration 0.000186499s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - add_pp, executed 5 times, duration 4.1066e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - and_bb, executed 9 times, duration 0.002433834s, send bytes 192
[2024-02-01 07:25:56.409] [info] [api.cc:194] - arshift_b, executed 1 times, duration 5.24e-06s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - b2a, executed 5 times, duration 0.005130928s, send bytes 2528
[2024-02-01 07:25:56.409] [info] [api.cc:194] - bitrev_b, executed 1 times, duration 7.557e-06s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - broadcast, executed 13 times, duration 5.4757e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - extract_slice, executed 12 times, duration 7.9287e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - lshift_b, executed 1 times, duration 3.193e-06s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - make_p, executed 32 times, duration 0.000148395s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - msb_a2b, executed 9 times, duration 0.01740561s, send bytes 2448
[2024-02-01 07:25:56.409] [info] [api.cc:194] - mul_a1b, executed 8 times, duration 0.004856491s, send bytes 1536
[2024-02-01 07:25:56.409] [info] [api.cc:194] - mul_aa, executed 14 times, duration 0.003835836s, send bytes 736
[2024-02-01 07:25:56.409] [info] [api.cc:194] - mul_ap, executed 2 times, duration 9.687e-06s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - not_a, executed 23 times, duration 0.000110248s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - not_p, executed 5 times, duration 4.955e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - p2a, executed 3 times, duration 0.00047469s, send bytes 24
[2024-02-01 07:25:56.409] [info] [api.cc:194] - reshape, executed 9 times, duration 4.3041e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - rshift_b, executed 7 times, duration 2.3574e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - transpose, executed 3 times, duration 4.3463e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - trunc_a, executed 9 times, duration 0.00320778s, send bytes 960
[2024-02-01 07:25:56.409] [info] [api.cc:194] - xor_bb, executed 13 times, duration 5.268e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:194] - xor_bp, executed 2 times, duration 3.0929e-05s, send bytes 0
[2024-02-01 07:25:56.409] [info] [api.cc:204] Link details: total send bytes 11400, send actions 193
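As you can see, the cost is much lower, but the result does not match expectations because of precision. Here is an example:
The input X is:
[[-2, 1, -4, -1],
[-1, 2, -3, -0.5],
[ 0, 3, -2, 0.5],
[ 1, 4, -1, 2]]
The expected result is:
[[0. 0. 0. 0.]
[1. 1. 1. 0.]
[2. 2. 2. 1.]
[2. 2. 2. 2.]]
The result computed by my algorithm is:
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 1., 1.],
[2., 2., 2., 2.]]
The error looks large, but the values just before the clip and floor operations are actually very close to the exact ones:
[[0. , 0. , 0. , 0. ],
[0.999989, 0.999985, 0.999981, 0.499996],
[1.999977, 1.999969, 1.999962, 1.499981],
[2.999966, 2.999954, 2.999943, 2.999966]]
As you can see, an error of less than 0.0001 is what causes the large error in the final result.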
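
The effect can be reproduced in plaintext with a small NumPy sketch. The function below is an illustrative float64 version of the scale, floor, clip idea (it assumes no column is constant), and spu_scaled simply copies the pre-floor values reported above; no fixed-point arithmetic is simulated here.

```python
import numpy as np


def kbins_uniform_scaled(X, n_bins):
    """Scale each column to [0, n_bins], floor, then clip to n_bins - 1."""
    x_min, x_max = X.min(axis=0), X.max(axis=0)
    # Assumes x_max > x_min in every column (no constant columns).
    scaled = (X - x_min) * n_bins / (x_max - x_min)
    return np.clip(np.floor(scaled), 0, n_bins - 1)


X = np.array([
    [-2.0, 1.0, -4.0, -1.0],
    [-1.0, 2.0, -3.0, -0.5],
    [0.0, 3.0, -2.0, 0.5],
    [1.0, 4.0, -1.0, 2.0],
])
exact = kbins_uniform_scaled(X, n_bins=3)

# Scaled values reported from the SPU emulator, just before floor and clip.
spu_scaled = np.array([
    [0.0, 0.0, 0.0, 0.0],
    [0.999989, 0.999985, 0.999981, 0.499996],
    [1.999977, 1.999969, 1.999962, 1.499981],
    [2.999966, 2.999954, 2.999943, 2.999966],
])
from_spu = np.clip(np.floor(spu_scaled), 0, 2)
# exact matches the expected result, while from_spu drops to the lower bin
# wherever the exact scaled value was a whole number (1.0 or 2.0).
```

In float64 the scaled values at bin boundaries are exactly 1.0 and 2.0, so floor keeps them; any value that ends up slightly below the integer is floored into the bin underneath.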

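I have a few questions:

  1. Why is the computed result always smaller than the expected result?
  2. Is there any way to mitigate this?
  3. The error looks large, but only a few specific values actually trigger it, and in most real-world scenarios it should rarely come up. Is that acceptable? (In other words, even without further optimization, could this algorithm be kept as an option?)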

@deadlywing
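Contributor

About the optimization you tried: as I understand it, the error mainly comes from the fixed-point multiplication and division involved. As for the results being smaller than expected, that is probably due to truncation. As for mitigation, I believe the error is unavoidable unless the computation precision is extremely high.

About question 3: the error arises because random error can make the result unstable at bin boundaries. Personally I don't think this is a good algorithm, because we cannot make overly strict assumptions about the data distribution.

Finally, method 1 is slow mainly because of the binary search. I think you can instead simply compare against each edge and accumulate, e.g.
Bin = [0,1,2], x=0.5
Compute x<=0, x<=1, x<=2, x<=inf (always 1), which gives [0,1,1,1]. Shifting right by 1 gives [0,0,1,1], and a bitwise XOR of the two gives [0,1,0,0], which marks the index of the bin that x falls in.

This should perform better. Give it a try.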
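
In plaintext NumPy, the compare, shift, XOR steps for this exact example can be written out as below. It is only a sketch of the idea: the "right shift" is emulated by prepending a 0 and dropping the last entry, and the variable names are illustrative.

```python
import numpy as np

edges = np.array([0.0, 1.0, 2.0])  # Bin = [0, 1, 2]
x = 0.5

# x <= 0, x <= 1, x <= 2, plus the always-true comparison against +inf.
below = np.append(x <= edges, True).astype(np.int64)
# "Shift right by one": prepend a 0 and drop the last entry.
shifted = np.concatenate(([0], below[:-1]))
# XOR leaves a single 1 at the first edge that is >= x.
one_hot = below ^ shifted
position = int(np.argmax(one_hot))
# below is [0, 1, 1, 1], shifted is [0, 0, 1, 1], one_hot is [0, 1, 0, 0]
```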

@winnylyc
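Contributor

winnylyc commented Feb 1, 2024

Thanks for the idea, it was very inspiring. I have now optimized away the binary search, and the cost is already lower than that of the algorithm I proposed earlier.

Your right-shift-and-XOR idea seems fairly hard to implement in SPU (mainly because there doesn't seem to be good native support for a right shift that zero-fills on the left).
My current implementation looks like this: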

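import jax
import jax.numpy as jnp

# bin_edges and X are defined by the surrounding code (not shown in this excerpt).
def compute_row(bin, x):
    def compute_element(x):
        # Compare x with every edge, then pad with a leading 0 and a trailing 1.
        encoding = jnp.concatenate((jnp.array([0]), jnp.append(jnp.where(x <= bin, 1, 0), 1)))
        # Shift by one position, then XOR to locate the 0 -> 1 transition.
        encoding_r = jnp.roll(encoding, 1)[1:]
        index = jnp.argmax(jnp.bitwise_xor(encoding[1:], encoding_r))
        return jnp.where(index == 0, index, index - 1)

    row = jax.vmap(compute_element)(x)
    return row

compute_rows_vmap = jax.vmap(compute_row, in_axes=(1, 1), out_axes=1)(bin_edges, X)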

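This implementation is still fairly expensive.

So I optimized further based on your idea. The current algorithm takes the element-wise comparison result [0,1,1], sums it to get 2, and then computes n_bins (3 here) - 2 to get the index. Here is the implementation: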

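# Excerpt from inside a method: self.n_bins, bin_edges and X come from the enclosing scope.
def compute_row(bin, x):
    def compute_element(x):
        # 1 for each edge (excluding the first) that is >= x, else 0.
        encoding = jnp.where(x <= bin[1:], 1, 0)
        # Subtract the number of such edges from n_bins to get the bin index.
        return self.n_bins - jnp.sum(encoding)

    row = jax.vmap(compute_element)(x)
    return row

compute_rows_vmap = jax.vmap(compute_row, in_axes=(1, 1), out_axes=1)(bin_edges, X)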
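
A plaintext NumPy sketch of the same compare-and-sum idea, vectorized with broadcasting instead of vmap, is shown below. It is a variant rather than a transcription of the snippet above: it counts the interior edges that are <= x, which differs from `n_bins - sum(x <= bin[1:])` only for values lying exactly on an interior edge. That choice is an assumption made here so that the sketch reproduces the expected result of the 4x4 example earlier in this thread; the function name is illustrative.

```python
import numpy as np


def kbins_uniform_by_counting(X, n_bins):
    """Bin index = number of interior edges that are <= the value."""
    x_min, x_max = X.min(axis=0), X.max(axis=0)
    # edges has shape (n_bins + 1, n_features); row k holds the k-th edge of every column.
    edges = np.linspace(x_min, x_max, n_bins + 1)
    interior = edges[1:-1]  # shape (n_bins - 1, n_features)
    # Compare every value with every interior edge of its own column, then sum.
    hits = X[:, None, :] >= interior[None, :, :]
    return hits.sum(axis=1)


X = np.array([
    [-2.0, 1.0, -4.0, -1.0],
    [-1.0, 2.0, -3.0, -0.5],
    [0.0, 3.0, -2.0, 0.5],
    [1.0, 4.0, -1.0, 2.0],
])
bins = kbins_uniform_by_counting(X, n_bins=3)
# bins is [[0, 0, 0, 0], [1, 1, 1, 0], [2, 2, 2, 1], [2, 2, 2, 2]]
```

Only comparisons and a sum are involved, so there is no search loop and no floor on a scaled value.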

@deadlywing
Contributor

Hello, what I originally had in mind was index = dot([0,1,0], [1,2,3]), which should cost less than argmax. But summing directly is indeed better, since it saves even the inner product.

@Candicepan Candicepan moved this from In Progress to In Review in OSCP Phase 3 Feb 26, 2024
@github-project-automation github-project-automation bot moved this from In Review to Done in OSCP Phase 3 Feb 28, 2024
Projects
Status: Done
3 participants