-
Notifications
You must be signed in to change notification settings - Fork 25
/
vec.lua
264 lines (237 loc) · 6.22 KB
/
vec.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
local ffi = require 'ffi'
local tds = require 'tds.env'
local elem = require 'tds.elem'
local C = tds.C
-- vec-independent temporary buffers
-- Single scratch tds_elem shared by every vector: values are staged here
-- before tds_vec_set/tds_vec_insert and read back from here after
-- tds_vec_get. The GC releases it through C.tds_elem_free.
local val__ = C.tds_elem_new()
ffi.gc(val__, C.tds_elem_free)
-- Metatable for the tds_vec cdata type; its metamethods are defined below.
local vec = {}
-- Sentinel compared against tds_vec_new()'s result in __new: nil when the
-- `jit` global exists, ffi.C.NULL otherwise (presumably because LuaJIT NULL
-- pointers compare equal to nil -- confirm for non-jit ffi backends).
local NULL = not jit and ffi.C.NULL or nil
-- Plain methods (insert, remove, resize, sort, ...) that vec.__index
-- resolves for non-numeric keys via rawget(mt, key).
local mt = {}
--- Insert a value, shifting later elements (1-based key).
-- Called as vec:insert(value) to append after the last element, or as
-- vec:insert(key, value) to insert at position `key`.
-- A nil value is stored as a nil element.
-- Raises on a bad argument count, a non-positive key, or when the C insert
-- reports failure (status 1).
function mt:insert(...)
  local nargs = select('#', ...)
  local lkey, lval
  if nargs == 2 then
    lkey, lval = ...
  elseif nargs == 1 then
    lkey = #self + 1
    lval = ...
  else
    error('[key] value expected')
  end
  assert(self)
  assert(type(lkey) == 'number' and lkey > 0, 'positive number expected as key')
  -- Only a genuine nil is stored as nil; false and every other value
  -- (including NULL cdata) go through elem.set.
  if type(lval) == 'nil' then
    C.tds_elem_set_nil(val__)
  else
    elem.set(val__, lval)
  end
  -- The C API is 0-based.
  local status = C.tds_vec_insert(self, lkey - 1, val__)
  if status == 1 then
    error('out of memory')
  end
end
--- Remove the element at 1-based position `lkey` (default: the last one).
-- Raises if the resulting key is not a positive number; that includes
-- calling it with no key on an empty vector, where the default is 0.
function mt:remove(lkey)
  if not lkey then
    lkey = #self
  end
  assert(self)
  assert(type(lkey) == 'number' and lkey > 0, 'positive number expected as key')
  -- Convert to the C API's 0-based index.
  C.tds_vec_remove(self, lkey - 1)
end
--- Resize the vector to hold exactly `size` elements.
-- @tparam number size new element count; 0 is allowed and empties the vector
-- Raises if `size` is not a number or is negative.
function mt:resize(size)
  -- The check accepts 0, so the message says "non-negative" (it previously
  -- said "positive", which contradicted the condition).
  assert(type(size) == 'number' and size >= 0, 'size must be a non-negative number')
  C.tds_vec_resize(self, size)
