
On the principle of Invariant Risk Minimization


Masanari Kimura

May 27, 2023
Transcript

  1. On the principle of Invariant Risk Minimization
     Created by: Masanari Kimura
     Institute: The Graduate University for Advanced Studies, SOKENDAI
     Dept: Department of Statistical Science, School of Multidisciplinary Sciences
     E-mail: [email protected]
     XeLaTeXed on 2023/05/26
  2. Table of contents
     - Invariant Risk Minimization and its variants
       - Invariant Risk Minimization
       - Definitions of IRM and IRMv1
       - Connection to causality
       - Learning theory of IRM
       - Variants of IRM
     - Limitations of Invariant Risk Minimization
       - The difficulties of IRM
       - The optimization dilemma in IRM
     - Conclusion
  3. Problem setting: Out-Of-Distribution Generalization (OOD) [6]
     ⊚ We consider datasets $D^e := \{(x_i^e, y_i^e)\}_{i=1}^{n_e}$ collected under multiple training environments $e \in \mathcal{E}_{tr}$.
     ⊚ Each dataset $D^e$ is generated from some $p^e(x, y)$ under the i.i.d. assumption.
     ⊚ Our goal is to obtain $f : \mathcal{X} \to \mathcal{Y}$ which minimizes
     $$R^{\mathrm{OOD}}(f) := \max_{e \in \mathcal{E}_{all}} R^e(f) := \max_{e \in \mathcal{E}_{all}} \mathbb{E}_{p^e(x,y)}[\ell(f(x), y)], \quad (1)$$
     for $\mathcal{E}_{all} \supset \mathcal{E}_{tr}$.
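The worst-case risk in Eq. (1) can be sketched in a few lines. The environments, squared loss, and predictor below are illustrative assumptions, not part of the slides:

```python
# Sketch of the OOD risk (Eq. 1): the worst-case expected loss over
# environments, here estimated empirically with a squared loss.
def env_risk(f, data):
    """Empirical R^e(f): mean squared loss on one environment's samples."""
    return sum((f(x) - y) ** 2 for x, y in data) / len(data)

def ood_risk(f, environments):
    """R_OOD(f) = max_e R^e(f)."""
    return max(env_risk(f, d) for d in environments)

# Two toy environments: y = x in e1, y = x + 1 in e2. The identity
# predictor pays the worst-case cost on e2.
e1 = [(0.0, 0.0), (1.0, 1.0)]
e2 = [(0.0, 1.0), (1.0, 2.0)]
print(ood_risk(lambda x: x, [e1, e2]))  # 1.0: the max over {0.0, 1.0}
```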
  4. Definition (Invariant predictor [1])
     We say that a data representation $\Phi : \mathcal{X} \to \mathcal{H}$ elicits an invariant predictor $\hat\beta \circ \Phi$ across environments $\mathcal{E}$ if there is a classifier $\hat\beta : \mathcal{H} \to \mathcal{Y}$ simultaneously optimal for all environments, that is,
     $$\hat\beta \in \operatorname*{argmin}_{\beta : \mathcal{H} \to \mathcal{Y}} R^e(\beta \circ \Phi) \quad \text{for all } e \in \mathcal{E}.$$
     Goal: learn an invariant representation $\Phi$ such that the optimal classifier $\hat\beta$ is identical for all environments $e \in \mathcal{E}$:
     $$\mathbb{E}_{p^e(x,y)}[y \mid \Phi(x) = h] = \mathbb{E}_{p^{e'}(x,y)}[y \mid \Phi(x) = h], \quad \forall e, e' \in \mathcal{E}. \quad (2)$$
  5. Invariant Risk Minimization (IRM)
     Definition (IRM [1])
     $$\min_{\Phi : \mathcal{X} \to \mathcal{H},\ \hat\beta : \mathcal{H} \to \mathcal{Y}} \sum_{e \in \mathcal{E}_{tr}} R^e(\hat\beta \circ \Phi), \quad (3)$$
     $$\text{s.t. } \hat\beta \in \operatorname*{argmin}_{\beta : \mathcal{H} \to \mathcal{Y}} R^e(\beta \circ \Phi), \quad \forall e \in \mathcal{E}_{tr}. \quad (4)$$
     ⊚ This bilevel program is highly non-convex and difficult to solve.
     ⊚ To find an approximate solution, we can consider a Lagrangian form, whereby the sub-optimality w.r.t. the constraint is expressed as the squared norm of the gradients of each of the inner optimization problems.
  6. Definition (IRMv1 [1])
     $$\min_{\Phi : \mathcal{X} \to \mathcal{Y}} \sum_{e \in \mathcal{E}_{tr}} R^e(\Phi) + \lambda \cdot \|\nabla_{\hat\beta} R^e(\hat\beta \cdot \Phi)\|_2^2. \quad (5)$$
     ⊚ Assuming the inner optimization problem is convex, achieving feasibility is equivalent to the penalty term being equal to 0.
     ⊚ For $\lambda = \infty$, IRMv1 is equivalent to IRM.
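The gradient penalty of Eq. (5) can be made concrete for a 1-D linear model with squared loss, where the gradient w.r.t. the classifier is available in closed form. This is a minimal sketch under those assumptions (the data, the fixed dummy classifier $\hat\beta = 1.0$, and the scalar featurizer are all illustrative):

```python
# Sketch of the IRMv1 objective for a 1-D linear model with squared loss,
# using the fixed dummy classifier beta = 1.0 from the IRMv1 construction.
def risk_e(beta, phi, data):
    """Empirical R^e(beta * phi): mean squared error on one environment."""
    return sum((beta * phi * x - y) ** 2 for x, y in data) / len(data)

def grad_beta(beta, phi, data):
    """Closed-form d/d(beta) of the squared loss (no autograd needed)."""
    return sum(2 * (beta * phi * x - y) * phi * x for x, y in data) / len(data)

def irmv1_objective(phi, envs, lam):
    beta = 1.0  # dummy classifier held fixed, as in IRMv1
    risks = sum(risk_e(beta, phi, d) for d in envs)
    penalty = sum(grad_beta(beta, phi, d) ** 2 for d in envs)
    return risks + lam * penalty

# Two toy environments where y = 2x exactly: phi = 2 makes beta = 1 optimal
# in every environment, so both the risks and the gradient penalty vanish.
envs = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (-1.0, -2.0)]]
print(irmv1_objective(2.0, envs, lam=100.0))      # 0.0: invariant featurizer
print(irmv1_objective(1.0, envs, lam=100.0) > 0)  # True: penalized
```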
  7. Connection to causality
     Definition (Structural Equation Model (SEM) [9, 5])
     A Structural Equation Model (SEM) $\mathcal{C} := (S, N)$ governing the random vector $X = (X_1, \dots, X_d)$ is a set of structural equations:
     $$S_i : X_i \leftarrow f_i(\mathrm{Pa}(X_i), N_i), \quad (6)$$
     where $\mathrm{Pa}(X_i) \subseteq \{X_1, \dots, X_d\} \setminus \{X_i\}$ are called the parents of $X_i$, and the $N_i$ are independent noise random variables.
     ⊚ We say that "$X_i$ causes $X_j$" if $X_i \in \mathrm{Pa}(X_j)$.
     ⊚ The causal graph of $X$ is obtained by drawing i) one node for each $X_i$, and ii) one edge from $X_i$ to $X_j$ if $X_i \in \mathrm{Pa}(X_j)$.
     ⊚ We assume acyclic causal graphs.
  8. ⊚ From an SEM $\mathcal{C}$, following the topological ordering of its causal graph, we can draw samples from the observational distribution $P(X)$.
     ⊚ We can intervene on a unique SEM in different ways, indexed by $e$, to obtain different but related SEMs $\mathcal{C}^e$.
     Definition
     Consider an SEM $\mathcal{C} = (S, N)$. An intervention $e$ on $\mathcal{C}$ consists of replacing one or several of its structural equations to obtain an intervened SEM $\mathcal{C}^e = (S^e, N^e)$, with structural equations
     $$S_i^e : X_i^e \leftarrow f_i^e(\mathrm{Pa}(X_i^e), N_i^e), \quad (7)$$
     where the variable $X_i^e$ is intervened if $S_i \neq S_i^e$ or $N_i \neq N_i^e$.
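Ancestral sampling and interventions can be illustrated on a toy SEM. The chain $X_1 \to X_2 \to Y$ below and all its coefficients are illustrative assumptions:

```python
# Toy SEM sketch: X1 -> X2 -> Y, sampled in topological order, plus an
# intervention e that replaces the structural equation of X2.
import random

def sample_sem(intervene_x2=None):
    n1, n2, ny = (random.gauss(0, 1) for _ in range(3))
    x1 = n1                              # X1 <- N1
    x2 = 0.5 * x1 + n2                   # X2 <- f2(Pa(X2), N2)
    if intervene_x2 is not None:
        x2 = intervene_x2                # intervened SEM C^e: S2 replaced
    y = 2.0 * x2 + ny                    # Y  <- f_Y(Pa(Y), N_Y)
    return x1, x2, y

random.seed(0)
obs = [sample_sem() for _ in range(5000)]                  # observational P(X)
do = [sample_sem(intervene_x2=1.0) for _ in range(5000)]   # interventional P^e
mean_y_do = sum(y for _, _, y in do) / len(do)
print(mean_y_do)  # close to 2.0, since Y <- 2 * X2 + noise and X2 := 1
```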
  9. Definition
     Consider an SEM $\mathcal{C}$ governing the random vector $(X_1, \dots, X_d, Y)$, and the learning goal of predicting $Y$ from $X$. Then, the set of all environments $\mathcal{E}_{all}(\mathcal{C})$ indexes all the interventional distributions $P^e(X, Y) = P(X^e, Y^e)$ obtainable by valid interventions $e$. An intervention $e \in \mathcal{E}_{all}(\mathcal{C})$ is valid as long as i) the causal graph remains acyclic; ii) $\mathbb{E}_{P^e(X,Y)}[Y \mid \mathrm{Pa}(Y)] = \mathbb{E}[Y \mid \mathrm{Pa}(Y)]$; iii) $\mathbb{V}[Y^e \mid \mathrm{Pa}(Y)]$ remains within a finite range.
     ⊚ The previous definitions relate causality and invariance.
     ⊚ One can show that a predictor $\beta : \mathcal{X} \to \mathcal{Y}$ is invariant across $\mathcal{E}_{all}(\mathcal{C})$ iff it attains optimal $R^{\mathrm{OOD}}$, and iff it uses only the direct causal parents of $Y$ to predict.
  10. Learning theory of IRM
      Goal: low error and invariance across $\mathcal{E}_{tr}$ lead to low error across $\mathcal{E}_{all}$.
      Intuition: Invariant Causal Prediction (ICP) [6]
      ICP recovers the target invariance as long as i) the data is Gaussian; ii) the data satisfies a linear SEM; iii) the data is obtained by certain types of interventions.
  11. Assumption
      A set of training environments $\mathcal{E}_{tr}$ lies in linear general position of degree $r$ if $|\mathcal{E}_{tr}| > d - r + d/r$ for some $r \in \mathbb{N}$, and for all non-zero $x \in \mathbb{R}^d$,
      $$\dim\left(\operatorname{span}\left(\left\{\mathbb{E}\left[X^e X^{e\top}\right] x - \mathbb{E}\left[X^e \epsilon^e\right]\right\}_{e \in \mathcal{E}_{tr}}\right)\right) > d - r. \quad (8)$$
      Theorem
      Assume that $Y^e = Z_1^e \cdot \gamma + \epsilon^e$, $Z_1^e \perp \epsilon^e$, $\mathbb{E}[\epsilon^e] = 0$, and $X^e = S(Z_1^e, Z_2^e)$. Here, $\gamma \in \mathbb{R}^c$. Assume that the $Z_1$ component of $S$ is invertible. Let $\Phi \in \mathbb{R}^{d \times d}$ have rank $r > 0$. Then, if at least $d - r + d/r$ training environments $\mathcal{E}_{tr} \subseteq \mathcal{E}$ lie in linear general position of degree $r$, we have that
      $$\Phi \mathbb{E}\left[X^e X^{e\top}\right] \Phi^\top \hat\beta = \Phi \mathbb{E}[X^e Y^e] \quad (9)$$
      holds for all $e \in \mathcal{E}_{tr}$ iff $\Phi$ elicits the invariant predictor $\Phi^\top \hat\beta$ for all $e \in \mathcal{E}_{all}$.
  12. Variants of IRM
      ⊚ Risk Extrapolation (REx) [4];
      ⊚ Risk Variance Penalization (RVP) [10];
      ⊚ Sparse Invariant Risk Minimization (SparseIRM) [11];
      ⊚ Derivative Invariant Risk Minimization (DIRM) [2];
      ⊚ Domain Extrapolation via Regret Minimization (RGM);
      ⊚ Domain Generalization using Causal Matching (MatchDG);
      ⊚ etc. [8]
  13. Risk Extrapolation (REx) [4]
      For $\gamma \in [0, \infty)$,
      $$R^{V\text{-REx}}(f) := \gamma \operatorname{Var}(\{R^1(f), \dots, R^m(f)\}) + \sum_{e \in \mathcal{E}_{tr}} R^e(f). \quad (10)$$
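Eq. (10) is a straightforward variance penalty over per-environment risks. A minimal sketch, assuming toy risk values and using the population variance (the choice of variance estimator is a detail not fixed by the slide):

```python
# Sketch of the V-REx objective: sum of per-environment risks plus
# gamma times their variance across environments.
from statistics import pvariance

def v_rex(risks, gamma):
    """R_V-REx(f) = gamma * Var({R^1, ..., R^m}) + sum_e R^e."""
    return gamma * pvariance(risks) + sum(risks)

equal = [0.5, 0.5, 0.5]    # risks invariant across environments: no penalty
spread = [0.1, 0.5, 0.9]   # same total risk, but penalized for variance
print(v_rex(equal, gamma=10.0))   # 1.5
print(v_rex(spread, gamma=10.0))  # larger: 1.5 plus the variance penalty
```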
  14. Risk Variance Penalization (RVP) [10]
      For $\lambda \in [0, \infty)$,
      $$R^{\mathrm{RVP}}(f) := \lambda \sqrt{\operatorname{Var}(\{R^1(f), \dots, R^m(f)\})} + \sum_{e \in \mathcal{E}_{tr}} R^e(f). \quad (11)$$
      By Slutsky's theorem, for $m = |\mathcal{E}|$,
      $$\mathbb{P}\left(\mathbb{E}_e[R^e(f)] - R^{\mathrm{RVP}}(f) \leq 0\right) \to \Phi(\sqrt{m}\lambda). \quad (12)$$
      Then, we can take $\lambda = \Phi^{-1}(1 - \gamma)/\sqrt{m}$ for some confidence level $1 - \gamma$.
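The choice $\lambda = \Phi^{-1}(1 - \gamma)/\sqrt{m}$ only needs the inverse standard normal CDF. A small sketch with the standard library (the values of $\gamma$ and $m$ are illustrative):

```python
# Sketch of choosing the RVP penalty weight lambda = Phi^{-1}(1 - gamma) / sqrt(m),
# where Phi is the standard normal CDF.
from math import sqrt
from statistics import NormalDist

def rvp_lambda(gamma, m):
    """Penalty weight for confidence level 1 - gamma over m environments."""
    return NormalDist().inv_cdf(1.0 - gamma) / sqrt(m)

# With gamma = 0.5, Phi^{-1}(0.5) = 0: no penalty is applied.
print(rvp_lambda(0.5, m=4))             # 0.0
print(round(rvp_lambda(0.05, m=4), 3))  # Phi^{-1}(0.95) / 2 ≈ 0.822
```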
  15. Sparse Invariant Risk Minimization (SparseIRM) [11]
      For $K \in \mathbb{N}$,
      $$\min_{\beta, \Phi, m} R(\beta, m \circ \Phi), \quad \text{s.t. } m \in \{0, 1\}^{d_\Phi}, \ \|m\|_1 \leq K.$$
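The constraint $m \in \{0,1\}^{d_\Phi}$, $\|m\|_1 \leq K$ just selects at most $K$ feature dimensions. A tiny sketch with an illustrative 4-dimensional representation:

```python
# Sketch of the SparseIRM mask: a binary mask m over the d_Phi feature
# dimensions with at most K ones selects a sparse sub-representation m ∘ Φ(x).
def masked_features(phi_x, m, K):
    # m in {0,1}^{d_Phi} with ||m||_1 <= K
    assert all(b in (0, 1) for b in m) and sum(m) <= K
    return [b * f for b, f in zip(m, phi_x)]

phi_x = [0.3, -1.2, 0.7, 2.0]
print(masked_features(phi_x, m=[1, 0, 0, 1], K=2))  # keeps dims 0 and 3 only
```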
  16. Limitations of Invariant Risk Minimization
      ⊚ IRM fundamentally does not improve over ERM;
      ⊚ the optimization dilemma in IRM.
  17. Theorem (The Failure of IRM in the Non-Linear Regime [7])
      Suppose we observe $E$ environments $\mathcal{E} = \{e_1, \dots, e_E\}$, where $\sigma_e^2 = 1$, $\forall e \in [1, E]$. Then, for any $\epsilon > 1$, there exists a featurizer $\Phi_\epsilon$ which, combined with the ERM-optimal classifier $\hat\beta = [\beta_c, \beta_{e;\mathrm{ERM}}, \beta_0]^\top$, satisfies the following:
      1. The regularization term of $(\Phi_\epsilon, \hat\beta)$ is bounded as
      $$\frac{1}{E} \sum_{e \in \mathcal{E}} \left\|\nabla_{\hat\beta} R^e(\Phi_\epsilon, \hat\beta)\right\|_2^2 \in O\left(p_\epsilon^2 \left(c_\epsilon d_e + \frac{1}{E} \sum_{e \in \mathcal{E}} \|\mu_e\|_2^2\right)\right), \quad (13)$$
      for some constant $c_\epsilon$, where $p_\epsilon := \exp\{-d_e \min(\epsilon - 1, (\epsilon - 1)^2/8)\}$.
      2. $(\Phi_\epsilon, \hat\beta)$ is equivalent to the ERM-optimal predictor on at least a $1 - q$ fraction of the test distribution, where $q := \frac{2R}{\sqrt{\pi}\delta} \exp\{-\delta^2\}$.
  18. Here, we suppose that, for any test distribution, its environmental mean $\mu_{E+1}$ is sufficiently far from the training means:
      $$\forall e \in \mathcal{E}, \quad \min_{y \in \{+1, -1\}} \|\mu_{E+1} - y \cdot \mu_e\|_2 \geq (\sqrt{\epsilon} + \delta)/\sqrt{d_e} \quad (14)$$
      for some $\delta > 0$. The predictor constructed above completely fails to use invariant prediction on most environments:
      ⊚ when $\delta$ is large, IRM fails to use invariant prediction on any environment that is even slightly outside the high-probability region of the prior;
      ⊚ when $\delta$ is small, ERM already guarantees reasonable performance at test time; thus, IRM fundamentally does not improve over ERM in this regime.
  19. The optimization dilemma in IRM
      ⊚ OOD objectives such as IRM usually require several relaxations for ease of optimization, which however introduce large gaps.
      ⊚ Gradient conflicts between the ERM and OOD objectives generally exist, for different objectives and at different penalty weights.
      ⊚ The typically used linear weighting scheme for combining the ERM and OOD objectives requires careful tuning of the weights to approach the solution.
  20. Pareto Invariant Risk Minimization [3]
      Given a robust OOD objective $R^{\mathrm{OOD}}$, Pareto IRM aims to solve the following multi-objective optimization problem:
      $$\min_f \left\{R^{\mathrm{ERM}}(f), R^{\mathrm{OOD}}(f)\right\}. \quad (15)$$
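Solving Eq. (15) means finding predictors that are not Pareto-dominated in the $(R^{\mathrm{ERM}}, R^{\mathrm{OOD}})$ plane. A minimal sketch of the dominance check, with illustrative candidate risk pairs:

```python
# Sketch of Pareto dominance between candidate predictors, each scored by the
# pair (R_ERM, R_OOD); the Pareto front contains the non-dominated candidates.
def dominates(a, b):
    """a dominates b if a is no worse on both risks and strictly better on one."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))

candidates = {"f1": (0.2, 0.9), "f2": (0.3, 0.4), "f3": (0.35, 0.45)}
front = [n for n, r in candidates.items()
         if not any(dominates(r2, r) for n2, r2 in candidates.items() if n2 != n)]
print(front)  # ['f1', 'f2']: f3 is dominated by f2 on both risks
```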
  21. Conclusion
      ⊚ IRM aims to learn an invariant predictor to achieve OOD generalization.
      ⊚ There are many variants of IRM.
      ⊚ Several negative results for IRM have been observed.
  22. References
      [1] Martin Arjovsky et al. "Invariant risk minimization". In: arXiv preprint arXiv:1907.02893 (2019).
      [2] Alexis Bellot and Mihaela van der Schaar. "Accounting for unobserved confounding in domain generalization". In: arXiv preprint arXiv:2007.10653 (2020).
      [3] Yongqiang Chen et al. "Pareto Invariant Risk Minimization: Towards Mitigating the Optimization Dilemma in Out-of-Distribution Generalization". In: The Eleventh International Conference on Learning Representations. 2023.
      [4] David Krueger et al. "Out-of-distribution generalization via risk extrapolation (REx)". In: International Conference on Machine Learning. PMLR. 2021, pp. 5815–5826.
      [5] Judea Pearl. Causality: models, reasoning, and inference. 2000.
      [6] Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. "Causal inference using invariant prediction: identification and confidence intervals". In: arXiv preprint (2015).
  23. References
      [7] Elan Rosenfeld, Pradeep Kumar Ravikumar, and Andrej Risteski. "The Risks of Invariant Risk Minimization". In: International Conference on Learning Representations. 2021. URL: https://openreview.net/forum?id=BbNIbVPJ-42.
      [8] Zheyan Shen et al. "Towards out-of-distribution generalization: A survey". In: arXiv preprint arXiv:2108.13624 (2021).
      [9] Sewall Wright. "Correlation and causation". In: Journal of Agricultural Research (1921).
      [10] Chuanlong Xie et al. "Risk variance penalization". In: arXiv preprint arXiv:2006.07544 (2020).
      [11] Xiao Zhou et al. "Sparse invariant risk minimization". In: International Conference on Machine Learning. PMLR. 2022, pp. 27222–27244.