-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdialogueGenerator.py
executable file
·112 lines (92 loc) · 3.5 KB
/
dialogueGenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from dataModule import DataStorage,DataLoader
from entities import Dialogue
from templateConstructor import TemplateConstructor
import config as cfg
import random
from math import gcd
import numpy as np
import copy
# random.seed(1250)
# np.random.seed(1250)
class DataGenerator():
    """Generate unique dialogues by filling templates with slot values.

    Templates, slot values and question phrases are obtained from a
    ``DataLoader``; finished dialogues are handed to a ``DataStorage``.
    """

    def __init__(self, args):
        """Create the loader/storer from *args* and load all input data.

        Args:
            args: Configuration mapping. ``generate`` reads the keys
                ``"max_number_tries_dial"`` and ``"max_dialogue_count"``;
                the same object is passed to ``DataLoader``/``DataStorage``.
        """
        self.args = args
        self.dataLoader = DataLoader(args)
        self.dataStorer = DataStorage(args)
        self.prepareData()

    def prepareData(self):
        """Load templates, slot values and question phrases via the loader."""
        self.templateList = self.dataLoader.loadTemplates()
        self.slotDict = self.dataLoader.loadSlotValues()
        self.questions = self.dataLoader.loadQuestions()

    def generate(self):
        """Build dialogues until a count or failure limit is reached.

        Each attempt picks a template, assigns slot values to its queries
        and forms a dialogue. A dialogue that is already present (as decided
        by the Dialogue object's own hashing/equality) counts as a failure.
        The loop stops once ``max_dialogue_count`` unique dialogues exist or
        ``max_number_tries_dial`` failures have accumulated.

        Returns:
            dict mapping every unique dialogue to ``1``. The dialogues are
            also added to the storage tree and stored before returning.
        """
        dialogues = {}
        fail = 0
        count = 0
        max_fails = self.args["max_number_tries_dial"]
        max_count = self.args["max_dialogue_count"]
        while fail < max_fails and count < max_count:
            if count % 1000 == 0:
                print(count)
            dialogue = Dialogue()
            # Choose a template to fill and extract its queries.
            queryDict, template = self.chooseTemplate()
            # Assign values to the queries from the data.
            self.fillQueryDict(queryDict, dialogue)
            # Create the dialogue from the template and the assigned values.
            dialogue.form(template, queryDict)
            if dialogue not in dialogues:
                dialogues[dialogue] = 1
                count += 1
            else:
                # Was ``fail =+1`` (plain assignment of +1), which kept the
                # counter stuck at 1 so the failure limit never triggered.
                fail += 1
            # NOTE(review): placement reconstructed as "once per attempt";
            # confirm against the upstream indentation.
            dialogue.terminateDialogue()
        for dialogue in dialogues:
            self.dataStorer.addToTree(dialogue)
        self.dataStorer.store()
        return dialogues

    def fillQueryDict(self, queryDict, dialouge):
        """Assign a random slot value to every query key, in place.

        Values are drawn from a private copy of ``self.slotDict`` and removed
        after use, so a value is used at most once per dialogue while the
        shared slot data stays untouched.

        Args:
            queryDict: Mapping whose keys expose a ``suffix`` attribute that
                selects the slot-value list; its values are overwritten.
            dialouge: Dialogue that records each (key, phrase, value) triple
                through ``addQueryItem``.

        Raises:
            IndexError: From ``random.choice`` when a slot list runs empty.
        """
        # Pop used values so that each is used once in a dialogue.
        slotDict = copy.deepcopy(self.slotDict)
        for key in list(queryDict):
            phrase = self.questions.returnRandomPhrase(key)
            candidates = slotDict[key.suffix]
            randomValue = random.choice(candidates)
            candidates.remove(randomValue)
            queryDict[key] = randomValue
            # Saved for later use, where a query dictionary has to be
            # returned together with the dialogue.
            dialouge.addQueryItem(key, phrase, randomValue)

    def chooseTemplate(self):
        """Pick a template fairly and return its empty query dictionary.

        The chosen template's ``selectedCount`` is incremented.

        Returns:
            Tuple ``(queryDict, template)`` where ``queryDict`` maps every
            value of ``template.queryDict`` to an empty string.
        """
        template = self.chooseFairly(self.templateList)
        template.selectedCount += 1
        queryDict = dict.fromkeys(template.queryDict.values(), "")
        return queryDict, template

    def chooseFairly(self, templateList):
        """Randomly pick a template, favouring rarely selected ones.

        Every template gets a probability inversely proportional to its
        ``selectedCount``; the normalized probability is written to its
        ``probWeight`` attribute. ``selectedCount`` must be non-zero.

        Args:
            templateList: Non-empty sequence of templates.

        Returns:
            The selected template.

        Raises:
            IndexError: If ``templateList`` is empty.
            ZeroDivisionError: If a template has ``selectedCount == 0``.
        """
        if not templateList:
            raise IndexError("cannot choose from an empty template list")
        # 1/count has the same ratios as lcm/count but avoids the huge
        # integer LCM whose float conversion overflows for many templates.
        sumWeights = 0.0
        for temp in templateList:
            temp.probWeight = 1 / temp.selectedCount
            sumWeights += temp.probWeight
        for temp in templateList:
            temp.probWeight /= sumWeights
        # Walk the cumulative distribution with a single random draw.
        rand_num = random.random()
        for temp in templateList:
            if rand_num < temp.probWeight:
                return temp
            rand_num -= temp.probWeight
        # Rounding can leave the weights summing to slightly less than 1;
        # fall back to the last template instead of returning None.
        return templateList[-1]
def main():
    """Optionally rebuild the templates, then generate the dialogues.

    Reads the configuration through ``cfg.get_args()``. When its
    ``"new_templates"`` entry is truthy, a ``TemplateConstructor`` is run
    through its fill / construct / tree / store steps first. A
    ``DataGenerator`` is then created and asked to generate in either case.
    """
    settings = cfg.get_args()

    if args_request_new_templates := settings["new_templates"]:
        builder = TemplateConstructor(settings)
        builder.initFill()
        builder.construct()
        builder.makeTree()
        builder.store()

    generator = DataGenerator(settings)
    generator.generate()


# Only run when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()