COAP: Memory-Efficient Training with Correlation-Aware Gradient Projection

¹ByteDance, ²Rutgers University
*Work done during an internship at ByteDance.

Comparison between COAP and other low-rank-based methods. The X-axis shows additional training time (lower is better). The Y-axis shows the change in a quantitative metric (e.g., FID, PPL) relative to the original optimizer (e.g., Adam, Adafactor); higher values indicate better performance.

Abstract

Training large-scale neural networks in vision and multimodal domains demands substantial memory resources, primarily due to the storage of optimizer states. While LoRA, a popular parameter-efficient method, reduces memory usage, it often suffers from suboptimal performance due to the constraints of low-rank updates. Low-rank gradient projection methods (e.g., GaLore, Flora) reduce optimizer memory by projecting gradients and moment estimates into low-rank spaces via singular value decomposition or random projection. However, they fail to account for inter-projection correlation, which causes performance degradation, and their projection strategies often incur high computational costs. In this paper, we present COAP (COrrelation-Aware Gradient Projection), a memory-efficient method that minimizes computational overhead while maintaining training performance. Evaluated across various vision, language, and multimodal tasks, COAP outperforms existing methods in both training speed and model performance. For LLaMA-1B, it reduces optimizer memory by 61% with only 2% additional time cost while achieving the same PPL as AdamW. With 8-bit quantization, COAP cuts optimizer memory by 81% and achieves a 4× speedup over GaLore for LLaVA-v1.5-7B fine-tuning, while delivering higher accuracy.
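To make the baseline idea in the abstract concrete, here is a minimal NumPy sketch of generic low-rank gradient projection: a GaLore-style projector from the gradient's top singular vectors, a Flora-style Gaussian random projector, and one Adam step whose moment estimates live in the projected r×n space. This is our own illustration under simplifying assumptions (projection on the left side only, a single weight matrix, no periodic projector refresh, no quantization); it is not COAP's correlation-aware update, and it is not the released code of GaLore, Flora, or COAP. All function names (`svd_projector`, `random_projector`, `projected_adam_step`, `optimizer_state_floats`) are hypothetical.

```python
import numpy as np


def svd_projector(grad: np.ndarray, rank: int) -> np.ndarray:
    """GaLore-style projector (assumed): top-`rank` left singular vectors, shape (m, r)."""
    u, _, _ = np.linalg.svd(grad, full_matrices=False)
    return u[:, :rank]


def random_projector(m: int, rank: int, rng: np.random.Generator) -> np.ndarray:
    """Flora-style projector (assumed): Gaussian entries with variance 1/r, so E[P @ P.T] = I."""
    return rng.standard_normal((m, rank)) / np.sqrt(rank)


def projected_adam_step(w, grad, proj, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step with first/second moments stored in the (r, n) projected space."""
    r_grad = proj.T @ grad  # project the (m, n) gradient down to (r, n)
    state["m"] = b1 * state["m"] + (1 - b1) * r_grad
    state["v"] = b2 * state["v"] + (1 - b2) * r_grad**2
    state["t"] += 1
    m_hat = state["m"] / (1 - b1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - b2 ** state["t"])  # bias-corrected second moment
    update = proj @ (m_hat / (np.sqrt(v_hat) + eps))  # project the update back to (m, n)
    return w - lr * update


def optimizer_state_floats(m: int, n: int, rank=None) -> int:
    """Floats held by Adam for one (m, n) matrix: full moments, or projected moments plus P."""
    if rank is None:
        return 2 * m * n
    return 2 * rank * n + m * rank


# Usage: one projected step on a small random matrix.
rng = np.random.default_rng(0)
m, n, r = 64, 32, 8
w = rng.standard_normal((m, n))
grad = rng.standard_normal((m, n))
proj = svd_projector(grad, r)
state = {"m": np.zeros((r, n)), "v": np.zeros((r, n)), "t": 0}
w_new = projected_adam_step(w, grad, proj, state)

full = optimizer_state_floats(4096, 4096)             # 33,554,432 floats
low_rank = optimizer_state_floats(4096, 4096, 256)    # 3,145,728 floats
```

For a single 4096×4096 matrix at rank 256, this counting gives 3/32 of the full Adam state. That per-matrix figure is only arithmetic on the sketch's assumptions and should not be read as the whole-model savings reported above. The abstract's point is that how `proj` is chosen and refreshed (SVD cost, and ignoring correlation between successive projectors) is where these baselines lose speed and accuracy.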


BibTeX

@article{xiao2024coap,
    title   = {COAP: Memory-Efficient Training with Correlation-Aware Gradient Projection},
    author  = {Xiao, Jinqi and Sang, Shen and Zhi, Tiancheng and Liu, Jing and Yan, Qing and Luo, Linjie and Yuan, Bo},
    journal = {arXiv preprint arXiv:2412.00071},
    year    = {2024},
    url     = {https://arxiv.org/abs/2412.00071}
}