Skip to content

Commit

Permalink
Handle range lock sector
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-visionaid committed Dec 9, 2024
1 parent 2aca570 commit 48d5b21
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 20 deletions.
24 changes: 24 additions & 0 deletions OpenMcdf.Tests/RootStorageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,28 @@ public void V3ThrowsIOExceptionAt2GB()

Assert.ThrowsException<IOException>(() => stream.Write(buffer, 0, buffer.Length));
}

[TestMethod]
[DoNotParallelize] // High memory usage
public void ValidateRangeLockSector()
{
RecyclableMemoryStreamManager manager = new();
using RecyclableMemoryStream baseStream = new(manager);
baseStream.Capacity64 = RootContext.RangeLockSectorOffset;

using var rootStorage = RootStorage.Create(baseStream, Version.V4);
using (CfbStream stream = rootStorage.CreateStream("Test"))
{
byte[] buffer = TestData.CreateByteArray(4096);
while (baseStream.Length <= RootContext.RangeLockSectorOffset)
stream.Write(buffer, 0, buffer.Length);
}

Assert.IsTrue(rootStorage.Validate());

rootStorage.Delete("Test");
rootStorage.Flush();

Assert.IsTrue(rootStorage.Validate());
}
}
36 changes: 29 additions & 7 deletions OpenMcdf/Fat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@ internal sealed class Fat : ContextBase, IEnumerable<FatEntry>, IDisposable
Sector cachedSector = Sector.EndOfChain;
private bool isDirty;

public Func<FatEntry, bool> IsUsed { get; }

public Fat(RootContextSite rootContextSite)
: base(rootContextSite)
{
fatSectorEnumerator = new(rootContextSite);
cachedSectorBuffer = new byte[Context.SectorSize];

if (Context.Version == Version.V3)
IsUsed = entry => entry.Value is not SectorType.Free;
else if (Context.Version == Version.V4)
IsUsed = entry => entry.Value is not SectorType.Free && entry.Index is not RootContext.RangeLockSectorId;
else
throw new NotSupportedException($"Unsupported major version: {Context.Version}.");
}

public void Dispose()
Expand Down Expand Up @@ -143,14 +152,14 @@ public uint Add(FatEnumerator fatEnumerator, uint startIndex)

public Sector GetLastUsedSector()
{
uint lastUsedSectorIndex = uint.MaxValue;
FatEntry lastUsedSectorIndex = new(uint.MaxValue, uint.MaxValue);
foreach (FatEntry entry in this)
{
if (!entry.IsFree)
lastUsedSectorIndex = entry.Index;
if (IsUsed(entry))
lastUsedSectorIndex = entry;
}

return new(lastUsedSectorIndex, Context.SectorSize);
return new(lastUsedSectorIndex.Index, Context.SectorSize);
}

