This is the official repository for "Efficient Neuron Segmentation in Electron Microscopy by Affinity-Guided Queries", an ICLR 2025 paper presenting a novel approach for neuron segmentation in electron microscopy images.
Accurate segmentation of neurons in electron microscopy (EM) images plays a crucial role in understanding the intricate wiring patterns of the brain. Existing automatic neuron segmentation methods rely on traditional clustering algorithms: affinities are predicted first, and then watershed and post-processing algorithms are applied to produce the segmentation. Due to the nature of the watershed algorithm, this paradigm is deficient in both prediction quality and speed.
Inspired by recent advances in natural image segmentation, we propose to use query-based methods, which do not require a watershed step. However, we find that directly applying existing query-based methods faces great challenges due to the large memory requirements of 3D data and the morphology of neurons, which differs considerably from that of objects in natural images.
To tackle these challenges, we introduce affinity-guided queries and integrate them into a lightweight query-based framework. Specifically:
- We first predict affinities with a lightweight branch, which provides coarse neuron structure information
- The affinities are then used to construct affinity-guided queries, facilitating segmentation with bottom-up cues
- These queries, along with additional learnable queries, interact with the image features to directly predict the final segmentation results
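The two ingredients in the list above, an affinity map and query-based mask prediction, can be illustrated with generic building blocks. The sketch below assumes the common nearest-neighbour affinity definition used in connectomics (one channel per axis) and a Mask2Former-style dot-product mask head. Both functions (`affinities_from_labels`, `decode_queries`) are hypothetical illustrations, not AGQ's implementation, and the construction of the affinity-guided queries themselves is not shown.

```python
import numpy as np


def affinities_from_labels(labels: np.ndarray) -> np.ndarray:
    """Nearest-neighbour affinities for a (Z, Y, X) label volume.

    Channel c is 1.0 where a voxel and its predecessor along axis c share
    the same non-zero label, and 0.0 otherwise (background 0 never links).
    """
    labels = np.asarray(labels)
    aff = np.zeros((3,) + labels.shape, dtype=np.float32)
    for axis in range(3):
        cur = [slice(None)] * 3
        prev = [slice(None)] * 3
        cur[axis] = slice(1, None)    # voxel i along `axis`
        prev[axis] = slice(None, -1)  # voxel i - 1 along `axis`
        a, b = labels[tuple(cur)], labels[tuple(prev)]
        aff[axis][tuple(cur)] = (a == b) & (a != 0)
    return aff


def decode_queries(queries: np.ndarray, features: np.ndarray) -> np.ndarray:
    """Generic query-based decoding (hypothetical, not AGQ's exact head).

    queries:  (Q, C) query embeddings
    features: (C, Z, Y, X) per-voxel embeddings
    Each voxel receives the id (1..Q) of the query with the highest mask
    logit; no background class or query filtering is modelled here.
    """
    logits = np.einsum("qc,czyx->qzyx", queries, features)
    return logits.argmax(axis=0).astype(np.int64) + 1
```

For example, the label row `[1, 1, 2]` yields an x-axis affinity of `[0, 1, 0]`: only the first pair of voxels belongs to the same neuron. Because every voxel is assigned directly to a query, no watershed step is needed to turn the output into instances.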
Experiments on benchmark datasets demonstrate that our method achieves better results than state-of-the-art methods, with a 2-3× speedup in inference.
```bash
# Clone the repository
git clone https://github.com/chenhang98/AGQ.git
cd AGQ
# Install dependencies
conda create -n agq python=3.8
conda activate agq
# Install PyTorch (assuming CUDA 11.1; refer to https://pytorch.org/get-started/previous-versions for other CUDA versions)
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
pip install -v -e ./submodules/cremi_python
pip install -v -e .
```

This repository is based on pytorch_connectomics. Please refer to pytorch_connectomics for more details about the underlying framework.
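To confirm that the pinned PyTorch wheels from the pip commands above actually landed in the active environment, a small standard-library check can help. This is a sketch: the helper name `check_environment` is hypothetical, and it only compares installed version strings against the pins, it does not test CUDA availability.

```python
from importlib import metadata

# Pins copied from the pip install command above.
EXPECTED = {
    "torch": "1.10.0+cu111",
    "torchvision": "0.11.0+cu111",
    "torchaudio": "0.10.0",
}


def check_environment(expected=EXPECTED):
    """Map each package to (installed_version or None, matches_pin)."""
    report = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None  # package not installed in this environment
        report[name] = (have, have == want)
    return report


if __name__ == "__main__":
    for name, (have, ok) in check_environment().items():
        print(f"{name:12s} installed={have} matches_pin={ok}")
```

`importlib.metadata` is available from Python 3.8, so the check runs inside the `agq` environment without extra dependencies.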
Download the AC3-AC4 dataset from the linked dataset page and organize the files as follows:
```
datasets/
└── AC3-AC4/
    ├── AC3_inputs.h5          # Test input volume (100×1024×1024)
    ├── AC3_labels.h5          # Test ground truth labels
    ├── AC4_train_inputs.h5    # Training input volume
    ├── AC4_train_labels.h5    # Training ground truth labels
    ├── AC4_val_inputs.h5      # Validation input volume (not used)
    └── AC4_val_labels.h5      # Validation ground truth labels (not used)
```

- Training Data: AC4 dataset for model training
- Test Data: AC3 dataset for inference and evaluation
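Before launching training or inference, it can save time to check that the layout above is in place. The sketch below uses only the file names listed in the tree; the helper `missing_dataset_files` is hypothetical, and the validation files are treated as optional because they are not used.

```python
from pathlib import Path

# File names from the layout above; AC4_val_* are not used, so not required.
REQUIRED_FILES = (
    "AC3_inputs.h5",
    "AC3_labels.h5",
    "AC4_train_inputs.h5",
    "AC4_train_labels.h5",
)


def missing_dataset_files(root="datasets/AC3-AC4", required=REQUIRED_FILES):
    """Return the required .h5 files that are not present under `root`."""
    root = Path(root)
    return [name for name in required if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_dataset_files()
    print("dataset OK" if not missing else f"missing: {missing}")
```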
Run inference on the provided test data:

```bash
bash scripts/inference.sh
```

The evaluation results will be printed and the final segmentation results will be saved as an int-type ndarray in outputs/inference/mask_int_merged.pkl.
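Since the result is a pickled integer ndarray, it can be inspected directly with NumPy. This is a sketch: `load_segmentation` and `summarize` are hypothetical helpers, and treating label 0 as background is an assumption, not something the inference script documents. Only unpickle files you produced yourself, since `pickle.load` can execute arbitrary code from untrusted input.

```python
import pickle

import numpy as np


def load_segmentation(path="outputs/inference/mask_int_merged.pkl") -> np.ndarray:
    """Load the merged segmentation written by scripts/inference.sh."""
    with open(path, "rb") as f:
        seg = np.asarray(pickle.load(f))
    if not np.issubdtype(seg.dtype, np.integer):
        raise TypeError(f"expected an integer label volume, got {seg.dtype}")
    return seg


def summarize(seg: np.ndarray) -> dict:
    """Basic statistics; label 0 is assumed to be background and skipped."""
    ids, counts = np.unique(seg, return_counts=True)
    fg = ids != 0
    ids, counts = ids[fg], counts[fg]
    largest = None
    if ids.size:
        k = int(counts.argmax())
        largest = (int(ids[k]), int(counts[k]))  # (segment id, voxel count)
    return {"shape": seg.shape, "num_segments": int(ids.size), "largest_segment": largest}
```

Typical use is `summarize(load_segmentation())` after the inference script has finished.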
Expected Output:

```
voi_split: 0.681 | voi_merge: 0.267 | voi_sum: 0.947 | adapted_RAND: 0.089
```

Pre-trained Model: The model checkpoint is available at checkpoints/checkpoint_200000.pth.tar.
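voi_split and voi_merge are the two conditional-entropy terms of the variation of information (VOI); lower is better for both. The repository installs cremi_python, which it presumably uses for evaluation. The sketch below is an independent NumPy reimplementation of the usual definitions (split = H(seg | gt), merge = H(gt | seg)), assuming ground-truth label 0 is ignored and a base-2 logarithm. `voi_split_merge` is a hypothetical helper, and its numbers are not guaranteed to match the evaluation script exactly.

```python
import numpy as np


def _entropy(counts: np.ndarray, base: float) -> float:
    """Shannon entropy of the empirical distribution given by `counts`."""
    p = counts[counts > 0] / counts.sum()
    return float(-(p * np.log(p)).sum() / np.log(base))


def voi_split_merge(seg, gt, ignore_gt=(0,), base=2.0):
    """Return (voi_split, voi_merge) = (H(seg | gt), H(gt | seg)).

    Voxels whose ground-truth label is in `ignore_gt` are excluded
    (assumption: 0 marks unlabeled background).
    """
    seg = np.asarray(seg).ravel()
    gt = np.asarray(gt).ravel()
    keep = ~np.isin(gt, ignore_gt)
    seg, gt = seg[keep], gt[keep]
    if seg.size == 0:
        return 0.0, 0.0
    # Counts of each (seg, gt) label pair are the contingency-table entries.
    _, joint = np.unique(np.stack([seg, gt], axis=1), axis=0, return_counts=True)
    _, seg_counts = np.unique(seg, return_counts=True)
    _, gt_counts = np.unique(gt, return_counts=True)
    h_joint = _entropy(joint, base)
    voi_split = h_joint - _entropy(gt_counts, base)   # H(seg | gt): over-segmentation
    voi_merge = h_joint - _entropy(seg_counts, base)  # H(gt | seg): under-segmentation
    return voi_split, voi_merge
```

Splitting one ground-truth neuron evenly into two predicted segments gives voi_split = 1 bit and voi_merge = 0; merging two equal neurons gives the reverse. voi_sum is simply the sum of the two terms.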
To train the model from scratch:

```bash
bash scripts/train.sh
```

Note: Training requires 8 GPUs by default. Modify NUM_GPUS in scripts/train.sh according to your hardware setup.
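If you switch GPU counts often, the edit can be scripted. This is a sketch under a loud assumption: that scripts/train.sh assigns the variable on its own line as `NUM_GPUS=<n>` (optionally with `export`). The helpers `replace_num_gpus` and `set_num_gpus` are hypothetical; check the actual script before relying on them.

```python
import re
from pathlib import Path

# Assumption: train.sh contains a line like "NUM_GPUS=8" or "export NUM_GPUS=8".
_NUM_GPUS_LINE = re.compile(r"^(\s*(?:export\s+)?NUM_GPUS=)\S*", re.MULTILINE)


def replace_num_gpus(text: str, n: int) -> str:
    """Rewrite every NUM_GPUS assignment in `text` to the value `n`."""
    new_text, count = _NUM_GPUS_LINE.subn(lambda m: f"{m.group(1)}{n}", text)
    if count == 0:
        raise ValueError("no NUM_GPUS=... assignment found")
    return new_text


def set_num_gpus(script="scripts/train.sh", n=1) -> None:
    """Edit the training script in place."""
    path = Path(script)
    path.write_text(replace_num_gpus(path.read_text(), n))
```

Raising on a missing assignment, rather than silently doing nothing, makes it obvious when the script's format differs from the assumption.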
```
AGQ/
├── projects/AGQ/              # Main AGQ implementation
│   ├── configs/               # Configuration files
│   ├── model/                 # Model architecture
│   ├── loss/                  # Loss functions
│   └── main.py                # Training and inference entry point
├── scripts/                   # Utility scripts
│   ├── inference.sh           # Inference script
│   ├── train.sh               # Training script
│   ├── inference_mp.py        # Multi-processing inference
│   └── concat_merge_eval.py   # Result merging and evaluation
├── datasets/                  # Dataset directory
├── outputs/                   # Output directory
├── checkpoints/               # Model checkpoints
└── requirements.txt           # Python dependencies
```
The model can be configured through YAML files in projects/AGQ/configs/:

- AGQ.yaml: Main configuration for the AGQ model
- SNEMI-Base.yaml: Base configuration for the SNEMI3D dataset
Key configuration parameters include:

- MODEL.ARCHITECTURE: Model architecture (AGQ)
- SOLVER.BASE_LR: Learning rate
- INFERENCE.INPUT_SIZE: Input volume size for inference
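The dotted names above address nested YAML sections: SOLVER.BASE_LR refers to the BASE_LR key under the top-level SOLVER section. pytorch_connectomics builds its configuration on yacs, whose `CfgNode.merge_from_list` applies such dotted overrides natively. The plain-dict sketch below only illustrates the mapping; `apply_overrides` is hypothetical and the example values in the usage note are made up.

```python
from copy import deepcopy


def apply_overrides(cfg: dict, overrides: dict) -> dict:
    """Return a copy of nested dict `cfg` with dotted-key overrides applied.

    Unknown sections or keys raise KeyError, so typos are not silently
    added as new settings.
    """
    out = deepcopy(cfg)
    for dotted, value in overrides.items():
        *parents, leaf = dotted.split(".")
        node = out
        for key in parents:
            if not isinstance(node.get(key), dict):
                raise KeyError(f"unknown config section in {dotted!r}")
            node = node[key]
        if leaf not in node:
            raise KeyError(f"unknown config key {dotted!r}")
        node[leaf] = value
    return out
```

For instance, `apply_overrides({"SOLVER": {"BASE_LR": 1e-4}}, {"SOLVER.BASE_LR": 5e-5})` returns a new config with the lower learning rate and leaves the original dict untouched.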
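Our method achieves state-of-the-art performance on benchmark datasets. The table below compares AGQ with PEA on the AC3 test set (lower is better for all four segmentation metrics):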
| Metric | AGQ | PEA |
|---|---|---|
| VOI Split | 0.681 | 0.852 |
| VOI Merge | 0.267 | 0.232 |
| VOI Sum | 0.947 | 1.084 |
| Adapted RAND | 0.089 | 0.094 |
| Inference Speed | 2-3× faster | Baseline |
We welcome contributions! Please see our CONTRIBUTING.md for guidelines on how to contribute to this project.
This project is licensed under the MIT License - see the LICENSE file for details.
We would like to thank pytorch_connectomics for providing the codebase that this project is built upon.
If you find this project useful in your research, please consider citing:
```bibtex
@inproceedings{AGQ,
  author    = {Hang Chen and Chufeng Tang and Xiao Li and Xiaolin Hu},
  title     = {Efficient Neuron Segmentation in Electron Microscopy by Affinity-Guided Queries},
  booktitle = {The Thirteenth International Conference on Learning Representations, {ICLR}},
  year      = {2025}
}
```

For questions and discussions about this work, please open an issue on GitHub or contact the authors.
