Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Rfc2898DeriveBytes to support spans #71888

Merged
merged 2 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,12 @@ private static void FillManaged(
Span<byte> destination)
{
using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(
password.ToArray(),
salt.ToArray(),
password,
salt,
iterations,
hashAlgorithmName,
clearPassword: true))
hashAlgorithmName))
{
byte[] result = deriveBytes.GetBytes(destination.Length);
result.AsSpan().CopyTo(destination);
deriveBytes.GetBytes(destination);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ public static unsafe void Fill(
Debug.Assert(hashAlgorithmName.Name is not null);

using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(
password.ToArray(),
salt.ToArray(),
password,
salt,
iterations,
hashAlgorithmName,
clearPassword: true))
hashAlgorithmName))
{
byte[] result = deriveBytes.GetBytes(destination.Length);
result.AsSpan().CopyTo(destination);
deriveBytes.GetBytes(destination);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public partial class Rfc2898DeriveBytes : DeriveBytes
{
private byte[] _salt;
private uint _iterations;
private HMAC _hmac;
private IncrementalHash _hmac;
private readonly int _blockSize;

private byte[] _buffer;
Expand Down Expand Up @@ -84,34 +84,36 @@ public Rfc2898DeriveBytes(string password, int saltSize, int iterations, HashAlg
HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac(passwordBytes);
CryptographicOperations.ZeroMemory(passwordBytes);
// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;
_blockSize = _hmac.HashLengthInBytes;

Initialize();
}

internal Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword)
internal Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword) :
this(
new ReadOnlySpan<byte>(password ?? throw new NullReferenceException()), // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
new ReadOnlySpan<byte>(salt ?? throw new ArgumentNullException(nameof(salt))),
iterations,
hashAlgorithm)
{
ArgumentNullException.ThrowIfNull(salt);
if (clearPassword)
{
CryptographicOperations.ZeroMemory(password);
}
}

internal Rfc2898DeriveBytes(ReadOnlySpan<byte> password, ReadOnlySpan<byte> salt, int iterations, HashAlgorithmName hashAlgorithm)
{
if (iterations <= 0)
throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum);
if (password is null)
throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.

_salt = new byte[salt.Length + sizeof(uint)];
salt.AsSpan().CopyTo(_salt);
salt.CopyTo(_salt);
_iterations = (uint)iterations;
HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac(password);

if (clearPassword)
{
CryptographicOperations.ZeroMemory(password);
}

// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;
_blockSize = _hmac.HashLengthInBytes;
Initialize();
}

Expand Down Expand Up @@ -167,27 +169,35 @@ protected override void Dispose(bool disposing)

public override byte[] GetBytes(int cb)
{
Debug.Assert(_blockSize > 0);

if (cb <= 0)
throw new ArgumentOutOfRangeException(nameof(cb), SR.ArgumentOutOfRange_NeedPosNum);
byte[] password = new byte[cb];

byte[] ret = new byte[cb];
GetBytes(ret);
return ret;
}

internal void GetBytes(Span<byte> destination)
{
Debug.Assert(_blockSize > 0);
int cb = destination.Length;
int offset = 0;
int size = _endIndex - _startIndex;
ReadOnlySpan<byte> bufferSpan = _buffer;

if (size > 0)
{
if (cb >= size)
{
Buffer.BlockCopy(_buffer, _startIndex, password, 0, size);
bufferSpan.Slice(_startIndex, size).CopyTo(destination);
_startIndex = _endIndex = 0;
offset += size;
}
else
{
Buffer.BlockCopy(_buffer, _startIndex, password, 0, cb);
bufferSpan.Slice(_startIndex, cb).CopyTo(destination);
_startIndex += cb;
return password;
return;
}
}

Expand All @@ -199,18 +209,17 @@ public override byte[] GetBytes(int cb)
int remainder = cb - offset;
if (remainder >= _blockSize)
{
Buffer.BlockCopy(_buffer, 0, password, offset, _blockSize);
bufferSpan.Slice(0, _blockSize).CopyTo(destination.Slice(offset));
offset += _blockSize;
}
else
{
Buffer.BlockCopy(_buffer, 0, password, offset, remainder);
bufferSpan.Slice(0, remainder).CopyTo(destination.Slice(offset));
_startIndex = remainder;
_endIndex = _buffer.Length;
return password;
return;
}
}
return password;
}

[Obsolete(Obsoletions.Rfc2898CryptDeriveKeyMessage, DiagnosticId = Obsoletions.Rfc2898CryptDeriveKeyDiagId, UrlFormat = Obsoletions.SharedUrlFormat)]
Expand All @@ -230,26 +239,25 @@ public override void Reset()
Initialize();
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA5350", Justification = "HMACSHA1 is needed for compat. (https://github.com/dotnet/runtime/issues/17618)")]
private HMAC OpenHmac(byte[] password)
private IncrementalHash OpenHmac(ReadOnlySpan<byte> password)
{
Debug.Assert(password != null);

HashAlgorithmName hashAlgorithm = HashAlgorithm;

if (string.IsNullOrEmpty(hashAlgorithm.Name))
{
throw new CryptographicException(SR.Cryptography_HashAlgorithmNameNullOrEmpty);
}

if (hashAlgorithm == HashAlgorithmName.SHA1)
return new HMACSHA1(password);
if (hashAlgorithm == HashAlgorithmName.SHA256)
return new HMACSHA256(password);
if (hashAlgorithm == HashAlgorithmName.SHA384)
return new HMACSHA384(password);
if (hashAlgorithm == HashAlgorithmName.SHA512)
return new HMACSHA512(password);
// Restrict the HashAlgorithmName to known hashes, particularly excluding MD5.
if (hashAlgorithm != HashAlgorithmName.SHA1 &&
hashAlgorithm != HashAlgorithmName.SHA256 &&
hashAlgorithm != HashAlgorithmName.SHA384 &&
hashAlgorithm != HashAlgorithmName.SHA512)
{
throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name));
}

throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name));
return IncrementalHash.CreateHMAC(hashAlgorithm, password);
}

[MemberNotNull(nameof(_buffer))]
Expand Down Expand Up @@ -281,20 +289,17 @@ private void Func()
//
Span<byte> uiSpan = stackalloc byte[64];
uiSpan = uiSpan.Slice(0, _blockSize);

if (!_hmac.TryComputeHash(_salt, uiSpan, out int bytesWritten) || bytesWritten != _blockSize)
{
throw new CryptographicException();
}
_hmac.AppendData(_salt);
int bytesWritten = _hmac.GetHashAndReset(uiSpan);
Debug.Assert(bytesWritten == _blockSize);

uiSpan.CopyTo(_buffer);

for (int i = 2; i <= _iterations; i++)
{
if (!_hmac.TryComputeHash(uiSpan, uiSpan, out bytesWritten) || bytesWritten != _blockSize)
{
throw new CryptographicException();
}
_hmac.AppendData(uiSpan);
bytesWritten = _hmac.GetHashAndReset(uiSpan);
Debug.Assert(bytesWritten == _blockSize);

for (int j = _buffer.Length - 1; j >= 0; j--)
{
Expand Down