PyTorch Geometric (PyG)
Overview
PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.
When to Use This Skill
This skill should be used when working with:
- Graph-based machine learning: Node classification, graph classification, link prediction
- Molecular property prediction: Drug discovery, chemical property prediction
- Social network analysis: Community detection, influence prediction
- Citation networks: Paper classification, recommendation systems
- 3D geometric data: Point clouds, meshes, molecular structures
- Heterogeneous graphs: Multi-type nodes and edges (e.g., knowledge graphs)
- Large-scale graph learning: Neighbor sampling, distributed training
Quick Start
Installation
```bash
uv pip install torch_geometric
```
For additional dependencies (sparse operations, clustering):
```bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
```
Basic Graph Creation
```python
import torch
from torch_geometric.data import Data

# Create a simple graph with 3 nodes
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]],  # target nodes
                          dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)  # node features

data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
```
Loading a Benchmark Dataset
```python
from torch_geometric.datasets import Planetoid

# Load the Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Get the first (and only) graph

print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
```
Core Concepts
Data Structure
PyG represents graphs using the torch_geometric.data.Data class with these key attributes:
- `data.x`: Node feature matrix `[num_nodes, num_node_features]`
- `data.edge_index`: Graph connectivity in COO format `[2, num_edges]`
- `data.edge_attr`: Edge feature matrix `[num_edges, num_edge_features]` (optional)
- `data.y`: Target labels for nodes or graphs
- `data.pos`: Node spatial positions `[num_nodes, num_dimensions]` (optional)
- Custom attributes: Can add any attribute (e.g., `data.train_mask`, `data.batch`)
Important: These attributes are not mandatory—extend Data objects with custom attributes as needed.
Edge Index Format
Edges are stored in COO (coordinate) format as a [2, num_edges] tensor:
- First row: source node indices
- Second row: target node indices
```python
# Edge list: (0→1), (1→0), (1→2), (2→1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
```
Mini-Batch Processing
PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:
- Adjacency matrices are stacked diagonally
- Node features are concatenated along the node dimension
- A `batch` vector maps each node to its source graph
- No padding is needed, which keeps batching computationally efficient
```python
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    # batch.batch maps each node to its source graph
```
Building Graph Neural Networks
Message Passing Paradigm
GNNs in PyG follow a neighborhood aggregation scheme:
- Transform node features
- Propagate messages along edges
- Aggregate messages from neighbors
- Update node representations
Using Pre-Built Layers
PyG provides 40+ convolutional layers. Common ones include:
GCNConv (Graph Convolutional Network):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```
GATConv (Graph Attention Network):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```
GraphSAGE:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
```
Custom Message Passing Layers
For custom layers, inherit from MessagePassing:
```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops to the adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node features
        x = self.lin(x)

        # Compute symmetric normalization coefficients
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Propagate messages
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: features of source nodes
        return norm.view(-1, 1) * x_j
```
Key methods:
- `forward()`: Main entry point
- `message()`: Constructs messages from source to target nodes
- `aggregate()`: Aggregates messages (usually don't override; set the `aggr` parameter instead)
- `update()`: Updates node embeddings after aggregation
Variable naming convention: Appending `_i` or `_j` to tensor argument names in `message()` automatically maps them to target and source nodes, respectively.
Working with Datasets
Loading Built-in Datasets
PyG provides extensive benchmark datasets:
```python
# Citation networks (node classification)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')  # or 'CiteSeer', 'PubMed'

# Graph classification
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Molecular datasets
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')

# Large-scale datasets
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
```
Check references/datasets_reference.md for a comprehensive list.
Creating Custom Datasets
For datasets that fit in memory, inherit from InMemoryDataset:
```python
import torch
from torch_geometric.data import InMemoryDataset, Data

class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['my_data.csv']  # Files expected in raw_dir

    @property
    def processed_file_names(self):
        return ['data.pt']  # Files written to processed_dir

    def download(self):
        # Download raw data to self.raw_dir
        pass

    def process(self):
        # Read raw data and create Data objects
        data_list = []

        # Example: create a simple graph
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)
        data_list.append(Data(x=x, edge_index=edge_index, y=y))

        # Apply pre_filter and pre_transform
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # Save processed data
        self.save(data_list, self.processed_paths[0])
```
For large datasets that don't fit in memory, inherit from Dataset and implement len() and get(idx).
Loading Graphs from CSV
```python
import pandas as pd
import torch
from torch_geometric.data import Data

# Load nodes
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)

# Load edges
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor([edges_df['source'].values,
                           edges_df['target'].values], dtype=torch.long)

