
ICLR2024: Reading "Training Unbiased Diffusion Models From Biased Dataset"

This is my trial post using SpeakerDeck:)
I recently read an intriguing paper on domain adaptation in diffusion models, and I created a slide deck to share my personal interpretation of the paper!

The original paper can be found at the following URL:
https://openreview.net/forum?id=39cPKijBed

htakagi

July 31, 2024


Transcript

  1. What & Why ‣What ‣ Adaptation of Denoising Score Matching

    (DSM) to limited unbiased data ‣ Reweighting approach with theoretical guarantee ‣Why ‣ Dealing with the recently emerged risk that generative AI can be biased by training dataset 2
  2. Background
     ‣ Datasets can be biased due to social, geographical, and physical factors
     ‣ Generative models trained on a biased dataset converge to the biased distribution
     ‣ They can therefore produce outputs that carry potential risk
     → Mitigating latent bias is a key factor in improving sample quality and proportion
     (Figure: examples from the CelebFaces Attributes Dataset)
  3. Problem Setup
     ‣ True data distribution p_data : 𝒳 → ℝ≥0
     ‣ Unknown biased distribution p_bias : 𝒳 → ℝ≥0
     ‣ Accessible data 𝒟_obs ⊂ 𝒳^n consists of two sets: 𝒟_obs = 𝒟_bias ∪ 𝒟_ref
       ‣ Each element of 𝒟_bias is an i.i.d. sample from p_bias
       ‣ Each element of 𝒟_ref is an i.i.d. sample from p_data
       ‣ |𝒟_ref| is relatively smaller than |𝒟_bias|
     ‣ This weak-supervision setting was proposed for GANs in Choi et al. (2020): the explicit bias is not provided, while the origin of each instance is known to be either 𝒟_bias or 𝒟_ref
  4. Problem Setup (annotated)
     Same setup as the previous slide, read practically:
     ‣ 𝒟_bias — collect as much data as possible, even if it is biased
     ‣ 𝒟_ref — a small subset verified by hand
     ‣ The underlying bias factor itself is left unannotated
  5. Diffusion model
     ‣ Forward process perturbs x_0 ∼ p_data into x_T:
         dx_t = f(x_t, t) dt + g(t) dw_t
       where f(·, t) : 𝒳 → 𝒳 is the drift term, g(·) : ℝ → ℝ is the diffusion term, and w_t is a standard Wiener process
     ‣ Reverse process transforms random noise x_T back into x_0:
         dx_t = [f(x_t, t) − g(t)² ∇ log p_t^data(x_t)] dt + g(t) dw̄_t
       where w̄_t is the standard Wiener process when time flows backward and p_t^data(x_t) is the p.d.f. of x_t
     ‣ Score matching (Song & Ermon, 2019) uses a score network s_θ(x_t, t) ≈ ∇ log p_t^data(x_t), trained by minimizing an objective built from a per-sample loss ℓ_dsm(θ, x_0) with a temporal weighting function λ : [0, T] → ℝ₊ (noise scheduling)
     ‣ The score-matching and denoising score-matching (DSM) objectives are equivalent with respect to the model parameter θ
     ‣ DSM is tractable and widely used (a minimal sketch of ℓ_dsm follows below)
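To make the per-sample loss ℓ_dsm(θ, x_0) concrete, here is a minimal PyTorch sketch. It assumes a simple variance-exploding perturbation x_t = x_0 + σ(t)·ε and a default weighting λ(t) = σ(t)², which may differ from the paper's exact SDE and schedule; score_net is a hypothetical network taking (x_t, t).

```python
# Minimal sketch of ℓ_dsm(θ, x0), assuming a variance-exploding perturbation kernel
# x_t = x0 + σ(t)·ε (an illustrative choice, not necessarily the paper's setup).
import torch

def sigma(t, sigma_min=0.01, sigma_max=50.0):
    # Hypothetical noise schedule: geometric interpolation between sigma_min and sigma_max.
    return sigma_min * (sigma_max / sigma_min) ** t

def dsm_loss(score_net, x0, lam=None):
    """Per-sample DSM loss: weighted distance between s_θ(x_t, t) and ∇ log p(x_t | x0)."""
    t = torch.rand(x0.shape[0], device=x0.device)            # t ~ U[0, 1], i.e. T = 1 here
    std = sigma(t).view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = x0 + std * eps                                       # perturbed sample x_t
    target = -eps / std                                       # ∇_{x_t} log p(x_t | x0)
    score = score_net(xt, t)
    weight = lam(t) if lam is not None else std.squeeze() ** 2   # λ(t), noise scheduling
    per_sample = ((score - target) ** 2).flatten(1).sum(dim=1)
    return weight * per_sample                                # one value per x0
```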
  6. Importance reweighting for an unbiased model
     ‣ We can obtain an unbiased diffusion model by the common reweighting technique if we know the density ratio
     ‣ For x_0 ∼ 𝒟_obs, we denote the density ratio as
         w(x_0) := p_data(x_0) / p_bias(x_0)
     ‣ Then, by minimizing the objective
         ℒ_DSM(θ; p_data) = (1/2) ∫_0^T 𝔼_{p_data(x_0)}[ℓ_dsm(θ, x_0)] dt
                          = (1/2) ∫_0^T 𝔼_{p_bias(x_0)}[w(x_0) ℓ_dsm(θ, x_0)] dt,
       where the expectation can be estimated empirically from the abundant 𝒟_bias and the known w(x_0), the DSM objective unbiased with respect to p_data can be obtained (sketch below)
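As a sketch of how the reweighted expectation 𝔼_{p_bias}[w(x_0) ℓ_dsm(θ, x_0)] would be estimated from a mini-batch of 𝒟_bias, assuming the density ratio is available as a function (hypothetical names; dsm_loss is the earlier sketch):

```python
# Monte Carlo estimate of E_{p_bias}[w(x0) · ℓ_dsm(θ, x0)] from a D_bias mini-batch,
# assuming density_ratio_fn(x0) returns the known ratio p_data(x0) / p_bias(x0).
def reweighted_dsm_loss(score_net, x0_bias, density_ratio_fn):
    w = density_ratio_fn(x0_bias)                     # w(x0), shape (batch,)
    return (w * dsm_loss(score_net, x0_bias)).mean()  # reuses dsm_loss from the sketch above
```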
  7. Importance reweighting for an unbiased model: density ratio estimation
     ‣ Density Ratio Estimation (DRE) is carried out through discriminative training
     ‣ We assign a pseudo label y = 0 if the sample comes from 𝒟_bias and y = 1 if it comes from 𝒟_ref, and train a discriminator d_φ : 𝒳 → [0, 1] by minimizing the binary cross-entropy (BCE) loss
         φ* = argmin_φ [ 𝔼_{p_data(x_0)}[−log d_φ(x_0)] + 𝔼_{p_bias(x_0)}[−log(1 − d_φ(x_0))] ]
     ‣ The true density ratio can then be recovered (with equal class priors p(y = 0) = p(y = 1), as in the half/half mini-batches) as
         w_{φ*}(x_0) = p_data(x_0) / p_bias(x_0)
                     = p(x_0 | y = 1) / p(x_0 | y = 0)
                     = [p(y = 0) p(y = 1 | x_0)] / [p(y = 1) p(y = 0 | x_0)]
                     = d_{φ*}(x_0) / (1 − d_{φ*}(x_0))
       (a minimal sketch follows below)
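A hedged sketch of the discriminative DRE step: train a binary discriminator with the BCE loss above, then convert its output into the ratio d/(1 − d). disc is a hypothetical network returning probabilities in [0, 1]; the identity assumes equal class priors, as in the half/half mini-batches.

```python
# Density-ratio estimation via a binary discriminator d_φ: X → [0, 1].
# Labels: y = 1 for D_ref (p_data), y = 0 for D_bias (p_bias). All names are illustrative.
import torch
import torch.nn.functional as F

def bce_loss(disc, x_ref, x_bias):
    # E_{p_data}[-log d_φ(x)] + E_{p_bias}[-log(1 - d_φ(x))]
    loss_ref = F.binary_cross_entropy(
        disc(x_ref), torch.ones(x_ref.shape[0], device=x_ref.device))
    loss_bias = F.binary_cross_entropy(
        disc(x_bias), torch.zeros(x_bias.shape[0], device=x_bias.device))
    return loss_ref + loss_bias

def density_ratio(disc, x0):
    # At the optimum (equal class priors): w(x0) = d_φ*(x0) / (1 - d_φ*(x0))
    d = disc(x0).clamp(1e-6, 1 - 1e-6)
    return d / (1 - d)
```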
  8. Algorithm: Importance-reWeighted DSM (IW-DSM)
     1. Train the discriminator by empirically minimizing the BCE loss
          φ* = argmin_φ [ 𝔼̂_{p_data(x_0)}[−log d_φ(x_0)] + 𝔼̂_{p_bias(x_0)}[−log(1 − d_φ(x_0))] ]
     2. Define p_obs(x_0) = (1/2) p_bias(x_0) + (1/2) p_data(x_0) and calculate the density ratio
          w̃_{φ*}(x_0) = p_data(x_0) / p_obs(x_0) = 2 w_{φ*}(x_0) / (1 + w_{φ*}(x_0)) = 2 d_{φ*}(x_0)
     3. Train the score network s_θ(x_t, t) on 𝒟_obs = 𝒟_bias ∪ 𝒟_ref by minimizing
          ℒ_IW-DSM(θ) = (1/2) ∫_0^T 𝔼_{p_obs(x_0)}[ w̃_{φ*}(x_0) ℓ_dsm(θ, x_0) ] dt
        (In each mini-batch, half the samples are drawn from 𝒟_bias and the other half from 𝒟_ref; see the sketch below)
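Putting the three steps together, a minimal sketch of one IW-DSM training iteration under the half/half mini-batch scheme; disc is the frozen discriminator d_φ* from step 1 and dsm_loss is the earlier sketch (all names illustrative).

```python
# One IW-DSM training step on D_obs, following the half/half mini-batch scheme.
# The weight is w̃_φ*(x0) = 2·d_φ*(x0), since p_obs = (p_bias + p_data) / 2.
import torch

def iw_dsm_step(score_net, disc, x_bias_batch, x_ref_batch, optimizer):
    x0 = torch.cat([x_bias_batch, x_ref_batch], dim=0)   # half from D_bias, half from D_ref
    with torch.no_grad():
        w_tilde = 2.0 * disc(x0)                          # w̃(x0) = p_data(x0) / p_obs(x0)
    loss = (w_tilde * dsm_loss(score_net, x0)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```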
  9. Algorithm: IW-DSM (continued)
     Same algorithm as the previous slide, with one caveat:
     ✓ In many cases, DRE suffers from estimation error
  10. Time-dependent Importance Reweighting
      ‣ DRE suffers from estimation errors due to the density-chasm problem (Rhodes et al., 2020)
        ‣ The two distributions are far apart
        ‣ The number of samples from the two distributions is small
      ✓ When the KL divergence between the two distributions is tens of nats, it is easy to train a discriminator that separates the samples well
      → The BCE loss then becomes very small even though the estimated ratio is still inaccurate
  11. Time-dependent Importance Reweighting
      When we train a diffusion model...
      ‣ DRE suffers from estimation errors due to the density-chasm problem (Rhodes et al., 2020)
        ‣ The two distributions are far apart → real-world image datasets are high-dimensional
        ‣ The number of samples from the two distributions is small → the reference set size |𝒟_ref| would be small
  12. Time-dependent Importance Reweighting
      Regarding the perturbed p_t^bias(x_t) and p_t^data(x_t), as t becomes larger...
      ‣ DRE suffers from estimation errors due to the density-chasm problem (Rhodes et al., 2020)
        ‣ The two distributions are far apart → with perturbation, the two distributions become closer
        ‣ The number of samples from the two distributions is small → the Monte Carlo error decreases because samples are drawn from each (simpler) perturbed distribution
      ✓ The true density ratio converges to 1, which is easy to estimate, mitigating the problem (a tiny numerical illustration follows)
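A tiny numerical illustration (my own, not from the paper) of why perturbation helps: for two 1-D Gaussians, adding the same Gaussian noise to both shrinks their KL divergence, so the density ratio of the perturbed distributions drifts toward 1 and becomes much easier to estimate.

```python
# Two 1-D Gaussians N(0, 1) and N(3, 1): after adding N(0, t²) noise to both,
# their KL divergence shrinks monotonically with the noise level t.
import math

def kl_gauss(m1, s1, m2, s2):
    # KL( N(m1, s1²) || N(m2, s2²) ) in nats
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

for t_noise in [0.0, 1.0, 5.0, 20.0]:
    s = math.sqrt(1.0 + t_noise ** 2)            # std after adding N(0, t_noise²) noise
    print(t_noise, kl_gauss(0.0, s, 3.0, s))     # 4.5 → 2.25 → 0.17 → 0.011 nats
```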
  13. Time-dependent Importance Reweighting
      ‣ Time-dependent density ratio w_φ^t(x_t)
        ‣ Represented by a time-dependent discriminator d_φ : 𝒳 × [0, T] → [0, 1]
        ‣ Trained with a temporally weighted binary cross-entropy (T-BCE) loss (sketch below)
      ✓ While the (time-independent) density ratio on x_0 suffers from estimation error, the MSE decreases rapidly as t increases; in the paper's example, the error area accumulated over t is reduced by about 60% with the time-dependent estimation, which is what we want to take advantage of
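A sketch of how the time-dependent discriminator d_φ(x_t, t) could be trained on perturbed samples with a temporally weighted BCE; the perturbation reuses sigma(t) from the earlier sketch, and the weighting function is a placeholder since the paper's exact T-BCE weighting is not reproduced here.

```python
# Temporally weighted BCE (T-BCE) for a time-dependent discriminator d_φ(x_t, t).
# disc_t takes (x_t, t); weight_fn is a placeholder for the temporal weighting.
import torch
import torch.nn.functional as F

def t_bce_loss(disc_t, x_ref, x_bias, weight_fn=lambda t: torch.ones_like(t)):
    def perturb(x0):
        t = torch.rand(x0.shape[0], device=x0.device)
        std = sigma(t).view(-1, *([1] * (x0.dim() - 1)))   # sigma() from the earlier sketch
        return x0 + std * torch.randn_like(x0), t
    xt_ref, t_ref = perturb(x_ref)                          # y = 1: perturbed D_ref samples
    xt_bias, t_bias = perturb(x_bias)                       # y = 0: perturbed D_bias samples
    loss_ref = weight_fn(t_ref) * F.binary_cross_entropy(
        disc_t(xt_ref, t_ref), torch.ones_like(t_ref), reduction="none")
    loss_bias = weight_fn(t_bias) * F.binary_cross_entropy(
        disc_t(xt_bias, t_bias), torch.zeros_like(t_bias), reduction="none")
    return loss_ref.mean() + loss_bias.mean()
```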
  14. TIW-DSM
      ‣ Time-dependent reweighting via the time-dependent density ratio
      ‣ The score of p_data is still difficult to estimate directly
      ‣ A modification that combines reweighting with a score correction is tractable and mathematically valid, converging to the unbiased distribution p_data
  15. TIW-DSM: Algorithm of Time-dependent Importance-reWeighted DSM
      1. Train the time-dependent discriminator by empirically minimizing the T-BCE loss
           φ* = argmin_φ ℒ_T-BCE(φ; p_data, p_bias)
      2. Define p_obs(x_0) = (1/2) p_bias(x_0) + (1/2) p_data(x_0) and calculate the time-dependent density ratio
           w̃_{φ*}^t(x_t) = p_t^data(x_t) / p_t^obs(x_t) = 2 w_{φ*}^t(x_t) / (1 + w_{φ*}^t(x_t)) = 2 d_{φ*}(x_t, t)
      3. Train the score network s_θ(x_t, t) on 𝒟_obs = 𝒟_bias ∪ 𝒟_ref by minimizing the TIW-DSM objective, which combines time-dependent reweighting with score correction
         (In each mini-batch, half the samples are drawn from 𝒟_bias and the other half from 𝒟_ref; see the sketch below)
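A hedged sketch of one TIW-DSM training iteration: the time-dependent weight is w̃^t_{φ*}(x_t) = 2·d_φ*(x_t, t), and, as one reading of the "score correction", the DSM target is shifted by ∇_{x_t} log w̃^t(x_t), obtained by autograd through the frozen discriminator. This follows the slide's description under my assumptions, not a verified reproduction of the paper's exact objective; sigma() and the VE perturbation come from the earlier sketch.

```python
# One TIW-DSM training step (illustrative): time-dependent reweighting plus a
# score-correction term ∇_{x_t} log w̃^t(x_t) computed through the frozen discriminator.
import torch

def tiw_dsm_step(score_net, disc_t, x_bias_batch, x_ref_batch, optimizer):
    x0 = torch.cat([x_bias_batch, x_ref_batch], dim=0)   # half D_bias, half D_ref
    t = torch.rand(x0.shape[0], device=x0.device)
    std = sigma(t).view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = (x0 + std * eps).requires_grad_(True)

    d = disc_t(xt, t).clamp(1e-6, 1 - 1e-6)
    log_w = torch.log(2.0 * d)                            # log w̃^t(x_t) = log 2·d_φ*(x_t, t)
    score_corr = torch.autograd.grad(log_w.sum(), xt)[0]  # ∇_{x_t} log w̃^t(x_t)
    xt = xt.detach()
    w_tilde = (2.0 * d).detach()                          # time-dependent weight

    target = -eps / std + score_corr                      # corrected DSM target (assumption)
    per_sample = ((score_net(xt, t) - target) ** 2).flatten(1).sum(dim=1)
    loss = (w_tilde * std.squeeze() ** 2 * per_sample).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```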
  16. Experiments
      ‣ Datasets (𝒟_bias / 𝒟_ref)
        ‣ CIFAR-10-LT / CIFAR-10
        ‣ CIFAR-100-LT / CIFAR-100
        ‣ Reference size: |𝒟_ref| is much smaller than |𝒟_bias|
      ‣ Metric: FID (smaller is better), using VGG-13 and DenseNet pre-trained on 𝒟_ref
      ‣ IW-DSM sometimes performs worse than plain DSM trained on 𝒟_obs, while TIW-DSM always performs well
  17. Experiments
      ‣ Datasets (𝒟_bias / 𝒟_ref)
        ‣ FFHQ sampled with the portion of females set to 80% or 90% / FFHQ
        ‣ CelebA / CelebA sampled with an unbiased portion of gender and hair color
      ‣ TIW-DSM is better in terms of FID and generates more of the minority attributes in 𝒟_obs
      ‣ Comparing images generated from the same seed by DSM(obs) and TIW-DSM, the samples change to people with different attributes
      (Figure legend: F = Female, M = Male, NB = Non-Black Hair, B = Black Hair; DSM(obs) vs. TIW-DSM)
  18. Experiments
      ‣ Observations on the effects of reweighting
        ‣ The weight can be scaled using α ∈ ℝ≥0: w → w^α (see the sketch below)
        ‣ α = 0 means no reweighting, equivalent to DSM(obs)
      ‣ Plain DSM and IW-DSM suffer a trade-off between bias and quality, while TIW-DSM with proper weighting can improve both
  19. Conclusion
      ‣ In diffusion models, utilizing the time-dependent distributions can mitigate the estimation error of the density ratio, avoiding the density-chasm problem
      ‣ The proposed method, TIW-DSM, which uses the time-dependent density ratio for reweighting as well as score correction, guarantees convergence to the unbiased distribution
      ‣ TIW-DSM can de-bias samples generated by a diffusion model trained on an imbalanced dataset, experimentally resulting in better performance
        ‣ (Not because of its implications as a social or political issue, but because the evaluation measure of the generative model uses latent features of discriminators trained on balanced data)
  20. References
      ‣ Kim, Y., Na, B., Park, M., Jang, J., Kim, D., Kang, W., & Moon. Training unbiased diffusion models from biased dataset. In The Twelfth International Conference on Learning Representations (ICLR), 2024.
      ‣ Choi, K., Grover, A., Singh, T., Shu, R., & Ermon, S. Fair generative modeling via weak supervision. In International Conference on Machine Learning, pp. 1887–1898. PMLR, 2020.
      ‣ Song, Y., & Ermon, S. Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32, 2019.
      ‣ Rhodes, B., Xu, K., & Gutmann, M. U. Telescoping density-ratio estimation. Advances in Neural Information Processing Systems, 33:4905–4916, 2020.