Fully Understanding PinSage

PinSage & GraphSage

Posted by fzr on April 22, 2020


This is an original article, first published at [全面理解PinSage - 一块小蛋糕的文章 - 知乎](https://zhuanlan.zhihu.com/p/133739758).

[Figures pinsage_content1.png – pinsage_content9.png: the original post's illustrated walkthrough of the PinSage model]
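
Since the original write-up lives in those images, here is a compact summary of what the Convolve layer below computes, in my own notation (read off the code, not taken from the post). For a target node $v$ with sampled neighborhood $N(v)$, neighbor embeddings $h_u$, and importance weights $\alpha_{vu}$:

$$n_v = \frac{\sum_{u \in N(v)} \alpha_{vu}\,\mathrm{LeakyReLU}(Q h_u)}{\sum_{u \in N(v)} \alpha_{vu}}, \qquad z_v = \mathrm{LeakyReLU}\big(W\,[h_v ; n_v]\big), \qquad h_v^{new} = \frac{z_v}{\lVert z_v \rVert_2}$$

The implementation, with imports: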

import numpy as np
import networkx as nx
import tensorflow as tf
class Convolve(tf.keras.Model):
    def __init__(self, hidden_channels):
        super(Convolve, self).__init__()
        self.Q = tf.keras.layers.Dense(units=hidden_channels, activation=tf.keras.layers.LeakyReLU())
        self.W = tf.keras.layers.Dense(units=hidden_channels, activation=tf.keras.layers.LeakyReLU())
    def call(self, inputs):
        # embeddings.shape = (batch, node number, in channels)
        embeddings = inputs[0]   # embeddings of all nodes
        # weights.shape = (node number, node number)
        weights = inputs[1]   # edge weights between all node pairs
        # neighbor_set.shape = (node number, neighbor number)
        neighbor_set = inputs[2]  # ids of the neighbors sampled for each node
        # neighbor_embeddings.shape = (batch, node number, neighbor number, in channels)
        # gather the embeddings of all sampled neighbors
        neighbor_embeddings = tf.keras.layers.Lambda(lambda x, neighbor_set: tf.transpose(tf.gather_nd(tf.transpose(x, (1, 0, 2)),
                                                                                                       tf.expand_dims(neighbor_set, axis=-1)
                                                                                                       ),
                                                                                          (2, 0, 1, 3)
                                                                                          ),
                                                     arguments={'neighbor_set': neighbor_set})(embeddings)
        # neighbor_hiddens.shape = (batch, node number, neighbor number, hidden channels)
        neighbor_hiddens = self.Q(neighbor_embeddings)  # pass every neighbor embedding through the first dense layer Q

        # indices.shape = (node number, neighbor number, 2)
        node_nums = tf.keras.layers.Lambda(lambda x: tf.tile(tf.expand_dims(tf.range(tf.shape(x)[0]), axis=1),
                                                             (1, tf.shape(x)[1])))(neighbor_set)
        indices = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=-1))([node_nums, neighbor_set])  # (target node id, neighbor id) index pairs

        # neighbor weights.shape = (node number, neighbor number)
        neighbor_weights = tf.keras.layers.Lambda(lambda x, indices: tf.gather_nd(x, indices),
                                                  arguments={"indices": indices})(weights)    # 提取所有要计算的邻居边的权重
        # neighbor weights.shape = (1, node number, neighbor number, 1)
        neighbor_weights = tf.keras.layers.Lambda(lambda x: tf.expand_dims(tf.expand_dims(x, 0), -1))(neighbor_weights)

        # weighted_sum_hidden.shape = (batch, node number, hidden channels)
        # importance-weighted average of each node's neighbor embeddings, using the edge weights to the target node
        weighted_sum_hidden = tf.keras.layers.Lambda(lambda x: tf.math.reduce_sum(x[0] * x[1], axis=2) / (tf.reduce_sum(x[1], axis=2)+1e-6))([neighbor_hiddens, neighbor_weights])

        # concated_hidden.shape = (batch, node number, in channels + hidden channels)  # concatenate each node's own embedding with its aggregated neighbor embedding
        concated_hidden = tf.keras.layers.Concatenate(axis=-1)([embeddings, weighted_sum_hidden])

        # hidden_new shape = (batch, node number, hidden_channels)
        hidden_new = self.W(concated_hidden)   # pass the concatenated embedding through the second dense layer W

        # normalized.shape = (batch, node number, hidden channels)  # L2-normalize the output embeddings
        normalized = tf.keras.layers.Lambda(lambda x: x / (tf.norm(x, axis=2, keepdims=True) + 1e-6))(hidden_new)
        return normalized
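
A quick smoke test of the layer with toy shapes (the sizes and random inputs below are my own, only to illustrate the expected input format):

conv = Convolve(hidden_channels=16)
embeddings = tf.random.normal((1, 5, 8))  # (batch, node number, in channels)
weights = tf.random.uniform((5, 5))  # dense edge-weight matrix
neighbor_set = tf.random.uniform((5, 2), maxval=5, dtype=tf.int32)  # 2 sampled neighbor ids per node
print(conv([embeddings, weights, neighbor_set]).shape)  # (1, 5, 16)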

Next comes the minibatch algorithm; the graph is represented as a networkx Graph:

class PinSage(tf.keras.Model):
    def __init__(self, hidden_channels, graph=None, edge_weights=None):
        # hidden_channels lists the output dimension of each convolve layer
        assert type(hidden_channels) is list
        if graph is not None: assert type(graph) is nx.classes.graph.Graph  # the original graph
        if edge_weights is not None: assert type(edge_weights) is list  # precomputed edge weights
        super(PinSage, self).__init__()
        # create the stack of convolve layers
        self.convs = list()
        for i in range(len(hidden_channels)):
            self.convs.append(Convolve(hidden_channels=hidden_channels[i]))
        # compute PageRank-based edge weights on the original graph
        self.edge_weights = self.pagerank(graph) if graph is not None else edge_weights
    def call(self, inputs):
        # embeddings.shape = (batch, node number, in channels)
        embeddings = inputs[0]   # embeddings of all nodes
        # number of neighbors to sample per node
        sample_neighbor_num = inputs[1]
        # sample neighbors according to the edge weights
        # neighbor_set.shape = (node number, neighbor number)
        # tf.random.categorical expects logits, so sample from the log of the normalized weights;
        # cast to int32 so the ids match the int32 range indices built inside Convolve
        neighbor_set = tf.cast(tf.random.categorical(tf.math.log(self.edge_weights + 1e-6), sample_neighbor_num), dtype=tf.int32)
        for conv in self.convs:   # apply the K convolve layers
            embeddings = conv([embeddings, self.edge_weights, neighbor_set])
        return embeddings
    def pagerank(self, graph, damp_rate=0.2):
        # node ids must be the consecutive integers 0..N-1
        node_ids = sorted([id for id in graph.nodes])
        assert node_ids == list(range(len(node_ids)))
        # adjacency matrix
        weights = np.zeros((len(graph.nodes), len(graph.nodes), ), dtype=np.float32)
        for f in graph.nodes:
            for t in list(graph.adj[f]):
                weights[f, t] = 1
        weights = tf.constant(weights)
        # row-normalize the adjacency matrix
        line_sum = tf.math.reduce_sum(weights, axis=1, keepdims=True) + 1e-6
        normalized = weights / line_sum
        # damping vector
        damping = tf.ones((len(graph.nodes),), dtype=tf.float32)
        damping = damping / tf.constant(len(graph.nodes), dtype=tf.float32)
        damping = tf.expand_dims(damping, 0)  # row vector
        # power iteration for PageRank
        v = damping
        while True:
            v_updated = (1 - damp_rate) * tf.linalg.matmul(v, normalized) + damp_rate * damping
            d = tf.norm(v_updated - v)
            if d < 1e-4: break
            v = v_updated
        # weight each edge by the PageRank score of the neighbor it points to, then row-normalize
        weights = weights * tf.tile(v, (tf.shape(weights)[0], 1))
        line_sum = tf.reduce_sum(weights, axis=1, keepdims=True) + 1e-6
        normalized = weights / line_sum
        return normalized
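
An end-to-end sketch on a toy graph (the graph choice, feature size, and layer widths here are my own assumptions, not from the original post):

graph = nx.karate_club_graph()  # 34 nodes labeled 0..33, so the id assertion in pagerank holds
model = PinSage(hidden_channels=[32, 16], graph=graph)
embeddings = tf.random.normal((1, len(graph.nodes), 8))  # (batch, node number, in channels)
outputs = model([embeddings, 10])  # sample 10 neighbors per node
print(outputs.shape)  # (1, 34, 16)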