-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscanner.go
153 lines (121 loc) · 3.3 KB
/
scanner.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
//
// Copyright (C) 2024 Dmitry Kolesnikov
//
// This file may be modified and distributed under the terms
// of the MIT license. See the LICENSE file for details.
// https://github.com/kshard/wreck
//
package wreck
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
)
// decoder turns a decoded *Chunk into its unique key, sort key and typed
// vector, optionally enforcing upper bounds on the length of each part.
type decoder[T any] struct {
	// fmap converts the chunk's raw vector bytes into a []T.
	fmap func([]uint8) []T

	// Upper length bounds checked by decode; a value <= 0 disables the check.
	// maxVector counts elements of T (the length of fmap's result), not bytes.
	maxUniqueKey, maxSortKey, maxVector int
}
// newDecoder returns a decoder[T] with a built-in vector mapping chosen from
// the element type T: float32 uses toFloat32 and uint8 uses toUint8. For any
// other T the mapping is left nil and has to be supplied with WithDecoder
// before decoding. No length limits are set.
func newDecoder[T any]() decoder[T] {
	var zero T
	var fmap func([]uint8) []T

	switch any(zero).(type) {
	case uint8:
		fmap = toUint8[T]
	case float32:
		fmap = toFloat32[T]
	}

	return decoder[T]{fmap: fmap}
}
// WithDecoder replaces the function used to convert a chunk's raw vector
// bytes into []T. These setters are promoted to Scanner and Decoder, which
// embed decoder.
func (codec *decoder[T]) WithDecoder(fmap func([]uint8) []T) { codec.fmap = fmap }

// WithMaxUniqueKey bounds the unique key length in bytes; v <= 0 disables the check.
func (codec *decoder[T]) WithMaxUniqueKey(v int) { codec.maxUniqueKey = v }

// WithMaxSortKey bounds the sort key length in bytes; v <= 0 disables the check.
func (codec *decoder[T]) WithMaxSortKey(v int) { codec.maxSortKey = v }

// WithMaxVector bounds the decoded vector length in elements of T; v <= 0
// disables the check.
func (codec *decoder[T]) WithMaxVector(v int) { codec.maxVector = v }
// decode splits chunk into its unique key, sort key and vector. It converts
// the raw vector bytes with fmap and enforces the configured length limits;
// a limit <= 0 is not enforced.
//
// The returned keys are the chunk's own slices (no copy is made). Whether
// the vector aliases chunk.Vector depends on fmap.
//
// On failure every returned slice is nil. Failure happens when no vector
// mapping is configured for T, or when a key or the vector exceeds its limit.
func (codec *decoder[T]) decode(chunk *Chunk) ([]uint8, []uint8, []T, error) {
	// newDecoder only installs a mapping for float32 and uint8. Report a
	// missing mapping as an error rather than panicking on a nil func call.
	if codec.fmap == nil {
		return nil, nil, nil, fmt.Errorf("vector decoder is not defined for type %T", *new(T))
	}

	uniqueKey := chunk.UniqueKey
	if codec.maxUniqueKey > 0 && len(uniqueKey) > codec.maxUniqueKey {
		return nil, nil, nil, fmt.Errorf("length exceeded : uniqueKey (%d)", len(uniqueKey))
	}

	sortKey := chunk.SortKey
	if codec.maxSortKey > 0 && len(sortKey) > codec.maxSortKey {
		return nil, nil, nil, fmt.Errorf("length exceeded : sortKey (%d)", len(sortKey))
	}

	// Convert the vector (which may allocate) only after the cheap key checks pass.
	vec := codec.fmap(chunk.Vector)
	if codec.maxVector > 0 && len(vec) > codec.maxVector {
		return nil, nil, nil, fmt.Errorf("length exceeded : vector (%d)", len(vec))
	}

	return uniqueKey, sortKey, vec, nil
}
//------------------------------------------------------------------------------

