Skip to content

torchspec-project/TorchSpec

Repository files navigation

TorchSpec

TorchSpec is a torch-native speculative decoding training framework. We introduce a disaggregated way of training speculative decoding draft models where inference and training are fully decoupled and stream hidden states directly from inference engine groups to distributed training workers via Mooncake (RDMA/TCP) store, allowing each side to scale independently.

TorchSpec Architecture

Setup

Choose Your Backend

TorchSpec supports two inference backends:

Backend Best For Installation
vLLM Flexibility, easier deployment ./tools/build_conda.sh 1 vllm
SGLang Production workloads, high throughput ./tools/build_conda.sh 1 sglang
Both Development, comparison testing ./tools/build_conda.sh 1 both

Quick Setup

# Install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec

# Or install with SGLang
./tools/build_conda.sh
micromamba activate torchspec

To install into your current environment instead:

./tools/build_conda.sh current sglang  # or 'vllm' or 'both'

Optional — install Flash Attention:

pip install -e ".[fa]"

Backend-Specific Usage

vLLM:

./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml

SGLang:

./examples/qwen3-8b-single-node/run.sh

TorchSpec uses vLLM's Worker Extension mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden states extraction.

Quick Start

Train an Eagle3 draft model for Qwen3-8B using inference engine (4 GPUs: 2 training + 2 inference):

./examples/qwen3-8b-single-node/run.sh

Override any config value via CLI:

./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500

Examples

Example Backend Model
hf-quickstart HuggingFace Qwen3-8B
qwen3-8b-single-node Inference Engine Qwen3-8B
kimi-k25-2node-h200 Inference Engine Kimi-K2.5
kimi-k25-3node-h100 Inference Engine Kimi-K2.5

See examples/README.md for details.

Checkpoint Conversion

Convert an FSDP checkpoint to HuggingFace format:

python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/

Vocabulary pruning — reducing the draft model's lm_head to a smaller token set and emitting d2t/t2d mappings — can be applied either during training (pre-pruning) or at conversion time (post-pruning):

  • Pre-pruning: set draft_vocab_size in your training config. The checkpoint already contains the pruned lm_head and d2t/t2d buffers. Use the basic conversion command above — no extra flags needed.

  • Post-pruning: train with the full vocabulary, then prune at conversion time by passing --prune-vocab along with a representative dataset to compute token frequencies:

python tools/convert_to_hf.py \
    --input-dir ./outputs/my_experiment/iter_0010000/ \
    --prune-vocab \
    --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
    --draft-vocab-size 32000 \
    --tokenizer Qwen/Qwen3-8B \
    --chat-template qwen \
    --prompt-key conversations

Pass --cache-dir ./cache to reuse the tokenized dataset cache from training.

Metrics Reporting

W&B logging is disabled by default (report_to: none). To enable it, set report_to: wandb in your config and supply your API key.

Troubleshooting

Set TORCHSPEC_LOG_LEVEL=DEBUG for verbose logging when diagnosing issues:

TORCHSPEC_LOG_LEVEL=DEBUG ./examples/qwen3-8b-single-node/run.sh

Per-Rank File Logging

Set TORCHSPEC_LOG_DIR to an absolute path on a shared filesystem (NFS) to enable per-rank log files for every Ray actor (training and inference):

export TORCHSPEC_LOG_DIR=/my_project/running_logs

This creates a structured directory with one file per actor, organized by role and node:

running_logs/
  training/
    10.0.0.1/
      training_g0_rank0_20260301_080012.log
      training_g0_rank1_20260301_080012.log
    10.0.0.2/
      training_g0_rank2_20260301_080013.log
  inference/
    10.0.0.1/
      inference_g0_rank0_20260301_080014.log
    10.0.0.2/
      inference_g0_rank1_20260301_080015.log

The path must be an absolute path on a shared filesystem (NFS) accessible from all nodes. If TORCHSPEC_LOG_DIR is not set or the path is not writable, per-rank file logging is disabled and only Ray's default stdout/stderr capture is used.

Issue Reference
Stuck or failing distributed runs, Ray actor errors docs/debugging_ray_jobs.md
Ray cluster setup, actor hierarchy, placement groups docs/ray.md
Pipeline bottlenecks, slow steps, throughput analysis docs/performance_metrics.md

About

A PyTorch native library for training speculative decoding models

Topics

Resources

License

Contributing

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages