---
license: mit
pipeline_tag: image-feature-extraction
---

# Masked Autoencoder (MAE) for Medical Imaging

A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.

## Overview

This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.

### Key Features

- **Vision Transformer Architecture**: Encoder-decoder transformer with positional encodings
- **Self-Supervised Learning**: Pre-training through masked image reconstruction
- **Optimized for Medical Imaging**: Designed specifically for chest X-ray analysis
- **Production-Ready Training Pipeline**:
  - Mixed precision training (FP16) with gradient scaling
  - Gradient accumulation support
  - Learning rate warmup and cosine annealing
  - Automatic checkpointing and resumption
- **Efficient Data Loading**:
  - Optimized ZIP file reader with LRU caching
  - Class-balanced sampling with a weighted random sampler
  - Multi-worker data loading with persistent workers
- **Comprehensive Logging**: Training/validation metrics tracking and visualization

## Architecture

### Masked Autoencoder Structure

```
Input Image (384×384)
        ↓
Patchify (16×16 patches → 576 patches)
        ↓
Random Masking (75% masked, 25% visible)
        ↓
┌────────────────────────────────────┐
│            MAE ENCODER             │
│  - Linear patch embedding          │
│  - Positional encoding (visible)   │
│  - 12 Transformer blocks           │
│  - 8 attention heads, 768 hidden   │
└────────────────────────────────────┘
        ↓
┌────────────────────────────────────┐
│            MAE DECODER             │
│  - Learnable mask tokens           │
│  - Positional encoding (all)       │
│  - 8 Transformer blocks            │
│  - 8 attention heads, 512 hidden   │
│  - Pixel reconstruction head       │
└────────────────────────────────────┘
        ↓
Reconstructed Image
        ↓
MSE Loss (on masked patches only)
```

### Model Configuration

| Parameter | Default Value | Description |
|-----------|---------------|-------------|
| Image Size | 384×384 | Input image resolution |
| Patch Size | 16×16 | Size of each patch |
| Mask Ratio | 0.75 | Fraction of patches to mask |
| Encoder Depth | 12 layers | Number of transformer blocks |
| Encoder Dim | 768 | Hidden dimension |
| Encoder Heads | 8 | Number of attention heads |
| Decoder Depth | 8 layers | Number of transformer blocks |
| Decoder Dim | 512 | Hidden dimension |
| Decoder Heads | 8 | Number of attention heads |
| MLP Ratio | 4× | MLP expansion ratio (768 → 3072 in the encoder) |
| Dropout | 0.25 | Dropout rate |
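
The patch arithmetic above can be sanity-checked with a generic patchify sketch (an illustration, not the exact code in `models/mae.py`; a single grayscale channel is assumed for the X-rays):

```python
import torch

def patchify(imgs: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split (B, C, H, W) images into (B, N, patch_size*patch_size*C) patches."""
    B, C, H, W = imgs.shape
    assert H % patch_size == 0 and W % patch_size == 0
    # (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, p, p, C) -> (B, N, p*p*C)
    x = imgs.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, -1, patch_size * patch_size * C)
    return x

imgs = torch.randn(2, 1, 384, 384)   # batch of grayscale chest X-rays
patches = patchify(imgs)
print(patches.shape)                  # torch.Size([2, 576, 256])
```

A 384×384 image yields (384/16)² = 576 patches, each flattened to 16×16×1 = 256 values.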

## Getting Started

### Prerequisites

- Python >= 3.8
- CUDA-capable GPU (recommended)
- 16GB+ RAM

### Installation

1. Clone the repository:
   ```bash
   git clone https://github.com/adelelsayed/mae.git
   cd mae
   ```

2. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```

### Dataset Preparation

This project is configured for the **CheXpert dataset**. To use it:

1. Download CheXpert-v1.0-small from the [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)
2. Update the paths in `configs/configs.py`:
   - `root`: Base directory for your data
   - `zip_path`: Path to the zipped dataset (optional, for faster loading)
   - `csv`: Path to the training CSV
   - `train_csv`, `val_csv`, `test_csv`: Split CSV files

## Usage

### Training

Start training from scratch:
```bash
python trainer/trainer.py
```

The trainer will:
- Automatically create checkpoint and log directories
- Resume from the last checkpoint if available
- Log training/validation metrics to text files
- Save plots every 10 epochs
- Save the best model based on validation loss

### Training Configuration

Edit `configs/configs.py` to customize training:

```python
mae_config = {
    # Training hyperparameters
    "lr": 1e-4,             # Learning rate
    "warmup": 5,            # Warmup epochs
    "weight_decay": 5e-4,   # AdamW weight decay
    "num_epochs": 200,      # Total training epochs
    "batch_size": 96,       # Batch size
    "accumulation": 1,      # Gradient accumulation steps

    # Model architecture
    "mask_ratio": 0.75,     # Masking ratio
    "encoder_depth": 12,    # Encoder layers
    "decoder_depth": 8,     # Decoder layers

    # Paths
    "checkpoints": "/path/to/checkpoints",
    "logdir": "/path/to/logs",
    ...
}
```

### Monitoring Training

Training logs are saved in three files:
- `training_log.txt`: Training metrics per epoch
- `val_log.txt`: Validation metrics per epoch
- `test_log.txt`: Test set evaluation results

Metrics plots are saved every 10 epochs to `{logdir}/{epoch}/metrics.png`.

### Evaluation

The trainer includes a test method. To evaluate:
```python
from trainer.utils import MAETrainer
from configs.configs import mae_config

trainer = MAETrainer(mae_config)
trainer.test()
```

## Project Structure

```
mae/
├── configs/
│   ├── __init__.py
│   └── configs.py            # Training configuration
├── data/
│   ├── __init__.py
│   ├── dataset.py            # CheXpert dataset loader
│   └── splitter.py           # Dataset splitting utilities
├── loss/
│   ├── __init__.py
│   └── mae_loss.py           # MAE reconstruction loss
├── models/
│   ├── __init__.py
│   └── mae.py                # MAE architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py            # Main training script
│   └── utils.py              # Training utilities
├── notebooks/
│   └── chexpert_mae.ipynb    # Jupyter notebook for experiments
├── training logs/            # Logged metrics and plots
├── weights/                  # Model checkpoints
├── results/                  # Evaluation results
├── requirements.txt          # Python dependencies
├── LICENSE                   # Project license
└── README.md                 # This file
```

## Components

### Dataset (`data/dataset.py`)

- **OptimizedZipReader**: Fast ZIP file reading with LRU caching
- **CheXpertDataset**: PyTorch dataset for CheXpert chest X-rays
  - 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
  - Albumentations-based augmentation pipeline
  - Class-balanced sampling support
  - Frontal/lateral view filtering
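
The idea behind `OptimizedZipReader` can be illustrated with a stdlib-only sketch (the class name, constructor, and cache size here are assumptions for illustration; the real implementation in `data/dataset.py` may differ):

```python
import zipfile
from functools import lru_cache

class ZipImageReader:
    """Read files from a ZIP archive without extracting to disk,
    keeping the most recently used decompressed files in memory."""

    def __init__(self, zip_path: str, cache_size: int = 1024):
        self._zf = zipfile.ZipFile(zip_path, "r")
        # Per-instance LRU cache: repeated reads of hot images skip decompression
        self._read = lru_cache(maxsize=cache_size)(self._read_uncached)

    def _read_uncached(self, name: str) -> bytes:
        return self._zf.read(name)

    def read(self, name: str) -> bytes:
        return self._read(name)
```

With multi-worker data loading, each worker process should hold its own reader instance so ZIP file handles are not shared across processes.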

### Model (`models/mae.py`)

- **Patchify/Unpatchify**: Image-to-patch conversion utilities
- **Random Masking**: Stochastic patch masking with restore indices
- **PositionalEncoding**: Learnable position embeddings
- **TransformerBlock**: Multi-head self-attention + MLP
- **MAEEncoder**: Processes visible patches only
- **MAEDecoder**: Reconstructs the full image using mask tokens
- **MaskedAutoEncoder**: Complete MAE model
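
The masking-with-restore-indices idea can be sketched generically (shapes follow the MAE paper; the repo's implementation may differ in details):

```python
import torch

def random_masking(x: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of patches. Returns the kept patches, a binary
    mask (1 = masked) in original patch order, and indices to undo the shuffle."""
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)                       # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)      # lowest scores are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    mask = torch.ones(B, N)
    mask[:, :len_keep] = 0                         # first len_keep are visible
    mask = torch.gather(mask, 1, ids_restore)      # back to original patch order
    return x_kept, mask, ids_restore

x = torch.randn(2, 576, 768)                       # embedded patches
x_kept, mask, ids_restore = random_masking(x)
```

The decoder later uses `ids_restore` to scatter mask tokens and visible tokens back into their original positions before reconstruction.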

### Loss (`loss/mae_loss.py`)

Mean squared error (MSE) computed only on masked patches:
```python
loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
```
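
Expanded into a runnable form with the shapes made explicit (`pred`/`target` as `(B, N, D)` patch tensors, `mask` as `(B, N)` with 1 = masked; this sketch averages the squared error within each patch, as in the MAE paper, which may differ slightly from the repo's exact broadcasting):

```python
import torch

def mae_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """MSE averaged over masked patches only."""
    per_patch = ((pred - target) ** 2).mean(dim=-1)   # (B, N): mean error per patch
    return (per_patch * mask).sum() / mask.sum()

# Visible (unmasked) patches contribute nothing, even if badly reconstructed
pred = torch.zeros(1, 4, 8)
target = torch.zeros(1, 4, 8)
target[:, 0] = 100.0                          # huge error on patch 0...
mask = torch.tensor([[0.0, 1.0, 1.0, 1.0]])   # ...but patch 0 is visible
loss = mae_loss(pred, target, mask)           # -> tensor(0.)
```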

### Trainer (`trainer/utils.py`)

- **MAETrainer**: Complete training pipeline
  - Mixed precision training (AMP)
  - Gradient clipping and accumulation
  - Learning rate scheduling (warmup → cosine)
  - Automatic checkpointing
  - Multi-file logging (train/val/test)
  - Live metric monitoring with tqdm
  - Periodic metric visualization
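
The warmup → cosine schedule can be reproduced with stock PyTorch schedulers. This sketches the shape of the schedule only; how `MAETrainer` actually builds it (and the `start_factor` of 0.01) is an assumption:

```python
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(10, 10)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)

warmup_epochs, total_epochs = 5, 200
scheduler = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=0.01, total_iters=warmup_epochs),  # linear warmup
        CosineAnnealingLR(opt, T_max=total_epochs - warmup_epochs),   # cosine decay
    ],
    milestones=[warmup_epochs],
)

lrs = []
for epoch in range(total_epochs):
    lrs.append(opt.param_groups[0]["lr"])
    # ... one training epoch would run here ...
    scheduler.step()
```

The learning rate climbs linearly to the base value of 1e-4 over the first 5 epochs, then decays along a cosine curve toward zero.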

## CheXpert Pathologies

The dataset includes 14 chest X-ray findings:

1. No Finding
2. Enlarged Cardiomediastinum
3. Cardiomegaly
4. Lung Opacity
5. Lung Lesion
6. Edema
7. Consolidation
8. Pneumonia
9. Atelectasis
10. Pneumothorax
11. Pleural Effusion
12. Pleural Other
13. Fracture
14. Support Devices

## Training Tips

1. **Learning Rate**: Start with 1e-4; use warmup for stability
2. **Batch Size**: Maximize based on GPU memory (96 works well on 40GB GPUs)
3. **Gradient Accumulation**: Use it if batch size is limited by memory
4. **Mixed Precision**: Enabled by default for faster training
5. **Masking Ratio**: 75% is standard; higher ratios increase task difficulty
6. **Resume Training**: The trainer automatically resumes from the last checkpoint
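
Tips 3 and 4 combine into a loop like the following (a generic sketch using the stock `torch.cuda.amp` API, with a placeholder model and synthetic data; AMP is disabled automatically when no GPU is present, and this is not the repo's exact trainer):

```python
import torch

model = torch.nn.Linear(16, 1)                 # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
accumulation = 4                               # effective batch = batch_size * accumulation

data = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(8)]
losses = []
opt.zero_grad()
for step, (x, y) in enumerate(data):
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    losses.append(loss.item())
    # Divide the loss so accumulated gradients average over the micro-batches
    scaler.scale(loss / accumulation).backward()
    if (step + 1) % accumulation == 0:
        scaler.unscale_(opt)                                  # clip true fp32 grads
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
```

With `accumulation = 4` and a micro-batch of 8, the optimizer steps on gradients equivalent to a batch of 32.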

## Use Cases

### Pre-training for Downstream Tasks
Use the trained encoder as a feature extractor:
```python
import torch

from models.mae import MaskedAutoEncoder

# Load pre-trained model
mae = MaskedAutoEncoder()
mae.load_state_dict(torch.load("best_mae.pth")["model"])

# Use the encoder for feature extraction
encoder = mae.encoder
features, _, _, _ = encoder(images)
```

### Fine-tuning on Classification
Add a classification head to the encoder for supervised tasks.
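
A minimal version of such a head might look like this (the mean pooling, `feature_dim=768` matching the encoder, and 14 outputs for the CheXpert labels are assumptions; multi-label training would pair this with `BCEWithLogitsLoss`):

```python
import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """Mean-pool encoder patch features, then classify with one linear layer."""

    def __init__(self, feature_dim: int = 768, num_classes: int = 14):
        super().__init__()
        self.norm = nn.LayerNorm(feature_dim)
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, patch_features: torch.Tensor) -> torch.Tensor:
        # patch_features: (B, N, feature_dim) from the (typically frozen) encoder
        pooled = self.norm(patch_features.mean(dim=1))
        return self.head(pooled)

logits = LinearProbe()(torch.randn(2, 144, 768))  # (B, num_classes)
```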

### Anomaly Detection
Reconstruction error can indicate abnormalities in medical images.

## Performance Optimization

This implementation includes several optimizations:

- **Efficient ZIP Reading**: Avoids extracting files to disk
- **LRU Cache**: Keeps frequently accessed images in memory
- **Persistent Workers**: Reduces data loading overhead
- **Mixed Precision**: Up to 2× faster training with minimal quality loss
- **Gradient Checkpointing**: Reduces memory usage (if enabled)
- **CUDA Memory Management**: Proper cache clearing and synchronization

## Contributing

Contributions are welcome! Please feel free to submit a pull request.

## License

This project is licensed under the terms specified in the LICENSE file.

## References

1. **Masked Autoencoders Are Scalable Vision Learners**
   He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022)
   [arXiv:2111.06377](https://arxiv.org/abs/2111.06377)

2. **CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison**
   Irvin, J., et al. (2019)
   [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)

## Acknowledgments

- Original MAE paper by Meta AI Research
- CheXpert dataset by the Stanford ML Group
- PyTorch and Albumentations communities

## Contact

For questions or issues, please open an issue on GitHub or contact the maintainer.

---

**Note**: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.