-
Notifications
You must be signed in to change notification settings - Fork 36
/
kg_link_tail_prediction.py
304 lines (273 loc) · 12.5 KB
/
kg_link_tail_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""A processor for the knowledge graph link tail prediction task."""
from __future__ import annotations
from collections.abc import Iterable
import json
from typing import Any
from explainaboard import TaskType
from explainaboard.analysis import feature
from explainaboard.analysis.analyses import Analysis, AnalysisLevel, BucketAnalysis
from explainaboard.analysis.case import AnalysisCase
from explainaboard.analysis.feature_funcs import count_tokens
from explainaboard.info import SysOutputInfo
from explainaboard.metrics.metric import MetricConfig, MetricStats
from explainaboard.metrics.ranking import (
HitsConfig,
MeanRankConfig,
MeanReciprocalRankConfig,
RankingMetric,
)
from explainaboard.processors.processor import Processor
from explainaboard.processors.processor_registry import register_processor
from explainaboard.utils import cache_api
from explainaboard.utils.logging import progress
from explainaboard.utils.typing_utils import narrow
@register_processor(TaskType.kg_link_tail_prediction)
class KGLinkTailPredictionProcessor(Processor):
    """A processor for the knowledge graph link tail prediction task."""

    @classmethod
    def task_type(cls) -> TaskType:
        """See Processor.task_type."""
        return TaskType.kg_link_tail_prediction

    def default_analysis_levels(self) -> list[AnalysisLevel]:
        """See Processor.default_analysis_levels."""
        features = {
            "true_head": feature.Value(dtype=feature.DataType.STRING),
            "true_head_decipher": feature.Value(dtype=feature.DataType.STRING),
            "true_link": feature.Value(
                dtype=feature.DataType.STRING, description="the relation type"
            ),
            "true_tail": feature.Value(dtype=feature.DataType.STRING),
            "true_tail_decipher": feature.Value(dtype=feature.DataType.STRING),
            "predict": feature.Value(dtype=feature.DataType.STRING),
            "predictions": feature.Sequence(
                feature=feature.Value(dtype=feature.DataType.STRING)
            ),
            "tail_entity_length": feature.Value(
                dtype=feature.DataType.FLOAT,
                description="length of the tail entity in tokens",
                func=lambda info, x, c: count_tokens(
                    info, x['true_tail_decipher'], side='target'
                ),
            ),
            "head_entity_length": feature.Value(
                dtype=feature.DataType.FLOAT,
                description="length of the head entity in tokens",
                func=lambda info, x, c: count_tokens(
                    info, x['true_head_decipher'], side='target'
                ),
            ),
            "tail_fre": feature.Value(
                dtype=feature.DataType.FLOAT,
                description="average frequency of the tail entity",
                require_training_set=True,
                func=lambda info, x, c, stat: stat['tail_fre'].get(
                    x['true_tail_decipher'], 0
                ),
            ),
            "link_fre": feature.Value(
                dtype=feature.DataType.FLOAT,
                description="frequency of relation in training set",
                require_training_set=True,
                func=lambda info, x, c, stat: stat['link_fre'].get(x['true_link'], 0),
            ),
            "head_fre": feature.Value(
                dtype=feature.DataType.FLOAT,
                description="frequency of head entity in training set",
                require_training_set=True,
                func=lambda info, x, c, stat: stat['head_fre'].get(
                    x['true_head_decipher'], 0
                ),
            ),
            "symmetry": feature.Value(
                dtype=feature.DataType.STRING,
                description="whether the relation is symmetric",
                func=lambda info, x, c: 'symmetric'
                if x['true_link'] in self._symmetric_relations
                else 'asymmetric',
            ),
            "entity_type_level": feature.Value(
                dtype=feature.DataType.STRING,
                description="most specific entity type level of the true tail entity",
                func=lambda info, x, c: self._get_entity_type_level(x),
            ),
        }
        return [
            AnalysisLevel(
                name='example',
                features=features,
                metric_configs=self.default_metrics(),
            )
        ]

    def default_analyses(self) -> list[Analysis]:
        """See Processor.default_analyses."""
        analysis_levels = self.default_analysis_levels()
        features = analysis_levels[0].features
        # Discrete features to bucket on, mapped to the number of buckets each.
        discrete_features = {'symmetry': 2, 'entity_type_level': 8, 'true_link': 15}
        analyses: list[Analysis] = [
            BucketAnalysis(
                level=analysis_levels[0].name,
                description=features[k].description,
                feature=k,
                method="discrete",
                number=v,
            )
            for k, v in discrete_features.items()
        ]
        analyses.extend(self.continuous_feature_analyses())
        return analyses

    @classmethod
    def default_metrics(
        cls, level='example', source_language=None, target_language=None
    ) -> list[MetricConfig]:
        """See Processor.default_metrics."""
        return [
            HitsConfig(name='Hits1', hits_k=1),
            HitsConfig(name='Hits2', hits_k=2),
            HitsConfig(name='Hits3', hits_k=3),
            HitsConfig(name='Hits5', hits_k=5),
            HitsConfig(name='Hits10', hits_k=10),
            MeanReciprocalRankConfig(name='MRR'),
            MeanRankConfig(name='MR'),
        ]

    # Relations treated as symmetric for the "symmetry" feature.
    # TODO: is this the best place to put this?
    _symmetric_relations = {
        '/base/popstra/celebrity/breakup./base/popstra/breakup/participant',
        '/base/popstra/celebrity/canoodled./base/popstra/canoodled/participant',
        '/base/popstra/celebrity/dated./base/popstra/dated/participant',
        '/base/popstra/celebrity/friendship./base/popstra/friendship/participant',
        '/celebrities/celebrity/celebrity_friends./celebrities/friendship/friend',
        '/celebrities/celebrity/sexual_relationships./celebrities/romantic_relationship/celebrity',  # noqa: E501
        '/influence/influence_node/peers./influence/peer_relationship/peers',
        '/location/location/adjoin_s./location/adjoining_relationship/adjoins',
        '/people/person/spouse_s./people/marriage/spouse',
        '/people/person/sibling_s./people/sibling relationship/sibling',
    }

    def __init__(self) -> None:
        """Constructor.

        Downloads (or reads from cache) the entity -> type-level map used by
        the "entity_type_level" feature.
        """
        super().__init__()
        file_path = cache_api.cache_online_file(
            'https://storage.googleapis.com/inspired-public-data/'
            'explainaboard/task_data/kg_link_tail_prediction/entity2wikidata.json',
            'explainaboard/task_data/kg_link_tail_prediction/entity2wikidata.json',
        )
        with open(file_path, 'r') as file:
            self.entity_type_level_map: dict = json.load(file)

    def _statistics_func(self, samples: Iterable[Any], sys_info: SysOutputInfo) -> dict:
        """See Processor._statistics_func.

        Counts training-set occurrences of each head entity, relation, and
        tail entity; the counts back the *_fre features.
        """
        dict_head: dict[str, int] = {}
        dict_link: dict[str, int] = {}
        dict_tail: dict[str, int] = {}
        for sample in progress(samples):
            tail = sample['true_tail_decipher']
            dict_tail[tail] = dict_tail.get(tail, 0) + 1
            head = sample['true_head_decipher']
            dict_head[head] = dict_head.get(head, 0) + 1
            link = sample['true_link']
            dict_link[link] = dict_link.get(link, 0) + 1
        return {
            "head_fre": dict_head,
            "link_fre": dict_link,
            "tail_fre": dict_tail,
        }

    def _gen_cases_and_stats(
        self,
        sys_info: SysOutputInfo,
        sys_output: list[dict],
        statistics: Any,
        analysis_level: AnalysisLevel,
    ) -> tuple[list[AnalysisCase], list[MetricStats]]:
        """Generate analysis cases and metric stats for one analysis level.

        Note that this is overridden (vs. the base Processor) so that ranking
        metrics can compute their stats directly from the rank of the true
        entity rather than from (true, predicted) label pairs.
        """
        cases = []
        true_data = [self._get_true_label(x) for x in sys_output]
        pred_data = [self._get_predicted_label(x) for x in sys_output]
        # Check for missing ranks BEFORE narrowing: narrow() would raise its
        # own error on a None value, making this friendlier message unreachable.
        raw_ranks = [x.get('true_rank') for x in sys_output]
        if any(item is None for item in raw_ranks):
            raise ValueError(
                'Some data points do not have rank information; check system outputs.'
            )
        rank_data = [narrow(int, item) for item in raw_ranks]
        metric_stats = []
        for metric in [x.to_metric() for x in analysis_level.metric_configs]:
            if isinstance(metric, RankingMetric):
                metric_stats.append(metric.calc_stats_from_rank(rank_data))
            else:
                metric_stats.append(metric.calc_stats_from_data(true_data, pred_data))
        # Calculate features
        for i, output in progress(
            enumerate(sys_output), desc='calculating example-level features'
        ):
            case = AnalysisCase(sample_id=i, features={})
            for feat_name, feat_spec in analysis_level.features.items():
                if feat_spec.func is None:
                    case.features[feat_name] = output[feat_name]
                elif not feat_spec.require_training_set:
                    case.features[feat_name] = feat_spec.func(sys_info, output, case)
                elif statistics is not None:
                    case.features[feat_name] = feat_spec.func(
                        sys_info, output, case, statistics
                    )
            cases.append(case)
        return cases, metric_stats

    # --- Feature functions accessible by ExplainaboardBuilder._get_feature_func()
    def _get_entity_type_level(self, existing_features: dict):
        """Return the most specific entity-type level of the true tail entity.

        Levels come from `entity_type_level_map`, which maps an entity id to a
        list of entity types at each level, e.g.
        ["Thing", "Agent", "Person", None, None, None, None].
        """
        # entities not found in `entity_type_level_map` get bucketed to this value.
        # in FB15k, "0" is the same as the most generic entity type, "Thing".
        default_level = "0"
        tail_entity_type_levels = self.entity_type_level_map.get(
            existing_features['true_tail'], None
        )
        if tail_entity_type_levels is None:
            return default_level  # entity types not found
        # the most specific level is the last non-None entry in the list
        if None in tail_entity_type_levels:
            most_specific_level = tail_entity_type_levels.index(None) - 1
        else:  # tail has entity types at every level
            most_specific_level = len(tail_entity_type_levels) - 1
        return str(most_specific_level)

    # --- End feature functions

    def _get_true_label(self, data_point: dict):
        """See processor._get_true_label."""
        # "predict" is either "head" or "tail"; the gold label is the
        # corresponding "true_head" / "true_tail" field.
        return data_point["true_" + data_point["predict"]]

    def _get_predicted_label(self, data_point: dict):
        """See processor._get_predicted_label."""
        return data_point["predictions"]