Missed optimisation/UBSAN in chacha reference implementation #144964

Open

Description

@nabijaczleweli

This is a variant on https://cr.yp.to/streamciphers/timings/estreambench/submissions/salsa20/chacha8/merged/chacha.c; the same problem plagues both:

/*
chacha-merged.c version 20080118
D. J. Bernstein
Public domain.
*/

#include <memory.h>
#include <stdint.h>
#include <unistd.h>
#include <stdio.h>
#include <endian.h>


#if !defined(BYTE_ORDER) || !defined(LITTLE_ENDIAN) || !defined(BIG_ENDIAN)
static_assert(false, "BYTE_ORDER is undefined. Perhaps, GNU extensions are not enabled");
#endif

#define IDENT16(x) ((uint16_t) (x))
#define IDENT32(x) ((uint32_t) (x))
#define IDENT64(x) ((uint64_t) (x))

#define SWAP16(x) ((((uint16_t) (x) & 0x00ff) << 8) | \
  (((uint16_t) (x) & 0xff00) >> 8))
#define SWAP32(x) ((((uint32_t) (x) & 0x000000ff) << 24) | \
  (((uint32_t) (x) & 0x0000ff00) <<  8) | \
  (((uint32_t) (x) & 0x00ff0000) >>  8) | \
  (((uint32_t) (x) & 0xff000000) >> 24))
#define SWAP64(x) ((((uint64_t) (x) & 0x00000000000000ff) << 56) | \
  (((uint64_t) (x) & 0x000000000000ff00) << 40) | \
  (((uint64_t) (x) & 0x0000000000ff0000) << 24) | \
  (((uint64_t) (x) & 0x00000000ff000000) <<  8) | \
  (((uint64_t) (x) & 0x000000ff00000000) >>  8) | \
  (((uint64_t) (x) & 0x0000ff0000000000) >> 24) | \
  (((uint64_t) (x) & 0x00ff000000000000) >> 40) | \
  (((uint64_t) (x) & 0xff00000000000000) >> 56))

static inline uint32_t rol32(uint32_t x, int r) {
  return (x << (r & 31)) | (x >> (-r & 31));
}

#if BYTE_ORDER == LITTLE_ENDIAN
#define SWAP16LE IDENT16
#define SWAP16BE SWAP16
#define swap16le ident16
#define swap16be swap16
#define mem_inplace_swap16le mem_inplace_ident
#define mem_inplace_swap16be mem_inplace_swap16
#define memcpy_swap16le memcpy_ident16
#define memcpy_swap16be memcpy_swap16
#define SWAP32LE IDENT32
#define SWAP32BE SWAP32
#define swap32le ident32
#define swap32be swap32
#define mem_inplace_swap32le mem_inplace_ident
#define mem_inplace_swap32be mem_inplace_swap32
#define memcpy_swap32le memcpy_ident32
#define memcpy_swap32be memcpy_swap32
#define SWAP64LE IDENT64
#define SWAP64BE SWAP64
#define swap64le ident64
#define swap64be swap64
#define mem_inplace_swap64le mem_inplace_ident
#define mem_inplace_swap64be mem_inplace_swap64
#define memcpy_swap64le memcpy_ident64
#define memcpy_swap64be memcpy_swap64
#endif

#if BYTE_ORDER == BIG_ENDIAN
#define SWAP16BE IDENT16
#define SWAP16LE SWAP16
#define swap16be ident16
#define swap16le swap16
#define mem_inplace_swap16be mem_inplace_ident
#define mem_inplace_swap16le mem_inplace_swap16
#define memcpy_swap16be memcpy_ident16
#define memcpy_swap16le memcpy_swap16
#define SWAP32BE IDENT32
#define SWAP32LE SWAP32
#define swap32be ident32
#define swap32le swap32
#define mem_inplace_swap32be mem_inplace_ident
#define mem_inplace_swap32le mem_inplace_swap32
#define memcpy_swap32be memcpy_ident32
#define memcpy_swap32le memcpy_swap32
#define SWAP64BE IDENT64
#define SWAP64LE SWAP64
#define swap64be ident64
#define swap64le swap64
#define mem_inplace_swap64be mem_inplace_ident
#define mem_inplace_swap64le mem_inplace_swap64
#define memcpy_swap64be memcpy_ident64
#define memcpy_swap64le memcpy_swap64
#endif


/*
 * The following macros are used to obtain exact-width results.
 */
#define U8V(v) ((uint8_t)(v) & UINT8_C(0xFF))
#define U32V(v) ((uint32_t)(v) & UINT32_C(0xFFFFFFFF))

/*
 * The following macros load words from an array of bytes with
 * different types of endianness, and vice versa.
 */
#define U8TO32_LITTLE(p) SWAP32LE(((uint32_t*)(p))[0])
#define U32TO8_LITTLE(p, v) (((uint32_t*)(p))[0] = SWAP32LE(v))

#define ROTATE(v,c) (rol32(v,c))
#define XOR(v,w) ((v) ^ (w))
#define PLUS(v,w) (U32V((v) + (w)))
#define PLUSONE(v) (PLUS((v),1))

#define QUARTERROUND(a,b,c,d) \
  a = PLUS(a,b); d = ROTATE(XOR(d,a),16); \
  c = PLUS(c,d); b = ROTATE(XOR(b,c),12); \
  a = PLUS(a,b); d = ROTATE(XOR(d,a), 8); \
  c = PLUS(c,d); b = ROTATE(XOR(b,c), 7);

static const char sigma[] = "expand 32-byte k";

static void chacha(unsigned rounds, const void* data, size_t length, const uint8_t* key, const uint8_t* iv, char* cipher) {
  uint32_t x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
  uint32_t j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15;
  char* ctarget = 0;
  char tmp[64];
  int i;

  if (!length) return;

  j0  = U8TO32_LITTLE(sigma + 0);
  j1  = U8TO32_LITTLE(sigma + 4);
  j2  = U8TO32_LITTLE(sigma + 8);
  j3  = U8TO32_LITTLE(sigma + 12);
  j4  = U8TO32_LITTLE(key + 0);
  j5  = U8TO32_LITTLE(key + 4);
  j6  = U8TO32_LITTLE(key + 8);
  j7  = U8TO32_LITTLE(key + 12);
  j8  = U8TO32_LITTLE(key + 16);
  j9  = U8TO32_LITTLE(key + 20);
  j10 = U8TO32_LITTLE(key + 24);
  j11 = U8TO32_LITTLE(key + 28);
  j12 = 0;
  j13 = 0;
  j14 = U8TO32_LITTLE(iv + 0);
  j15 = U8TO32_LITTLE(iv + 4);

  for (;;) {
    if (length < 64) {
      memcpy(tmp, data, length);
      data = tmp;
      ctarget = cipher;
      cipher = tmp;
    }
    x0  = j0;
    x1  = j1;
    x2  = j2;
    x3  = j3;
    x4  = j4;
    x5  = j5;
    x6  = j6;
    x7  = j7;
    x8  = j8;
    x9  = j9;
    x10 = j10;
    x11 = j11;
    x12 = j12;
    x13 = j13;
    x14 = j14;
    x15 = j15;
    for (i = rounds;i > 0;i -= 2) {
      QUARTERROUND( x0, x4, x8,x12)
      QUARTERROUND( x1, x5, x9,x13)
      QUARTERROUND( x2, x6,x10,x14)
      QUARTERROUND( x3, x7,x11,x15)
      QUARTERROUND( x0, x5,x10,x15)
      QUARTERROUND( x1, x6,x11,x12)
      QUARTERROUND( x2, x7, x8,x13)
      QUARTERROUND( x3, x4, x9,x14)
    }
    x0  = PLUS( x0, j0);
    x1  = PLUS( x1, j1);
    x2  = PLUS( x2, j2);
    x3  = PLUS( x3, j3);
    x4  = PLUS( x4, j4);
    x5  = PLUS( x5, j5);
    x6  = PLUS( x6, j6);
    x7  = PLUS( x7, j7);
    x8  = PLUS( x8, j8);
    x9  = PLUS( x9, j9);
    x10 = PLUS(x10,j10);
    x11 = PLUS(x11,j11);
    x12 = PLUS(x12,j12);
    x13 = PLUS(x13,j13);
    x14 = PLUS(x14,j14);
    x15 = PLUS(x15,j15);

    x0  = XOR( x0,U8TO32_LITTLE((uint8_t*)data +  0));
    x1  = XOR( x1,U8TO32_LITTLE((uint8_t*)data +  4));
    x2  = XOR( x2,U8TO32_LITTLE((uint8_t*)data +  8));
    x3  = XOR( x3,U8TO32_LITTLE((uint8_t*)data + 12));
    x4  = XOR( x4,U8TO32_LITTLE((uint8_t*)data + 16));
    x5  = XOR( x5,U8TO32_LITTLE((uint8_t*)data + 20));
    x6  = XOR( x6,U8TO32_LITTLE((uint8_t*)data + 24));
    x7  = XOR( x7,U8TO32_LITTLE((uint8_t*)data + 28));
    x8  = XOR( x8,U8TO32_LITTLE((uint8_t*)data + 32));
    x9  = XOR( x9,U8TO32_LITTLE((uint8_t*)data + 36));
    x10 = XOR(x10,U8TO32_LITTLE((uint8_t*)data + 40));
    x11 = XOR(x11,U8TO32_LITTLE((uint8_t*)data + 44));
    x12 = XOR(x12,U8TO32_LITTLE((uint8_t*)data + 48));
    x13 = XOR(x13,U8TO32_LITTLE((uint8_t*)data + 52));
    x14 = XOR(x14,U8TO32_LITTLE((uint8_t*)data + 56));
    x15 = XOR(x15,U8TO32_LITTLE((uint8_t*)data + 60));

    j12 = PLUSONE(j12);
    if (!j12)
    {
      j13 = PLUSONE(j13);
      /* stopping at 2^70 bytes per iv is user's responsibility */
    }

    U32TO8_LITTLE(cipher +  0,x0);
    U32TO8_LITTLE(cipher +  4,x1);
    U32TO8_LITTLE(cipher +  8,x2);
    U32TO8_LITTLE(cipher + 12,x3);
    U32TO8_LITTLE(cipher + 16,x4);
    U32TO8_LITTLE(cipher + 20,x5);
    U32TO8_LITTLE(cipher + 24,x6);
    U32TO8_LITTLE(cipher + 28,x7);
    U32TO8_LITTLE(cipher + 32,x8);
    U32TO8_LITTLE(cipher + 36,x9);
    U32TO8_LITTLE(cipher + 40,x10);
    U32TO8_LITTLE(cipher + 44,x11);
    U32TO8_LITTLE(cipher + 48,x12);
    U32TO8_LITTLE(cipher + 52,x13);
    U32TO8_LITTLE(cipher + 56,x14);
    U32TO8_LITTLE(cipher + 60,x15);

    if (length <= 64) {
      if (length < 64) {
        memcpy(ctarget, cipher, length);
      }
      return;
    }
    length -= 64;
    cipher += 64;
    data = (uint8_t*)data + 64;
  }
}

static void chacha20(const void* data, size_t length, const uint8_t* key, const uint8_t* iv, char* cipher)
{
  chacha(20, data, length, key, iv, cipher);
}

static const uint8_t data[65];
static const uint8_t key[32] = {
  (uint8_t)'\x25', (uint8_t)'\xe6', (uint8_t)'\xd1', (uint8_t)'\xe0', (uint8_t)'\x16', (uint8_t)'\x4f', (uint8_t)'\xfd', (uint8_t)'\x98', (uint8_t)'\x64', (uint8_t)'\x55', (uint8_t)'\xd8', (uint8_t)'\x46', (uint8_t)'\x67', (uint8_t)'\x7c', (uint8_t)'\x71', (uint8_t)'\xf1',
  (uint8_t)'\xaa', (uint8_t)'\x0b', (uint8_t)'\xef', (uint8_t)'\xa8', (uint8_t)'\x19', (uint8_t)'\x4f', (uint8_t)'\x84', (uint8_t)'\xca', (uint8_t)'\xd6', (uint8_t)'\x20', (uint8_t)'\x23', (uint8_t)'\xf4', (uint8_t)'\x1e', (uint8_t)'\x6a', (uint8_t)'\x1a', (uint8_t)'\x63',
};
static const uint8_t iv[8] = {
  (uint8_t)'\x50', (uint8_t)'\xe2', (uint8_t)'\x1c', (uint8_t)'\xe0', (uint8_t)'\x50', (uint8_t)'\x7a', (uint8_t)'\x03', (uint8_t)'\x48',
};

int main() {
  char outbuf[sizeof(data)];
  chacha20(data, sizeof(data), key, iv, outbuf);
  write(1, outbuf, sizeof(outbuf));
}

This compiles cleanly with cc -Wall -Wextra -fsanitize=undefined -O3 on Clang 21.0.0 (++20250519112653+d0ee35851bb9-1exp120250519112844.1459).

It is relatively glaring that:

  1. when length < 64,
  2. char tmp[64]; is uninitialised,
  3. memcpy(tmp, data, length); initialises only tmp[0..length), and
  4. x0 = XOR( x0,U8TO32_LITTLE((uint8_t*)data + 0)); &c. then read all of tmp[0..64), including the uninitialised tail tmp[length..64), which is UB;

therefore the compiler is entitled to conclude that length < 64 can never hold, making both ifs dead code. Clang misses this and compiles them both in.
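For illustration, the partial-block path reduces to roughly the shape below. This is a hypothetical reduction written for this discussion, not code from the submission: the name last_block and the dummy keystream word w are invented, and the type-punned U8TO32_LITTLE loads are replaced with memcpy so that only the uninitialised read is in play; the other names mirror chacha(). One presumably well-defined variant of the reference code would zero-fill tmp (memset(tmp, 0, sizeof tmp)) before the partial memcpy.

#include <stdint.h>
#include <string.h>

/* Hypothetical reduction of chacha()'s final-block handling. */
void last_block(const void* data, size_t length, char* cipher) {
  char tmp[64];                   /* uninitialised, as in the reference code */
  char* ctarget = cipher;

  if (length < 64) {              /* partial final block */
    memcpy(tmp, data, length);    /* initialises only tmp[0..length) */
    data = tmp;
    cipher = tmp;
  }

  for (size_t i = 0; i < 64; i += 4) {
    uint32_t w;                                 /* stand-in for a keystream word */
    memcpy(&w, (const uint8_t*)data + i, 4);    /* reads tmp[length..64) when length < 64 */
    w ^= UINT32_C(0x9e3779b9);
    memcpy(cipher + i, &w, 4);
  }

  if (length < 64)                /* copy the partial block back out */
    memcpy(ctarget, cipher, length);
}

If the reads of the uninitialised tail are UB, both length < 64 branches are only reachable through UB and could be folded away; the point of the report is that Clang neither exploits this to drop the branches nor diagnoses the reads.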
