-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
StringSpanOrdinalKey.cs
200 lines (166 loc) · 7.45 KB
/
StringSpanOrdinalKey.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Dictionary key whose contents can come either from a <see cref="string"/> or from a borrowed
/// character buffer, compared ordinally so both forms find the same entry.
/// </summary>
/// <remarks>
/// The pointer/length form exists for transient lookups only: the memory at <see cref="Ptr"/> must remain
/// valid for as long as the key is used. Keys that are stored in a dictionary should always be built from a string.
/// </remarks>
[JsonConverter(typeof(StringSpanOrdinalKeyConverter))]
internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
{
    /// <summary>Start of the borrowed character buffer, or null when the key is string-backed.</summary>
    public readonly char* Ptr;

    /// <summary>Number of characters starting at <see cref="Ptr"/>; zero for string-backed keys.</summary>
    public readonly int Length;

    /// <summary>The backing string, or null when the key was built from a pointer.</summary>
    public readonly string? Data;

    /// <summary>Creates a key that borrows <paramref name="length"/> characters starting at <paramref name="ptr"/>.</summary>
    public StringSpanOrdinalKey(char* ptr, int length)
    {
        Ptr = ptr;
        Length = length;
        Data = null;
    }

    /// <summary>Creates a key backed by <paramref name="data"/>.</summary>
    public StringSpanOrdinalKey(string data)
    {
        Ptr = null;
        Length = 0;
        Data = data;
    }

    // A non-null pointer takes precedence; otherwise the string is used (a null string yields an empty span).
    private ReadOnlySpan<char> Span
    {
        get
        {
            if (Ptr is null)
            {
                return Data.AsSpan();
            }

            return new ReadOnlySpan<char>(Ptr, Length);
        }
    }

    /// <summary>Returns the backing string when present, otherwise a new string copied from the borrowed buffer.</summary>
    public override string ToString() => Data is { } text ? text : Span.ToString();

    public override bool Equals(object? obj) => obj is StringSpanOrdinalKey other && Equals(other);

    /// <summary>Ordinal, character-by-character comparison of the two keys' contents.</summary>
    public bool Equals(StringSpanOrdinalKey other) => Span.SequenceEqual(other.Span);

    /// <summary>Hash of the key's characters, so string-backed and pointer-backed keys with equal contents collide.</summary>
    public override int GetHashCode() => Helpers.GetHashCode(Span);
}
/// <summary>
/// Composite dictionary key made of two <see cref="StringSpanOrdinalKey"/> halves; two pairs are equal
/// when both halves compare equal ordinally.
/// </summary>
internal readonly unsafe struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
{
    private readonly StringSpanOrdinalKey _left;
    private readonly StringSpanOrdinalKey _right;

    /// <summary>Creates a lookup-only pair that borrows two character buffers.</summary>
    public StringSpanOrdinalKeyPair(char* ptr1, int length1, char* ptr2, int length2)
        : this(new StringSpanOrdinalKey(ptr1, length1), new StringSpanOrdinalKey(ptr2, length2))
    {
    }

    /// <summary>Creates a pair backed by two strings, suitable for storing in a dictionary.</summary>
    public StringSpanOrdinalKeyPair(string data1, string data2)
        : this(new StringSpanOrdinalKey(data1), new StringSpanOrdinalKey(data2))
    {
    }

    // Shared initializer used by both public constructors.
    private StringSpanOrdinalKeyPair(StringSpanOrdinalKey left, StringSpanOrdinalKey right)
    {
        _left = left;
        _right = right;
    }

    public override bool Equals(object? obj) => obj is StringSpanOrdinalKeyPair other && Equals(other);

    public bool Equals(StringSpanOrdinalKeyPair other) => _left.Equals(other._left) && _right.Equals(other._right);

    public override int GetHashCode() => HashCode.Combine(_left, _right);
}
/// <summary>
/// Bounded cache keyed by string contents that can also be queried with a character span.
/// </summary>
/// <typeparam name="TValue">Type of the cached values.</typeparam>
/// <remarks>
/// Once the cache holds <c>capacity</c> entries, no new keys are admitted, but existing entries can still be
/// overwritten or removed. Every operation runs under a lock on the underlying dictionary.
/// </remarks>
internal sealed class StringSpanOrdinalKeyCache<TValue>
{
    private readonly int _capacity;
    private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;

    // The dictionary is private, so locking on it cannot contend with locks taken by outside code.
    private object SyncObj => _map;

    internal StringSpanOrdinalKeyCache() : this(BpeTokenizer.DefaultCacheCapacity) { }

    /// <param name="capacity">Maximum number of distinct keys the cache will hold.</param>
    internal StringSpanOrdinalKeyCache(int capacity)
    {
        _capacity = capacity;
        _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
    }

    /// <summary>Looks up a cached value by string key.</summary>
    internal bool TryGetValue(string key, out TValue value)
    {
        lock (SyncObj)
        {
            return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
        }
    }

    /// <summary>Looks up a cached value by span key without allocating a string.</summary>
    internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
    {
        lock (SyncObj)
        {
            // The span is pinned only for the duration of the lookup; the pointer-backed key never escapes.
            fixed (char* ptr = key)
            {
                return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
            }
        }
    }

