
DynaGraph: Dynamic Graph Neural Networks at Scale


Anand Iyer

June 12, 2022

Transcript

  1. DynaGraph: Dynamic Graph
    Neural Networks at Scale
    Mingyu Guan*, Anand Iyer▴, Taesoo Kim*
    *Georgia Institute of Technology ▴Microsoft Research
    GRADES-NDA 2022


  2. Graph Neural Networks (GNNs)
    • The recent past has seen an increasing interest in GNNs.
    • Node embeddings are generated by combining graph structure and
    feature information.
    • Most GNN models can fit into the Message Passing Paradigm.
    [Figure: a GNN maps the initial features/embeddings of each node in a small example graph (nodes A–E) to output features/embeddings for each node.]


  3. Message Passing Paradigm
    [Figure: a node with its current state $h_v^{k-1}$ and the current states of its neighbors B, C, D.]


  4. Message Passing Paradigm
    [Figure: the neighbors B, C, D send messages to the node, alongside its current state $h_v^{k-1}$.]


  5. Message Passing Paradigm
    [Figure: the received messages are aggregated and reduced into $m_v^{k}$.]


  6. Message Passing Paradigm
    [Figure: the aggregated message $m_v^{k}$ and the current state $h_v^{k-1}$ pass through an update step to produce the next node state $h_v^{k}$.]

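As a concrete illustration of the aggregate-and-update loop above, here is a minimal sketch of one message-passing step in DGL; the toy graph, feature size, and the linear-plus-ReLU update function are illustrative assumptions, not details from the talk.

```python
# Minimal sketch of one message-passing step (message -> aggregate -> update).
import torch
import torch.nn as nn
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2, 3], [4, 4, 4, 0]))    # toy 5-node directed graph
g.ndata['h'] = torch.randn(g.num_nodes(), 8)   # current states h^{k-1}
update = nn.Linear(2 * 8, 8)                   # U^k combines h^{k-1} and m^k

# Message: copy each neighbor's state; Reduce: sum the messages per node.
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'm_sum'))

# Update: next state h^k from the previous state and the aggregated message.
h_next = torch.relu(update(torch.cat([g.ndata['h'], g.ndata['m_sum']], dim=1)))
```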

  7. Dynamic GNNs
    • Most existing GNN frameworks assume that the input graph is static.
    • Real-world graphs are often dynamic in nature.
    • Representation: a time series of snapshots of the graph.
    • Common approach: combine GNNs and RNNs (see the sketch below).
    o GNNs for encoding spatial information (graph structure)
    o RNNs for encoding temporal information

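To make the "GNN for space, RNN for time" combination concrete, here is a minimal sketch (not taken from the paper) that encodes each snapshot with a graph convolution and carries the result through a GRU; it assumes a fixed node set across snapshots, and the layer choices and sizes are illustrative.

```python
# Hedged sketch of the common GNN + RNN combination for dynamic graphs.
import torch
import torch.nn as nn
from dgl.nn import GraphConv

class SimpleDynamicGNN(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super().__init__()
        self.gnn = GraphConv(in_feats, hidden_feats)       # spatial encoder
        self.rnn = nn.GRUCell(hidden_feats, hidden_feats)  # temporal encoder

    def forward(self, snapshots, features):
        # snapshots: list of DGLGraph; features: list of [N, in_feats] tensors
        h = torch.zeros(features[0].shape[0], self.rnn.hidden_size,
                        device=features[0].device)
        for g, x in zip(snapshots, features):
            spatial = torch.relu(self.gnn(g, x))  # encode graph structure
            h = self.rnn(spatial, h)              # encode temporal dynamics
        return h                                  # final node embeddings
```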

  8. [Figure: standard LSTM and GRU cells. Each LSTM gate ($i$, $f$, $c$, $o$) and GRU gate ($r$, $z$, $h$) combines an input term $W_x x_t$ with a hidden-state term $W_h h_{t-1}$ to produce the next state $h_t$.]


  9. [Figure: the same LSTM and GRU cells, with the input terms $W_x x_t$ highlighted as time-independent.]


  10. [Figure: the same cells, now also highlighting the hidden-state terms $W_h h_{t-1}$ as time-dependent.]

  11. [Figure: the same cells with both the time-independent ($W_x x_t$) and time-dependent ($W_h h_{t-1}$) terms highlighted.]


  12. GraphLSTM and GraphGRU
    [Figure: graph-convolutional recurrent cells. Each LSTM gate ($i$, $f$, $c$, $o$) and GRU gate ($r$, $z$, $h$) replaces its matrix multiplications with graph convolutions: a time-independent convolution $G_{conv}(x_t, W_x)$ on the current input and a time-dependent convolution $G_{conv}(h_{t-1}, W_h)$ on the previous hidden state.]

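To make the structure above concrete, here is a minimal GraphGRU-style cell built from DGL's GraphConv; it is a sketch rather than DynaGraph's implementation, and the gate convention and sizes are illustrative. Note how every gate repeats its own neighborhood aggregation, which is the redundancy the next slide identifies.

```python
# Hedged sketch of a GraphGRU-style cell: each gate pairs a time-independent
# convolution on x_t with a time-dependent convolution on h_{t-1}.
import torch
import torch.nn as nn
from dgl.nn import GraphConv

class NaiveGraphGRUCell(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super().__init__()
        # One (x-conv, h-conv) pair per gate: reset r, update z, candidate h.
        self.conv_xr = GraphConv(in_feats, hidden_feats)
        self.conv_hr = GraphConv(hidden_feats, hidden_feats)
        self.conv_xz = GraphConv(in_feats, hidden_feats)
        self.conv_hz = GraphConv(hidden_feats, hidden_feats)
        self.conv_xh = GraphConv(in_feats, hidden_feats)
        self.conv_hh = GraphConv(hidden_feats, hidden_feats)

    def forward(self, g, x_t, h_prev):
        # Six separate graph convolutions -> six neighborhood aggregations.
        r = torch.sigmoid(self.conv_xr(g, x_t) + self.conv_hr(g, h_prev))
        z = torch.sigmoid(self.conv_xz(g, x_t) + self.conv_hz(g, h_prev))
        h_cand = torch.tanh(self.conv_xh(g, x_t) + self.conv_hh(g, r * h_prev))
        return z * h_prev + (1 - z) * h_cand
```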

  13. Challenge #1: Redundant Neighborhood Aggregation
    • Two categories of graph convolutions:
    Ø Time-independent graph convolutions depend on the current representations of nodes.
    Ø Time-dependent graph convolutions depend on the previous hidden states.
    • Redundancy: graph convolutions in the same category perform the same neighborhood aggregation.
    [Figure: the GraphLSTM cell, with each gate ($i$, $f$, $c$, $o$) computing a time-independent convolution $G_{conv}(x_t, W_x)$ and a time-dependent convolution $G_{conv}(h_{t-1}, W_h)$; convolutions in the same category repeat the same neighborhood aggregation.]


  14. Challenge #2: Inefficient Distributed Training
    • No existing system for training static GNNs (e.g., DGL) supports distributed dynamic GNN training efficiently.
    • Static GNN training:
    • Partition both the graph structure and the node features across machines.
    • Use data parallelism to train the static GNN.
    • Can we partition each snapshot individually?
    § Partitioning and maintaining a large number of snapshots can be expensive.
    § The graph structure and the node features in each snapshot may vary.


  15. Cached Message Passing
    [Figure: the GraphLSTM cell with each gate's graph convolutions decomposed into shared message-passing results $m_{x_t}$ (time-independent) and $m_{h_{t-1}}$ (time-dependent), followed by the per-gate weights $W_x$ and $W_h$.]
    Typical message passing paradigm of a GNN:
    $m_{u \to v}^{k} = M^{k}\left(h_v^{k-1}, h_u^{k-1}, e_{u \to v}^{k-1}\right)$
    $m_v^{k} = \sum_{u \in N(v)} m_{u \to v}^{k}$
    $h_v^{k} = U^{k}\left(h_v^{k-1}, m_v^{k}\right)$


  16. Cached Message Passing
    Typical message passing paradigm of a GNN:
    $m_{u \to v}^{k} = M^{k}\left(h_v^{k-1}, h_u^{k-1}, e_{u \to v}^{k-1}\right)$
    $m_v^{k} = \sum_{u \in N(v)} m_{u \to v}^{k}$
    $h_v^{k} = U^{k}\left(h_v^{k-1}, m_v^{k}\right)$
    The message-passing results can be reused by all graph convolutions in the same category.
    [Figure: the GraphLSTM cell, with the shared results $m_{x_t}$ and $m_{h_{t-1}}$ feeding every gate's per-gate weights.]

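A minimal sketch of the reuse idea (not DynaGraph's actual code): the neighborhood aggregation is performed once for $x_t$ and once for $h_{t-1}$, and every gate then applies only its own weights to the shared results. The LSTM-style cell, the mean aggregation, and the argument names are illustrative assumptions.

```python
# Hedged sketch: aggregate neighborhoods once per category, reuse across gates.
import torch
import torch.nn as nn
import dgl.function as fn

class CachedGraphLSTMCell(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super().__init__()
        # Per-gate weights (i, f, c, o), applied after the shared aggregation.
        self.w_x = nn.Linear(in_feats, 4 * hidden_feats)
        self.w_h = nn.Linear(hidden_feats, 4 * hidden_feats)

    @staticmethod
    def aggregate(g, feat):
        # One neighborhood aggregation, shared by all gates in a category.
        with g.local_scope():
            g.ndata['f'] = feat
            g.update_all(fn.copy_u('f', 'm'), fn.mean('m', 'agg'))
            return g.ndata['agg']

    def forward(self, g, x_t, h_prev, c_prev, m_x=None, m_h=None):
        # m_x / m_h can come from a cache store; recompute only on a miss.
        m_x = self.aggregate(g, x_t) if m_x is None else m_x
        m_h = self.aggregate(g, h_prev) if m_h is None else m_h
        xi, xf, xc, xo = self.w_x(m_x).chunk(4, dim=1)
        hi, hf, hc, ho = self.w_h(m_h).chunk(4, dim=1)
        i, f, o = (torch.sigmoid(v) for v in (xi + hi, xf + hf, xo + ho))
        c = f * c_prev + i * torch.tanh(xc + hc)
        h = o * torch.tanh(c)
        return h, c, m_x, m_h
```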

  17. Cached Message Passing
    • Dynamic graphs are often trained using sequence-to-sequence
    models in a sliding-window fashion.
    [Figure: an encoder-decoder of stacked GraphRNN layers (Layer 1, Layer 2) unrolled over snapshots t=1..4 as sequence 1; the decoder is fed teacher states (ground truth).]


  18. Cached Message Passing
    • Dynamic graphs are often trained using sequence-to-sequence
    models in a sliding-window fashion.
    [Figure: the same encoder-decoder slid forward by one step, over snapshots t=2..5 as sequence 2.]


  19. Cached Message Passing
    • Dynamic graphs are often trained using sequence-to-sequence
    models in a sliding-window fashion.
    [Figure: the same encoder-decoder over snapshots t=2..5 (sequence 2).]
    Neighborhood aggregation has already been performed in previous sequence(s)!


  20. Cached Message Passing
    [Figure: the GraphLSTM cell alongside a cache store that holds message-passing results ($m_{x_{t-n}}$, $m_{h_{t-n-1}}$, ...) for snapshots t-n through t; the inputs are $x_t$ and $h_{t-1}$.]


  21. Cached Message Passing
    [Figure: message passing is applied to $x_t$ and $h_{t-1}$ to produce $m_{x_t}$ and $m_{h_{t-1}}$, which are fed to all gates.]


  22. Cached Message Passing
    [Figure: the freshly computed $m_{x_t}$ and $m_{h_{t-1}}$ are PUT into the cache store under snapshot t.]


  23. Cached Message Passing
    [Figure: previously computed results are fetched from the cache store with GET instead of being recomputed.]


  24. Cached Message Passing
    [Figure: all gates are served by GETs from the cache store; message passing runs only for results that are not yet cached.]

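The cache store itself can be as simple as a keyed map from (snapshot id, input kind) to the aggregated message tensor. The sketch below only shows the PUT/GET shape; the keying scheme, eviction policy, and gradient handling are assumptions left out rather than DynaGraph's design.

```python
# Hedged sketch of a cache store for message-passing results.
import torch

class MessageCache:
    def __init__(self):
        self._store = {}   # (snapshot_id, kind) -> aggregated message tensor

    def put(self, snapshot_id: int, kind: str, msg: torch.Tensor) -> None:
        self._store[(snapshot_id, kind)] = msg

    def get(self, snapshot_id: int, kind: str):
        return self._store.get((snapshot_id, kind))   # None on a cache miss

# Illustrative use with the cached cell sketched earlier:
#   m_x = cache.get(t, 'x')
#   h, c, m_x, m_h = cell(g, x_t, h_prev, c_prev, m_x=m_x)
#   cache.put(t, 'x', m_x)
```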

  25. Distributed DGNN Training
    [Figure: the snapshots (t=1..n) and input features are partitioned across machines $M_1, M_2, \dots$; each machine runs the K-layer model (Layer 1 ... Layer K) over its partition with a sliding window across snapshots.]

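A generic data-parallel sketch of the per-machine training loop (the partitioning itself, the model's forward signature, and the loss are assumptions, not DynaGraph's code): each worker trains on its own partition of the snapshot sequence with a sliding window, and DistributedDataParallel keeps the replicas in sync.

```python
# Hedged sketch: sliding-window training on one partition, with gradients
# synchronized across machines by PyTorch's DistributedDataParallel.
# Assumes a launcher such as torchrun has set up the process-group env vars.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_partition(model, snapshots, targets, window, epochs, device):
    dist.init_process_group('nccl')            # one process per GPU
    model = DDP(model.to(device))
    opt = torch.optim.Adam(model.parameters())
    for _ in range(epochs):
        for s in range(len(snapshots) - window):
            seq = snapshots[s:s + window]       # sliding window of snapshots
            pred = model(seq)                   # encoder-decoder forward pass
            loss = torch.nn.functional.mse_loss(pred, targets[s])
            opt.zero_grad()
            loss.backward()                     # DDP averages gradients
            opt.step()
```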

  26. DynaGraph API
    cache()            Caches the caller function's outputs; does nothing if already cached.
    msg_pass()         Computes intermediate message-passing results.
    update()           Computes the output representation from intermediate message-passing results.
    integrate()        Integrates a GNN into a GraphRNN to create a dynamic GNN.
    stack_seq_model()  Stacks dynamic GNN layers into an encoder-decoder structure.
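An illustrative sketch of how these calls might compose; only the function names come from the slide, so the module name, signatures, and arguments below are assumptions rather than DynaGraph's documented API.

```python
# Hypothetical composition of the DynaGraph API names listed above; the
# module name, signatures, and arguments are illustrative assumptions.
import dynagraph as dg               # assumed module name
from dgl.nn import GraphConv

gnn = GraphConv(in_feats=2, out_feats=64)      # a static GNN layer

# Integrate the GNN into a GraphRNN to obtain a dynamic GNN layer, then
# stack such layers into an encoder-decoder sequence model.
dyn_layer = dg.integrate(gnn, rnn='gru')                    # assumed signature
model = dg.stack_seq_model(encoder=[dyn_layer, dyn_layer],  # assumed signature
                           decoder=[dyn_layer, dyn_layer])
```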


  27. Implementation & Evaluation
    • Implemented on Deep Graph Library (DGL) v0.7
    • Evaluated on 8 machines, each with 2 NVIDIA Tesla V100 GPUs
    § METR-LA: 207 nodes/snapshot, |F| = 2, |S| = 34K
    § PEMS-BAY: 325 nodes/snapshot, |F| = 2, |S| = 52K
    § METR-LA-Large: 0.4M nodes/snapshot, |F| = 128, |S| = 34K
    § PEMS-BAY-Large: 0.7M nodes/snapshot, |F| = 128, |S| = 52K
    • Several dynamic GNN architectures
    § GCRN-GRU, GCRN-LSTM [ICONIP ‘18]
    § DCRNN [ICLR ‘18]


  28. DynaGraph Single-Machine Performance
    [Bar chart: average epoch time (s) for DCRNN, GCRN-GRU, and GCRN-LSTM on METR-LA and PEMS-BAY, comparing DGL and DynaGraph.]
    Up to 2.31x speedup


  29. DynaGraph Distributed Performance
    [Bar chart: average epoch time (s) for DCRNN, GCRN-GRU, and GCRN-LSTM on METR-LA-Large and PEMS-BAY-Large, comparing DGL and DynaGraph.]
    Up to 2.23x speedup


  30. DynaGraph Scaling
    [Line charts: training throughput (snapshots/sec) for GCRN-GRU and GCRN-LSTM as the cluster grows from 2 machines (4 GPUs) to 4 (8) and 8 (16), comparing DGL and DynaGraph.]


  31. Summary
    • Supporting dynamic graphs is increasingly important for enabling
    many GNN applications.
    • Existing GNN systems mainly focus on static graphs and static GNNs.
    • Dynamic GNN architectures combine GNN techniques and temporal
    embedding techniques like RNNs.
    • DynaGraph enables dynamic GNN training at scale.
    • Several techniques to reuse intermediate results.
    • Efficient distributed training.
    • Outperforms state-of-the-art solutions.
    Thank you!
    Contact: [email protected]