data = Data(x=x, edge_index=edge_index)
```
Training Workflows
Node Classification (Single Graph)
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Create model (GCN defined above)
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

# Evaluation
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
```
Graph Classification (Multiple Graphs)
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Global pooling (aggregate node features to graph level)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

# Load dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
```
Large-Scale Graphs with Neighbor Sampling
For large graphs, use NeighborLoader to sample subgraphs:
```python
from torch_geometric.loader import NeighborLoader

# Create a neighbor sampler
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # Sample 25 neighbors for the 1st hop, 10 for the 2nd
    batch_size=128,
    input_nodes=data.train_mask,
)

# Training
model.train()
for batch in train_loader:
    optimizer.zero_grad()
    out = model(batch)
    # Only compute the loss on seed nodes (the first batch_size nodes)
    loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```
Important:
- Output subgraphs are directed
- Node indices are relabeled (0 to batch.num_nodes - 1)
- Only use seed node predictions for loss computation
- Sampling beyond 2-3 hops is generally not feasible
Advanced Features
Heterogeneous Graphs
For graphs with multiple node and edge types, use HeteroData:
```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Add node features for each node type
data['paper'].x = torch.randn(100, 128)   # 100 papers with 128 features
data['author'].x = torch.randn(200, 64)   # 200 authors with 64 features

# Add edges per relation: (source_type, edge_type, target_type).
# Each row must index into the matching node type.
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 200, (500,)),  # author (source) indices
    torch.randint(0, 100, (500,)),  # paper (target) indices
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))

print(data)
```
Convert homogeneous models to heterogeneous:
```python
from torch_geometric.nn import to_hetero

# Define a homogeneous model
model = GNN(...)

# Convert it to a heterogeneous model
model = to_hetero(model, data.metadata(), aggr='sum')

# Use as normal
out = model(data.x_dict, data.edge_index_dict)
```
Or use HeteroConv for custom edge-type-specific operations:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict
```
Transforms
Apply transforms to modify graph structure or features:
```python
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose

# Single transform
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)

# Compose multiple transforms
transform = Compose([
    AddSelfLoops(),
    NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
```
Common transforms:
- Structure: `ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph`
- Features: `NormalizeFeatures`, `NormalizeScale`, `Center`
- Splitting: `RandomNodeSplit`, `RandomLinkSplit`
- Positional encoding: `AddLaplacianEigenvectorPE`, `AddRandomWalkPE`
See references/transforms_reference.md for the full list.
Model Explainability
PyG provides explainability tools to understand model predictions:
```python
from torch_geometric.explain import Explainer, GNNExplainer

# Create the explainer
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',  # or 'phenomenon'
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Generate an explanation for a specific node
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)

# Inspect the result
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
```
Pooling Operations
For hierarchical graph representations:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)
```
Common Patterns and Best Practices
Check Graph Properties
```python
from torch_geometric.utils import (is_undirected, contains_self_loops,
                                   contains_isolated_nodes, to_networkx)

# Undirected check
print(f"Is undirected: {is_undirected(data.edge_index)}")

# Self-loop check
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")

# Isolated-node check
print(f"Has isolated nodes: {contains_isolated_nodes(data.edge_index, data.num_nodes)}")

# Connected components (via networkx; PyG utils have no direct counter)
import networkx as nx
G = to_networkx(data, to_undirected=True)
print(f"Connected components: {nx.number_connected_components(G)}")
```
GPU Training
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)

# For a DataLoader
for batch in loader:
    batch = batch.to(device)
    # Train...
```
Save and Load Models
```python
# Save
torch.save(model.state_dict(), 'model.pth')

# Load
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
```
Layer Capabilities
When choosing layers, consider these capabilities:
- SparseTensor: Supports efficient sparse matrix operations
- edge_weight: Handles one-dimensional edge weights
- edge_attr: Processes multi-dimensional edge features
- Bipartite: Works with bipartite graphs (different source/target dimensions)
- Lazy: Enables initialization without specifying input dimensions
See the GNN cheatsheet at references/layer_capabilities.md.
Resources
Bundled References
This skill includes detailed reference documentation:
- `references/layers_reference.md`: Complete listing of all 40+ GNN layers with descriptions and capabilities
- `references/datasets_reference.md`: Comprehensive dataset catalog organized by category
- `references/transforms_reference.md`: All available transforms and their use cases
- `references/api_patterns.md`: Common API patterns and coding examples
Scripts
Utility scripts are provided in scripts/:
- `scripts/visualize_graph.py`: Visualize graph structure using networkx and matplotlib
- `scripts/create_gnn_template.py`: Generate boilerplate code for common GNN architectures
- `scripts/benchmark_model.py`: Benchmark model performance on standard datasets
Execute scripts directly or read them for implementation patterns.