Getting Started with Training Models with Diode

This guide walks through building a machine learning model from scratch with the Diode toolkit. You will generate a dataset of matrix multiplication performance data, train a model on it to predict optimal configurations, and then validate that model on a separate dataset.

Overview

The Diode workflow involves four main steps:

  1. Data Collection: Generate matrix multiplication performance data using PyTorch’s autotuning capabilities

  2. Model Training: Train a deep learning model on the collected data

  3. Validation Dataset Creation: Create a separate validation dataset from predefined operation shapes

  4. Model Validation: Evaluate the trained model’s performance on the validation dataset

Prerequisites

Before starting, ensure you have:

  • Access to your target hardware

  • A PyTorch nightly build

  • The Diode toolkit

Step 1: Data Collection

The first step is to generate a training dataset by collecting matrix multiplication performance data. Diode uses PyTorch’s feedback saver interface to automatically collect timing information for different matrix multiplication configurations.

Setting Up the Data Collector

The MatmulDatasetCollector class provides flexible data collection capabilities:

from torch_diode.collection.matmul_dataset_collector import MatmulDatasetCollector, CollectionMode

# Initialize the collector with log-normal distribution mode
collector = MatmulDatasetCollector(
    hardware_name="your_gpu_name",
    mode=CollectionMode.LOG_NORMAL,
    operations=["mm", "addmm", "bmm"],
    num_shapes=1000,
    seed=50,
)

Collection Modes

Diode supports three collection modes:

  1. LOG_NORMAL: Uses log-normal distributions to generate realistic matrix sizes based on production workloads

  2. RANDOM: Generates uniformly random matrix sizes within specified bounds

  3. OPERATION_SHAPE_SET: Uses predefined shapes from a configuration file
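
To make the difference between the first two modes concrete, the following sketch samples (M, N, K) shapes both ways using only the Python standard library. The function names, the default mean and standard deviation, and the 8192 upper bound are illustrative assumptions, not Diode's actual sampling code or defaults.

```python
import random
import statistics


def sample_log_normal_shape(rng, mean=6.0, std=1.0, max_dim=8192):
    """Draw (M, N, K) with each dimension log-normally distributed.

    `mean` and `std` describe the natural log of the dimension (assumed units).
    """
    def one_dim():
        value = int(round(rng.lognormvariate(mean, std)))
        return max(1, min(value, max_dim))  # clamp to a valid size

    return one_dim(), one_dim(), one_dim()


def sample_uniform_shape(rng, min_dim=1, max_dim=8192):
    """Draw (M, N, K) with each dimension uniform in [min_dim, max_dim]."""
    return tuple(rng.randint(min_dim, max_dim) for _ in range(3))


rng = random.Random(50)  # fixed seed so the sample is reproducible
log_normal_shapes = [sample_log_normal_shape(rng) for _ in range(1000)]
uniform_shapes = [sample_uniform_shape(rng) for _ in range(1000)]

# Compare the typical size of the first dimension under each mode.
log_normal_median_m = statistics.median([s[0] for s in log_normal_shapes])
uniform_median_m = statistics.median([s[0] for s in uniform_shapes])
print("log-normal median M:", log_normal_median_m)
print("uniform median M:", uniform_median_m)
```

With these assumed parameters, the log-normal sample concentrates around a few hundred elements per dimension with a long tail of larger sizes, while the uniform sample spreads evenly across the whole range.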

Running Data Collection

Use the matmul_toolkit.py script to collect training data:

python matmul_toolkit.py \
    --format msgpack \
    --seed 50 \
    collect \
    --output train_dataset.msgpack \
    --num-shapes 1000 \
    --log-normal \
    --search-space EXHAUSTIVE \
    --search-mode max-autotune \
    --chunk-size 5

Key parameters:

  • --format msgpack: Use MessagePack format for efficient serialization

  • --seed 50: Set random seed for reproducibility

  • --num-shapes 1000: Generate 1000 different matrix configurations

  • --log-normal: Use log-normal distribution for realistic sizes

  • --search-space EXHAUSTIVE: Use exhaustive search for optimal configurations

  • --search-mode max-autotune: Use PyTorch’s max-autotune mode

  • --chunk-size 5: Write data every 5 operations to prevent data loss during collection
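
The idea behind chunked writes can be sketched in a few lines. This is not the toolkit's writer: it uses JSON lines from the standard library as a stand-in for MessagePack, and the names write_in_chunks and flaky_source are hypothetical. It shows that when collection is interrupted, everything flushed in earlier chunks survives.

```python
import json
import os
import tempfile


def write_in_chunks(records, path, chunk_size=5):
    """Append records to `path` as JSON lines, flushing every `chunk_size` records.

    Returns the number of records that reached the file.
    """
    written = 0
    buffer = []
    with open(path, "a", encoding="utf-8") as f:

        def flush_buffer():
            nonlocal written
            for item in buffer:
                f.write(json.dumps(item) + "\n")
            f.flush()
            os.fsync(f.fileno())  # push the chunk to disk before continuing
            written += len(buffer)
            buffer.clear()

        for record in records:
            buffer.append(record)
            if len(buffer) >= chunk_size:
                flush_buffer()
        if buffer:  # write the final partial chunk
            flush_buffer()
    return written


def flaky_source(total, fail_at):
    """Yield fake benchmark records, then raise to simulate a crash."""
    for i in range(total):
        if i == fail_at:
            raise RuntimeError("simulated crash during collection")
        yield {"op": "mm", "index": i}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_dataset.jsonl")
    try:
        write_in_chunks(flaky_source(total=20, fail_at=12), path, chunk_size=5)
    except RuntimeError:
        pass  # the crash loses only the records still in the buffer
    with open(path, encoding="utf-8") as f:
        surviving = sum(1 for _ in f)

print("records that survived the crash:", surviving)
```

A smaller chunk size loses less work on failure at the cost of more frequent writes.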

Understanding the Collection Process

The data collection process works by:

  1. Generating matrix shapes based on the selected mode

  2. Creating random tensors with the specified dimensions and data types

  3. Compiling matrix multiplication operations with PyTorch’s autotuning

  4. Capturing timing data for different Triton GEMM configurations through the feedback saver interface

  5. Storing the results in a structured dataset format
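
The overall shape of this loop can be illustrated without a GPU. The sketch below is a simplified stand-in, not Diode's collector: it replaces PyTorch autotuning with a pure-Python blocked matrix multiply, treats the block size as the only "configuration", and times each candidate with time.perf_counter. All function names and the record layout are assumptions made for illustration.

```python
import random
import time


def blocked_matmul(a, b, block):
    """Multiply list-of-lists matrices a (m x k) and b (k x n) in square blocks."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, block):
        for j0 in range(0, n, block):
            for k0 in range(0, k, block):
                for i in range(i0, min(i0 + block, m)):
                    row_a, row_c = a[i], c[i]
                    for kk in range(k0, min(k0 + block, k)):
                        a_ik, row_b = row_a[kk], b[kk]
                        for j in range(j0, min(j0 + block, n)):
                            row_c[j] += a_ik * row_b[j]
    return c


def collect_timings(shapes, block_sizes, seed=50):
    """Time every candidate block size on every (m, k, n) shape."""
    rng = random.Random(seed)
    records = []
    for m, k, n in shapes:
        # Step 2: random inputs with the requested dimensions
        a = [[rng.random() for _ in range(k)] for _ in range(m)]
        b = [[rng.random() for _ in range(n)] for _ in range(k)]
        for block in block_sizes:
            # Steps 3-4: run one candidate configuration and capture its time
            start = time.perf_counter()
            blocked_matmul(a, b, block)
            elapsed = time.perf_counter() - start
            # Step 5: store a structured record
            records.append({"shape": (m, k, n), "block": block, "seconds": elapsed})
    return records


records = collect_timings(shapes=[(16, 8, 12), (24, 24, 24)], block_sizes=[4, 8])
for record in records:
    print(record)
```

In the real workflow the candidate configurations are Triton GEMM configurations and the timings come from PyTorch's autotuner through the feedback saver interface, but the resulting dataset has the same basic form: one timing per (problem, configuration) pair.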

Step 2: Model Training

Once you have collected training data, train a deep learning model to predict optimal GEMM configurations:

python matmul_toolkit.py \
    --seed 50 \
    train \
    --data-dir ./data \
    --model matmul_model.pt \
    --model-type deep \
    --batch-size 64 \
    --num-epochs 1000 \
    --learning-rate 0.001 \
    --log-dir ./logs

Training parameters:

  • --model-type deep: Use a deep neural network architecture

  • --batch-size 64: Process 64 samples per batch

  • --num-epochs 1000: Train for 1000 epochs

  • --learning-rate 0.001: Set the learning rate

  • --log-dir: Directory to save training logs and metrics

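The model learns to predict the execution time of a Triton GEMM configuration from matrix dimensions, data types, and hardware characteristics. The optimal configuration for a given problem is then the one with the lowest predicted time.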
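
One way to use a timing predictor for configuration selection is to score every candidate and keep the fastest. The sketch below assumes this usage pattern. The rank_configs helper, the toy_predict_log_time cost function, and the BLOCK_M/BLOCK_N keys are hypothetical stand-ins for a trained model and real Triton GEMM parameters.

```python
import math
from typing import Callable, Dict, List, Tuple

Problem = Tuple[int, int, int]  # (M, N, K)
Config = Dict[str, int]


def rank_configs(
    predict_log_time: Callable[[Problem, Config], float],
    problem: Problem,
    configs: List[Config],
) -> List[Config]:
    """Return configs sorted from fastest to slowest predicted time."""
    return sorted(configs, key=lambda cfg: predict_log_time(problem, cfg))


def toy_predict_log_time(problem: Problem, config: Config) -> float:
    """Toy cost model: log of the padded work implied by the tile sizes.

    A trained model would replace this function.
    """
    m, n, k = problem
    tiles = math.ceil(m / config["BLOCK_M"]) * math.ceil(n / config["BLOCK_N"])
    work_per_tile = config["BLOCK_M"] * config["BLOCK_N"] * k
    return math.log(tiles * work_per_tile)


candidates = [
    {"BLOCK_M": 128, "BLOCK_N": 128},
    {"BLOCK_M": 64, "BLOCK_N": 64},
    {"BLOCK_M": 32, "BLOCK_N": 32},
]
ranked = rank_configs(toy_predict_log_time, (160, 160, 64), candidates)
best = ranked[0]
print("best predicted config:", best)
```

Because only the ordering matters for selection, the predictor does not need accurate absolute times as long as it ranks configurations correctly.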

Model Architecture

Diode provides two simple neural network architectures for timing prediction. These are not meant to be state-of-the-art models, but rather serve as a starting point for further experimentation and development:

Standard Model (MatmulTimingModel)

The standard model uses a feedforward neural network with the following architecture:

class MatmulTimingModel(nn.Module):
    def __init__(
        self,
        problem_feature_dim: int,
        config_feature_dim: int,
        hidden_dims: List[int] = [256, 512, 256, 128, 64],
        dropout_rate: float = 0.2,
    ):
        ...  # constructor body omitted

Architecture components:

  • Input Layer: Concatenates problem features (matrix dimensions, data types) and configuration features (Triton GEMM parameters)

  • Hidden Layers: Multiple fully connected layers with ReLU activation, batch normalization, and dropout

  • Output Layer: Single neuron predicting log execution time

  • Regularization: Dropout and batch normalization to prevent overfitting
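
Predicting log execution time rather than raw time means errors are measured relative to the size of the timing, so fast and slow kernels are weighted comparably. The short sketch below illustrates this. It assumes natural logarithms and times in seconds, neither of which is stated by the toolkit, and the helper names are hypothetical.

```python
import math


def to_log_target(seconds: float) -> float:
    """Convert a measured time into a training target (assumed natural log)."""
    return math.log(seconds)


def from_log_prediction(log_seconds: float) -> float:
    """Convert a model output back into a time."""
    return math.exp(log_seconds)


def log_space_error(predicted_seconds: float, true_seconds: float) -> float:
    """Absolute error between two timings, measured in log space."""
    return abs(to_log_target(predicted_seconds) - to_log_target(true_seconds))


# A 2x overestimate costs the same in log space for a fast and a slow kernel,
# even though the raw differences (10 microseconds vs 1 millisecond) differ by 100x.
fast_error = log_space_error(predicted_seconds=20e-6, true_seconds=10e-6)
slow_error = log_space_error(predicted_seconds=2e-3, true_seconds=1e-3)
print(fast_error, slow_error)
```

With raw times as targets, a squared-error loss would be dominated by the largest matrices and would barely register mistakes on small ones.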

Deep Model (DeepMatmulTimingModel)

The deep model uses residual connections for training deeper networks:

class DeepMatmulTimingModel(nn.Module):
    def __init__(
        self,
        problem_feature_dim: int,
        config_feature_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 10,
        dropout_rate: float = 0.2,
    ):
        ...  # constructor body omitted

Key features:

  • Residual Blocks: Each block contains two linear layers with skip connections

  • Deeper Architecture: 10 or more layers (num_layers defaults to 10), all with the same hidden dimension

  • Better Gradient Flow: Residual connections help train deeper networks effectively

Residual Block Implementation

class ResidualBlock(nn.Module):
    # Excerpt: self.block and self.relu are created in __init__ (not shown)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.block(x)
        out += identity  # Skip connection
        out = self.relu(out)
        return out

The residual blocks enable training much deeper networks while maintaining stable gradients throughout the network depth.
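
The excerpt above shows only the forward pass. A minimal sketch of a complete block follows, assuming that self.block holds two linear layers with batch normalization and dropout, as the feature list suggests. The constructor arguments and the exact layer ordering are assumptions, not Diode's implementation.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Two linear layers with a skip connection (illustrative layout)."""

    def __init__(self, hidden_dim: int, dropout_rate: float = 0.2):
        super().__init__()
        # Assumed contents of the block; input and output widths must match
        # so the skip connection can add them element-wise.
        self.block = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.block(x)
        out = out + identity  # skip connection
        return self.relu(out)


block = ResidualBlock(hidden_dim=16)
block.eval()  # disable dropout and use running batch-norm statistics
x = torch.randn(4, 16)
y = block(x)
print(y.shape)
```

Because the block preserves the feature width, any number of these blocks can be stacked, which is what a num_layers parameter would control in a deeper model.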

Step 3: Creating a Validation Dataset

Create a separate validation dataset using predefined operation shapes to evaluate model performance:

python matmul_toolkit.py \
    --format msgpack \
    --seed 50 \
    create-validation \
    --output validation_dataset.msgpack \
    --shapeset operation_shapeset.json \
    --operations mm addmm bmm \
    --search-space EXHAUSTIVE \
    --search-mode max-autotune

This step:

  • Loads predefined matrix shapes from operation_shapeset.json

  • Runs autotuning to find optimal configurations for these shapes

  • Creates a validation dataset with known ground truth performance data

Step 4: Model Validation

Finally, evaluate your trained model against the validation dataset:

python matmul_toolkit.py \
    --seed 50 \
    validate-model \
    --model matmul_model.pt \
    --dataset validation_dataset.msgpack \
    --batch-size 64 \
    --top-n-worst 10

This validation step:

  • Loads the trained model and validation dataset

  • Makes predictions for each validation sample

  • Compares predictions against ground truth timing data

  • Reports accuracy metrics and identifies the worst-performing predictions
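
The comparison step can be sketched as follows. This is an illustration of the general idea, not the toolkit's validation code: it assumes predictions and ground truth are both log execution times, and the evaluate function and its returned keys are hypothetical. The top_n_worst argument mirrors the --top-n-worst flag.

```python
import math
from typing import Dict, List


def evaluate(
    predicted_log_times: List[float],
    true_log_times: List[float],
    top_n_worst: int = 10,
) -> Dict[str, object]:
    """Compute error metrics and the indices of the worst predictions."""
    errors = [abs(p - t) for p, t in zip(predicted_log_times, true_log_times)]
    mean_abs_error = sum(errors) / len(errors)
    rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
    # Indices of the samples with the largest absolute error, worst first
    worst = sorted(range(len(errors)), key=lambda i: errors[i], reverse=True)
    return {
        "mean_abs_error": mean_abs_error,
        "rmse": rmse,
        "worst_indices": worst[:top_n_worst],
    }


report = evaluate(
    predicted_log_times=[0.0, 1.0, 3.0],
    true_log_times=[0.0, 2.0, 0.0],
    top_n_worst=2,
)
print(report)
```

Inspecting the worst predictions is often more informative than the aggregate metrics, since it shows which shapes or configurations the training data covers poorly.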

Complete Workflow Script

Here’s a complete bash script that orchestrates the entire process:

#!/bin/bash

set -e  # Exit on any error

# Configuration
SEED=50
DATA_DIR="./data"
TRAIN_DATASET="${DATA_DIR}/seed_${SEED}_train_dataset.msgpack"
VALIDATION_DATASET="${DATA_DIR}/validation/validation_dataset.msgpack"
MODEL_PATH="${DATA_DIR}/matmul_model.pt"
LOG_DIR="${DATA_DIR}/logs"
NUM_SHAPES=1000
NUM_EPOCHS=1000
PYTHON_CMD="python"
TOOLKIT_PATH="matmul_toolkit.py"
OPERATION_SHAPESET_PATH="operation_shapeset.json"

echo "Starting Diode workflow..."

# Step 1: Create data directory
mkdir -p "${DATA_DIR}"
mkdir -p "${DATA_DIR}/validation"

# Step 2: Generate training dataset
echo "Collecting training data..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
    --format msgpack \
    --seed "${SEED}" \
    collect \
    --output "${TRAIN_DATASET}" \
    --num-shapes ${NUM_SHAPES} \
    --log-normal \
    --search-space EXHAUSTIVE \
    --search-mode max-autotune \
    --chunk-size 5

# Step 3: Train model
echo "Training model..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
    --seed "${SEED}" \
    train \
    --data-dir "${DATA_DIR}" \
    --model "${MODEL_PATH}" \
    --model-type deep \
    --batch-size 64 \
    --num-epochs ${NUM_EPOCHS} \
    --learning-rate 0.001 \
    --log-dir "${LOG_DIR}"

# Step 4: Create validation dataset
echo "Creating validation dataset..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
    --format msgpack \
    --seed "${SEED}" \
    create-validation \
    --output "${VALIDATION_DATASET}" \
    --shapeset "${OPERATION_SHAPESET_PATH}" \
    --operations mm addmm bmm \
    --search-space EXHAUSTIVE \
    --search-mode max-autotune

# Step 5: Validate model
echo "Validating model..."
${PYTHON_CMD} "${TOOLKIT_PATH}" \
    --seed "${SEED}" \
    validate-model \
    --model "${MODEL_PATH}" \
    --dataset "${VALIDATION_DATASET}" \
    --batch-size 64 \
    --top-n-worst 10

echo "Workflow completed successfully!"

Advanced Configuration

Custom Collection Parameters

For more control over data collection, you can customize the log-normal distribution parameters:

from torch_diode.collection.matmul_dataset_collector import MatmulDatasetCollector, CollectionMode

# Custom parameters for different workload characteristics
collector = MatmulDatasetCollector(
    mode=CollectionMode.LOG_NORMAL,
    # Larger matrices (shift mean higher)
    log_normal_m_mean=7.0,
    log_normal_n_mean=6.5,
    log_normal_k_mean=6.8,
    # Smaller variance for more consistent sizes
    log_normal_m_std=1.5,
    log_normal_n_std=1.2,
    log_normal_k_std=1.8,
)
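
To get a feel for what these values mean, it helps to convert them back into matrix sizes. The sketch below assumes the mean and standard deviation describe the natural log of each dimension, which is the usual log-normal parameterization but is not confirmed here for Diode. The helper names are hypothetical.

```python
import math


def log_normal_median(mean: float) -> float:
    """Median of a log-normal distribution with the given log-space mean."""
    return math.exp(mean)


def log_normal_interval(mean: float, std: float, z: float = 1.0):
    """Range covering `z` log-space standard deviations around the median."""
    return math.exp(mean - z * std), math.exp(mean + z * std)


# Parameter values taken from the collector example above
params = {"m": (7.0, 1.5), "n": (6.5, 1.2), "k": (6.8, 1.8)}
for name, (mean, std) in params.items():
    low, high = log_normal_interval(mean, std)
    print(f"{name}: median ~{log_normal_median(mean):.0f}, "
          f"one-sigma range ~{low:.0f} to {high:.0f}")
```

Raising a mean by 1.0 multiplies the typical dimension by about 2.7, so small changes to these parameters shift the sampled workload substantially.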

Tips

  1. Start Small: Begin with a smaller number of shapes (100-200) to validate your setup

  2. Monitor Memory: Keep an eye on GPU memory usage during collection

  3. Save Frequently: Use the --chunk-size parameter to save data periodically

  4. Reproducibility: Always set a random seed for consistent results

  5. Hardware Consistency: Collect training and validation data on the same hardware

Next Steps

After completing this workflow, you can:

  • Experiment with different model architectures

  • Collect data for specific workloads using OPERATION_SHAPE_SET mode

  • Integrate the trained model into your own applications

  • Analyze the collected data to understand performance patterns

For more advanced usage, see the API documentation and examples in the repository.