Getting Started: Training Models with Diode
This guide walks you through creating a machine learning model from scratch with the Diode toolkit. You’ll learn how to generate a dataset of matrix multiplication performance data and train a model to predict optimal GEMM kernel configurations.
Overview
The Diode workflow involves four main steps:
Data Collection: Generate matrix multiplication performance data using PyTorch’s autotuning capabilities
Model Training: Train a deep learning model on the collected data
Validation Dataset Creation: Create a separate validation dataset from predefined operation shapes
Model Validation: Evaluate the trained model’s performance on the validation dataset
Prerequisites
Before starting, ensure you have:
Access to your target hardware
PyTorch nightlies
The Diode toolkit
Step 1: Data Collection
The first step is to generate a training dataset by collecting matrix multiplication performance data. Diode uses PyTorch’s feedback saver interface to automatically collect timing information for different matrix multiplication configurations.
Setting Up the Data Collector
The MatmulDatasetCollector class provides flexible data collection capabilities:
from torch_diode.collection.matmul_dataset_collector import MatmulDatasetCollector, CollectionMode
# Initialize the collector with log-normal distribution mode
collector = MatmulDatasetCollector(
    hardware_name="your_gpu_name",
    mode=CollectionMode.LOG_NORMAL,
    operations=["mm", "addmm", "bmm"],
    num_shapes=1000,
    seed=50,
)
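The hardware_name argument above is a placeholder. One way to get a descriptive name for the current GPU is PyTorch’s torch.cuda.get_device_name; whether Diode expects exactly this string is an assumption, and detect_hardware_name is a hypothetical helper, so treat this as a convenience sketch.

```python
import torch


def detect_hardware_name(default: str = "cpu") -> str:
    """Hypothetical helper: return a device name for labeling collected data."""
    if torch.cuda.is_available():
        # Name of CUDA device 0 as reported by the driver.
        return torch.cuda.get_device_name(0)
    # No GPU visible: fall back to a caller-supplied label.
    return default


hardware_name = detect_hardware_name()
print(hardware_name)
```

If you collect on several machines, using a per-device name like this keeps datasets from different hardware clearly separated.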
Collection Modes
Diode supports three collection modes:
LOG_NORMAL: Uses log-normal distributions to generate realistic matrix sizes based on production workloads
RANDOM: Generates uniformly random matrix sizes within specified bounds
OPERATION_SHAPE_SET: Uses predefined shapes from a configuration file
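The difference between the two sampling modes can be sketched with Python’s random module. The function generate_shapes, its string mode names, and its default parameters are hypothetical; Diode’s collector has its own parameters (for example the log_normal_*_mean and log_normal_*_std arguments shown under Advanced Configuration).

```python
import random
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (M, N, K)


def generate_shapes(mode: str, num_shapes: int, rng: random.Random,
                    mean: float = 6.0, std: float = 1.5,
                    low: int = 16, high: int = 4096) -> List[Shape]:
    """Hypothetical shape generator for the LOG_NORMAL and RANDOM modes."""
    shapes: List[Shape] = []
    for _ in range(num_shapes):
        if mode == "LOG_NORMAL":
            # Mostly small/medium sizes with a long tail of large ones.
            m, n, k = (max(1, round(rng.lognormvariate(mean, std))) for _ in range(3))
        elif mode == "RANDOM":
            # Every size in [low, high] is equally likely.
            m, n, k = (rng.randint(low, high) for _ in range(3))
        else:
            # OPERATION_SHAPE_SET reads fixed shapes from a file instead of sampling.
            raise ValueError(f"unsupported mode: {mode}")
        shapes.append((m, n, k))
    return shapes


rng = random.Random(50)  # fixed seed, mirroring seed=50
log_normal = generate_shapes("LOG_NORMAL", 5, rng)
uniform = generate_shapes("RANDOM", 5, rng)
```

Uniform sampling spends as much effort on rare huge shapes as on common small ones, which is one reason a log-normal distribution is a better fit for realistic workloads.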
Running Data Collection
Use the matmul_toolkit.py script to collect training data:
python matmul_toolkit.py \
--format msgpack \
--seed 50 \
collect \
--output train_dataset.msgpack \
--num-shapes 1000 \
--log-normal \
--search-space EXHAUSTIVE \
--search-mode max-autotune \
--chunk-size 5
Key parameters:
--format msgpack: Use the MessagePack format for efficient serialization
--seed 50: Set the random seed for reproducibility
--num-shapes 1000: Generate 1000 different matrix configurations
--log-normal: Use a log-normal distribution for realistic sizes
--search-space EXHAUSTIVE: Use exhaustive search for optimal configurations
--search-mode max-autotune: Use PyTorch’s max-autotune mode
--chunk-size 5: Write data every 5 operations to prevent data loss during collection
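The idea behind --chunk-size is to flush results to disk every few operations, so a crash mid-collection loses at most one chunk. The class below is a hypothetical sketch of that idea using JSON Lines from the standard library; Diode itself writes msgpack, and its internals may differ.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


class ChunkedWriter:
    """Hypothetical: buffer records and append them to disk every `chunk_size` records."""

    def __init__(self, path: str, chunk_size: int = 5):
        self.path = path
        self.chunk_size = chunk_size
        self.buffer: List[Dict[str, Any]] = []
        self.flushed = 0  # number of records already written to disk

    def add(self, record: Dict[str, Any]) -> None:
        self.buffer.append(record)
        if len(self.buffer) >= self.chunk_size:
            self.flush()

    def flush(self) -> None:
        if not self.buffer:
            return
        # Append mode: earlier chunks survive even if a later operation fails.
        with open(self.path, "a") as f:
            for rec in self.buffer:
                f.write(json.dumps(rec) + "\n")
        self.flushed += len(self.buffer)
        self.buffer.clear()


path = os.path.join(tempfile.mkdtemp(), "train_dataset.jsonl")
writer = ChunkedWriter(path, chunk_size=5)
for i in range(12):
    writer.add({"op": "mm", "shape": [i + 1, 64, 64], "time_ms": 0.1 * i})
on_disk_before_final_flush = writer.flushed  # two full chunks of 5
writer.flush()  # write the remaining 2 records
```

Smaller chunks mean less lost work after a failure but more frequent file writes.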
Understanding the Collection Process
The data collection process works by:
Generating matrix shapes based on the selected mode
Creating random tensors with the specified dimensions and data types
Compiling matrix multiplication operations with PyTorch’s autotuning
Capturing timing data for different Triton GEMM configurations through the feedback saver interface
Storing the results in a structured dataset format
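A heavily simplified version of this loop is sketched below. It does not use the feedback saver or torch.compile autotuning; as a stand-in it times eager torch.mm on the CPU with time.perf_counter and keeps the fastest of a few runs. The function collect_timings and its record fields are hypothetical.

```python
import time
from typing import Dict, List, Tuple

import torch


def collect_timings(shapes: List[Tuple[int, int, int]],
                    dtype: torch.dtype = torch.float32,
                    repeats: int = 3) -> List[Dict[str, object]]:
    """Hypothetical stand-in: times eager torch.mm instead of autotuned Triton kernels."""
    records: List[Dict[str, object]] = []
    for m, n, k in shapes:
        # Random input tensors with the requested dimensions and dtype.
        a = torch.randn(m, k, dtype=dtype)
        b = torch.randn(k, n, dtype=dtype)
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            torch.mm(a, b)
            best = min(best, time.perf_counter() - start)
        # One structured record per shape; a real dataset would also store the kernel config.
        records.append({"op": "mm", "m": m, "n": n, "k": k,
                        "dtype": str(dtype), "time_s": best})
    return records


records = collect_timings([(64, 64, 64), (128, 32, 256)])
```

On a GPU, timing code like this must call torch.cuda.synchronize() before reading the clock, because CUDA kernels launch asynchronously.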
Step 2: Model Training
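Once you have collected training data, train a deep learning model to predict optimal GEMM configurations. The command below reads training data from --data-dir, so place the dataset from Step 1 in ./data first (the complete workflow script writes it there directly):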
python matmul_toolkit.py \
--seed 50 \
train \
--data-dir ./data \
--model matmul_model.pt \
--model-type deep \
--batch-size 64 \
--num-epochs 1000 \
--learning-rate 0.001 \
--log-dir ./logs
Training parameters:
--model-type deep: Use a deep neural network architecture
--batch-size 64: Process 64 samples per batch
--num-epochs 1000: Train for 1000 epochs
--learning-rate 0.001: Set the learning rate
--log-dir: Directory to save training logs and metrics
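What --batch-size, --num-epochs, and --learning-rate control can be seen in a minimal, generic PyTorch training loop. This is not Diode’s trainer: the synthetic data, the small model, the Adam optimizer, and the MSE loss on log execution time are assumptions chosen for illustration, and the epoch count is cut to 20 so the example runs quickly.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(50)  # mirrors --seed 50

# Synthetic stand-in data: 8 features per sample, target = log execution time.
features = torch.randn(256, 8)
log_times = features.sum(dim=1, keepdim=True) * 0.1

loader = DataLoader(TensorDataset(features, log_times), batch_size=64, shuffle=True)  # --batch-size
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # --learning-rate
loss_fn = nn.MSELoss()

num_epochs = 20  # --num-epochs (1000 in the command above)
history = []
for epoch in range(num_epochs):
    model.train()
    total = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item() * x.size(0)
    # Average training loss over the epoch.
    history.append(total / len(loader.dataset))
```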
The model learns to predict optimal Triton GEMM configurations based on matrix dimensions, data types, and hardware characteristics.
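Because the models described below output a predicted log execution time for a (problem, configuration) pair, one natural way to use them as a configuration picker is to score every candidate and keep the minimum. The sketch assumes that usage; pick_best_config, the toy predictor, and the candidate fields (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) are illustrative stand-ins, not Diode’s API.

```python
import math
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]
Problem = Tuple[int, int, int]  # (M, N, K)


def pick_best_config(problem: Problem, candidates: List[Config],
                     predict_log_time: Callable[[Problem, Config], float]) -> Config:
    # Lower predicted log time means a faster predicted kernel.
    return min(candidates, key=lambda cfg: predict_log_time(problem, cfg))


def toy_predictor(problem: Problem, cfg: Config) -> float:
    # Hypothetical stand-in for a trained model: fewer, well-fitting tiles score better.
    m, n, _ = problem
    tiles = math.ceil(m / cfg["BLOCK_M"]) * math.ceil(n / cfg["BLOCK_N"])
    waste = tiles * cfg["BLOCK_M"] * cfg["BLOCK_N"] - m * n
    return math.log(1 + tiles + waste / 1024)


candidates = [
    {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "num_warps": 4},
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "num_warps": 8},
    {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64, "num_warps": 2},
]
best = pick_best_config((128, 128, 256), candidates, toy_predictor)
```

Ranking many candidates with a cheap model avoids benchmarking each one, which is the main payoff of a learned timing predictor.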
Model Architecture
Diode provides two simple neural network architectures for timing prediction. These are not meant to be state-of-the-art models, but rather serve as a starting point for further experimentation and development:
Standard Model (MatmulTimingModel)
The standard model uses a feedforward neural network with the following architecture:
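from typing import List
import torch.nn as nn
class MatmulTimingModel(nn.Module):
    def __init__(
        self,
        problem_feature_dim: int,
        config_feature_dim: int,
        hidden_dims: List[int] = [256, 512, 256, 128, 64],
        dropout_rate: float = 0.2,
    ):
        ...  # Signature only; layer construction omitted here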
Architecture components:
Input Layer: Concatenates problem features (matrix dimensions, data types) and configuration features (Triton GEMM parameters)
Hidden Layers: Multiple fully connected layers with ReLU activation, batch normalization, and dropout
Output Layer: Single neuron predicting log execution time
Regularization: Dropout and batch normalization to prevent overfitting
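The components listed above can be sketched in PyTorch as follows. This is an illustrative reconstruction from the description, not Diode’s source: the class name SketchMatmulTimingModel and the per-layer ordering (Linear, BatchNorm, ReLU, Dropout) are assumptions.

```python
from typing import List, Sequence

import torch
import torch.nn as nn


class SketchMatmulTimingModel(nn.Module):
    def __init__(self, problem_feature_dim: int, config_feature_dim: int,
                 hidden_dims: Sequence[int] = (256, 512, 256, 128, 64),
                 dropout_rate: float = 0.2):
        super().__init__()
        layers: List[nn.Module] = []
        in_dim = problem_feature_dim + config_feature_dim
        for h in hidden_dims:
            # Fully connected layer with batch norm, ReLU, and dropout.
            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout_rate)]
            in_dim = h
        layers.append(nn.Linear(in_dim, 1))  # single output: predicted log execution time
        self.net = nn.Sequential(*layers)

    def forward(self, problem: torch.Tensor, config: torch.Tensor) -> torch.Tensor:
        # Concatenate problem and configuration features along the feature axis.
        return self.net(torch.cat([problem, config], dim=1))


model = SketchMatmulTimingModel(problem_feature_dim=7, config_feature_dim=5)
model.eval()  # BatchNorm uses running statistics in eval mode
with torch.no_grad():
    out = model(torch.randn(4, 7), torch.randn(4, 5))
```

Using a tuple default for hidden_dims sidesteps Python’s shared-mutable-default pitfall.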
Deep Model (DeepMatmulTimingModel)
The deep model uses residual connections for training deeper networks:
import torch.nn as nn
class DeepMatmulTimingModel(nn.Module):
    def __init__(
        self,
        problem_feature_dim: int,
        config_feature_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 10,
        dropout_rate: float = 0.2,
    ):
        ...  # Signature only; layer construction omitted here
Key features:
Residual Blocks: Each block contains two linear layers with skip connections
Deeper Architecture: num_layers defaults to 10, and every layer uses the same hidden dimension (hidden_dim, default 128)
Better Gradient Flow: Residual connections help train deeper networks effectively
Residual Block Implementation
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    # __init__ (omitted here) defines self.block and self.relu
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.block(x)
        out += identity  # Skip connection
        out = self.relu(out)
        return out
The residual blocks enable training much deeper networks while maintaining stable gradients throughout the network depth.
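Putting the residual block and the deep model together, a self-contained sketch might look like this. It is a reconstruction under assumptions, not Diode’s code: the names SketchResidualBlock and SketchDeepMatmulTimingModel, the block’s internal layers (two Linear layers with BatchNorm, as the key features describe), and the choice of one block per num_layers are all hypothetical.

```python
import torch
import torch.nn as nn


class SketchResidualBlock(nn.Module):
    def __init__(self, dim: int, dropout_rate: float = 0.2):
        super().__init__()
        # Two linear layers per block; the skip connection is added in forward().
        self.block = nn.Sequential(
            nn.Linear(dim, dim), nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(dim, dim), nn.BatchNorm1d(dim),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.block(x) + x)  # skip connection, then activation


class SketchDeepMatmulTimingModel(nn.Module):
    def __init__(self, problem_feature_dim: int, config_feature_dim: int,
                 hidden_dim: int = 128, num_layers: int = 10, dropout_rate: float = 0.2):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Linear(problem_feature_dim + config_feature_dim, hidden_dim), nn.ReLU())
        # Assumption: one residual block per layer, all with the same width.
        self.blocks = nn.Sequential(
            *[SketchResidualBlock(hidden_dim, dropout_rate) for _ in range(num_layers)])
        self.output_layer = nn.Linear(hidden_dim, 1)  # predicted log execution time

    def forward(self, problem: torch.Tensor, config: torch.Tensor) -> torch.Tensor:
        x = self.input_layer(torch.cat([problem, config], dim=1))
        return self.output_layer(self.blocks(x))


model = SketchDeepMatmulTimingModel(problem_feature_dim=7, config_feature_dim=5)
model.eval()
with torch.no_grad():
    out = model(torch.randn(3, 7), torch.randn(3, 5))
n_blocks = len(model.blocks)
```

Writing the skip as `self.block(x) + x` rather than an in-place `+=` avoids autograd errors if the block ever ends in a layer whose backward pass needs its own output.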
Step 3: Creating a Validation Dataset
Create a separate validation dataset using predefined operation shapes to evaluate model performance:
python matmul_toolkit.py \
--format msgpack \
--seed 50 \
create-validation \
--output validation_dataset.msgpack \
--shapeset operation_shapeset.json \
--operations mm addmm bmm \
--search-space EXHAUSTIVE \
--search-mode max-autotune
This step:
Loads predefined matrix shapes from operation_shapeset.json
Runs autotuning to find optimal configurations for these shapes
Creates a validation dataset with known ground truth performance data
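The layout of operation_shapeset.json is not described in this guide. Assuming, purely for illustration, a JSON object that maps each operation name to a list of shapes (with a leading batch dimension for bmm), loading it and expanding it into per-operation autotuning jobs could look like this; load_shapeset is a hypothetical helper.

```python
import json
import os
import tempfile
from typing import Dict, List, Tuple


def load_shapeset(path: str, operations: List[str]) -> List[Tuple[str, Tuple[int, ...]]]:
    # Assumed format: {"mm": [[M, N, K], ...], "bmm": [[B, M, N, K], ...], ...}
    with open(path) as f:
        shapeset: Dict[str, List[List[int]]] = json.load(f)
    jobs: List[Tuple[str, Tuple[int, ...]]] = []
    for op in operations:
        # Operations missing from the file simply contribute no jobs.
        for shape in shapeset.get(op, []):
            jobs.append((op, tuple(shape)))
    return jobs


example = {"mm": [[1024, 1024, 512]], "addmm": [[256, 4096, 1024]], "bmm": [[8, 128, 128, 64]]}
path = os.path.join(tempfile.mkdtemp(), "operation_shapeset.json")
with open(path, "w") as f:
    json.dump(example, f)
jobs = load_shapeset(path, ["mm", "addmm", "bmm"])  # mirrors --operations mm addmm bmm
```

Because the shapes are fixed rather than sampled, the resulting validation set is identical across runs, which makes model comparisons fair.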
Step 4: Model Validation
Finally, evaluate your trained model against the validation dataset:
python matmul_toolkit.py \
--seed 50 \
validate-model \
--model matmul_model.pt \
--dataset validation_dataset.msgpack \
--batch-size 64 \
--top-n-worst 10
This validation step:
Loads the trained model and validation dataset
Makes predictions for each validation sample
Compares predictions against ground truth timing data
Reports accuracy metrics and identifies the worst-performing predictions
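This guide does not specify which accuracy metrics the validator reports. The sketch below assumes a simple choice, mean absolute error in log space, and shows how a --top-n-worst style report can be produced; validation_report is a hypothetical name.

```python
import math
from typing import List, Tuple


def validation_report(predicted_ms: List[float], actual_ms: List[float],
                      top_n_worst: int = 10) -> Tuple[float, List[Tuple[int, float]]]:
    # Compare in log space so a 2x miss counts the same for fast and slow kernels.
    errors = [abs(math.log(p) - math.log(a)) for p, a in zip(predicted_ms, actual_ms)]
    mean_abs_log_error = sum(errors) / len(errors)
    # Indices of the largest errors, worst first.
    worst = sorted(range(len(errors)), key=lambda i: errors[i], reverse=True)[:top_n_worst]
    return mean_abs_log_error, [(i, errors[i]) for i in worst]


predicted = [1.0, 2.0, 0.5, 4.0]
actual = [1.0, 1.0, 0.5, 1.0]
mae, worst = validation_report(predicted, actual, top_n_worst=2)
```

Here sample 3 (predicted 4x too slow) is reported first, followed by sample 1 (2x too slow).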
Complete Workflow Script
Here’s a complete bash script that orchestrates the entire process:
#!/bin/bash
set -e # Exit on any error
# Configuration
SEED=50
DATA_DIR="./data"
TRAIN_DATASET="${DATA_DIR}/seed_${SEED}_train_dataset.msgpack"
VALIDATION_DATASET="${DATA_DIR}/validation/validation_dataset.msgpack"
MODEL_PATH="${DATA_DIR}/matmul_model.pt"
LOG_DIR="${DATA_DIR}/logs"
NUM_SHAPES=1000
NUM_EPOCHS=1000
PYTHON_CMD="python"
TOOLKIT_PATH="matmul_toolkit.py"
OPERATION_SHAPESET_PATH="operation_shapeset.json"
echo "Starting Diode workflow..."
# Step 1: Create data directory
mkdir -p "${DATA_DIR}"
mkdir -p "${DATA_DIR}/validation"
# Step 2: Generate training dataset
echo "Collecting training data..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
--format msgpack \
--seed "${SEED}" \
collect \
--output "${TRAIN_DATASET}" \
--num-shapes ${NUM_SHAPES} \
--log-normal \
--search-space EXHAUSTIVE \
--search-mode max-autotune \
--chunk-size 5
# Step 3: Train model
echo "Training model..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
--seed "${SEED}" \
train \
--data-dir "${DATA_DIR}" \
--model "${MODEL_PATH}" \
--model-type deep \
--batch-size 64 \
--num-epochs ${NUM_EPOCHS} \
--learning-rate 0.001 \
--log-dir "${LOG_DIR}"
# Step 4: Create validation dataset
echo "Creating validation dataset..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
--format msgpack \
--seed "${SEED}" \
create-validation \
--output "${VALIDATION_DATASET}" \
--shapeset "${OPERATION_SHAPESET_PATH}" \
--operations mm addmm bmm \
--search-space EXHAUSTIVE \
--search-mode max-autotune
# Step 5: Validate model
echo "Validating model..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
--seed "${SEED}" \
validate-model \
--model "${MODEL_PATH}" \
--dataset "${VALIDATION_DATASET}" \
--batch-size 64 \
--top-n-worst 10
echo "Workflow completed successfully!"
Advanced Configuration
Custom Collection Parameters
For more control over data collection, you can customize the log-normal distribution parameters:
# Custom parameters for different workload characteristics
collector = MatmulDatasetCollector(
    mode=CollectionMode.LOG_NORMAL,
    # Larger matrices (shift mean higher)
    log_normal_m_mean=7.0,
    log_normal_n_mean=6.5,
    log_normal_k_mean=6.8,
    # Smaller variance for more consistent sizes
    log_normal_m_std=1.5,
    log_normal_n_std=1.2,
    log_normal_k_std=1.8,
)
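To build intuition for these values: if each mean/std pair describes the natural logarithm of a dimension (an assumption; check the collector’s documentation), the median size is exp(mean). The sketch below computes those medians and draws a few seeded samples with random.lognormvariate from the Python standard library.

```python
import math
import random

# (mean, std) pairs copied from the example above; log-scale interpretation is assumed.
params = {"m": (7.0, 1.5), "n": (6.5, 1.2), "k": (6.8, 1.8)}

# Median of a log-normal is exp(mean) when mean/std describe log(size).
medians = {dim: round(math.exp(mu)) for dim, (mu, _) in params.items()}
print(medians)  # {'m': 1097, 'n': 665, 'k': 898}

rng = random.Random(50)
samples = [
    {dim: max(1, round(rng.lognormvariate(mu, sigma))) for dim, (mu, sigma) in params.items()}
    for _ in range(3)
]
```

Under this reading, raising a mean by 1.0 multiplies the typical size by about e (2.7x), while the std controls how widely sizes spread around that median.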
Tips
Start Small: Begin with a smaller number of shapes (100-200) to validate your setup
Monitor Memory: Keep an eye on GPU memory usage during collection
Save Frequently: Use the --chunk-size parameter to save data periodically
Reproducibility: Always set a random seed for consistent results
Hardware Consistency: Collect training and validation data on the same hardware
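The reproducibility tip can be made concrete with a small seeding helper. seed_everything is a hypothetical name, and NumPy is left out to keep the sketch dependency-free; seed it too if your pipeline uses it.

```python
import random

import torch


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's and PyTorch's random number generators."""
    random.seed(seed)
    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(seed)


seed_everything(50)
first = (random.random(), torch.rand(2).tolist())
seed_everything(50)
second = (random.random(), torch.rand(2).tolist())
```

Seeding makes shape generation and weight initialization repeatable, but measured kernel timings still vary between runs, and some GPU operations stay nondeterministic unless torch.use_deterministic_algorithms(True) is set.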
Next Steps
After completing this workflow, you can:
Experiment with different model architectures
Collect data for specific workloads using OPERATION_SHAPE_SET mode
Integrate the trained model into your own applications
Analyze the collected data to understand performance patterns
For more advanced usage, see the API documentation and examples in the repository.