// Scanner reads a stream of encoded chunks from an io.Reader, one record per
// call to Scan. While Scan returns true, the current record is available via
// UniqueKey, SortKey and Vector. Once Scan returns false, Err reports the
// failure, or nil when the stream simply ended with io.EOF.
type Scanner[T any] struct {
	decoder[T] // vector mapping and length limits (WithDecoder, WithMax*)

	r io.Reader

	// Parts of the record produced by the most recent Scan.
	uniqueKey, sortKey []uint8
	vec                []T

	// Error from the most recent Scan; io.EOF is never stored here.
	err error
}
// NewScanner returns a Scanner that decodes chunks from r. It uses the
// built-in vector mapping for T when one exists (float32 or uint8).
func NewScanner[T any](r io.Reader) *Scanner[T] {
	s := new(Scanner[T])
	s.decoder = newDecoder[T]()
	s.r = r
	return s
}
// Err returns the error recorded by the most recent Scan. It is nil when
// scanning stopped because the reader returned io.EOF.
func (codec *Scanner[T]) Err() error {
	return codec.err
}

// UniqueKey returns the unique key of the current record.
func (codec *Scanner[T]) UniqueKey() []uint8 {
	return codec.uniqueKey
}

// SortKey returns the sort key of the current record.
func (codec *Scanner[T]) SortKey() []uint8 {
	return codec.sortKey
}

// Vector returns the decoded vector of the current record.
func (codec *Scanner[T]) Vector() []T {
	return codec.vec
}
// Scan decodes the next chunk from the underlying reader and reports whether
// a record is available. It returns false:
//   - at io.EOF, without recording an error;
//   - when reading or decoding the chunk fails;
//   - when the record exceeds a configured length limit.
//
// In the last two cases Err returns the cause.
func (codec *Scanner[T]) Scan() bool {
	var chunk Chunk
	err := Decode(codec.r, &chunk)

	switch {
	case err == nil:
		codec.uniqueKey, codec.sortKey, codec.vec, codec.err = codec.decode(&chunk)
		return codec.err == nil
	case errors.Is(err, io.EOF):
		// Clean end of stream: nothing to report through Err.
		return false
	default:
		codec.err = err
		return false
	}
}
//------------------------------------------------------------------------------

// Decoder converts a single binary packet into its unique key, sort key and
// typed vector. It embeds the same configurable decoder used by Scanner, so
// WithDecoder and the WithMax* limits are available on it.
type Decoder[T any] struct {
	decoder[T]
}
// NewDecoder returns a Decoder that uses the built-in vector mapping for T
// when one exists (float32 or uint8).
func NewDecoder[T any]() *Decoder[T] {
	d := new(Decoder[T])
	d.decoder = newDecoder[T]()
	return d
}
// Decode parses pack and stores its unique key, sort key and vector through
// the given pointers.
//
// If the packet itself cannot be parsed, the error is returned and the
// outputs are left untouched. If the parsed chunk fails validation (see the
// embedded decoder's limits), all three outputs are set to nil and the error
// is returned.
func (codec *Decoder[T]) Decode(pack []byte, uniqueKey, sortKey *[]uint8, vec *[]T) (err error) {
	var chunk Chunk
	if err = Decode(bytes.NewBuffer(pack), &chunk); err != nil {
		return err
	}

	uk, sk, v, err := codec.decode(&chunk)
	*uniqueKey, *sortKey, *vec = uk, sk, v
	return err
}
//------------------------------------------------------------------------------

// toFloat32 decodes b as a packed sequence of little-endian IEEE-754 float32
// values and returns them as []T, so it can be used as a decoder[T] mapping.
// T must be float32; otherwise the final type assertion panics.
//
// Trailing bytes that do not form a complete 4-byte element are ignored.
// Previously they made the loop index past the end of the result and panic,
// which a truncated or malformed packet could trigger.
func toFloat32[T any](b []uint8) []T {
	v := make([]float32, len(b)/4)
	for i := range v {
		off := 4 * i
		v[i] = math.Float32frombits(binary.LittleEndian.Uint32(b[off : off+4]))
	}
	return any(v).([]T)
}
// toUint8 is the identity mapping for byte vectors. It returns b itself, with
// no copy, viewed as []T. T must be uint8; otherwise the type assertion panics.
func toUint8[T any](b []uint8) []T {
	var boxed any = b
	return boxed.([]T)
}