# Source code for KGE.models.base_model.BaseModel

"""
base module for Knowledge Graph Embedding Model
"""

import os
import datetime
import logging
import numpy as np
import tensorflow as tf
import multiprocessing as mp
from tqdm import tqdm, trange
from tensorboard.plugins import projector
from KGE.ns_strategy import TypedStrategy, UniformStrategy
from KGE.data_utils import calculate_data_size, set_tf_iterator
from KGE.metrics import mean_reciprocal_rank, mean_rank, hits_at_k, median_rank, geometric_mean_rank, harmonic_mean_rank, std_rank

# Surface INFO-level records (training progress messages) on the root logger.
logging.getLogger().setLevel(logging.INFO)

# Enable on-demand memory growth on the first visible GPU, if there is one.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

class KGEModel:
    """A base module for Knowledge Graph Embedding Model.

    Subclass of :class:`KGEModel` can have their own translation and
    interaction model. Subclasses are expected to implement
    :meth:`_init_embeddings`, :meth:`score_hrt`, :meth:`_constraint_loss`
    and :meth:`_check_model_weights`, and to expose their trainable
    variables through the ``model_weights`` dictionary.
    """

    def __init__(self, embedding_params, negative_ratio, corrupt_side, loss_fn, ns_strategy, n_workers):
        """Initialize KGEModel.

        Parameters
        ----------
        embedding_params : dict
            embedding dimension parameters
        negative_ratio : int
            number of negative sample
        corrupt_side : str
            corrupt from which side while training, can be :code:`'h'`,
            :code:`'t'`, or :code:`'h+t'`
        loss_fn : class
            loss function class :py:mod:`KGE.loss.Loss`
        ns_strategy : function
            negative sampling strategy
        n_workers : int
            number of workers for negative sampling
        """
        assert corrupt_side in ['h+t', 'h', 't'], "Invalid corrupt_side, valid options: 'h+t', 'h', 't'"

        self.embedding_params = embedding_params
        self.negative_ratio = negative_ratio
        self.corrupt_side = corrupt_side
        self.loss_fn = loss_fn
        self.ns_strategy = ns_strategy
        self.__n_workers = n_workers

    def train(self, train_X, val_X, metadata, epochs, batch_size, early_stopping_rounds=None,
              model_weights_initial=None, restore_best_weight=True, optimizer="Adam",
              seed=None, log_path="./logs", log_projector=False):
        """Train the Knowledge Graph Embedding Model.

        Parameters
        ----------
        train_X : np.ndarray or str
            training triplets. \n
            If :code:`np.ndarray`, shape should be :code:`(n,3)` for
            :math:`(h,r,t)` respectively. \n
            If :code:`str`, training triplets should be saved under this
            folder path with csv format, every csv file should have 3 columns
            without header for :math:`(h,r,t)` respectively.
        val_X : np.ndarray or str
            validation triplets, or :code:`None` for no validation. \n
            If :code:`np.ndarray`, shape should be :code:`(n,3)` for
            :math:`(h,r,t)` respectively. \n
            If :code:`str`, validation triplets should be saved under this
            folder path with csv format, every csv file should have 3 columns
            without header for :math:`(h,r,t)` respectively.
        metadata : dict
            metadata for kg data. should have following keys: \n
            :code:`'ent2ind'`: dict, dictionary that maps entity to index. \n
            :code:`'ind2ent'`: list, list that maps index to entity. \n
            :code:`'rel2ind'`: dict, dictionary that maps relation to index. \n
            :code:`'ind2rel'`: list, list that maps index to relation. \n
            :code:`'ind2type'` is additionally read when the typed negative
            sampling strategy is used. \n
            can use KGE.data_utils.index_kg to index and get metadata.
        epochs : int
            number of epochs
        batch_size : int
            batch_size
        early_stopping_rounds : int, optional
            number of rounds that trigger early stopping, by default None
            (no early stopping)
        model_weights_initial : dict, optional
            initial model weights with specific value, by default None
        restore_best_weight : bool, optional
            restore weight to the best iteration if early stopping rounds
            is not None, by default True
        optimizer : str or tensorflow.keras.optimizers, optional
            optimizer that apply in training, by default :code:`'Adam'`,
            use the default setting of `tf.keras.optimizers.Adam
            <https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam>`_
        seed : int, optional
            random seed for shuffling data & embedding initialization,
            by default None
        log_path : str, optional
            path for tensorboard logging, by default "./logs"
        log_projector : bool, optional
            project the embeddings in the tensorboard projector tab,
            setting this True will write the metadata and embedding tsv files
            in :code:`log_path` and project this data on tensorboard
            projector tab, by default False
        """
        self.metadata = metadata
        self.batch_size = batch_size
        self._model_weights_initial = model_weights_initial
        self.__optimizer = optimizer
        self.seed = seed
        self.log_path = log_path

        has_val = val_X is not None

        # Create Summary Writers to log the metrics to Tensorboard
        summary_writer = tf.summary.create_file_writer(log_path)
        train_logger = tf.summary.create_file_writer(log_path + '/scalar/train')
        val_logger = tf.summary.create_file_writer(log_path + '/scalar/validation') if has_val else None

        logging.info("[%s] Preparing for training...", datetime.datetime.now())
        train_iter, val_iter = self.__prepare_for_train(train_X=train_X, val_X=val_X)

        train_loss_history = []
        val_loss_history = []
        patience_count = 0

        # Start Training
        logging.info("[%s] Start Training...", datetime.datetime.now())
        epoch_bar = trange(epochs, desc="Epoch", leave=True)
        for i in epoch_bar:
            train_loss = 0
            val_loss = 0

            batch_bar = trange(self.__batch_count_train, desc=" Batch", leave=False)
            for b in batch_bar:
                train_loss += self.__run_single_batch(batch_data=next(train_iter), is_train=True)
                # validation batches are interleaved with the training batches
                if has_val and b < self.__batch_count_val:
                    val_loss += self.__run_single_batch(batch_data=next(val_iter), is_train=False)

            train_loss /= self.__batch_count_train
            if has_val:
                # When validation has more batches than training, run the
                # remaining ones so the mean below covers every batch.
                for _ in range(self.__batch_count_train, self.__batch_count_val):
                    val_loss += self.__run_single_batch(batch_data=next(val_iter), is_train=False)
                # Only divide when validation exists: the validation batch
                # count is never computed otherwise.
                val_loss /= self.__batch_count_val

            train_loss_history = self.__append_history_and_log(
                loss=train_loss, loss_history=train_loss_history,
                summary_writer=train_logger, step=i
            )

            if has_val:
                val_loss_history = self.__append_history_and_log(
                    loss=val_loss, loss_history=val_loss_history,
                    summary_writer=val_logger, step=i
                )
                epoch_bar.set_description(
                    "epoch: %i, train loss: %f, valid loss: %f"
                    % (i, train_loss_history[i], val_loss_history[i])
                )
            else:
                epoch_bar.set_description("epoch: %i, train loss: %f" % (i, train_loss_history[i]))
            epoch_bar.refresh()

            self.__log_embeddings_histogram(summary_writer=summary_writer, step=i)

            if early_stopping_rounds is not None:
                assert val_X is not None, "val_X should be given if want to check early stopping."
                early_stop, patience_count = self.__check_early_stopping(
                    metric_history=val_loss_history, magnitude="larger",
                    patience_now=patience_count, patience_max=early_stopping_rounds,
                    step=i, restore_best_weight=restore_best_weight
                )
                if early_stop:
                    logging.info(
                        "[%s] Val loss does not improve within %i iterations, trigger early stopping.",
                        datetime.datetime.now(), early_stopping_rounds
                    )
                    if restore_best_weight:
                        logging.info(
                            "[%s] Restore best weight from %i to %i step.",
                            datetime.datetime.now(), i, self.best_step
                        )
                    break
            else:
                # without early stopping the checkpoint simply tracks the latest epoch
                self.ckpt_manager.save()

        if log_projector:
            logging.info("[%s] Logging final embeddings into tensorboard projector...",
                         datetime.datetime.now())
            self.__log_embeddings_projector(log_path=log_path)

        logging.info("[%s] Finished training!", datetime.datetime.now())
        # TODO: the multiprocessing pool created for the typed negative
        # sampling strategy is not closed/joined after training.

    def __prepare_for_train(self, train_X, val_X):
        """Preparation before training.

        Do the following steps:

        - calculate number of batch
        - set tensorflow dataset iterator
        - initialize embedding, optimizer & checkpoint manager
        - create "type2inds" metadata if using typed_strategy
        - create pool for multiprocessing negative sampling if n_workers > 1

        Parameters
        ----------
        train_X : np.ndarray or str
            training triplets. If `np.ndarray`, shape should be (n,3) for
            (h,r,t) respectively. If `str`, training triplets should be saved
            under this folder path with csv format, every csv file should
            have 3 columns without header for (h,r,t) respectively.
        val_X : np.ndarray or str
            validation triplets (or None), same format as `train_X`.

        Returns
        -------
        iterator, iterator
            training and validation data iterator (the latter is None when
            `val_X` is None).
        """
        # calculate number of batch
        logging.info("[%s] - Calculating number of batch...", datetime.datetime.now())
        n_train = calculate_data_size(train_X)
        self.__batch_count_train = int(np.ceil(n_train / self.batch_size))
        if val_X is not None:
            n_val = calculate_data_size(val_X)
            self.__batch_count_val = int(np.ceil(n_val / self.batch_size))

        # set tensorflow dataset iterator
        logging.info("[%s] - Setting data iterator...", datetime.datetime.now())
        train_iter = set_tf_iterator(data=train_X, batch_size=self.batch_size, shuffle=True,
                                     buffer_size=n_train, seed=self.seed)
        if val_X is not None:
            val_iter = set_tf_iterator(data=val_X, batch_size=self.batch_size, shuffle=False,
                                       buffer_size=None, seed=None)
        else:
            val_iter = None

        # initialize embedding, optimizer & checkpoint manager
        logging.info("[%s] - Initialized embedding...", datetime.datetime.now())
        self._init_embeddings(seed=self.seed)

        logging.info("[%s] - Initialized optimizer...", datetime.datetime.now())
        if self.__optimizer == "Adam":
            self.__optimizer = tf.optimizers.Adam()

        logging.info("[%s] - Initialized checkpoint manager...", datetime.datetime.now())
        self.best_ckpt = tf.train.Checkpoint()
        self.best_ckpt.listed = list(self.model_weights.values())
        self.best_ckpt.mapped = dict(self.model_weights)
        self.ckpt_manager = tf.train.CheckpointManager(
            checkpoint=self.best_ckpt, directory=self.log_path, max_to_keep=1
        )

        # The strategy is passed in as a class and replaced by an instance here.
        if self.ns_strategy == UniformStrategy:
            self.ns_strategy = UniformStrategy(sample_pool=tf.range(len(self.metadata["ind2ent"])))
        elif self.ns_strategy == TypedStrategy:
            # Create type2inds metadata: entity indices grouped by entity type
            ind2type = np.asarray(self.metadata["ind2type"])
            self.metadata["type2inds"] = {
                t: np.flatnonzero(ind2type == t) for t in np.unique(ind2type)
            }

            # Create pool for multiprocessing negative sampling
            pool = mp.Pool(self.__n_workers) if self.__n_workers > 1 else None

            self.ns_strategy = TypedStrategy(
                pool=pool,
                metadata={
                    "type2inds": self.metadata["type2inds"],
                    "ind2type": self.metadata["ind2type"]
                }
            )

        return train_iter, val_iter

    def _init_embeddings(self, seed=None):
        """Initialize embeddings.

        Should be implemented in subclass for their own embedding parameters.

        Parameters
        ----------
        seed : int, optional
            random seed passed by the training preparation step,
            by default None

        Raises
        ------
        NotImplementedError
            subclass does not implement _init_embeddings().
        """
        raise NotImplementedError("subclass of KGEModel should implement _init_embeddings()")

    def __run_single_batch(self, batch_data, is_train):
        """Run training procedure on one single batch.

        whole training procedure:

        - perform negative sampling
        - calculate constraint term or constraint embedding
        - calculate positive & negative score
        - calculate loss
        - backpropagation & update model weights (if is_train)

        Parameters
        ----------
        batch_data : tf.Tensor
            batch data to be processed, shape should be (n,3)
        is_train : bool
            whether to calculate gradients and update model weights

        Returns
        -------
        float
            loss of this batch.
        """
        neg_triplet = self.__negative_sampling(batch_data, strategy=self.ns_strategy)

        with tf.GradientTape() as g:
            constraint_term = self._constraint_loss(batch_data)
            pos_score = self.score_hrt(batch_data[:, 0], batch_data[:, 1], batch_data[:, 2])
            neg_score = self.score_hrt(neg_triplet[:, 0], neg_triplet[:, 1], neg_triplet[:, 2])
            batch_loss = self.loss_fn(pos_score, neg_score)
            batch_loss += constraint_term

        if is_train:
            weights = list(self.model_weights.values())
            gradients = g.gradient(batch_loss, weights)
            # clip each gradient to a maximum norm of 5.0 before applying it
            gradients = [tf.clip_by_norm(grad, clip_norm=5.0) for grad in gradients]
            self.__optimizer.apply_gradients(zip(gradients, weights))

        return batch_loss.numpy()

    def __negative_sampling(self, X, strategy):
        """Perform negative sampling.

        Parameters
        ----------
        X : tf.Tensor
            triplets to be corrupt with shape (n,3)
        strategy : function
            negative sampling strategy function in KGE.ns_strategy

        Returns
        -------
        tf.Tensor
            corrupted triplets with 3 columns; for :code:`'h'` / :code:`'t'`
            each triplet yields ``self.negative_ratio`` negatives.
        """
        if self.corrupt_side == 'h':
            neg_triplet = self.__corrupt_h(X, self.negative_ratio, strategy)
        elif self.corrupt_side == "t":
            neg_triplet = self.__corrupt_t(X, self.negative_ratio, strategy)
        elif self.corrupt_side == "h+t":
            # NOTE(review): each side gets negative_ratio // 2 negatives, so an
            # odd negative_ratio yields one fewer negative per triplet than
            # requested -- confirm this is intended.
            half = self.negative_ratio // 2
            neg_triplet_h = self.__corrupt_h(X, half, strategy)
            neg_triplet_t = self.__corrupt_t(X, half, strategy)
            # concatenating column-wise then reshaping to 3 columns interleaves
            # head-corrupted and tail-corrupted rows
            neg_triplet = tf.reshape(tf.concat([neg_triplet_h, neg_triplet_t], axis=-1), [-1, 3])

        return neg_triplet

    def __corrupt_h(self, X, negative_ratio, strategy):
        """Corrupt triplets from head side.

        Parameters
        ----------
        X : tf.Tensor
            triplets to be corrupt with shape (n,3)
        negative_ratio : int
            number of negative triplets to be generated for each triplet
        strategy : function
            negative sampling strategy function in KGE.ns_strategy

        Returns
        -------
        tf.Tensor
            corrupted triplets with shape (n*negative_ratio, 3)
        """
        h = strategy(X, negative_ratio=negative_ratio, side="h")
        r = tf.repeat(X[:, 1], negative_ratio)
        t = tf.repeat(X[:, 2], negative_ratio)

        return tf.stack([h, r, t], axis=1)

    def __corrupt_t(self, X, negative_ratio, strategy):
        """Corrupt triplets from tail side.

        Parameters
        ----------
        X : tf.Tensor
            triplets to be corrupt with shape (n,3)
        negative_ratio : int
            number of negative triplets to be generated for each triplet
        strategy : function
            negative sampling strategy function in KGE.ns_strategy

        Returns
        -------
        tf.Tensor
            corrupted triplets with shape (n*negative_ratio, 3)
        """
        h = tf.repeat(X[:, 0], negative_ratio)
        r = tf.repeat(X[:, 1], negative_ratio)
        t = strategy(X, negative_ratio=negative_ratio, side="t")

        return tf.stack([h, r, t], axis=1)

    def score_hrt(self, h, r, t):
        """Scoring the triplets.

        The base implementation only resolves the arguments: when ``h`` (or
        ``t``) is None it is replaced by the indices of all entities so a
        subclass can score one ``(r, t)`` (or ``(h, r)``) pair against every
        candidate entity. Subclasses implement the actual scoring function.

        Parameters
        ----------
        h : scalar index, array of indices or None
            head entities; None means "all entities"
        r : scalar index or array of indices
            relations
        t : scalar index, array of indices or None
            tail entities; None means "all entities"

        Returns
        -------
        tuple
            resolved ``(h, r, t)``

        Raises
        ------
        AssertionError
            if ``h`` and ``t`` are both None, or if the non-None arguments
            are not scalars when one side is None.
        """
        # `not` is required here: bitwise `~` on a bool is always truthy
        assert not (h is None and t is None), "h and t should not be None simultaneously"

        if h is None:
            assert len(r.shape) == 0
            assert len(t.shape) == 0
            h = np.arange(len(self.metadata["ind2ent"]))

        if t is None:
            assert len(h.shape) == 0
            assert len(r.shape) == 0
            t = np.arange(len(self.metadata["ind2ent"]))

        return h, r, t

    def _constraint_loss(self, X=None):
        """Perform penalty on loss or constraint on model weights.

        Should be implemented in subclass for their own constraint.

        Parameters
        ----------
        X : tf.Tensor, optional
            the current batch of triplets, as passed by the training step

        Raises
        ------
        NotImplementedError
            subclass does not implement _constraint_loss().
        """
        raise NotImplementedError("subclass of KGEModel should implement _constraint_loss()")

    def __append_history_and_log(self, loss, loss_history, summary_writer, step):
        """Append current loss history and log into tensorboard.

        Parameters
        ----------
        loss : float
            current loss that should be appended to loss_history
        loss_history : list of float
            loss history to be appended
        summary_writer : tensorflow.python.ops.summary_ops_v2.ResourceSummaryWriter
            tensorboard summary writer
        step : int
            current step

        Returns
        -------
        list of float
            appended loss history (the same list object, mutated in place)
        """
        loss_history.append(loss)
        with summary_writer.as_default():
            tf.summary.scalar("loss", loss, step=step)

        return loss_history

    def __log_embeddings_histogram(self, summary_writer, step):
        """Logging all model weights to tensorboard histogram.

        Parameters
        ----------
        summary_writer : tensorflow.python.ops.summary_ops_v2.ResourceSummaryWriter
            tensorboard summary writer
        step : int
            current step
        """
        with summary_writer.as_default():
            for name, weight in self.model_weights.items():
                tf.summary.histogram(name, weight, step=step)

    def __check_early_stopping(self, metric_history, magnitude, patience_now, patience_max, step,
                               restore_best_weight=True):
        """Check early stopping and restore the weights to the best step if
        early stopping criteria is matched.

        Parameters
        ----------
        metric_history : list of float
            metric history to be checked for early stopping.
        magnitude : str
            overfitting metric magnitude, can be 'larger' or 'smaller'
            for example, if metric is loss, the loss becomes larger and
            larger when overfitting occurs.
        patience_now : int
            how many times that metric does not improve.
        patience_max : int
            maximum patience that metrics does not improve.
            when patience_now == patience_max, trigger early stopping.
        step : int
            current step
        restore_best_weight : bool, optional
            whether to restore model weights to the best step,
            by default True

        Returns
        -------
        bool, int
            whether to trigger early stopping, updated patience_now
        """
        if step == 0:
            # the first step is the best step so far by definition
            self.ckpt_manager.save()
            self.best_step = step
            return False, patience_now

        assert magnitude in ["larger", "smaller"], "magnitude must be 'larger' or 'smaller'"

        if self.best_step is None:
            self.best_step = step

        current = metric_history[step]
        best = metric_history[self.best_step]
        if magnitude == "larger":
            not_improved = current >= best
        else:
            not_improved = current <= best

        if not_improved:
            patience_now += 1
        else:
            # improvement: reset patience and checkpoint the new best weights
            patience_now = 0
            self.best_step = step
            self.ckpt_manager.save()

        if patience_now == patience_max:
            if restore_best_weight:
                self.best_ckpt.restore(self.ckpt_manager.latest_checkpoint)
            return True, patience_now

        return False, patience_now

    def __log_embeddings_projector(self, log_path):
        """Log embedding to TensorBoard projector tab.

        Parameters
        ----------
        log_path : str
            path for tensorboard logging
        """

        def write_metadata_file(path, obj):
            # one label per line, in index order
            with open(path, "w", encoding="utf-8") as f:
                f.writelines("{}\n".format(x) for x in obj)

        has_rel_emb = self.model_weights.get("rel_emb") is not None

        write_metadata_file(path=os.path.join(log_path, "ent_metadata.tsv"),
                            obj=self.metadata["ind2ent"])
        if has_rel_emb:
            write_metadata_file(path=os.path.join(log_path, "rel_metadata.tsv"),
                                obj=self.metadata["ind2rel"])
            ckpt = tf.train.Checkpoint(ent_emb=self.model_weights["ent_emb"],
                                       rel_emb=self.model_weights["rel_emb"])
        else:
            ckpt = tf.train.Checkpoint(ent_emb=self.model_weights["ent_emb"])
        ckpt.save(os.path.join(log_path, "embedding.ckpt"))

        config = projector.ProjectorConfig()
        ent_embedding = config.embeddings.add()
        ent_embedding.tensor_name = "ent_emb/.ATTRIBUTES/VARIABLE_VALUE"
        ent_embedding.metadata_path = "ent_metadata.tsv"
        if has_rel_emb:
            rel_embedding = config.embeddings.add()
            rel_embedding.tensor_name = "rel_emb/.ATTRIBUTES/VARIABLE_VALUE"
            rel_embedding.metadata_path = "rel_metadata.tsv"
        projector.visualize_embeddings(log_path, config)

    def evaluate(self, eval_X, corrupt_side, positive_X=None):
        """Evaluate triplets.

        Parameters
        ----------
        eval_X : tf.Tensor or np.array
            triplets to be evaluated
        corrupt_side : str
            corrupt triplets from which side, can be :code:`'h'` and
            :code:`'t'`
        positive_X : tf.Tensor or np.array, optional
            positive triplets that should be filtered while generating
            corrupted triplets, by default None (no filter applied)

        Returns
        -------
        dict
            evaluation result
        """
        n_eval = calculate_data_size(eval_X)
        # triplets are ranked one at a time
        eval_iter = set_tf_iterator(data=eval_X, batch_size=1, shuffle=False)

        ranks = [
            self.get_rank(next(eval_iter), positive_X, corrupt_side)
            for _ in tqdm(range(n_eval))
        ]

        eval_result = {
            "mean_rank": mean_rank(ranks),
            "mean_reciprocal_rank": mean_reciprocal_rank(ranks),
            "median_rank": median_rank(ranks),
            "geometric_mean_rank": geometric_mean_rank(ranks),
            "harmonic_mean_rank": harmonic_mean_rank(ranks),
            "std_rank": std_rank(ranks),
            "hit@1": hits_at_k(ranks, k=1),
            "hit@3": hits_at_k(ranks, k=3),
            "hit@10": hits_at_k(ranks, k=10)
        }

        return eval_result

    def get_rank(self, x, positive_X, corrupt_side):
        """Get rank for specific one triplet.

        Parameters
        ----------
        x : tf.Tensor or np.array
            rank this triplet
        positive_X : tf.Tensor or np.array, optional
            positive triplets that should be filtered while generating
            corrupted triplets, if :code:`None`, no filter applied
        corrupt_side : str
            corrupt triplets from which side, can be :code:`'h'` and
            :code:`'t'`

        Returns
        -------
        int
            ranking result (1 is the best rank)

        Raises
        ------
        ValueError
            if ``corrupt_side`` is neither :code:`'h'` nor :code:`'t'`.
        """
        x = tf.squeeze(x)

        if corrupt_side == "h":
            filter_side, corrupt_col = 2, 0
            scores = self.score_hrt(h=None, r=x[1], t=x[2])
        elif corrupt_side == "t":
            filter_side, corrupt_col = 0, 2
            scores = self.score_hrt(h=x[0], r=x[1], t=None)
        else:
            raise ValueError("Invalid corrupt_side, valid options: 'h', 't'")

        if positive_X is not None:
            # known positives sharing the relation and the fixed entity are
            # pushed to -inf so they cannot outrank the evaluated triplet
            r_mask = positive_X[:, 1] == x[1]
            e_mask = positive_X[:, filter_side] == x[filter_side]
            positive_e = positive_X[r_mask & e_mask, corrupt_col]
            scores = tf.tensor_scatter_nd_update(
                scores, tf.expand_dims(positive_e, -1), [-np.inf] * len(positive_e)
            )

        pos_score = self.score_hrt(x[0], x[1], x[2])
        # int64 accumulator: int16 would overflow past 32767 better-scored candidates
        n_better = tf.reduce_sum(tf.cast(scores > pos_score, tf.int64))

        return n_better.numpy() + 1

    def restore_model_weights(self, model_weights):
        """Restore the model weights.

        Parameters
        ----------
        model_weights : dict
            dictionary of model weights to be restored
        """
        # NOTE(review): the check runs before the assignment and does not
        # receive `model_weights`; confirm against the subclass
        # implementations whether the new weights should be validated here.
        self._check_model_weights()
        self.model_weights = model_weights

    def _check_model_weights(self):
        """Check model weights have required keys.

        Should be implemented in subclass for their own key checking.

        Raises
        ------
        NotImplementedError
            subclass does not implement _check_model_weights().
        """
        raise NotImplementedError("subclass of KGEModel should implement _check_model_weights()")