A collection of TensorFlow implementations of entity embedding models.
Requirements:
- Python 3
- TensorFlow >= 1.2
- Hyperopt
The generic abstract model is defined in model.py; all specific models are implemented in efe.py.
Model | Implementations | Reference |
---|---|---|
TransE | TransE_L2; TransE_L1 | Bordes et al. (NIPS 2013) |
NTN | NTN | Socher et al. (NIPS 2013) |
DistMult | DistMult; DistMult_tanh | Yang et al. (ICLR 2015) |
ComplEx | Complex; Complex_tanh | Trouillon et al. (ICML 2016) |
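To make the table concrete, here is a minimal pure-Python sketch of the scoring functions as described in the cited papers. It is not the code in efe.py: the function names are hypothetical, embeddings are plain lists instead of TensorFlow tensors, the `_tanh` variants are not reproduced, and it assumes the convention that a higher score means a more plausible triple (h, r, t), so TransE returns a negative distance.

```python
import math

# Hypothetical reference versions of the scoring functions (not the repo's API).
# Convention assumed here: higher score = more plausible triple (h, r, t).

def transe_score(h, r, t, norm=2):
    """TransE: negative L1 or L2 distance between h + r and t."""
    diffs = [hi + ri - ti for hi, ri, ti in zip(h, r, t)]
    if norm == 1:
        return -sum(abs(d) for d in diffs)
    return -math.sqrt(sum(d * d for d in diffs))

def distmult_score(h, r, t):
    """DistMult: trilinear product sum_i h_i * r_i * t_i (real-valued)."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))

def complex_score(h, r, t):
    """ComplEx: Re(sum_i h_i * r_i * conj(t_i)) with complex-valued embeddings."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real

print(distmult_score([1, 2], [3, 4], [5, 6]))  # 1*3*5 + 2*4*6 = 63
```

The conjugate on the tail embedding is what lets ComplEx model asymmetric relations, whereas DistMult's score is symmetric in h and t.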
To preprocess the data, run:

```bash
python preprocess.py -d [data_name]
```
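This README does not say what preprocess.py produces. A common preprocessing step for WN18 and FB15k is mapping entity and relation strings to integer IDs; the sketch below shows that step under the assumption that each raw line is a tab-separated `head<TAB>relation<TAB>tail` triple. The function name `build_id_maps` is hypothetical.

```python
# Hypothetical sketch of ID mapping for knowledge-graph triples.
# Assumes tab-separated "head<TAB>relation<TAB>tail" lines.

def build_id_maps(lines):
    entity2id, relation2id = {}, {}
    triples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        head, relation, tail = line.split("\t")
        # Assign the next free integer ID the first time a name is seen.
        for ent in (head, tail):
            entity2id.setdefault(ent, len(entity2id))
        relation2id.setdefault(relation, len(relation2id))
        triples.append((entity2id[head], relation2id[relation], entity2id[tail]))
    return entity2id, relation2id, triples

raw = ["a\tr1\tb", "b\tr2\tc", "a\tr1\tc"]
print(build_id_maps(raw)[2])  # [(0, 0, 1), (1, 1, 2), (0, 0, 2)]
```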
Add a hyperparameter dict and its identifier to model_param_space.py.
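One possible shape for such an entry is an identifier-to-dict registry, sketched below. Everything here is an assumption for illustration: the names `param_space_dict` and `get_param_space`, the keys, and the values (placeholders, not tuned settings). For searchable entries the real file may use Hyperopt expressions such as `hp.choice` or `hp.loguniform` instead of fixed values.

```python
# Hypothetical layout: model identifier -> hyperparameter dict.
# Keys and values are illustrative placeholders, not tuned settings.
param_space_dict = {
    "TransE_L2_example": {
        "embedding_size": 100,
        "margin": 1.0,
        "lr": 0.01,
        "batch_size": 1000,
        "max_iter": 1000,
    },
    "DistMult_example": {
        "embedding_size": 200,
        "l2_reg": 1e-4,
        "lr": 0.01,
        "batch_size": 1000,
        "max_iter": 1000,
    },
}

def get_param_space(model_name):
    """Look up the hyperparameters registered under an identifier."""
    if model_name not in param_space_dict:
        raise KeyError(f"unknown model identifier: {model_name!r}")
    return param_space_dict[model_name]

print(get_param_space("TransE_L2_example")["margin"])  # 1.0
```

The identifier passed via `-m` would then select one of these entries.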
Then run the hyperparameter search:

```bash
python task.py -m [model_name] -d [data_name] -e [max_evals] -c [cv_runs]
```
where:
- model_name is the identifier defined in model_param_space.py;
- data_name is either wn18 or fb15k;
- max_evals is the maximum number of hyperparameter search runs (default: 100);
- cv_runs is the number of cross-validation runs (default: 3).
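The flags and defaults above can be mirrored with `argparse`, as in this sketch. It is not taken from task.py: making `-m` and `-d` required is an assumption (the README gives no defaults for them), and the destination names simply reuse the placeholders from the command line.

```python
import argparse

def build_parser():
    """Sketch of a parser for the task.py flags described above."""
    parser = argparse.ArgumentParser(description="Hyperparameter search (sketch)")
    parser.add_argument("-m", dest="model_name", required=True,
                        help="identifier defined in model_param_space.py")
    parser.add_argument("-d", dest="data_name", required=True,
                        choices=["wn18", "fb15k"], help="dataset name")
    parser.add_argument("-e", dest="max_evals", type=int, default=100,
                        help="maximum number of hyperparameter search runs")
    parser.add_argument("-c", dest="cv_runs", type=int, default=3,
                        help="number of cross-validation runs")
    return parser

# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(["-m", "DistMult_example", "-d", "wn18"])
print(args.max_evals, args.cv_runs)  # 100 3
```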
The search process and results are stored in the log folder.
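How `max_evals` and `cv_runs` interact can be illustrated with plain random search, used here as a simplified stand-in for Hyperopt's TPE algorithm: each of `max_evals` trials samples a setting, scores it by averaging `cv_runs` evaluations, and records the outcome. The `evaluate` callback and the in-memory `log` list are hypothetical; the real script writes its log to disk, and this sketch assumes a higher score (e.g. MRR) is better.

```python
import random
import statistics

def random_search(space, evaluate, max_evals=100, cv_runs=3, seed=0):
    """Random search (stand-in for Hyperopt TPE): sample max_evals settings
    from `space`, score each by the mean of cv_runs evaluate() calls."""
    rng = random.Random(seed)
    log = []    # (trial index, params, mean score) for every trial
    best = None  # (params, mean score) of the best trial so far
    for trial in range(max_evals):
        params = {name: rng.choice(values) for name, values in space.items()}
        scores = [evaluate(params, run) for run in range(cv_runs)]
        mean_score = statistics.mean(scores)
        log.append((trial, params, mean_score))
        if best is None or mean_score > best[1]:
            best = (params, mean_score)
    return best, log

space = {"lr": [0.1, 0.01], "embedding_size": [50, 100]}
best, log = random_search(space, lambda p, run: p["embedding_size"] - p["lr"],
                          max_evals=20, cv_runs=3)
print(best)
```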
```bash
python train.py -m [model_name] -d [data_name]
```
This trains the model with the given hyperparameter setting and reports its results on the test set.
Model | WN18 Filtered MRR | WN18 Hits@1 | WN18 Hits@3 | WN18 Hits@10 | FB15K Filtered MRR | FB15K Hits@1 | FB15K Hits@3 | FB15K Hits@10 |
---|---|---|---|---|---|---|---|---|
TransE | 0.454 | 0.089 | 0.814 | 0.954 | 0.407 | 0.272 | 0.480 | 0.657 |
DistMult | 0.868 | 0.786 | 0.948 | 0.970 | 0.761 | 0.691 | 0.815 | 0.875 |
ComplEx | 0.971 | 0.969 | 0.973 | 0.974 | 0.768 | 0.676 | 0.843 | 0.908 |
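The metrics in the table follow the standard filtered link-prediction protocol: the true entity is ranked against all candidates after removing candidates that form other known-true triples, then MRR is the mean of 1/rank and Hits@k is the fraction of ranks at most k. The sketch below shows that computation; the function names are hypothetical, and counting only strictly higher scores (so ties favor the true entity) is an assumption the repo's evaluation may not share.

```python
def filtered_rank(scores, true_idx, other_true_idxs):
    """Rank (1 = best) of the true candidate, ignoring candidates that
    form other known-true triples (the 'filtered' setting).
    Ties are resolved in favor of the true candidate (assumption)."""
    true_score = scores[true_idx]
    rank = 1
    for idx, s in enumerate(scores):
        if idx == true_idx or idx in other_true_idxs:
            continue
        if s > true_score:
            rank += 1
    return rank

def mrr_and_hits(ranks, ks=(1, 3, 10)):
    """Mean reciprocal rank and Hits@k over a list of ranks."""
    mrr = sum(1.0 / r for r in ranks) / len(ranks)
    hits = {k: sum(1 for r in ranks if r <= k) / len(ranks) for k in ks}
    return mrr, hits

scores = [0.9, 0.5, 0.8, 0.7]
print(filtered_rank(scores, 1, set()))  # raw rank: 4
print(filtered_rank(scores, 1, {0}))    # filtered rank: 3
```

Filtering matters because a corrupted triple that is actually true (candidate 0 above) should not count as a ranking error.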
License: MIT