Skip to content

UniPC remove order short-circuit for solve(R, b) #11612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

Beinsezii
Copy link
Contributor

@Beinsezii Beinsezii commented May 26, 2025

When unifying skrample's UniP and UniC solvers I noticed the rks matrix solve short circuits when it likely shouldn't.

I think it's because it appends to RKS

but only checks by order

# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

Reading the paper https://arxiv.org/abs/2302.04867 I don't see any 0.5 short circuits so I'm assuming this was added to stop a 1x1 solve from erring. The comments would imply the results should be equivalent, however because you're skipping actual 2x2 solves the results are quite different for high orders

main
grid_00005
beinsezii/unipc_matrix_solve
grid_00009

cc @yiyixuxu

I'm not a mathematician so it might be prudent to double-check the paper yourself to make sure my homework is correct.

@Beinsezii
Copy link
Contributor Author

we might not even need the check at all let me see.

@Beinsezii
Copy link
Contributor Author

Yeah solve() is fine because it does append(1.0)

@Beinsezii Beinsezii changed the title UniPC adjust solve() short circuit to len(rks)==1 UniPC remove order short-circuit for solve(R, b) May 26, 2025
@chaObserv
Copy link

chaObserv commented May 28, 2025

It's in the implementation details of the paper. They use an approximate value 0.5 for UniP-2 and UniC-1 to avoid solving the equation.

@Beinsezii
Copy link
Contributor Author

It's in the implementation details of the paper. They use an approximate value 0.5 for UniP-2 and UniC-1 to avoid solving the equation.

Hm, it's strange to me that it's off by something like a double digit % vs linalg.solve(). The perf impact of such an op in a denoise is basically immeasurable even for sd15.

@Beinsezii Beinsezii closed this Jun 1, 2025
@Beinsezii Beinsezii deleted the beinsezii/unipc_matrix_solve branch June 1, 2025 00:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants