DistMult

class KGE.models.semantic_based.DistMult.DistMult

Bases: KGE.models.base_model.SemanticModel.SemanticModel

An implementation of DistMult from [yang 2014].

DistMult simplifies RESCAL by restricting each relation matrix \(\textbf{R}_r\) to a diagonal matrix.

The score of \((h,r,t)\) is defined by a bilinear function:

\[f(h,r,t) = \textbf{e}_h^{T} \textbf{R}_{r} \textbf{e}_t = \textbf{e}_h^{T} diag(\textbf{R}_{r}) \textbf{e}_t = \sum_i (\textbf{e}_h^{T})_i diag(\textbf{R}_{r})_i (\textbf{e}_t)_i\]

where \(\textbf{e}_h, \textbf{e}_t \in \mathbb{R}^k\) are the vector representations of the head and tail entities, and \(\textbf{R}_r \in \mathbb{R}^{k \times k}\) is a diagonal matrix associated with the relation \(r\).
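The two forms of the score above can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation; the variable names and the embedding size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 4  # embedding dimension (illustrative value)

e_h = rng.normal(size=k)     # head entity embedding
e_t = rng.normal(size=k)     # tail entity embedding
r_diag = rng.normal(size=k)  # diagonal entries of R_r

# Bilinear form with the full diagonal matrix R_r.
score_bilinear = e_h @ np.diag(r_diag) @ e_t

# Equivalent element-wise form: sum_i (e_h)_i * (R_r)_ii * (e_t)_i.
score_elementwise = np.sum(e_h * r_diag * e_t)

# Swapping head and tail gives the same value, because the
# element-wise product is commutative.
score_swapped = np.sum(e_t * r_diag * e_h)

print(np.isclose(score_bilinear, score_elementwise))
print(np.isclose(score_elementwise, score_swapped))
```

The element-wise form needs only \(k\) parameters per relation instead of \(k^2\). The last check shows that the score is symmetric in \(h\) and \(t\).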

If constraint=True is passed to __init__(), the entity embeddings are renormalized to unit length, \(\left\| \textbf{e}_i \right\|_2 = 1\), at every iteration, and L2 regularization is applied to \(\textbf{R}\) as described in the original DistMult paper:

\[\text{regularization term} = \lambda \times \sum_{i} {\left\| \textbf{R}_i \right\|}_F^2 = \lambda \times \sum_i {\left\| diag(\textbf{R}_i) \right\|}_2^2\]
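As a rough illustration of the constraint, the sketch below renormalizes a table of entity embeddings and evaluates the regularization term in both of the forms given above. It uses plain NumPy; the array names, sizes, and the way relations are stored (one row of diagonal entries per relation) are assumptions made for the example, not the library's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n_entities, n_relations, k = 5, 3, 4  # illustrative sizes
lam = 1.0  # regularization weight (lambda)

ent_emb = rng.normal(size=(n_entities, k))
rel_diag = rng.normal(size=(n_relations, k))  # row i holds diag(R_i)

# Rescale every entity embedding to unit L2 norm.
ent_emb = ent_emb / np.linalg.norm(ent_emb, axis=1, keepdims=True)

# Regularization term written with the diagonal vectors.
reg_vec = lam * np.sum(rel_diag ** 2)

# Same term written with the full diagonal matrices and the Frobenius norm.
reg_mat = lam * sum(
    np.linalg.norm(np.diag(d), ord="fro") ** 2 for d in rel_diag
)

print(np.allclose(np.linalg.norm(ent_emb, axis=1), 1.0))
print(np.isclose(reg_vec, reg_mat))
```

The two regularization values agree because the off-diagonal entries of each \(\textbf{R}_i\) are zero and contribute nothing to the Frobenius norm.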

Methods Summary

evaluate(eval_X, corrupt_side[, positive_X])

Evaluate triplets.

get_rank(x, positive_X, corrupt_side)

Get the rank of a single triplet.

restore_model_weights(model_weights)

Restore the model weights.

score_hrt(h, r, t)

Score the triplets \((h,r,t)\).

train(train_X, val_X, metadata, epochs, ...)

Train the Knowledge Graph Embedding Model.

Methods Documentation

evaluate(eval_X, corrupt_side, positive_X=None)

Evaluate triplets.

Parameters
  • eval_X (tf.Tensor or np.array) – triplets to be evaluated

  • corrupt_side (str) – side from which triplets are corrupted, either 'h' or 't'

  • positive_X (tf.Tensor or np.array, optional) – positive triplets that should be filtered while generating corrupted triplets, by default None (no filter applied)

Returns

evaluation result

Return type

dict
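The keys of the returned dict are not listed here. Link-prediction results are commonly summarized from per-triplet ranks, so the sketch below shows how such a summary could be computed. The metric names and the rank values are assumptions for illustration and may not match what evaluate() actually returns.

```python
import numpy as np

# Illustrative ranks of five true triplets among their corruptions (1 = best).
ranks = np.array([1, 3, 2, 10, 1])

# Hypothetical result dict; the real keys may differ.
metrics = {
    "mean_rank": float(np.mean(ranks)),
    "mrr": float(np.mean(1.0 / ranks)),    # mean reciprocal rank
    "hits@3": float(np.mean(ranks <= 3)),  # share of ranks within the top 3
}

print(metrics)
```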

get_rank(x, positive_X, corrupt_side)

Get the rank of a single triplet.

Parameters
  • x (tf.Tensor or np.array) – the triplet to rank

  • positive_X (tf.Tensor or np.array, optional) – positive triplets that should be filtered out while generating corrupted triplets; if None, no filter is applied

  • corrupt_side (str) – side from which triplets are corrupted, either 'h' or 't'

Returns

ranking result

Return type

int
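To make the role of positive_X concrete, here is a minimal sketch of ranking one true tail against all candidate tails, with and without filtering. The helper name, its arguments, and the tie handling (only strictly higher scores count) are assumptions for illustration; get_rank may differ in these details.

```python
import numpy as np

def rank_of_true_tail(scores, true_idx, known_idx=()):
    """Rank (1 = best) of the true tail among all candidate tails.

    scores    -- score of (h, r, t_i) for every entity t_i, shape (n_entities,)
    true_idx  -- index of the true tail
    known_idx -- indices of other tails known to be positive (filtered out)
    """
    scores = np.asarray(scores, dtype=float).copy()
    true_score = scores[true_idx]
    # Filtered setting: other known positives must not outrank the true triplet.
    for idx in known_idx:
        if idx != true_idx:
            scores[idx] = -np.inf
    return int(1 + np.sum(scores > true_score))

scores = np.array([0.1, 0.9, 0.5, 0.7])
raw_rank = rank_of_true_tail(scores, true_idx=2)                      # 3
filtered_rank = rank_of_true_tail(scores, true_idx=2, known_idx=[1])  # 2
print(raw_rank, filtered_rank)
```

In the unfiltered case entities 1 and 3 both score higher than the true tail. Once entity 1 is treated as a known positive, only entity 3 remains ahead.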

restore_model_weights(model_weights)

Restore the model weights.

Parameters

model_weights (dict) – dictionary of model weights to be restored

score_hrt(h, r, t)

Score the triplets \((h,r,t)\).

If h is None, score every entity as the head: \((h_i, r, t)\).

If t is None, score every entity as the tail: \((h, r, t_i)\).

h and t should not be None simultaneously.

Parameters
  • h (tf.Tensor or np.ndarray or None) – index of heads with shape (n,)

  • r (tf.Tensor or np.ndarray) – index of relations with shape (n,)

  • t (tf.Tensor or np.ndarray or None) – index of tails with shape (n,)

Returns

triplet scores with shape (n,)

Return type

tf.Tensor
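The three calling modes can be sketched with NumPy as follows. This is a hypothetical stand-in, not the library's code. The function name, the embedding tables passed as arguments, and the (n, n_entities) output shape when h or t is None are all assumptions.

```python
import numpy as np

def score_hrt_sketch(ent_emb, rel_diag, h, r, t):
    """DistMult scores for index arrays h, r, t (hypothetical helper).

    ent_emb  -- entity embeddings, shape (n_entities, k)
    rel_diag -- diagonal entries of each relation matrix, shape (n_relations, k)
    """
    if h is None and t is None:
        raise ValueError("h and t should not be None simultaneously")
    r_vec = rel_diag[r]  # (n, k)
    if h is None:
        # Score every entity as the head: result has shape (n, n_entities).
        return (r_vec * ent_emb[t]) @ ent_emb.T
    if t is None:
        # Score every entity as the tail: result has shape (n, n_entities).
        return (ent_emb[h] * r_vec) @ ent_emb.T
    # Score the given triplets only: result has shape (n,).
    return np.sum(ent_emb[h] * r_vec * ent_emb[t], axis=1)

rng = np.random.default_rng(0)
ent_emb = rng.normal(size=(6, 4))
rel_diag = rng.normal(size=(2, 4))
h, r, t = np.array([0, 1, 2]), np.array([0, 1, 0]), np.array([3, 4, 5])

triplet_scores = score_hrt_sketch(ent_emb, rel_diag, h, r, t)
all_tail_scores = score_hrt_sketch(ent_emb, rel_diag, h, r, None)
print(triplet_scores.shape, all_tail_scores.shape)
```

Picking column t[i] from row i of all_tail_scores recovers triplet_scores[i], which is a quick consistency check between the modes.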

train(train_X, val_X, metadata, epochs, batch_size, early_stopping_rounds=None, model_weights_initial=None, restore_best_weight=True, optimizer='Adam', seed=None, log_path='./logs', log_projector=False)

Train the Knowledge Graph Embedding Model.

Parameters
  • train_X (np.ndarray or str) –

    training triplets.

    If np.ndarray, shape should be (n,3) for \((h,r,t)\) respectively.

    If str, the training triplets should be saved under this folder path in CSV format; every CSV file should have 3 columns without a header, for \((h,r,t)\) respectively.

  • val_X (np.ndarray or str) –

    validation triplets.

    If np.ndarray, shape should be (n,3) for \((h,r,t)\) respectively.

    If str, the validation triplets should be saved under this folder path in CSV format; every CSV file should have 3 columns without a header, for \((h,r,t)\) respectively.

  • metadata (dict) –

    metadata for the knowledge graph data. It should have the following keys:

    'ent2ind': dict, dictionary mapping each entity to its index.

    'ind2ent': list, list mapping each index to its entity.

    'rel2ind': dict, dictionary mapping each relation to its index.

    'ind2rel': list, list mapping each index to its relation.

    KGE.data_utils.index_kg can be used to index the data and obtain this metadata.

  • epochs (int) – number of epochs

  • batch_size (int) – batch size

  • early_stopping_rounds (int, optional) – number of rounds that trigger early stopping, by default None (no early stopping)

  • model_weights_initial (dict, optional) – initial model weights with specific values, by default None

  • restore_best_weight (bool, optional) – restore the weights from the best iteration if early_stopping_rounds is not None, by default True

  • optimizer (str or tensorflow.keras.optimizers, optional) – optimizer used in training, by default 'Adam', which uses the default settings of tf.keras.optimizers.Adam

  • seed (int, optional) – random seed for shuffling data and initializing embeddings, by default None

  • log_path (str, optional) – path for TensorBoard logging, by default './logs'

  • log_projector (bool, optional) – project the embeddings in the TensorBoard projector tab; setting this to True writes the metadata and embedding TSV files to log_path and displays them in the projector tab, by default False
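The expected shapes of metadata and train_X can be illustrated by building them by hand from a few string triplets. This is only a sketch of the data layout described above. The example triplets are made up, and KGE.data_utils.index_kg may assign indices in a different order.

```python
import numpy as np

# Made-up raw triplets (head, relation, tail).
raw_triplets = [
    ("alice", "knows", "bob"),
    ("bob", "knows", "carol"),
    ("alice", "works_at", "acme"),
]

# Sorted so that the index assignment is deterministic.
ind2ent = sorted({e for h, _, t in raw_triplets for e in (h, t)})
ind2rel = sorted({r for _, r, _ in raw_triplets})

metadata = {
    "ent2ind": {e: i for i, e in enumerate(ind2ent)},
    "ind2ent": ind2ent,
    "rel2ind": {r: i for i, r in enumerate(ind2rel)},
    "ind2rel": ind2rel,
}

# Indexed triplets with shape (n, 3), columns (h, r, t).
train_X = np.array([
    [metadata["ent2ind"][h], metadata["rel2ind"][r], metadata["ent2ind"][t]]
    for h, r, t in raw_triplets
])

print(train_X.shape)
```

The list and dict forms are inverses of each other, so metadata["ind2ent"][metadata["ent2ind"][e]] returns e for every entity e.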

__init__(embedding_params, negative_ratio, corrupt_side, loss_fn=<KGE.loss.PairwiseHingeLoss object>, ns_strategy=<class 'KGE.ns_strategy.UniformStrategy'>, constraint=True, constraint_weight=1.0, n_workers=1)

Initialize DistMult.

Parameters
  • embedding_params (dict) – embedding dimension parameters, should have key 'embedding_size' for embedding dimension \(k\)

  • negative_ratio (int) – number of negative samples generated per positive triplet

  • corrupt_side (str) – side to corrupt during training, one of 'h', 't', or 'h+t'

  • loss_fn (class, optional) – loss function of class KGE.loss.Loss, by default KGE.loss.PairwiseHingeLoss

  • ns_strategy (function, optional) – negative sampling strategy, by default KGE.ns_strategy.UniformStrategy

  • constraint (bool, optional) – whether to apply the unit-norm constraint and the regularization, by default True

  • constraint_weight (float, optional) – regularization weight \(\lambda\), by default 1.0

  • n_workers (int, optional) – number of workers for negative sampling, by default 1
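How negative_ratio and corrupt_side interact can be sketched with a simple uniform corruption routine. This is an illustrative assumption about uniform negative sampling in general, not the code of KGE.ns_strategy. The function name, its arguments, and the 50/50 split used for 'h+t' are hypothetical.

```python
import numpy as np

def uniform_corrupt(X, n_entities, negative_ratio, corrupt_side, rng):
    """Return negative_ratio corrupted copies of every triplet in X.

    X has shape (n, 3) with integer columns (h, r, t). A corrupted triplet
    can coincide with a true one by chance; this sketch does not filter those.
    """
    neg = np.repeat(X, negative_ratio, axis=0)  # new array, X is untouched
    n = len(neg)
    random_ent = rng.integers(0, n_entities, size=n)
    if corrupt_side == "h":
        replace_head = np.ones(n, dtype=bool)
    elif corrupt_side == "t":
        replace_head = np.zeros(n, dtype=bool)
    elif corrupt_side == "h+t":
        # Assumption: choose head or tail with equal probability per sample.
        replace_head = rng.random(n) < 0.5
    else:
        raise ValueError("corrupt_side must be 'h', 't', or 'h+t'")
    neg[replace_head, 0] = random_ent[replace_head]
    neg[~replace_head, 2] = random_ent[~replace_head]
    return neg

rng = np.random.default_rng(0)
X = np.array([[0, 0, 1], [2, 1, 3]])
neg_t = uniform_corrupt(X, n_entities=5, negative_ratio=3, corrupt_side="t", rng=rng)
print(neg_t.shape)
```

With negative_ratio=3 each of the two positive triplets yields three negatives, and with corrupt_side='t' the head and relation columns stay unchanged.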

__new__(*args, **kwargs)