end
if pcall(require, 'torch') then
  --- Concatenate tostring() of elements i..j into a torch.CharStorage.
  -- @param sep optional separator placed between consecutive elements
  -- @tparam[opt=1] number i first index
  -- @tparam[opt=#self] number j last index
  -- @return torch.CharStorage sized exactly to the result (empty if i > j)
  function mt:concatstorage(sep, i, j)
    i = i or 1
    j = j or #self
    local sepstorage
    local sepsize = 0
    if sep then
      sepstorage = torch.CharStorage():string(sep)
      sepsize = sepstorage:size()
    end
    local buffer = torch.CharStorage()
    local size = 0

    -- Copy `src` (srcsize bytes) to the end of `buffer`. torch refuses to
    -- build a zero-length storage view ("size out of bounds"), so empty
    -- pieces (empty elements or an empty separator) are skipped.
    local function append(src, srcsize)
      if srcsize > 0 then
        local view = torch.CharStorage(buffer, size + 1, srcsize)
        view:copy(src)
        size = size + srcsize
      end
    end

    for k = i, j do
      local str = tostring(self[k])
      assert(str, 'vector elements must return a non-nil tostring()')
      str = torch.CharStorage():string(str)
      local strsize = str:size()
      local needed = size + strsize + sepsize
      if needed > buffer:size() then
        -- Grow geometrically; floor keeps the requested size integral.
        buffer:resize(math.max(math.floor(buffer:size() * 1.5), needed))
      end
      -- The separator goes between elements, so test the position rather
      -- than the byte count. A leading empty element still gets its separator.
      if sep and k > i then
        append(sepstorage, sepsize)
      end
      append(str, strsize)
    end
    buffer:resize(size)
    return buffer
  end
  --- Same as concatstorage(), but returns a Lua string.
  function mt:concat(sep, i, j)
    return self:concatstorage(sep, i, j):string()
  end
end
--- Sort the vector in place.
-- @param compare either a Lua "strictly less than" predicate
--   compare(a, b) -> boolean (nil elements are passed as nil), or a raw C
--   comparator passed straight to tds_vec_sort (caller is responsible).
function mt:sort(compare)
  if type(compare) == 'function' then
    -- Adapt the Lua predicate to a three-way comparator
    -- (negative: a first, positive: b first, zero: equivalent). Returning 0
    -- for equivalent elements keeps the ordering consistent. Previously equal
    -- elements answered "greater" in both directions, which a qsort-style
    -- sort is not required to tolerate.
    local function compare__(cval1, cval2)
      local lval1, lval2
      if C.tds_elem_isnil(cval1) == 0 then
        lval1 = elem.get(cval1)
      end
      if C.tds_elem_isnil(cval2) == 0 then
        lval2 = elem.get(cval2)
      end
      if compare(lval1, lval2) then
        return -1
      elseif compare(lval2, lval1) then
        return 1
      end
      return 0
    end
    -- FFI callbacks are a limited resource: free it once sorting is done.
    local cb_compare__ = ffi.cast('int (*)(tds_elem*, tds_elem*)', compare__)
    C.tds_vec_sort(self, cb_compare__)
    cb_compare__:free()
  else -- you must know what you are doing
    assert(compare ~= nil, 'compare function must be a lua or C function')
    C.tds_vec_sort(self, compare)
  end
end
--- Decide whether a Lua table should become a tds.Vec (vs a tds.Hash).
-- @tparam table tbl
-- @treturn boolean true when every key is a number (an empty table
--   qualifies); false as soon as any non-number key is seen
local function isvec(tbl)
  local numeric_only = true
  for key in pairs(tbl) do
    if type(key) ~= 'number' then
      numeric_only = false
      break
    end
  end
  return numeric_only
end
--- Copy the entries of a numerically-keyed Lua table into vector `self`.
-- Nested Lua tables are converted recursively: numeric-keyed ones become
-- tds.Vec, all others become tds.Hash. Other values are stored as-is.
-- Raises if `tbl` has any non-number key.
local function fill(self, tbl)
  assert(isvec(tbl), 'lua table with number keys expected')
  for key, val in pairs(tbl) do
    local item = val
    if type(val) == 'table' then
      if isvec(val) then
        item = tds.Vec(val)
      else
        item = tds.Hash(val)
      end
    end
    self[key] = item
  end
end
function vec:__new(...) -- beware of the :
  -- The implicit `self` is whatever the call was made on (e.g. the tds.Vec
  -- constructor table) and is ignored; a new C vector is returned instead.
  -- Call forms: no args (empty vector), one Lua table (its entries are
  -- copied), or several values (stored in order via a packed table).
  local nargs = select('#', ...)
  local first = ...
  local cvec = C.tds_vec_new()
  if cvec == NULL then
    error('unable to allocate vec')
  end
  cvec = ffi.cast('tds_vec&', cvec)
  -- Tie the C allocation's lifetime to the Lua object.
  ffi.gc(cvec, C.tds_vec_free)
  if nargs == 1 and type(first) == 'table' then
    fill(cvec, first)
  elseif nargs > 0 then
    fill(cvec, {...})
  end
  return cvec
end
--- vec[lkey] = lval: store a value at a 1-based position.
-- A nil value is stored as a nil element. Raises on a non-positive or
-- non-number key, or when the C setter reports failure (status 1).
function vec:__newindex(lkey, lval)
  assert(self)
  assert(type(lkey) == 'number' and lkey > 0, 'positive number expected as key')
  -- Stage the value in the shared scratch element; false is kept as a value.
  if type(lval) == 'nil' then
    C.tds_elem_set_nil(val__)
  else
    elem.set(val__, lval)
  end
  local status = C.tds_vec_set(self, lkey - 1, val__)
  if status == 1 then
    error('out of memory')
  end
end
--- vec[lkey] / vec:method lookup.
-- Number keys read the 1-based element (nil elements come back as nil).
-- Any other key must name a method in `mt`, otherwise an error is raised.
function vec:__index(lkey)
  assert(self)
  if type(lkey) ~= 'number' then
    local method = rawget(mt, lkey)
    if not method then
      error('invalid key (number) or method name')
    end
    return method
  end
  assert(lkey > 0, 'positive number expected as key')
  C.tds_vec_get(self, lkey - 1, val__)
  if C.tds_elem_isnil(val__) ~= 0 then
    return nil
  end
  -- Parenthesised so exactly one value is returned.
  return (elem.get(val__))
end
--- #vec: number of elements, converted to a plain Lua number.
function vec:__len()
  assert(self)
  local count = C.tds_vec_size(self)
  return tonumber(count)
end
--- ipairs(vec): yields (k, vec[k]) for k = 1 .. current size.
-- Nil elements are yielded too, since the loop stops on the key, not the
-- value. The size is re-read at every step.
function vec:__ipairs()
  assert(self)
  local idx = 0
  local function step()
    idx = idx + 1
    if idx > C.tds_vec_size(self) then
      return
    end
    return idx, self[idx]
  end
  return step
end
-- pairs(vec) iterates exactly like ipairs(vec): keys 1..#vec in order.
vec.__pairs = vec.__ipairs
-- Attach `vec` as the metatable of every tds_vec cdata.
-- NOTE(review): more fields (__write, __read, __factory, __tostring, ...) are
-- added to `vec` after this call. The LuaJIT FFI docs say the metatable must
-- not be modified after ffi.metatype -- confirm this is safe, or move this
-- call to the end of the module.
ffi.metatype('tds_vec', vec)
if pcall(require, 'torch') and torch.metatype then
  --- torch serialization: element count (long), then each element in order.
  function vec:__write(f)
    local n = #self
    f:writeLong(n)
    -- Index explicitly instead of using ipairs(self). Whether ipairs honours
    -- __ipairs on a cdata varies across Lua versions, and an ipairs that stops
    -- at the first nil element would write fewer objects than the count
    -- above, corrupting the stream. This loop always writes exactly n objects.
    for k = 1, n do
      f:writeObject(self[k])
    end
  end
  --- torch deserialization: read the count, then assign elements 1..n.
  function vec:__read(f)
    local n = f:readLong()
    for k = 1, n do
      local v = f:readObject()
      self[k] = v
    end
  end
  vec.__factory = vec.__new
  vec.__version = 0
  torch.metatype('tds.Vec', vec, 'tds_vec&')
end
--- Human-readable rendering of the vector.
-- Output: a 'tds.Vec[N]{' header, then one "    k : value" entry per element.
-- Continuation lines of multi-line values are indented under the first line.
-- At most 20 entries are shown, followed by '...' only when more exist.
-- A closing '}' ends the output.
function vec:__tostring()
  local size = #self
  local shown = math.min(size, 20)
  local str = {}
  table.insert(str, string.format('tds.Vec[%d]{', size))
  -- Explicit indexing (rather than ipairs on the cdata) works regardless of
  -- the Lua version's __ipairs support.
  for k = 1, shown do
    local v = self[k]
    local kstr = string.format("%5d : ", k)
    local vstr = tostring(v) or type(v)
    if vstr == '' then
      -- The gsub below only decorates non-empty lines, so an empty value
      -- would otherwise lose its key label.
      vstr = kstr
    else
      local sp = string.rep(' ', #kstr)
      local i = 0
      vstr = vstr:gsub(
        '([^\n]+)',
        function(line)
          i = i + 1
          if i == 1 then
            return kstr .. line
          else
            return sp .. line
          end
        end
      )
    end
    table.insert(str, vstr)
  end
  -- Only signal truncation when elements were actually left out.
  if size > shown then
    table.insert(str, '...')
  end
  table.insert(str, '}')
  return table.concat(str, '\n')
end
-- table constructor
-- Calling tds.Vec(...) / tds.vec(...) runs vec.__new. Reading a field of the
-- constructor table reads it from `vec`, and assigning one assigns into `vec`.
local vec_ctr = setmetatable({}, {
  __index = vec,
  __newindex = vec,
  __call = vec.__new,
})
tds.vec = vec_ctr
tds.Vec = vec_ctr
return vec_ctr