-
Notifications
You must be signed in to change notification settings - Fork 1
/
generalization_plot.py
91 lines (59 loc) · 2.23 KB
/
generalization_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import cPickle as pickle
import matplotlib.pyplot as plt
import numpy as np
import sys
# The pickle file holding the generalization data must be given as the first
# command-line argument.
if len(sys.argv) > 1:
    fname = sys.argv[1]
else:
    # Report on stderr and exit with a non-zero status so shells and callers
    # can detect the failure (exiting with 0 would signal success).
    # sys.stderr.write / sys.exit work identically on Python 2 and 3 and do
    # not depend on the interactive `exit` helper installed by `site`.
    sys.stderr.write("Require a file containing data to generate the plot\n")
    sys.exit(1)
def pg_barplot(pg_list, title, out_dir='./generalization_plots',
               labels=('sub', 'basic', 'super')):
    """
    Create a bar chart of the probabilities of generalization and save it.

    Parameters
    ----------
    pg_list: sequence of float
        Probabilities of generalization for an example 'y' from the
        subordinate, basic and superordinate categories (one bar each).
    title: string
        Title of the plot; also used as the file name (inside ``out_dir``)
        under which the plot is saved.
    out_dir: string, optional
        Directory the plot is written to; created if it does not exist.
        Defaults to './generalization_plots'.
    labels: sequence of string, optional
        Tick label for each bar; must have the same length as ``pg_list``.
        Defaults to ('sub', 'basic', 'super').

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module (kept for backward compatibility).
        The figure created here has already been closed.

    Raises
    ------
    ValueError
        If ``labels`` and ``pg_list`` have different lengths.
    """
    import os

    if len(labels) != len(pg_list):
        raise ValueError(
            "Got %d labels for %d bars" % (len(labels), len(pg_list)))

    ind = np.arange(len(pg_list))
    width = 1

    # Create the output directory if needed (Py2-compatible: makedirs has no
    # exist_ok there, so tolerate the "already exists" error explicitly).
    try:
        os.makedirs(out_dir)
    except OSError:
        if not os.path.isdir(out_dir):
            raise

    fig = plt.figure(figsize=(5, 6), facecolor='white')
    try:
        # align='edge' puts each bar's left edge at `ind`, so the ticks at
        # ind + width/2. below sit at the bar centres. Matplotlib >= 2.0
        # defaults to align='center', which would shift labels by half a bar.
        plt.bar(ind, pg_list, width, color='grey', align='edge')
        plt.xlabel('Categories', fontsize=11)
        plt.ylabel('Probability of generalization', fontsize=12)
        plt.title(title)
        plt.xticks(ind + width / 2., list(labels))
        # savefig has no 'fontsize' option; recent matplotlib raises a
        # TypeError for unknown keyword arguments, so none are passed.
        plt.savefig(os.path.join(out_dir, title))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return plt
# Cluster prefixes of the keys stored in the pickle file; each key has the
# form '<cluster>_<condition>', e.g. 'vegetable_1sub'.
CLUSTERS = ('vegetable', 'vehicle', 'animal')


def _average_clusters(data, condition, clusters=CLUSTERS):
    """
    Return the element-wise mean of one condition across several clusters.

    Parameters
    ----------
    data: dict
        Mapping loaded from the pickle file; must contain the key
        '<cluster>_<condition>' for every entry of ``clusters``.
    condition: string
        Condition suffix, e.g. '1sub' or '3basic'.
    clusters: sequence of string, optional
        Cluster prefixes to average over.

    Returns
    -------
    list of float
        The mean of the clusters' values at each position.

    Raises
    ------
    KeyError
        If a '<cluster>_<condition>' key is missing from ``data``.
    ValueError
        If the clusters hold different numbers of values (``zip`` would
        otherwise silently truncate to the shortest one).
    """
    series = [data[cluster + '_' + condition] for cluster in clusters]
    if len(set(len(values) for values in series)) > 1:
        raise ValueError(
            "Clusters %s have different lengths for condition %r"
            % (', '.join(clusters), condition))
    count = float(len(series))  # float() keeps true division on Python 2
    return [sum(column) / count for column in zip(*series)]


# Read the data from the pickle file. Unpickling can execute arbitrary code,
# so only pass files that come from a trusted source.
with open(fname, 'rb') as data_file:
    data = pickle.load(data_file)

# Average each condition over the three clusters.
sub1 = _average_clusters(data, '1sub')
sub3 = _average_clusters(data, '3sub')
basic3 = _average_clusters(data, '3basic')
sup3 = _average_clusters(data, '3sup')

# Plot the averaged generalization results.
pg_barplot(sub1, '1 sub')
pg_barplot(sub3, '3 sub')
pg_barplot(basic3, '3 basic')
pg_barplot(sup3, '3 sup')