Statistical Rates of Diffusion Transformers (NeurIPS '24)
This was the first major project I worked on at Prof. Han Liu's MAGICS Lab (now the Center for Foundation Models and Generative AI). I was very fortunate to have the mentorship of Northwestern PhD candidates Jerry Han and Weimin Wu, who helped me navigate the latent manifold of sanity in an otherwise terrifyingly new ambient space. This project was, for me, a cold plunge into the deep end of statistical learning theory - I spent many hours crossing out my confused scribbles and lamenting the lack of a 3b1b video on the Universal Approximation Theorem for transformers - but by the end of it I had a laundry list of new questions I didn't know how to answer and a stronger desire to pursue them. Maybe the ice bath people were onto something.
Abstract
We investigate the statistical and computational limits of latent Diffusion Transformers (DiTs) under a low-dimensional linear latent space assumption. Statistically, we derive approximation error bounds for the score network that are sub-linear in the latent space dimension, along with corresponding sample complexity bounds. Computationally, we characterize the hardness of forward inference and backward computation, identifying efficient criteria that enable almost-linear time inference and training by leveraging low-rank structure in gradient computations.
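For concreteness, here is a minimal sketch of what a linear latent space assumption of this kind looks like, written in my own illustrative notation (the symbols B, h_0, d_0 and the exact decomposition are mine, not necessarily the paper's): the data sits on a d_0-dimensional subspace of the D-dimensional ambient space, and the score splits into a learned on-subspace part plus a closed-form orthogonal part, so the score network only has to approximate a d_0-dimensional object.

```latex
% Illustrative only; notation and decomposition are my own, not the paper's.
% Data lies on a d_0-dimensional linear subspace of the ambient space R^D:
x_0 = B h_0, \qquad B \in \mathbb{R}^{D \times d_0}, \quad B^\top B = I_{d_0}, \quad d_0 \ll D.
% With a forward process x_t = \alpha_t x_0 + \sigma_t \epsilon, the off-subspace
% component of x_t is pure Gaussian noise, so the score decomposes as
\nabla_x \log p_t(x)
  = B \,\nabla_h \log q_t(h)\big|_{h = B^\top x}
  \;-\; \frac{1}{\sigma_t^{2}}\,\bigl(I_D - B B^\top\bigr)\, x,
% where q_t is the density of the noised latent \alpha_t h_0 + \sigma_t \epsilon_\parallel.
% Only the first term, a d_0-dimensional object, needs to be learned.
```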
Under the low-dimensional assumption, both the statistical rates and the computational efficiency are governed by the subspace dimension rather than the ambient data dimension, suggesting that latent DiTs can bypass the challenges associated with high-dimensional data.
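And here is a toy sketch of why low-rank structure is the lever for almost-linear time. This is not the paper's algorithm (the names n, k, d, L, R, V are mine): it only shows that if an n-by-n attention-style matrix factors as L @ R.T with rank k much smaller than n, regrouping the matrix product removes the n^2 cost entirely.

```python
# Toy illustration (not the paper's algorithm): multiplying an n x n
# attention-style matrix against a value matrix costs O(n^2 d) if formed
# explicitly, but only O(n k d) if the matrix admits a rank-k factorization
# A ~= L @ R.T with k << n.
import time
import numpy as np

n, d, k = 4096, 64, 32
rng = np.random.default_rng(0)

# Rank-k factors standing in for a (hypothetical) low-rank attention matrix.
L = rng.standard_normal((n, k))
R = rng.standard_normal((n, k))
V = rng.standard_normal((n, d))

# Naive path: materialize the full n x n matrix, then multiply.
t0 = time.perf_counter()
out_naive = (L @ R.T) @ V        # costs O(n^2 k + n^2 d)
t_naive = time.perf_counter() - t0

# Low-rank path: regroup so the n x n matrix is never formed.
t0 = time.perf_counter()
out_lowrank = L @ (R.T @ V)      # costs O(n k d)
t_lowrank = time.perf_counter() - t0

assert np.allclose(out_naive, out_lowrank)
print(f"naive: {t_naive:.3f}s   low-rank: {t_lowrank:.3f}s")
```

The paper's actual criteria concern when the matrices arising in DiT inference and gradient computation admit this kind of structure; the toy above only shows the cost regrouping that such structure buys you.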