unify.py
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession
from pyspark import SparkConf


def process(sc, input_cases, input_metadata, num_partitions):
    """Join text-reuse case pairs with publication metadata for both sides."""
    # Pairwise cases: one row per matched passage pair (side "a" and side "b").
    df_cases = (
        sc.read.parquet(input_cases)
        .select(
            F.col("doi_a"),
            F.col("doi_b"),
            F.col("begin_a"),
            F.col("end_a"),
            F.col("text_a"),
            F.col("before_a"),
            F.col("after_a"),
            F.col("doc_length_a"),
            F.col("begin_b"),
            F.col("end_b"),
            F.col("text_b"),
            F.col("before_b"),
            F.col("after_b"),
            F.col("doc_length_b"),
        )
    )
    # Per-publication metadata keyed by DOI.
    df_metadata = (
        sc.read.parquet(input_metadata)
        .select(
            F.col("doi"),
            F.col("year"),
            F.col("board"),
            F.col("area"),
            F.col("discipline")
        )
    )
    # Attach metadata to each side of the pair with two left joins, dropping
    # the join key after each join so no duplicate "doi" column remains.
    df = (
        df_cases.alias("cases")
        .join(
            df_metadata.select(
                F.col("doi"),
                F.col("year").alias("year_a"),
                F.col("board").alias("board_a"),
                F.col("area").alias("area_a"),
                F.col("discipline").alias("discipline_a")
            ).alias("metadata"),
            F.col("cases.doi_a") == F.col("metadata.doi"),
            'left'
        )
        .drop("doi")
        .join(
            df_metadata.select(
                F.col("doi"),
                F.col("year").alias("year_b"),
                F.col("board").alias("board_b"),
                F.col("area").alias("area_b"),
                F.col("discipline").alias("discipline_b")
            ).alias("metadata"),
            F.col("cases.doi_b") == F.col("metadata.doi"),
            'left'
        )
        .drop("doi")
    )
    # Final column order; "board" is exposed as "field" in the output.
    df = (
        df
        .select(
            F.col("doi_a"),
            F.col("begin_a"),
            F.col("end_a"),
            F.col("text_a"),
            F.col("before_a"),
            F.col("after_a"),
            F.col("doc_length_a"),
            F.col("year_a"),
            F.col("board_a").alias("field_a"),
            F.col("area_a"),
            F.col("discipline_a"),
            F.col("doi_b"),
            F.col("begin_b"),
            F.col("end_b"),
            F.col("text_b"),
            F.col("before_b"),
            F.col("after_b"),
            F.col("doc_length_b"),
            F.col("year_b"),
            F.col("board_b").alias("field_b"),
            F.col("area_b"),
            F.col("discipline_b"),
        )
    )
    return df.coalesce(num_partitions)


def run(sc, args):
    # Entry point for runners that pass the input paths as positional arguments.
    case_path = args[0]
    metadata_path = args[1]
    return process(sc, case_path, metadata_path, 1)


if __name__ == '__main__':
    # spark session
    spark = (
        SparkSession
        .builder
        .config(conf=SparkConf())
        .getOrCreate()
    )
    # arguments
    INPUT_CASES = "stereo-aligned.parquet/*/*"
    INPUT_METADATA = "stereo-oag.parquet/*"
    NUM_PARTITIONS = 50
    OUTPUT = "stereo-corpus-final.jsonl"
    # process and write the unified corpus as gzip-compressed JSON lines
    (
        process(spark, INPUT_CASES, INPUT_METADATA, NUM_PARTITIONS)
        .write
        .mode("overwrite")
        .option("compression", "gzip")
        .json(OUTPUT)
    )
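
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original script): because the
# __main__ block above hard-codes INPUT_CASES, INPUT_METADATA and OUTPUT, the
# job can be launched directly with spark-submit, e.g.:
#
#   spark-submit --master yarn unify.py
#
# The --master value is illustrative; adjust it, and the input/output paths,
# to match your cluster and data layout before running.
# ---------------------------------------------------------------------------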