[OSCP] Implement the AP (average_precision_score) function with SPU #801

Merged
merged 8 commits into from
Aug 20, 2024

z0gSh1u
Contributor

@z0gSh1u z0gSh1u commented Aug 4, 2024

Pull Request

What problem does this PR solve?

Issue Number: Fixed #727

Implemented average_precision_score function for binary classification and multi-class classification with three average methods.

Contributor

@deadlywing deadlywing left a comment

Also, please add a test that calls AP in classification_emul.py; you can follow the other examples in that file.

sml/metrics/classification/BUILD.bazel (1 resolved review thread, outdated)
sml/metrics/classification/classification.py (6 resolved review threads, 5 outdated)
@deadlywing
Contributor

TODO @z0gSh1u

  1. Fix and test the case with tied values
  2. Add an emulation test

@z0gSh1u
Contributor Author

z0gSh1u commented Aug 13, 2024

👌🏻 I'll ping you when I'm ready.

@z0gSh1u
Contributor Author

z0gSh1u commented Aug 15, 2024

Updated; ready for another round of review.

@deadlywing
Contributor

image

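The computation logic still has a small problem.
Also, the tolerances in the test are too loose, which may be why you did not notice the problem. I suggest tightening both atol and rtol to 1e-3.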
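To illustrate why the tolerance matters, here is a small sketch with made-up numbers (both the reference value and the "buggy" value are hypothetical, not taken from the PR's tests). A result that is off by a few percent passes a loose check but fails once atol and rtol are set to 1e-3.

```python
import numpy as np

# Hypothetical values for illustration only.
reference = 5 / 6  # plaintext AP of a toy case
buggy = 0.80       # a result that mishandles tied scores

# np.isclose checks |a - b| <= atol + rtol * |b|.
loose_ok = bool(np.isclose(buggy, reference, rtol=0.05, atol=0.05))
strict_ok = bool(np.isclose(buggy, reference, rtol=1e-3, atol=1e-3))

print(loose_ok, strict_ok)
```

The same thresholds can be passed to np.testing.assert_allclose in a unit test.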

@deadlywing
Contributor

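Here is a fairly direct approach:

  1. The difficulty is that tied values cause NaN (or extremely large values) to appear at the tail of the precision array and zeros at the tail of the recall array, which breaks the AP formula.
  2. Using threshold > 0 gives a mask identifying the entries affected by tied values, so precision can be converted to [x,x,x,x,...,0,0,..]. This matches the layout of recall, so both have the form [y,y,y,y,...,0,0,..].
  3. AP is essentially an integral; numerically it is diff(recall) * precision. So we can compute the diff(recall) array directly, but note that what is actually needed is diff(recall) = [y1-y0, y2-y1, ..., yn-y_(n-1), 0,0,0] (the trailing zeros remain; this can be done with a rotate plus one mux).
  4. Likewise, precision needs to be turned into [x0,x1,x2,..,0,0.] using the mask (the logic here may be more involved and need some flips), again with trailing zeros.

Anyway, the core is still the AP formula; just handle the trailing zeros carefully so that the result matches the plaintext computation.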
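The steps above can be sketched in JAX as follows. This is a minimal illustration, not the code merged in this PR: the function name, the array layout (valid entries first, ordered by decreasing threshold, padded tail after them), and the toy inputs are all assumptions. It computes AP = sum_n (R_n - R_{n-1}) * P_n with R_{-1} = 0, using jnp.roll as the "rotate" and jnp.where as the "mux".

```python
import jax.numpy as jnp


def masked_average_precision(precision, recall, valid):
    """Sum (R_n - R_{n-1}) * P_n over the valid prefix only (hypothetical helper)."""
    # Zero the padded tail so NaN or huge values there cannot reach the sum.
    precision = jnp.where(valid, precision, 0.0)
    recall = jnp.where(valid, recall, 0.0)
    # Rotate: shift recall right by one so prev[n] = R_{n-1}, with R_{-1} = 0.
    prev = jnp.roll(recall, 1).at[0].set(0.0)
    # Mux: keep the difference on valid slots and force 0 on the padded tail.
    d_recall = jnp.where(valid, recall - prev, 0.0)
    return jnp.sum(d_recall * precision)


# Toy case: scores [0.9, 0.8, 0.8, 0.3] with labels [1, 0, 1, 0].
# The tie at 0.8 leaves three distinct thresholds and one padded slot.
thresholds = jnp.array([0.9, 0.8, 0.3, 0.0])
precision = jnp.array([1.0, 2.0 / 3.0, 0.5, jnp.nan])
recall = jnp.array([0.5, 1.0, 1.0, 0.0])

ap = masked_average_precision(precision, recall, thresholds > 0)
print(ap)  # approximately 0.8333 (= 5/6)
```

Because the output shape never depends on how many ties there are, the same code path runs regardless of the data, which is what a fixed-shape MPC computation needs.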

@z0gSh1u
Contributor Author

z0gSh1u commented Aug 16, 2024

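Done. I now use thresholds > 0 as a mask to restrict the range over which precision is computed, and switched to stricter tolerance checks. Because the thresholds > 0 condition is inaccurate when y_score contains zeros (a trailing 0 and a score of 0 cannot be distinguished), I also added a lower bound, score_eps.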
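A small sketch of the ambiguity described here, under stated assumptions: the value of score_eps and the clamping via jnp.maximum are illustrative guesses at how a lower bound could be applied, not a copy of the PR's code.

```python
import jax.numpy as jnp

score_eps = 1e-5  # hypothetical value

# Third entry is a genuine score of 0; fourth entry is padding.
thresholds = jnp.array([0.9, 0.4, 0.0, 0.0])
naive_valid = thresholds > 0  # the genuine zero is wrongly treated as padding

# Clamp real scores away from zero before the padded array is built,
# so that only padding can be exactly 0.
scores = jnp.array([0.9, 0.4, 0.0])
clamped = jnp.maximum(scores, score_eps)
padded = jnp.concatenate([clamped, jnp.zeros(1)])
valid = padded > 0

print(naive_valid, valid)
```

With the clamp in place the mask keeps all three real entries and drops only the padded slot.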

@deadlywing
Contributor

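Hello, correctness looks fine now, but there is still a lot of room for performance optimization.

  1. I ran a local test with 10000 samples; the communication cost was:
Link details: total send bytes 25116300315, recv bytes 25137340315, send actions 720907, recv actions 721170

This cost is still excessive. The main cause is

sorted_pairs = pairs[jnp.argsort(pairs[:, 1], descending=True, stable=True)]

This line indexes with a secret index, which is extremely expensive under MPC.
It can be replaced with:

sorted_pairs = create_sorted_label_score_pair(y_true, y_score)
  2. Likewise, the recall computation can be replaced with
max_tp = jnp.max(tp)
recalls = jnp.where(max_tp == 0, jnp.ones_like(tp), tp / max_tp)

After these replacements, communication drops by more than 250x and runtime by roughly 100x.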
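For illustration, one way to avoid gathering rows by a secret index is to let the sort itself carry the labels alongside the scores. The sketch below uses jax.lax.sort with two operands; it is not the implementation of create_sorted_label_score_pair (its internals are not shown in this thread), the helper names are hypothetical, and whether this lowers to an efficient sort on SPU is an assumption. The recall guard is the one quoted above.

```python
import jax
import jax.numpy as jnp


def sort_labels_by_score_desc(y_true, y_score):
    # Sort on the negated score (the single key); labels ride along as a
    # second operand, so no argsort-then-index step is needed.
    neg_score, labels = jax.lax.sort(
        (-y_score, y_true), num_keys=1, is_stable=True
    )
    return labels, -neg_score


def safe_recall(tp):
    # If there are no positives at all, define recall as 1 everywhere
    # instead of dividing by zero.
    max_tp = jnp.max(tp)
    return jnp.where(max_tp == 0, jnp.ones_like(tp), tp / max_tp)


y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
y_score = jnp.array([0.3, 0.9, 0.8, 0.1])

labels, scores = sort_labels_by_score_desc(y_true, y_score)
tp = jnp.cumsum(labels)
recalls = safe_recall(tp)
no_positive_recalls = safe_recall(jnp.zeros(4))
print(labels, scores, recalls, no_positive_recalls)
```

In plaintext both variants give the same ordering; the difference the reviewer points to is the cost of the indexing step under MPC.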

@deadlywing
Contributor

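Also, I noticed that the other two functions in the emulation file do not seal their inputs before run either, which is also a problem. Could you please add that for them as well? 🙏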

@deadlywing
Contributor

@z0gSh1u
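Hello, the code is OK now, but the Bazel format checker is failing. I ran buildifier locally and the issue seems to be the order of these two deps. Please fix it locally (ideally run a buildifier check locally again before pushing).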

image

@z0gSh1u
Contributor Author

z0gSh1u commented Aug 19, 2024

OK, I'll check and fix it later, and merge master while I'm at it.

Contributor

@deadlywing deadlywing left a comment

LGTM

@deadlywing deadlywing merged commit 28fef7d into secretflow:main Aug 20, 2024
8 of 9 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Aug 20, 2024