    /// <summary>Removes the entry for <paramref name="key"/>, if present.</summary>
    internal void Remove(string key)
    {
        lock (SyncObj)
        {
            _map.Remove(new StringSpanOrdinalKey(key));
        }
    }

    /// <summary>
    /// Stores <paramref name="v"/> under <paramref name="k"/>. A new key is ignored when the cache is full;
    /// an existing key is always overwritten.
    /// </summary>
    internal void Set(string k, TValue v)
    {
        var key = new StringSpanOrdinalKey(k);
        lock (SyncObj)
        {
            // Only new keys are subject to the capacity limit. Refusing to overwrite an existing key
            // would leave the cache serving a stale value for a key the caller explicitly re-set.
            if (_map.Count < _capacity || _map.ContainsKey(key))
            {
                _map[key] = v;
            }
        }
    }
}
/// <summary>
/// Dictionary from token keys to <c>(int, string)</c> pairs. When deserialized through
/// <see cref="VocabularyConverter"/>, each value holds the JSON number and the key text.
/// </summary>
[JsonConverter(typeof(VocabularyConverter))]
internal sealed class Vocabulary : Dictionary<StringSpanOrdinalKey, (int, string)>
{
}
/// <summary>
/// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>, used both for JSON string values and for
/// dictionary property names.
/// </summary>
internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
{
    public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
        new StringSpanOrdinalKey(reader.GetString()!);

    // A property name must be written with WritePropertyName. Writing a string value at this position
    // would put a value where the writer expects a name; with validation enabled (the default) the writer
    // throws InvalidOperationException. ToString() covers pointer-backed keys, whose Data is null.
    public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
        writer.WritePropertyName(value.ToString());

    public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
        new StringSpanOrdinalKey(reader.GetString()!);

    // ToString() returns the backing string when present, otherwise the characters the key points to,
    // so pointer-backed keys serialize their contents instead of null.
    public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
        writer.WriteStringValue(value.ToString());
}
/// <summary>
/// Reads a <see cref="Vocabulary"/> from a flat JSON object whose property names are token strings and whose
/// values are 32-bit integers. Each entry is stored as <c>(value, name)</c>. Serialization is not implemented.
/// </summary>
internal class VocabularyConverter : JsonConverter<Vocabulary>
{
    /// <exception cref="JsonException">
    /// The input is not a JSON object of string-to-Int32 pairs, or it ends before the closing brace.
    /// </exception>
    public override Vocabulary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        // Tolerate a freshly created reader that has not been advanced to the first token yet.
        if (reader.TokenType == JsonTokenType.None && !reader.Read())
        {
            throw new JsonException("Invalid JSON.");
        }

        // Converters are handed the reader on the first token of their value. Accepting anything but an
        // object would let the loop below read past this value into whatever follows it.
        if (reader.TokenType != JsonTokenType.StartObject)
        {
            throw new JsonException($"Invalid JSON: expected an object but found '{reader.TokenType}'.");
        }

        var dictionary = new Vocabulary();
        while (reader.Read())
        {
            if (reader.TokenType == JsonTokenType.EndObject)
            {
                return dictionary;
            }

            if (reader.TokenType != JsonTokenType.PropertyName)
            {
                throw new JsonException($"Invalid JSON: expected a property name but found '{reader.TokenType}'.");
            }

            string key = reader.GetString()!;

            // Advance to the value: it must exist, be a number, and fit in an Int32.
            if (!reader.Read() || reader.TokenType != JsonTokenType.Number || !reader.TryGetInt32(out int value))
            {
                throw new JsonException($"Invalid JSON: the value for '{key}' is not a 32-bit integer.");
            }

            dictionary.Add(new StringSpanOrdinalKey(key), (value, key));
        }

        // The reader ran out of data before the closing brace.
        throw new JsonException("Invalid JSON.");
    }

    public override void Write(Utf8JsonWriter writer, Vocabulary value, JsonSerializerOptions options) => throw new NotImplementedException();
}
/// <summary>
/// Lookup helpers that let dictionaries keyed by <see cref="StringSpanOrdinalKey"/> or
/// <see cref="StringSpanOrdinalKeyPair"/> be queried directly with strings or character spans.
/// </summary>
internal static class StringSpanOrdinalKeyExtensions
{
    /// <summary>Looks up <paramref name="key"/> without allocating a string.</summary>
    public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
    {
        // The span stays pinned only while the transient pointer-backed probe is in use.
        fixed (char* pinned = key)
        {
            var probe = new StringSpanOrdinalKey(pinned, key.Length);
            return map.TryGetValue(probe, out value!);
        }
    }

    /// <summary>Looks up a string key.</summary>
    public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value)
    {
        var probe = new StringSpanOrdinalKey(key);
        return map.TryGetValue(probe, out value!);
    }

    /// <summary>Looks up a pair of span keys without allocating strings.</summary>
    public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
    {
        // Both spans stay pinned for the duration of the lookup.
        fixed (char* first = key1)
        fixed (char* second = key2)
        {
            var probe = new StringSpanOrdinalKeyPair(first, key1.Length, second, key2.Length);
            return map.TryGetValue(probe, out value!);
        }
    }

    /// <summary>Looks up a pair of string keys.</summary>
    public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, string key1, string key2, out TValue value)
    {
        var probe = new StringSpanOrdinalKeyPair(key1, key2);
        return map.TryGetValue(probe, out value!);
    }
}
}