-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathenv.lua
55 lines (51 loc) · 1.34 KB
/
env.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
local env = {}

--- Default type predicate (user configurable; overridden when torch loads).
-- When `obj` has a table metatable holding a raw (non-inherited) truthy
-- `__typename` entry, that entry alone decides the answer. Otherwise the
-- primitive Lua type of `obj` is compared.
-- @param obj any value
-- @tparam string typename expected type name
-- @treturn boolean true when `obj` is of type `typename`
function env.istype(obj, typename)
   local meta = getmetatable(obj)
   -- a protected metatable (`__metatable` field) may not be a table at all
   local declared = type(meta) == 'table' and rawget(meta, '__typename')
   if not declared then
      return type(obj) == typename
   end
   return declared == typename
end
--- Default type-name function (user configurable; overridden when torch loads).
-- Returns the raw (non-inherited) `__typename` from the object's metatable
-- when the metatable is a table and that entry is truthy; otherwise returns
-- the primitive Lua type name.
-- @param obj any value
-- @return the declared `__typename`, or the string from `type(obj)`
function env.type(obj)
   local meta = getmetatable(obj)
   if type(meta) ~= 'table' then
      return type(obj)
   end
   return rawget(meta, '__typename') or type(obj)
end
-- Torch-specific overrides: when torch can be loaded, type checks follow
-- torch's registered class hierarchy instead of raw metatable lookups.
local has_torch, torch_module = pcall(require, 'torch')
if has_torch then
   -- Prefer the module returned by require over the global it installs;
   -- fall back to the global only if require did not return a table.
   local torch = type(torch_module) == 'table' and torch_module or _G.torch

   -- True when class name `name` satisfies `typename`. `typename` is used as
   -- a Lua pattern: if the match is identical to `typename` (a literal spec),
   -- it must cover the whole of `name`. If the match differs from `typename`
   -- (pattern syntax was used), any match is accepted.
   local function matches(name, typename)
      local found = name:match(typename)
      return found ~= nil and (found ~= typename or found == name)
   end

   --- Report whether `obj` is of type `typename`, following torch classes.
   -- For objects with a torch typename, the object's own typename is tested
   -- first (the registered metatable's `__typename` might be absent). The
   -- registered metatable chain (torch.getmetatable, then getmetatable) is
   -- then walked, testing each `__typename`. Objects without a torch
   -- typename fall back to `type(obj) == typename`.
   -- @param obj any value
   -- @tparam string typename type name or Lua pattern
   -- @treturn boolean
   function env.istype(obj, typename)
      local thname = torch.typename(obj)
      if not thname then
         return type(obj) == typename
      end
      if matches(thname, typename) then
         return true
      end
      local mt = torch.getmetatable(thname)
      -- Stop at non-table values (a `__metatable` field can make getmetatable
      -- return anything); indexing them would raise.
      while type(mt) == 'table' do
         local mtname = mt.__typename
         if type(mtname) == 'string' and matches(mtname, typename) then
            return true
         end
         mt = getmetatable(mt)
      end
      return false
   end

   --- Return the type name of `obj` as reported by torch.type.
   -- @param obj any value
   -- @treturn string
   function env.type(obj)
      return torch.type(obj)
   end
end

return env