This is an unofficial PyTorch (Lightning) implementation of EDM (*Elucidating the Design Space of Diffusion-Based Generative Models*) and its follow-up, *Analyzing and Improving the Training Dynamics of Diffusion Models*.
- Config G.
- Post-hoc EMA (see the sketch after this list).
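As a minimal sketch of what post-hoc EMA builds on, here is the power-function EMA update from the EDM2 paper, where the decay follows beta_t = (1 - 1/t)^(gamma + 1). The `PowerEMA` class and its interface are illustrative, not tinyedm's actual API; in the full method, snapshots at two values of gamma are saved during training and combined afterwards by least squares to reconstruct arbitrary EMA profiles, a step omitted here.

```python
# Illustrative sketch (not tinyedm's API) of the power-function EMA update
# from the EDM2 paper: beta_t = (1 - 1/t)^(gamma + 1).
import copy

import torch


class PowerEMA:
    def __init__(self, model: torch.nn.Module, gamma: float = 6.94):
        # gamma = 6.94 roughly corresponds to sigma_rel = 0.10 in the paper.
        self.gamma = gamma
        self.ema = copy.deepcopy(model).eval().requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module, step: int):
        # At t = 1, beta = 0, so the EMA starts as an exact copy of the weights.
        beta = (1.0 - 1.0 / max(step, 1)) ** (self.gamma + 1.0)
        for p_ema, p in zip(self.ema.parameters(), model.parameters()):
            p_ema.lerp_(p, 1.0 - beta)  # p_ema <- beta * p_ema + (1 - beta) * p
```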
Install from source:

```bash
git clone https://github.com/YichengDWu/tinyedm.git
cd tinyedm && pip install .
```

Train on MNIST or CIFAR-10:

```bash
python experiments/train.py --config-name=mnist
python experiments/train.py --config-name=cifar10
```

To download the ImageNet dataset, follow these steps:
- Visit the ImageNet website: http://www.image-net.org/
- Register for an account and request access to the dataset.
- Once approved, follow the instructions provided by ImageNet to download the dataset.
After downloading the ImageNet dataset, extract the files to a directory. When running the feature extraction script, use the `--data-dir` option to specify the path to this directory.
For example:
```bash
python src/tinyedm/datamodules/extract_latents.py --data-dir ./datasets/imagenet/train --out-dir ./datasets/imagenet/latents/train
```

To generate samples from a trained checkpoint:

```bash
python src/tinyedm/generate.py \
    --ckpt_path /path/to/checkpoint.ckpt \
    --load_ema \
    --output_dir /path/to/output \
    --num_samples 50000 \
    --image_size 32 \
    --num_classes 10 \
    --batch_size 128 \
    --num_workers 16 \
    --num_steps 32
```

Results:

| Dataset | Params | Type | Epochs | FID |
|---|---|---|---|---|
| CIFAR-10 | 35.6 M | unconditional | 1700 | 4.0 |
- FP16 mixed-precision training on CIFAR-10 sometimes overflows, so this implementation uses bf16 mixed precision instead (shown after these notes), which may cost some model accuracy.
- The skip-connection scale factors are learned by a small network (sketched after these notes), inspired by *ScaleLong: Towards More Stable Training of Diffusion Model via Scaling Network Long Skip Connection*. Experiments show that this improves results.
- Using the paper's multi-task learning showed no improvement in my experiments; it may be more effective over long training runs, but I do not have the compute to verify this.
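As a hedged example of the bf16 note above (how tinyedm wires this into its configs is an assumption on my part), bf16 mixed precision can be enabled through the standard PyTorch Lightning `Trainer` flag:

```python
# Assumption: a standard Lightning 2.x Trainer; "bf16-mixed" runs bfloat16
# autocast with fp32 master weights.
import lightning as L

trainer = L.Trainer(
    precision="bf16-mixed",  # avoids the overflows seen with "16-mixed" on CIFAR-10
    max_epochs=1700,         # matches the CIFAR-10 run in the table above
)
```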
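The learned skip-scale idea can be sketched as follows. This is a hypothetical module in the spirit of ScaleLong, not tinyedm's actual code; `LearnedSkipScale`, `emb_dim`, and the sigmoid choice are all illustrative assumptions.

```python
# Hypothetical sketch: predict a per-channel scale for a U-Net skip connection
# from the noise-level embedding, instead of a fixed constant like 1/sqrt(2).
import torch
import torch.nn as nn


class LearnedSkipScale(nn.Module):
    def __init__(self, emb_dim: int, channels: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, channels),
            nn.Sigmoid(),  # keep scales in (0, 1) for stability
        )

    def forward(self, skip: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # skip: (B, C, H, W) feature map from the encoder;
        # emb:  (B, emb_dim) noise-level embedding.
        scale = self.mlp(emb)
        return skip * scale[:, :, None, None]
```

The decoder block would then concatenate the rescaled skip with its input as usual; only the scaling of the skip branch changes.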