-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscripts.py
21 lines (19 loc) · 945 Bytes
/
scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
def shredFacts(facts, hiss, raw_facts=None):
    """Split a batch of facts into one tensor per column.

    Args:
        facts: 2-D indexable batch (e.g. a NumPy array) with at least six
            columns. Columns 0-2 are returned as int64 tensors
            (heads, rels, tails) and columns 3-5 as float32 tensors
            (years, months, days).
        hiss: Array-like converted as a whole to an int64 tensor. Its
            meaning is not visible in this function (presumably per-fact
            history ids -- TODO confirm against the caller).
        raw_facts: 2-D indexable batch with at least nine columns; columns
            6-8 are returned as int64 tensors (yearsid, monthsid, daysid).
            Despite the ``None`` default (kept so existing call signatures
            stay valid) this argument is required.

    Returns:
        The tuple ``(heads, rels, tails, years, months, days, yearsid,
        monthsid, daysid, hiss)``. All tensors live on the current CUDA
        device when CUDA is available, otherwise on the CPU.

    Raises:
        TypeError: If ``raw_facts`` is ``None``.
    """
    # Fail fast with a readable message instead of the opaque
    # "'NoneType' object is not subscriptable" raised after half of the
    # tensors have already been built and copied to the device.
    if raw_facts is None:
        raise TypeError(
            "shredFacts() requires raw_facts: columns 6-8 of raw_facts are "
            "needed to build yearsid, monthsid and daysid"
        )

    # Use the GPU when there is one (same placement as a plain .cuda()),
    # but keep the function usable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def as_long(values):
        return torch.tensor(values).long().to(device)

    def as_float(values):
        return torch.tensor(values).float().to(device)

    heads = as_long(facts[:, 0])
    rels = as_long(facts[:, 1])
    tails = as_long(facts[:, 2])
    years = as_float(facts[:, 3])
    months = as_float(facts[:, 4])
    days = as_float(facts[:, 5])
    yearsid = as_long(raw_facts[:, 6])
    monthsid = as_long(raw_facts[:, 7])
    daysid = as_long(raw_facts[:, 8])
    hiss = as_long(hiss)
    return heads, rels, tails, years, months, days, yearsid, monthsid, daysid, hiss