
q4_matmul_cuda kernel does not yield reproducible results #153

Closed
fxmarty opened this issue Jul 13, 2023 · 2 comments

fxmarty commented Jul 13, 2023

Hi,

I see a slight deviation in the output of q4_matmul_cuda between different calls with the same input. Is this expected? If so, why?

The deviation is on the order of 0.04%, and from what I've seen it does not influence the generated output; only the logits differ slightly.

The issue does not happen when calling q4_matmul_recons_cuda instead (just change inp = torch.rand(1, 1, hidden_size, dtype=torch.float16).to(device) to inp = torch.rand(1, 9, hidden_size, dtype=torch.float16).to(device) in the example below).

Related: #73

Reproduction:
Download https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GPTQ to a local repository, and then:

from safetensors import safe_open

from model import Ex4bitLinear, ExLlamaConfig
import torch
import torch.nn as nn
import cuda_ext
import hashlib
import copy

path = "/fsx/felix/WizardLM-7B-uncensored-GPTQ/WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"
with safe_open(path, framework="pt", device=0) as f:
    for k in f.keys():
        if k.startswith("model.layers.0.self_attn.v_proj"):
            key = ".".join(k.split(".")[:-1])  # otherwise we get model.layers.0.self_attn.v_proj.g_idx.scales
            scales = f.get_tensor(key + ".scales")
            qzeros = f.get_tensor(key + ".qzeros")
            qweight = f.get_tensor(key + ".qweight")
            break

config = ExLlamaConfig("../WizardLM-7B-uncensored-GPTQ/config.json")
config.set_tuning_params()

device = torch.device("cuda:0")

max_dq_buffer_size = 45088768

temp_state = torch.zeros((config.max_seq_len, config.intermediate_size), dtype = torch.float16, device = device)
temp_mlp = torch.zeros((config.fused_mlp_thd * 2, config.intermediate_size), dtype = torch.float16, device = device)
temp_zeros_float = torch.zeros((1, 65536), dtype = torch.float32, device = device)
temp_dq = torch.zeros((1, max_dq_buffer_size), dtype = torch.float16, device = device)

cuda_ext.exllama_ext.prepare_buffers(device,
                                    temp_state,
                                    temp_mlp,
                                    temp_zeros_float,
                                    temp_dq)

tensors = {
    "v_proj.qweight": qweight.to(device),
    "v_proj.qzeros": qzeros.to(device),
    "v_proj.scales": scales.to(device),
}

hidden_size = qweight.shape[0] * 8

q4linear = Ex4bitLinear(config, hidden_size, qweight.shape[1], False, tensors, "v_proj")
inp = torch.rand(1, 1, hidden_size, dtype=torch.float16).to(device)

prec = None
with torch.no_grad():
    # warmup (whatever)
    _ = q4linear.forward(inp, None)

    for i in range(10):
        real_inp = copy.deepcopy(inp)  # make sure no in-place op modifies the input between calls
        res = q4linear.forward(real_inp, None)

        if i > 0:
            print(f"Mean abs diff with previous result: {((res - prec) / (prec.abs() + 1e-5)).abs().mean() * 10:.4f} %")

        prec = copy.deepcopy(res)

        h = hashlib.new('sha256')
        h.update(str(res).encode())
        sha_hash = h.hexdigest()
        print("Hash:", sha_hash[:8])
        print("Argmax:", torch.argmax(res).item())
        print("Some res:", res[0, 0, 0:5])

Result:

Hash: df39e735
Argmax: 809
Some: tensor([-0.0050, -0.0982,  0.2544, -0.0842,  0.2179], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0233 %
Hash: bad7ae24
Argmax: 809
Some: tensor([-0.0050, -0.0984,  0.2544, -0.0845,  0.2180], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0243 %
Hash: c12facc3
Argmax: 809
Some: tensor([-0.0050, -0.0983,  0.2544, -0.0845,  0.2177], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.1443 %
Hash: 166472ea
Argmax: 809
Some: tensor([-0.0051, -0.0983,  0.2544, -0.0842,  0.2184], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0588 %
Hash: d48faa34
Argmax: 809
Some: tensor([-0.0052, -0.0984,  0.2544, -0.0842,  0.2180], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0303 %
Hash: 0e53c270
Argmax: 809
Some: tensor([-0.0051, -0.0983,  0.2544, -0.0841,  0.2180], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0214 %
Hash: a5656d6c
Argmax: 809
Some: tensor([-0.0050, -0.0984,  0.2542, -0.0844,  0.2180], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0486 %
Hash: bad7ae24
Argmax: 809
Some: tensor([-0.0050, -0.0984,  0.2544, -0.0839,  0.2179], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0430 %
Hash: 014b5485
Argmax: 809
Some: tensor([-0.0051, -0.0983,  0.2544, -0.0841,  0.2179], device='cuda:0',
       dtype=torch.float16)
Mean abs diff with previous result: 0.0381 %
Hash: 4917bbf1
Argmax: 809
Some: tensor([-0.0050, -0.0983,  0.2544, -0.0837,  0.2180], device='cuda:0',
       dtype=torch.float16)

Edit: could it be because of a different atomicAdd order?

turboderp (Owner) commented

The q4 matmul kernel isn't strictly deterministic due to the non-associativity of floating-point addition and CUDA providing no guarantees about the order in which blocks in a grid are processed. It's essentially an artifact of relying on atomicAdd.
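For illustration only (this is not the exllama kernel, just a minimal PyTorch sketch of the underlying effect): forcing a strictly sequential FP16 accumulation and then permuting the summation order typically changes the last bits of the result, which is exactly the kind of variation an unordered atomicAdd introduces.

import torch

torch.manual_seed(0)
vals = torch.rand(2048).to(torch.float16)

def fp16_running_sum(x):
    # strictly sequential FP16 accumulation, rounding to FP16 after every add
    acc = torch.tensor(0.0, dtype=torch.float16)
    for v in x:
        acc = acc + v
    return acc

a = fp16_running_sum(vals)
b = fp16_running_sum(vals[torch.randperm(vals.numel())])  # same values, different order
print(a.item(), b.item())  # typically differ in the last bits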

Ways to mitigate it would be either switching to a reduction method (as used by cuBLAS in the reconstruction version) or computing an intermediate result in FP32 before downcasting to FP16. Both of those methods add VRAM and compute overhead, though, so it wouldn't make sense without first establishing that a difference on the order of 0.04% actually matters. I would question the use of quantization at all for applications where it does, given there's a significantly greater loss going from FP16 to GPTQ in the first place.
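As a rough sketch of the FP32-accumulation idea (again an illustration, not the actual kernel): keeping the accumulator in FP32 and downcasting once at the end usually absorbs the order-dependent rounding, at the cost of a wider intermediate.

import torch

torch.manual_seed(0)
a = torch.rand(4096).to(torch.float16)
b = torch.rand(4096).to(torch.float16)

def dot_fp16(x, y):
    # accumulate the partial products directly in FP16: order-sensitive
    acc = torch.tensor(0.0, dtype=torch.float16)
    for xi, yi in zip(x, y):
        acc = acc + xi * yi
    return acc

def dot_fp32(x, y):
    # accumulate in FP32, downcast to FP16 once at the end: far less order-sensitive
    acc = torch.tensor(0.0, dtype=torch.float32)
    for xi, yi in zip(x, y):
        acc = acc + xi.float() * yi.float()
    return acc.to(torch.float16)

perm = torch.randperm(a.numel())
print(dot_fp16(a, b).item(), dot_fp16(a[perm], b[perm]).item())  # typically differ
print(dot_fp32(a, b).item(), dot_fp32(a[perm], b[perm]).item())  # typically identical after downcast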


fxmarty commented Jul 13, 2023

Thanks, indeed it looks to be the atomicAdd: https://forums.developer.nvidia.com/t/get-different-results-for-every-running-with-atomicadd/229649/2

I haven't seen a case where such a small difference matters; I just caught it in my CI and wondered why my logits differed slightly. Thanks!