public IEnumerator<FatEntry> GetEnumerator() => new FatEnumerator(Context.Fat);
Expand All @@ -172,7 +181,7 @@ internal void WriteTrace(TextWriter writer)
foreach (FatEntry entry in this)
{
Sector sector = new(entry.Index, Context.SectorSize);
if (entry.IsFree)
if (entry.Value is SectorType.Free)
{
freeCount++;
writer.WriteLine($"{entry}");
Expand All @@ -194,7 +203,7 @@ internal void WriteTrace(TextWriter writer)
}

[ExcludeFromCodeCoverage]
internal void Validate()
internal bool Validate()
{
long fatSectorCount = 0;
long difatSectorCount = 0;
Expand All @@ -213,8 +222,21 @@ internal void Validate()
throw new FileFormatException($"FAT sector count mismatch. Expected: {Context.Header.FatSectorCount} Actual: {fatSectorCount}.");
if (Context.Header.DifatSectorCount != difatSectorCount)
throw new FileFormatException($"DIFAT sector count mismatch: Expected: {Context.Header.DifatSectorCount} Actual: {difatSectorCount}.");

if (Context.Length < RootContext.RangeLockSectorOffset)
{
if (this.TryGetValue(RootContext.RangeLockSectorId, out uint value) && value != SectorType.Free)
throw new FileFormatException($"Range lock FAT entry is not free.");
}
else
{
if (this[RootContext.RangeLockSectorId] != SectorType.EndOfChain)
throw new FileFormatException($"Range lock sector is not at the end of the chain.");
}

return true;
}

[ExcludeFromCodeCoverage]
internal long GetFreeSectorCount() => this.Count(entry => entry.IsFree);
internal long GetFreeSectorCount() => this.Count(entry => entry.Value == SectorType.Free);
}
2 changes: 0 additions & 2 deletions OpenMcdf/FatEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ namespace OpenMcdf;
/// </summary>
internal record struct FatEntry(uint Index, uint Value)
{
public readonly bool IsFree => Value == SectorType.Free;

[ExcludeFromCodeCoverage]
public override readonly string ToString() => $"#{Index}: {Value}";
}
2 changes: 1 addition & 1 deletion OpenMcdf/FatEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public bool MoveNextFreeEntry()
{
while (MoveNext())
{
if (value == SectorType.Free)
if (value is SectorType.Free)
return true;
}

Expand Down
5 changes: 4 additions & 1 deletion OpenMcdf/FatStream.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace OpenMcdf;
using System.Diagnostics;

namespace OpenMcdf;

/// <summary>
/// Provides a <inheritdoc cref="Stream"/> for a stream object in a compound file./>
Expand Down Expand Up @@ -185,6 +187,7 @@ public override void Write(byte[] buffer, int offset, int count)
long writeLength = Math.Min(remaining, sector.Length - sectorOffset);
writer.Write(buffer, localOffset, (int)writeLength);
Context.ExtendStreamLength(sector.EndPosition);
Debug.Assert(Context.Length >= Context.Stream.Length);
position += writeLength;
writeCount += (int)writeLength;
sectorOffset = 0;
Expand Down
6 changes: 4 additions & 2 deletions OpenMcdf/MiniFat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public uint Add(MiniFatEnumerator miniFatEnumerator, uint startIndex)
FatEntry entry = miniFatEnumerator.Current;
this[entry.Index] = SectorType.EndOfChain;

Debug.Assert(entry.IsFree);
Debug.Assert(entry.Value is SectorType.Free);
MiniSector miniSector = new(entry.Index, Context.MiniSectorSize);
if (Context.MiniStream.Length < miniSector.EndPosition)
Context.MiniStream.SetLength(miniSector.EndPosition);
Expand All @@ -153,7 +153,7 @@ internal void WriteTrace(TextWriter writer)
}

[ExcludeFromCodeCoverage]
internal void Validate()
internal bool Validate()
{
using MiniFatEnumerator miniFatEnumerator = new(ContextSite);

Expand All @@ -165,5 +165,7 @@ internal void Validate()
throw new FileFormatException($"Mini FAT entry {current} is beyond the end of the mini stream.");
}
}

return true;
}
}
16 changes: 12 additions & 4 deletions OpenMcdf/RootContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ enum IOContextFlags
/// </summary>
internal sealed class RootContext : ContextBase, IDisposable
{
const long MaximumV3StreamLength = 2147483648;
internal const long MaximumV3StreamLength = 2147483648;
internal const uint RangeLockSectorOffset = 0x7FFFFF00;
internal const uint RangeLockSectorId = RangeLockSectorOffset / (1 << Header.SectorShiftV4) - 1;

readonly IOContextFlags contextFlags;
readonly CfbBinaryWriter? writer;
Expand Down Expand Up @@ -187,11 +189,14 @@ public void Flush()

public void ExtendStreamLength(long length)
{
if (Length >= length)
return;

if (Version is Version.V3 && length > MaximumV3StreamLength)
throw new IOException("V3 compound files are limited to 2 GB.");

if (Length < length)
Length = length;
else if (Version is Version.V4 && Length < RangeLockSectorOffset && length >= RangeLockSectorOffset)
Fat[RangeLockSectorId] = SectorType.EndOfChain;
Length = length;
}

void TrimBaseStream()
Expand All @@ -200,6 +205,9 @@ void TrimBaseStream()
if (!lastUsedSector.IsValid)
throw new FileFormatException("Last used sector is invalid");

if (Version is Version.V4 && lastUsedSector.EndPosition < RangeLockSectorOffset)
Fat.TrySetValue(RangeLockSectorId, SectorType.Free);

Length = lastUsedSector.EndPosition;
BaseStream.SetLength(Length);
}
Expand Down
7 changes: 4 additions & 3 deletions OpenMcdf/RootStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,11 @@ internal void Trace(TextWriter writer)
Context.MiniFat.WriteTrace(writer);
}

// TODO: Move checks to Tests project as Asserts
[ExcludeFromCodeCoverage]
internal void Validate()
internal bool Validate()
{
Context.Fat.Validate();
Context.MiniFat.Validate();
return Context.Fat.Validate()
&& Context.MiniFat.Validate();
}
}

0 comments on commit 48d5b21

Please sign in to comment.