TorchSpec is a torch-native speculative decoding training framework. It introduces a disaggregated approach to training speculative-decoding draft models: inference and training are fully decoupled, and hidden states stream directly from inference engine groups to distributed training workers via the Mooncake (RDMA/TCP) store, so each side can scale independently.
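The decoupled pattern can be sketched with an in-memory dict standing in for the Mooncake store (all names here are illustrative, not TorchSpec APIs): inference workers publish hidden states under a key, and training workers fetch and consume them independently.

```python
import queue

# Stand-in for a Mooncake-style key-value store: inference workers put
# hidden states under a sample key; training workers fetch them by key.
store = {}
ready = queue.Queue()  # keys published by the inference side

def inference_step(sample_id, hidden_states):
    """Producer side: publish hidden states instead of holding them."""
    key = f"hidden/{sample_id}"
    store[key] = hidden_states
    ready.put(key)

def training_step():
    """Consumer side: pull the next batch of hidden states by key."""
    key = ready.get()
    return store.pop(key)  # pop: each batch is consumed exactly once

inference_step(0, [0.1, 0.2, 0.3])
batch = training_step()
```

Because producer and consumer only share keys and a store, either side can be scaled or restarted without the other noticing, which is the property the real RDMA/TCP store provides across nodes.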
TorchSpec supports two inference backends:
| Backend | Best For | Installation |
|---|---|---|
| vLLM | Flexibility, easier deployment | `./tools/build_conda.sh 1 vllm` |
| SGLang | Production workloads, high throughput | `./tools/build_conda.sh 1 sglang` |
| Both | Development, comparison testing | `./tools/build_conda.sh 1 both` |
```bash
# Install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec

# Or install with SGLang
./tools/build_conda.sh 1 sglang
micromamba activate torchspec
```

To install into your current environment instead:

```bash
./tools/build_conda.sh current sglang  # or 'vllm' or 'both'
```

Optionally, install Flash Attention:

```bash
pip install -e ".[fa]"
```

vLLM:

```bash
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```

SGLang:

```bash
./examples/qwen3-8b-single-node/run.sh
```

TorchSpec uses vLLM's Worker Extension mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden-state extraction.
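The in-process capture idea can be illustrated with a toy forward hook. This is a minimal sketch of the pattern only; the real extension uses vLLM's Worker Extension mechanism and the model's actual layers, none of which appear here.

```python
class ToyModel:
    """Toy model whose forward pass exposes an intermediate activation."""
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)

    def forward(self, x):
        hidden = [v * 2 for v in x]  # pretend intermediate hidden state
        for hook in self._hooks:
            hook(hidden)             # captured in-process: no RPC round trip
        return sum(hidden)           # final output (e.g. logits)

captured = []
model = ToyModel()
model.register_forward_hook(captured.append)
out = model.forward([1, 2, 3])
```

Because the hook runs inside the worker process that owns the tensors, hidden states never need to be serialized across an RPC boundary to be observed.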
Train an Eagle3 draft model for Qwen3-8B using the inference engine backend (4 GPUs: 2 training + 2 inference):

```bash
./examples/qwen3-8b-single-node/run.sh
```

Override any config value via the CLI:

```bash
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
```

| Example | Backend | Model |
|---|---|---|
| hf-quickstart | HuggingFace | Qwen3-8B |
| qwen3-8b-single-node | Inference Engine | Qwen3-8B |
| kimi-k25-2node-h200 | Inference Engine | Kimi-K2.5 |
| kimi-k25-3node-h100 | Inference Engine | Kimi-K2.5 |
See examples/README.md for details.
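CLI overrides like `training.learning_rate=5e-5` follow a common dotted-key convention. A small helper like the following shows how such overrides map onto a nested config; this is a sketch of the convention, not TorchSpec's actual parser.

```python
def apply_override(config, override):
    """Apply a dotted 'a.b.c=value' override to a nested config dict."""
    dotted, raw = override.split("=", 1)
    *parents, leaf = dotted.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    # Best-effort type coercion: int, then float, else keep the string.
    for cast in (int, float):
        try:
            node[leaf] = cast(raw)
            return config
        except ValueError:
            pass
    node[leaf] = raw
    return config

cfg = {"training": {"learning_rate": 3e-4}}
apply_override(cfg, "training.learning_rate=5e-5")
apply_override(cfg, "training.num_train_steps=500")
```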
Convert an FSDP checkpoint to HuggingFace format:
```bash
python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/
```

Vocabulary pruning (reducing the draft model's `lm_head` to a smaller token set and emitting `d2t`/`t2d` mappings) can be applied either during training (pre-pruning) or at conversion time (post-pruning):
- **Pre-pruning**: set `draft_vocab_size` in your training config. The checkpoint already contains the pruned `lm_head` and `d2t`/`t2d` buffers. Use the basic conversion command above; no extra flags are needed.
- **Post-pruning**: train with the full vocabulary, then prune at conversion time by passing `--prune-vocab` along with a representative dataset to compute token frequencies:
```bash
python tools/convert_to_hf.py \
    --input-dir ./outputs/my_experiment/iter_0010000/ \
    --prune-vocab \
    --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
    --draft-vocab-size 32000 \
    --tokenizer Qwen/Qwen3-8B \
    --chat-template qwen \
    --prompt-key conversations
```

Pass `--cache-dir ./cache` to reuse the tokenized dataset cache from training.
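The idea behind the `d2t`/`t2d` mappings can be sketched as follows: keep the most frequent target-vocabulary tokens, then map each draft id to its target id (`d2t`) and each target id back to a draft id (`t2d`). This is a simplified illustration; the function name and the `-1` convention for pruned tokens are assumptions, not TorchSpec's implementation.

```python
from collections import Counter

def build_vocab_maps(token_stream, full_vocab_size, draft_vocab_size):
    """Keep the draft_vocab_size most frequent target tokens.

    Returns (d2t, t2d): d2t[i] is the target id of draft id i;
    t2d[j] is the draft id of target id j, or -1 if pruned.
    """
    freq = Counter(token_stream)
    # Most frequent first; inner sort makes ties deterministic by target id.
    kept = sorted(sorted(freq), key=freq.__getitem__, reverse=True)[:draft_vocab_size]
    d2t = sorted(kept)  # draft ids ordered by target id
    t2d = [-1] * full_vocab_size
    for draft_id, target_id in enumerate(d2t):
        t2d[target_id] = draft_id
    return d2t, t2d

# Toy stream: token 2 appears 3x, token 5 twice, token 9 once.
d2t, t2d = build_vocab_maps([5, 5, 2, 2, 2, 9],
                            full_vocab_size=10, draft_vocab_size=2)
```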
W&B logging is disabled by default (`report_to: none`). To enable it, set `report_to: wandb` in your config and supply your API key (for example via the `WANDB_API_KEY` environment variable).
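For example, a config fragment could look like this (the exact nesting of `report_to` in your config file may differ; this mirrors the flat key shown above):

```yaml
report_to: wandb  # default: none
```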
Set `TORCHSPEC_LOG_LEVEL=DEBUG` for verbose logging when diagnosing issues:

```bash
TORCHSPEC_LOG_LEVEL=DEBUG ./examples/qwen3-8b-single-node/run.sh
```

Set `TORCHSPEC_LOG_DIR` to an absolute path on a shared filesystem (NFS) to enable per-rank log files for every Ray actor (training and inference):

```bash
export TORCHSPEC_LOG_DIR=/my_project/running_logs
```

This creates a structured directory with one file per actor, organized by role and node:
```
running_logs/
  training/
    10.0.0.1/
      training_g0_rank0_20260301_080012.log
      training_g0_rank1_20260301_080012.log
    10.0.0.2/
      training_g0_rank2_20260301_080013.log
  inference/
    10.0.0.1/
      inference_g0_rank0_20260301_080014.log
    10.0.0.2/
      inference_g0_rank1_20260301_080015.log
```
`TORCHSPEC_LOG_DIR` must be an absolute path on a shared filesystem (NFS) accessible from all nodes. If it is not set or the path is not writable, per-rank file logging is disabled and only Ray's default stdout/stderr capture is used.
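The per-actor file layout above can be reproduced with a small helper like this (a sketch of the naming scheme only, not TorchSpec internals; the function name is illustrative):

```python
import os
import time

def per_rank_log_path(log_dir, role, node_ip, group, rank, when=None):
    """Build '<log_dir>/<role>/<node_ip>/<role>_g<group>_rank<rank>_<ts>.log'."""
    ts = time.strftime("%Y%m%d_%H%M%S", when or time.localtime())
    return os.path.join(log_dir, role, node_ip,
                        f"{role}_g{group}_rank{rank}_{ts}.log")

path = per_rank_log_path("/my_project/running_logs", "training", "10.0.0.1",
                         group=0, rank=0,
                         when=time.strptime("20260301_080012", "%Y%m%d_%H%M%S"))
```

Grouping by role and then node IP means a hung rank can be located by grepping a single node's directory rather than one interleaved stream.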
| Issue | Reference |
|---|---|
| Stuck or failing distributed runs, Ray actor errors | `docs/debugging_ray_jobs.md` |
| Ray cluster setup, actor hierarchy, placement groups | `docs/ray.md` |
| Pipeline bottlenecks, slow steps, throughput analysis | `docs/performance_metrics.md` |
