Commit

init
yiduyton committed Sep 10, 2023
1 parent 22b1a9d commit 314beff
Showing 7 changed files with 513 additions and 5 deletions.
37 changes: 36 additions & 1 deletion README.md
@@ -1,2 +1,37 @@
# CodeGen4Libs
We will release the complete project soon.

### Benchmark Format
The benchmark is structured and saved in the Hugging Face `DatasetDict` format and is available at [Dataset and Models of CodeGen4Libs](https://zenodo.org/record/7920906#.ZFyPm-xByDV). The data fields for each task are listed below (a minimal loading sketch follows the list):

- id
- method
- clean_method
- doc
- comment
- method_name
- extra
- license
- path
- repo_name
- size
- imports_info
- libraries_info

- input_str
- attention_mask
- input_ids
- tokenized_input_str
- input_token_length
- labels
- tokenized_labels_str
- labels_token_length

- retrieved_imports_info
- generated_imports_info
- union_gen_ret_imports_info
- intersection_gen_ret_imports_info
- similar_code

- decoded_labels
- predictions
- decoded_preds
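
A minimal sketch (not part of this repository's code) of loading the released benchmark with the `datasets` library and inspecting a few of these fields; the local path is a placeholder for wherever the Zenodo archive is unpacked:

```python
from datasets import DatasetDict

# Placeholder path to the unpacked Zenodo archive (adjust to your setup).
benchmark = DatasetDict.load_from_disk("path/to/codegen4libs-benchmark")

print(benchmark)  # splits and row counts
sample = benchmark["train"][0]
print(sample["comment"])         # natural-language comment paired with the method
print(sample["libraries_info"])  # libraries the generated code should use
print(sample["imports_info"])    # ground-truth import statements
```
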
1 change: 1 addition & 0 deletions data/README.md
@@ -0,0 +1 @@
https://zenodo.org/record/7920906#.ZFyPm-xByDV
1 change: 0 additions & 1 deletion dataset/README.md

This file was deleted.

4 changes: 1 addition & 3 deletions build_corpus.py → generation/corpus_builder.py
@@ -2,14 +2,12 @@
import jsonlines
import pickle
import string
from collections import defaultdict
import random

import re
from collections import defaultdict
from tqdm import tqdm

from script.extract_method import ProjectMethodExtractor

from sckg.util.path_util import PathUtil
from sckg.util.log_util import LogUtil

134 changes: 134 additions & 0 deletions generation/dataset_filter.py
@@ -0,0 +1,134 @@
import pickle
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from project.util.path_util import PathUtil
from project.util.logs_util import LogsUtil

logger = LogsUtil.get_logs_util()


def filter_with_upper(version: str, upper: int):
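    """Cap the number of retained samples per third-party library at ``upper``.

    Rows that only use JDK/SDK packages (java., javax., android., androidx.)
    or that no longer fit under any library's quota go to a separate "other"
    dataset; rows using libraries absent from count_lib.bin are dropped. The
    updated per-library counts are saved alongside the filtered dataset.
    """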
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
    data = []
    other_data = []
    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
        count_lib = pickle.load(f)
    chose_libs = count_lib.keys()
    count_lib = {_: 0 for _ in count_lib.copy().keys()}
    for row in tqdm(dataset):
        lib_size = len(row["libraries"])
        is_append = False
        if all(
            any(lib.startswith(_) for _ in ("java.", "javax.", "android.", "androidx.")) for lib in row["libraries"]
        ):
            other_data.append(row)
            continue
        if any(lib not in chose_libs for lib in row["libraries"]):
            continue
        for lib in row["libraries"]:
            if count_lib[lib] >= upper:
                continue
            # Filter JDK & SDK usages by priority
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            if not any(lib.startswith(_) for _ in ("java.", "javax.", "android.", "androidx.")):
                count_lib[lib] += 1
                is_append = True
        if is_append:
            data.append(row)
        else:
            other_data.append(row)
    dataset = DatasetDict()
    dataset["train"] = Dataset.from_list(data)
    dataset.save_to_disk(PathUtil.datasets(f"{version}_{upper}/filter-github-code-java-libs"))

    other_dataset = DatasetDict()
    other_dataset["train"] = Dataset.from_list(other_data)
    other_dataset.save_to_disk(PathUtil.datasets(f"{version}_{upper}/other-github-code-java-libs"))
    with open(PathUtil.datasets(f"{version}_{upper}/train-github-code-java-libs.txt"), "w") as file:
        for lib, count in count_lib.items():
            file.write(lib + ", " + str(count) + "\n")
    with open(PathUtil.datasets(f"{version}_{upper}/count_lib.bin"), "wb") as file:
        pickle.dump({lib: count for lib, count in count_lib.items()}, file)


def split_data(version: str, ration: float = 0.02, test_size: int = None):
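    """Split the filtered train data into train/validation/test splits.

    Roughly ``ration`` of each library's samples is reserved first for
    validation and then for test (JDK/SDK rows are filtered by priority);
    rows whose comment + libraries_info already appears in validation or
    test are dropped from train to avoid leakage. ``test_size`` is computed
    but not otherwise used here.
    """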
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
    if test_size is None:
        test_size = len(dataset) * ration
    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
        count_lib = pickle.load(f)
    validation_dataset, train_dataset, test_dataset = [], [], []
    lib_count_4_validation = {_: 0 for _ in count_lib.copy().keys()}
    lib_count_4_test = lib_count_4_validation.copy()
    nl_set_4_validation, nl_set_4_test = set(), set()
    for row in tqdm(dataset):
        lib_size = len(row["libraries"])
        # Split the dataset by library
        is_append_validation = False
        for lib in row["libraries"]:
            if lib not in count_lib:
                continue
            if lib_count_4_validation[lib] >= count_lib[lib] * ration:
                break
            # Filter JDK & SDK usages by priority
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_validation[lib] += 1
            is_append_validation = True
        if is_append_validation:
            validation_dataset.append(row)
            nl_set_4_validation.add(row["comment"] + row["libraries_info"])
            continue
        is_append_test = False
        for lib in row["libraries"]:
            if lib not in count_lib:
                continue
            if lib_count_4_test[lib] >= count_lib[lib] * ration:
                break
            # Filter JDK & SDK usages by priority
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_test[lib] += 1
            is_append_test = True
        if is_append_test:
            test_dataset.append(row)
            nl_set_4_test.add(row["comment"] + row["libraries_info"])
            continue
        # Skip rows whose NL (comment + libraries_info) already appears in validation/test
        if (
            row["comment"] + row["libraries_info"] in nl_set_4_validation
            or row["comment"] + row["libraries_info"] in nl_set_4_test
        ):
            logger.info(row["comment"] + row["libraries_info"])
            continue
        train_dataset.append(row)
    dataset = DatasetDict()
    dataset["train"] = Dataset.from_list(train_dataset)
    dataset["validation"] = Dataset.from_list(validation_dataset)
    dataset["test"] = Dataset.from_list(test_dataset)
    dataset.save_to_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))


def slim_data(version: str):
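    """Keep only the model inputs (input_ids, attention_mask, labels) and save a slim copy of the dataset."""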
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))

    def chunk_examples(examples):
        return {
            "input_ids": examples["input_ids"],
            "attention_mask": examples["attention_mask"],
            "labels": examples["labels"],
        }

    # A single map pass with remove_columns keeps only the model inputs.
    dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset["train"].column_names)
    dataset.save_to_disk(PathUtil.datasets(f"{version}/slim-github-code-java-libs"))


if __name__ == "__main__":
    # with open(PathUtil.datasets(f"latest_0,800000_5000/count_lib.bin"), "rb") as file:
    #     data = pickle.load(file)
    #     print(data)

    # filter_with_upper("latest_400000,600000", 5000)

    split_data("latest_0,400000_5000", ration=0.02)
187 changes: 187 additions & 0 deletions generation/dataset_generator.py
@@ -0,0 +1,187 @@
import argparse
import pickle
from tqdm import tqdm
from collections import defaultdict
from datasets import Dataset, DatasetDict, load_dataset
from project.dataset.process import DataProcessUtil
from project.util.path_util import PathUtil
from project.util.logs_util import LogsUtil

logger = LogsUtil.get_logs_util()

DATA_TYPE_TRAIN = "train"
DATA_TYPE_VALID = "validation"
DATA_TYPE_TEST = "test"


def convert_data(version: str):
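    """Load the raw JSON export and save it to disk as a ``DatasetDict``."""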
    dataset = load_dataset("json", data_files=PathUtil.datasets(f"github-code-java-libs-{version}.json"))
    dataset.save_to_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))


def process_data(version: str, input_r: str, label_r: str):
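    """Preprocess the raw dataset into model inputs and labels for the given input (``input_r``) and label (``label_r``) types."""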
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))
    tokenized_dataset = dataset.map(
        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r),
        batched=True,
        load_from_cache_file=False,
    )
    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))


def filter_with_token_length(version: str, max_length: int = 384):
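    """Drop rows whose input or label exceeds ``max_length`` tokens, then dump per-library counts for the train split."""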
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))
    filter_dataset = dataset.filter(
        lambda x: x["labels_token_length"] <= max_length and x["input_token_length"] <= max_length
    )
    filter_dataset.save_to_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))
    analyse_data(version, "train", filter_dataset, is_limit=True)


def analyse_data(version: str, typ: str, dataset, is_limit=False):
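    """Count library frequencies in one split and write them to a text report.

    JDK/SDK packages are tallied separately and appended after the third-party
    libraries; when ``is_limit`` is set, the counts are also pickled to
    count_lib.bin for the later filtering and splitting steps.
    """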
    analysis = defaultdict(int)
    jdk_sdk_analysis = defaultdict(int)
    for item in tqdm(dataset[typ]):
        libs = item["libraries"]
        for lib in libs:
            if any(lib.startswith(_) for _ in ("java", "javax", "android", "androidx")):
                jdk_sdk_analysis[lib] += 1
                continue
            analysis[lib] += 1
    analysis = sorted(analysis.items(), key=lambda x: x[1], reverse=True)
    jdk_sdk_analysis = sorted(jdk_sdk_analysis.items(), key=lambda x: x[1], reverse=True)
    analysis += jdk_sdk_analysis
    if is_limit:
        with open(PathUtil.datasets(f"{version}/count_lib.bin"), "wb") as file:
            pickle.dump({lib: count for lib, count in analysis}, file)
    with open(PathUtil.datasets(f"{version}/{typ}-github-code-java-libs.txt"), "w") as file:
        for lib, count in analysis:
            file.write(lib + ", " + str(count) + "\n")


def check_data(args, console_only: bool = False, do_analyse: bool = False):
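    """Print the dataset, optionally run per-split library analysis, and log sample comments, decoded labels, and decoded predictions from the test split."""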
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
    print(dataset)
    if do_analyse:
        # Dataset analysis: third-party library frequency statistics
        analyse_data(args.version, DATA_TYPE_TRAIN, dataset)
        analyse_data(args.version, DATA_TYPE_VALID, dataset)
        analyse_data(args.version, DATA_TYPE_TEST, dataset)
    for i in range(args.check_size):
        if console_only:
            continue
        logger.info("comment=" + dataset[DATA_TYPE_TEST][i]["comment"])
        logger.info("decoded_labels=" + dataset[DATA_TYPE_TEST][i]["decoded_labels"])
        logger.info("decoded_preds=" + dataset[DATA_TYPE_TEST][i]["decoded_preds"])


def split_data(version: str, ration: float = 0.02, test_size: int = None):
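    """Split the filtered train data into train/validation/test by per-library
    quota (about ``ration`` of each library's samples per held-out split); rows
    whose comment already appears in validation or test are dropped from train.
    """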
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
    # dataset = dataset.train_test_split(test_size=100)["test"]
    if test_size is None:
        test_size = len(dataset) * ration
    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
        count_lib = pickle.load(f)
    validation_dataset, train_dataset, test_dataset = [], [], []
    lib_count_4_validation = {_: 0 for _ in count_lib.copy().keys()}
    lib_count_4_test = lib_count_4_validation.copy()
    nl_set_4_validation, nl_set_4_test = set(), set()
    for row in tqdm(dataset):
        lib_size = len(row["libraries"])
        # Split the dataset by library
        is_append_validation = False
        for lib in row["libraries"]:
            if lib_count_4_validation[lib] >= count_lib[lib] * ration:
                break
            # Filter JDK & SDK usages by priority
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_validation[lib] += 1
            is_append_validation = True
        if is_append_validation:
            validation_dataset.append(row)
            nl_set_4_validation.add(row["comment"])
            continue
        is_append_test = False
        for lib in row["libraries"]:
            if lib_count_4_test[lib] >= count_lib[lib] * ration:
                break
            # Filter JDK & SDK usages by priority
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_test[lib] += 1
            is_append_test = True
        if is_append_test:
            test_dataset.append(row)
            nl_set_4_test.add(row["comment"])
            continue
        # Skip rows whose NL (comment) already appears in validation/test
        if row["comment"] in nl_set_4_validation or row["comment"] in nl_set_4_test:
            logger.info(row["comment"])
            continue
        train_dataset.append(row)
    dataset = DatasetDict()
    dataset["train"] = Dataset.from_list(train_dataset)
    dataset["validation"] = Dataset.from_list(validation_dataset)
    dataset["test"] = Dataset.from_list(test_dataset)
    dataset.save_to_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))


def postprocess_data(args, saved_version: str, input_r: str, label_r: str):
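    """Re-run preprocessing on an already split dataset with a new input/label configuration and save it under ``saved_version``."""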
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
    tokenized_dataset = dataset.map(
        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r), batched=True
    )
    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{saved_version}/github-code-java-libs"))

def add_args(parser):
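    """Register the command-line arguments and return the parsed args."""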
    parser.add_argument(
        "--task", type=str, required=True, choices=["convert", "process", "filter", "split", "check", "postprocess"]
    )
    parser.add_argument("--version", type=str, help="The version tag of the dataset.")
    parser.add_argument("--filename", type=str, help="The filename of the dataset.")
    parser.add_argument("--check_size", type=int, help="Number of samples to check.")
    parser.add_argument("--input", default=None, type=str, help="Type of input.")
    parser.add_argument("--label", default=None, type=str, help="Type of label.")

    parser.add_argument("--upper", default=None, type=int, help="Maximum number of samples kept per library.")
    parser.add_argument("--test_size", default=4000, type=int, help="Size of the test split.")
    parser.add_argument("--saved_version", default=None, type=str, help="The version tag of the postprocessed dataset.")
    args = parser.parse_args()
    return args

if __name__ == "__main__":
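    # Map a --saved_version prefix (l0..l6) to its (input type, label type) configuration.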
    versions = {
        "l0": ("nl", "code"),
        "l1": ("nl+libs", "code"),
        "l2": ("nl+libs+codeRet", "code"),
        "l3": ("nl+libs+importsGen", "code"),
        "l4": ("nl+libs+importsGen+codeRet", "code"),
        "l5": ("nl+libs", "imports"),
        "l6": ("nl+libs+importsRet", "imports"),
    }

    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logger.info(args)

    if args.task == "convert":
        # step 1
        convert_data(version=args.version)
    elif args.task == "process":
        # step 2
        process_data(version=args.version, input_r=args.input, label_r=args.label)
    elif args.task == "filter" and args.upper is None:
        # step 3
        filter_with_token_length(version=args.version)
    elif args.task == "filter" and args.upper is not None:
        # step 4: filter_with_upper is defined in generation/dataset_filter.py and
        # needs to be imported from there; without the --upper guard on step 3 this
        # branch would be unreachable.
        filter_with_upper(version=args.version, upper=args.upper)
    elif args.task == "split":
        # step 5
        split_data(version=args.version, test_size=args.test_size)
    elif args.task == "check":
        # step 6
        check_data(args=args)
    elif args.task == "postprocess":
        # other step
        input_label = versions[args.saved_version[:2]]
        postprocess_data(args, saved_version=args.saved_version, input_r=input_label[0], label_r=input_label[1])
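
As a usage note, a hedged sketch (not part of this commit) of the intended step order when driving these functions programmatically; the version tags are illustrative, taken from the commented-out example in dataset_filter.py, and `filter_with_upper` would be imported from generation/dataset_filter.py:

```python
# Hypothetical end-to-end run in step order; version tags are illustrative and
# filter_with_upper (step 4) comes from generation/dataset_filter.py.
version = "latest_0,400000"
convert_data(version=version)                                      # step 1: raw JSON -> DatasetDict
process_data(version=version, input_r="nl+libs", label_r="code")   # step 2: build inputs/labels (the "l1" setting)
filter_with_token_length(version=version, max_length=384)          # step 3: drop over-length rows
filter_with_upper(version=version, upper=5000)                     # step 4: cap samples per library
split_data(version=f"{version}_5000", ration=0.02)                 # step 5: train/validation/test split
```
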
