KGEModel
- class KGE.models.base_model.BaseModel.KGEModel
Bases: object
Base class for Knowledge Graph Embedding models.
Subclasses of KGEModel can implement their own translation and interaction models.
Methods Summary
evaluate(eval_X, corrupt_side[, positive_X]) – Evaluate triplets.
get_rank(x, positive_X, corrupt_side) – Get the rank of a single triplet.
restore_model_weights(model_weights) – Restore the model weights.
score_hrt(h, r, t) – Score the triplets.
train(train_X, val_X, metadata, epochs, ...) – Train the Knowledge Graph Embedding Model.
Methods Documentation
- evaluate(eval_X, corrupt_side, positive_X=None)
Evaluate triplets.
- Parameters
eval_X (tf.Tensor or np.array) – triplets to be evaluated
corrupt_side (str) – side from which to corrupt triplets, either 'h' or 't'
positive_X (tf.Tensor or np.array, optional) – positive triplets to filter out while generating corrupted triplets, by default None (no filter applied)
- Returns
evaluation result
- Return type
dict
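evaluate returns a dict, but its keys are not documented here. Assuming it computes one rank per triplet (as get_rank does) and then aggregates them, the sketch below shows the standard link-prediction metrics such an aggregation typically produces: mean rank, mean reciprocal rank (MRR) and Hits@k. The function summarize_ranks and its key names are hypothetical and need not match the keys KGEModel actually returns.

```python
import numpy as np


def summarize_ranks(ranks, hits_at=(1, 3, 10)):
    """Aggregate per-triplet ranks (1 = best) into link-prediction metrics.

    Hypothetical helper: the key names are illustrative only and are not
    necessarily the keys of the dict returned by KGEModel.evaluate.
    """
    ranks = np.asarray(ranks, dtype=np.float64)
    result = {
        "mean_rank": float(ranks.mean()),
        # MRR rewards ranks near 1 much more than a plain mean rank does.
        "mean_reciprocal_rank": float((1.0 / ranks).mean()),
    }
    for k in hits_at:
        # Fraction of triplets whose true entity lands in the top k.
        result[f"hits@{k}"] = float((ranks <= k).mean())
    return result
```

For ranks [1, 2, 4, 10], the mean rank is 4.25, Hits@1 is 0.25, Hits@3 is 0.5 and Hits@10 is 1.0.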
- get_rank(x, positive_X, corrupt_side)
Get the rank of a single triplet.
- Parameters
x (tf.Tensor or np.array) – the triplet to rank
positive_X (tf.Tensor or np.array, optional) – positive triplets to filter out while generating corrupted triplets; if None, no filter is applied
corrupt_side (str) – side from which to corrupt triplets, either 'h' or 't'
- Returns
ranking result
- Return type
int
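The procedure behind get_rank is not spelled out above. A common approach, which this sketch assumes, is the "filtered" rank: replace the corrupted side of the triplet with every entity, discard corruptions that appear in positive_X, and count how many of the remaining corruptions score higher than the true triplet. The score_fn and n_entities arguments, the "higher score = more plausible" convention, and the strict comparison used for ties are all assumptions, not KGEModel's documented behavior.

```python
import numpy as np


def filtered_rank(x, positive_X, corrupt_side, score_fn, n_entities):
    """Filtered rank (1 = best) of one (h, r, t) triplet of integer indices.

    positive_X: (n, 3) array of known true triplets, or None for no filter.
    corrupt_side: 'h' or 't'.
    score_fn: hypothetical callable mapping an (m, 3) array to m scores,
    where a higher score means a more plausible triplet.
    """
    if corrupt_side not in ("h", "t"):
        raise ValueError("corrupt_side must be 'h' or 't'")
    h, r, t = (int(v) for v in x)
    col = 0 if corrupt_side == "h" else 2
    # Build every corruption by replacing the chosen side with each entity.
    candidates = np.tile(np.array([h, r, t]), (n_entities, 1))
    candidates[:, col] = np.arange(n_entities)
    true_entity = h if corrupt_side == "h" else t
    # Drop the candidate that is identical to the triplet being ranked.
    keep = candidates[:, col] != true_entity
    if positive_X is not None:
        # Filter out corruptions that are actually known positive triplets.
        known = {tuple(row) for row in np.asarray(positive_X).tolist()}
        keep &= np.array([tuple(row) not in known for row in candidates.tolist()])
    scores = score_fn(candidates)
    true_score = score_fn(np.array([[h, r, t]]))[0]
    # Rank = 1 + number of surviving corruptions scoring strictly higher.
    return 1 + int(np.sum(scores[keep] > true_score))
```

Filtering keeps other true triplets from pushing the ranked triplet down, which is why positive_X usually contains the training, validation and test triplets together.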
- restore_model_weights(model_weights)
Restore the model weights.
- Parameters
model_weights (dict) – dictionary of model weights to restore
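A weights dict like the one restore_model_weights accepts is what train's early_stopping_rounds and restore_best_weight options need to keep around. The sketch below is one hypothetical way to track the best round: it keeps a deep copy of the weights whenever the validation score improves and signals a stop after early_stopping_rounds rounds without improvement. BestWeightsTracker is not part of the KGE package, and it assumes the weights dict holds plain values (for example numpy arrays) that copy.deepcopy can duplicate.

```python
import copy


class BestWeightsTracker:
    """Hypothetical helper: remember the weights from the best validation round."""

    def __init__(self, early_stopping_rounds):
        self.early_stopping_rounds = early_stopping_rounds
        self.best_score = float("-inf")
        self.best_weights = None
        self.rounds_without_improvement = 0

    def update(self, score, model_weights):
        """Record one validation score (higher = better); return True to stop."""
        if score > self.best_score:
            self.best_score = score
            # Deep copy so later in-place training updates do not alter the snapshot.
            self.best_weights = copy.deepcopy(model_weights)
            self.rounds_without_improvement = 0
        else:
            self.rounds_without_improvement += 1
        return self.rounds_without_improvement >= self.early_stopping_rounds
```

After training stops, the snapshot in tracker.best_weights is the kind of dict that could then be passed to restore_model_weights, assuming its keys match what the model expects.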
- score_hrt(h, r, t)
Score the triplets.
Subclasses should implement this method with their own scoring function.
- Raises
NotImplementedError – if the subclass does not implement score_hrt().
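The subclassing contract can be illustrated with a short sketch. So that it runs without the KGE package, it defines a hypothetical stand-in base class that raises NotImplementedError just as KGEModel.score_hrt does; in real code you would subclass KGE.models.base_model.BaseModel.KGEModel instead. It also assumes that h, r and t arrive as embedding arrays of shape (batch, dim), which the documentation above does not state. The TransE-style score, the negative distance -‖h + r - t‖, is used only as a familiar example of a translation model.

```python
import numpy as np


class KGEModelStandIn:
    """Hypothetical stand-in mimicking KGEModel's score_hrt contract."""

    def score_hrt(self, h, r, t):
        raise NotImplementedError("subclass should implement score_hrt()")


class TransELike(KGEModelStandIn):
    """Scores triplets with the TransE-style function -||h + r - t||."""

    def __init__(self, norm=2):
        # norm=1 gives the L1 distance, norm=2 the Euclidean distance.
        self.norm = norm

    def score_hrt(self, h, r, t):
        # h, r, t: arrays of shape (batch, dim), assumed to be embeddings.
        diff = np.asarray(h) + np.asarray(r) - np.asarray(t)
        # Closer translations (smaller distance) give higher scores.
        return -np.linalg.norm(diff, ord=self.norm, axis=-1)
```

Returning the negative distance keeps the "higher score = more plausible" convention, so a perfect translation h + r = t scores 0, the maximum.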
- train(train_X, val_X, metadata, epochs, batch_size, early_stopping_rounds=None, model_weights_initial=None, restore_best_weight=True, optimizer='Adam', seed=None, log_path='./logs', log_projector=False)
Train the Knowledge Graph Embedding Model.
- Parameters
train_X (np.ndarray or str) – training triplets.
If np.ndarray, the shape should be (n, 3), with the columns holding \((h,r,t)\) respectively.
If str, the path of a folder under which the training triplets are saved in CSV format; every CSV file should have 3 columns, without a header, for \((h,r,t)\) respectively.
val_X (np.ndarray or str) – validation triplets.
If np.ndarray, the shape should be (n, 3), with the columns holding \((h,r,t)\) respectively.
If str, the path of a folder under which the validation triplets are saved in CSV format; every CSV file should have 3 columns, without a header, for \((h,r,t)\) respectively.
metadata (dict) – metadata for the KG data. It should have the following keys:
'ent2ind': dict, dictionary mapping each entity to its index.
'ind2ent': list, list mapping each index to its entity.
'rel2ind': dict, dictionary mapping each relation to its index.
'ind2rel': list, list mapping each index to its relation.
KGE.data_utils.index_kg can be used to index the data and get the metadata.
epochs (int) – number of epochs
batch_size (int) – batch size
early_stopping_rounds (int, optional) – number of rounds that triggers early stopping, by default None (no early stopping)
model_weights_initial (dict, optional) – initial model weights with specific values, by default None
restore_best_weight (bool, optional) – restore the weights to the best iteration if early_stopping_rounds is not None, by default True
optimizer (str or tensorflow.keras.optimizers, optional) – optimizer applied in training, by default 'Adam', which uses the default settings of tf.keras.optimizers.Adam
seed (int, optional) – random seed for data shuffling and embedding initialization, by default None
log_path (str, optional) – path for TensorBoard logging, by default './logs'
log_projector (bool, optional) – project the embeddings in the TensorBoard projector tab; setting this to True writes the metadata and embedding TSV files to log_path and projects this data in the TensorBoard projector tab, by default False
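To make the metadata and CSV-folder formats concrete, the sketch below builds a dict with the four documented keys from labelled triplets, converts the triplets to an (n, 3) integer array, and writes that array as headerless 3-column CSV files. The helpers build_metadata, to_index_array and write_csv_folder are hypothetical; KGE.data_utils.index_kg is the documented way to obtain metadata, and its exact output (ordering of indices, for example) may differ. Whether train expects integer indices rather than raw labels in train_X is also an assumption here.

```python
import os

import numpy as np


def build_metadata(triplets):
    """Hypothetical helper: build ent2ind/ind2ent/rel2ind/ind2rel from labels."""
    ind2ent, ind2rel = [], []
    ent2ind, rel2ind = {}, {}
    for h, r, t in triplets:
        # Assign indices in order of first appearance.
        for e in (h, t):
            if e not in ent2ind:
                ent2ind[e] = len(ind2ent)
                ind2ent.append(e)
        if r not in rel2ind:
            rel2ind[r] = len(ind2rel)
            ind2rel.append(r)
    return {"ent2ind": ent2ind, "ind2ent": ind2ent,
            "rel2ind": rel2ind, "ind2rel": ind2rel}


def to_index_array(triplets, metadata):
    """Map labelled (h, r, t) triplets to an integer array of shape (n, 3)."""
    return np.array(
        [[metadata["ent2ind"][h], metadata["rel2ind"][r], metadata["ent2ind"][t]]
         for h, r, t in triplets],
        dtype=np.int64,
    ).reshape(-1, 3)


def write_csv_folder(index_array, folder, n_files=1):
    """Write (n, 3) triplets as headerless 3-column CSV files under folder."""
    os.makedirs(folder, exist_ok=True)
    for i, chunk in enumerate(np.array_split(index_array, n_files)):
        np.savetxt(os.path.join(folder, f"part-{i}.csv"), chunk,
                   fmt="%d", delimiter=",")
```

The indexed array could be passed directly as train_X, or written out with write_csv_folder and passed as a folder path; the file naming scheme part-{i}.csv is arbitrary.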
- __init__(embedding_params, negative_ratio, corrupt_side, loss_fn, ns_strategy, n_workers)
Initialize KGEModel.
- Parameters
embedding_params (dict) – embedding dimension parameters
negative_ratio (int) – number of negative samples
corrupt_side (str) – side from which to corrupt triplets during training, can be 'h', 't', or 'h+t'
loss_fn (class) – loss function class KGE.loss.Loss
ns_strategy (function) – negative sampling strategy
n_workers (int) – number of workers for negative sampling
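ns_strategy is passed as a function, but the signature KGEModel expects is not documented here. The sketch below shows uniform negative sampling, one common strategy, to illustrate how negative_ratio and corrupt_side interact: each positive triplet is repeated negative_ratio times and one side of each copy is replaced by a uniformly drawn entity. The function name and signature are hypothetical and may not match what KGEModel passes to ns_strategy; interpreting negative_ratio as "negatives per positive" is also an assumption.

```python
import numpy as np


def uniform_negative_sampling(pos_X, negative_ratio, corrupt_side, n_entities, rng):
    """Hypothetical uniform sampler: corrupt each positive negative_ratio times.

    pos_X: (n, 3) integer array of (h, r, t) indices.
    corrupt_side: 'h', 't', or 'h+t' (each negative corrupts head or tail at random).
    rng: a numpy.random.Generator, e.g. np.random.default_rng(seed).
    Returns an array of shape (n * negative_ratio, 3).
    """
    if corrupt_side not in ("h", "t", "h+t"):
        raise ValueError("corrupt_side must be 'h', 't' or 'h+t'")
    # np.repeat returns a new array, so pos_X itself is never modified.
    neg = np.repeat(np.asarray(pos_X), negative_ratio, axis=0)
    m = len(neg)
    if corrupt_side == "h":
        cols = np.zeros(m, dtype=int)
    elif corrupt_side == "t":
        cols = np.full(m, 2)
    else:
        cols = rng.choice([0, 2], size=m)
    # Replace the chosen column with a uniformly drawn entity index.
    # No filtering is done, so a "negative" can occasionally equal a true triplet.
    neg[np.arange(m), cols] = rng.integers(0, n_entities, size=m)
    return neg
```

Uniform sampling is cheap but can produce easy or even accidentally positive negatives; that trade-off is one reason the strategy is left pluggable through ns_strategy.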
- __new__(*args, **kwargs)