・Amortization model \(\hat{y}_θ\) (section 2.1)
→ Fully-amortized (section 2.1.1): no objective access
→ Semi-amortized (section 2.1.2): accesses objective
・Amortization loss \(\mathcal{L}\) (section 2.2)
→ Regression (section 2.2.1): \(\mathbb{E}_{p(x)}||\hat{y}_θ(x)−y^⋆(x)||_2^2\)
→ Objective (section 2.2.1): \(\mathbb{E}_{p(x)}f(\hat{y}_θ(x))\)
Figure 2.1: Overview of amortized optimization modeling and loss choices.
The machine learning, statistics, and optimization communities are exploring methods of learning to optimize to obtain fast solvers for eq. (1.1). I will refer to these methods as amortized optimization as they amortize the cost of solving the optimization problems across many contexts to approximate the solution mapping \(y^⋆\). Amortized optimization is promising because in many applications, there are significant correlations and structure between the solutions which show up in \(y^⋆\) that a model can learn. This tutorial follows Shu [2017] for defining the core foundation of amortized optimization.
Definition 1 An amortized optimization method to solve eq. (1.1) can be represented by \(\mathcal{A}:= (f,\mathcal{Y,X},p(x),\hat{y}_θ,\mathcal{L})\), where \(f: \mathcal{Y}× \mathcal{X}→ \mathbb{R}\) is the unconstrained objective to optimize, \(\mathcal{Y}\) is the domain, \(\mathcal{X}\) is the context space, \(p(x)\) is the probability distribution over contexts to optimize, \(\hat{y}_θ: \mathcal{X} → \mathcal{Y}\) is the amortization model parameterized by θ which is learned by optimizing a loss defined on all the components \(\mathcal{L}(f,\mathcal{Y,X},p(x),\hat{y}_θ)\).
The objective \(f\) and domain \(\mathcal{Y}\) arise from the problem setting along with the context space \(\mathcal{X}\) and the distribution \(p(x)\) over it, while the remaining definitions of the model \(\hat{y}_θ\) and loss \(\mathcal{L}\) are application-specific design decisions that sections 2.1 and 2.2 open up. These sections present the modeling and loss foundations for the core problem in definition 1 agnostic of the specific downstream applications that will use them. The key choices highlighted in chapter 2 are how much information 1) the model \(\hat{y}_θ\) has about the objective \(f\) (fully- vs. semi-amortized), and 2) the loss has about the true solution \(y^⋆\) (regression- vs. objective-based). Figure 1.2 instantiates these components for amortizing the control of a robotic system. The model \(\hat{y}_θ\) approximates the solution mapping \(y^⋆\) simultaneously for all contexts. The methods here usually assume the solution mapping \(y^⋆\) to be almost-everywhere smooth and well-behaved. The best modeling approach is an open research topic as there are many tradeoffs, and many specialized insights from the application domain can significantly improve the performance. The generalization capacity along with the model’s convergence guarantees are challenging topics which section 5.2 covers in more detail.
Origins of the term “amortization” for optimization. The word “amortization” generally means to spread out costs and thus “amortized optimization” usually means to spread out computational costs of the optimization process. The term originated in the variational inference community for inference optimization [Kingma and Welling, 2014, Rezende et al., 2014, Stuhlmüller et al., 2013, Gershman and Goodman, 2014, Webb et al., 2018, Ravi and Beatson, 2019, Cremer et al., 2018, Wu et al., 2020], and is used more generally in Xue et al. [2020], Sercu et al. [2021], Xiao et al. [2021]. Marino [2021, p. 28] gives further background on the origins and uses of amortization. Concurrent to these developments, other communities have independently developed amortization methods without referring to them by the same terminology and analysis, such as in reinforcement learning, policy optimization, and sparse coding; chapter 3 connects all of these under definition 1.
Conventions and notation. The context space \(\mathcal{X}\) represents the sample space of a probability space that the distribution \(p(x)\) is defined on, assuming it is Borel if not otherwise specified. For a function \(f: \mathbb{R}^n → \mathbb{R}\) in standard Euclidean space, \(∇_xf(\overline{x})∈ \mathbb{R}^n\) denotes the gradient at a point \(\overline{x}\) and \(∇_x^2f(\overline{x}) ∈ \mathbb{R}^{n×n}\) denotes the Hessian. For \(f : \mathbb{R}^n → \mathbb{R}^m\), \(D_xf(\overline{x}) ∈ \mathbb{R}^{m×n}\) represents the Jacobian at \(\overline{x}\) with entries \([D_xf(\overline{x})]_{ij}:= \frac{∂f_i}{∂x_j}(\overline{x})\). I abbreviate the loss to \(\mathcal{L}(\hat{y})\) when the other components can be inferred from the surrounding text and prefer the term “context” for \(x\) instead of “parameterization” to make the distinction between the \(x\)-parameterized optimization problem and the \(θ\)-parameterized model clear. I use “;” as separation in \(f(y;x)\) to emphasize the separation between the domain variables \(y\) that eq. (1.1) optimizes over and the context variables \(x\) that remain fixed. A model’s parameters \(θ\) are usually written as subscripts, as in \(h_θ(x)\), but I will sometimes equivalently write \(h(x;θ)\).
The model \(\hat{y}_θ(x): \mathcal{X}×Θ → \mathcal{Y}\) predicts a solution to eq. (1.1). In many applications, the best model design is an active area of research that searches for models that are expressive and more computationally efficient than the algorithms classically used to solve the optimization problem. Section 2.1.1 starts simple with fully-amortized models that approximate the entire solution to the optimization problem with a single black-box model. Then section 2.1.2 shows how semi-amortized models open up the model to include more information about the optimization problem and leverage domain knowledge.
Definition 2 A fully-amortized model \(\hat{y}_θ: \mathcal{X} → \mathcal{Y}\) maps the context to the solution of eq. (1.1) and does not access the objective \(f\).
I use the prefix “fully” to emphasize that the entire computation of the solution to the optimization problem is absorbed into a black-box model that does not access the objective \(f\). The prefix “fully” can be omitted when the context is clear because most amortization is fully amortized. These models are standard in amortized variational inference (section 3.1) and policy learning (section 3.6), which typically use feedforward neural networks to map from the context space \(\mathcal{X}\) to the solution of the optimization problem living in \(\mathcal{Y}\). Fully-amortized models are remarkable because, after being trained, they are often able to predict the solution to the optimization problem in eq. (1.1) without ever accessing its objective.
Fully-amortized models are the most useful for attaining approximate solutions that are computationally efficient. They tend to work best when the solution mappings \(y^⋆(x)\) are predictable, the domain \(\mathcal{Y}\) is relatively small, usually hundreds or thousands of dimensions, and the context distribution isn’t too large. When fully-amortized models don’t work well, semi-amortized models help open up the black box and use information about the objective.
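As a minimal sketch of definition 2, the block below uses a toy scalar problem that is an assumption made here and not from the text: the objective \(f(y;x) = (y − \sin x)^2\), whose solution mapping is \(y^⋆(x) = \sin x\), and a cubic polynomial standing in for the neural network. The names `objective` and `fully_amortized_model` are hypothetical. The point it illustrates is that the prediction is computed from the context and the parameters alone.

```python
import math


def objective(y, x):
    """Hypothetical toy objective f(y; x); its minimizer is y*(x) = sin(x)."""
    return (y - math.sin(x)) ** 2


def fully_amortized_model(theta, x):
    """Predict a solution from the context alone with a cubic polynomial in x.

    `objective` is never called here, which is what makes the model
    fully amortized.
    """
    return sum(t * x ** i for i, t in enumerate(theta))


# Illustrative parameters (not learned): Taylor coefficients of sin around 0.
theta = [0.0, 1.0, 0.0, -1.0 / 6.0]
x = 0.3
y_hat = fully_amortized_model(theta, x)

# The objective is only used here to evaluate the quality of the prediction.
gap = objective(y_hat, x)
```

In practice the parameters would be learned with one of the losses in section 2.2 rather than set by hand.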
Definition 3 A semi-amortized model \(\hat{y}_θ: \mathcal{X}→\mathcal{Y}\) maps the context to the solution of the optimization problem and accesses the objective f of eq. (1.1), typically iteratively.
Kim et al. [2018], Marino et al. [2018b] proposed semi-amortized models for variational inference that add back domain knowledge of the optimization problem to the model \(\hat{y}_θ\) that the fully-amortized models do not use. These are brilliant ways of integrating the optimization-based domain knowledge into the learning process. The model can now internally integrate solvers to improve the prediction. Semi-amortized methods are typically iterative and update iterates in the domain \(\mathcal{Y}\) or in an auxiliary or latent space \(\mathcal{Z}\). I refer to the space the semi-amortization iterates over as the amortization space and denote iterate \(t\) in these spaces, respectively, as \(\hat{y}_θ^t\) and \(z_θ^t\). While the iterates and final prediction \(\hat{y}_θ\) can now query the objective \(f\) and gradient \(∇_yf\), I notationally leave this dependence implicit for brevity and only reference these queries in the relevant definitions.
Semi-amortized models over the domain \(\mathcal{Y}\)
\[
\hat{y}_θ^0 → \hat{y}_θ^1 → \cdots → \hat{y}_θ^K =: \hat{y}_θ(x)
\]
One of the most common semi-amortized models parameterizes and integrates an optimization procedure used to solve eq. (1.1), such as gradient descent [Andrychowicz et al., 2016, Finn et al., 2017, Kim et al., 2018], into the model \(\hat{y}_θ\). This optimization procedure is an internal part of the amortization model \(\hat{y}_θ\), often referred to as the inner-level optimization problem in the bi-level setting that arises for learning.
Examples. This section instantiates a canonical semi-amortized model based on gradient descent that learns the initialization, as in model-agnostic meta-learning (MAML) by Finn et al. [2017], structured prediction energy networks (SPENs) by Belanger et al. [2017], and semi-amortized variational auto-encoders (SAVAEs) by Kim et al. [2018]. The initial iterate \(\hat{y}_θ^0(x):= θ\) is parameterized by \(θ ∈ \mathcal{Y}\) for all contexts. Iteratively updating \(\hat{y}_θ^t\) for \(K\) gradient steps with a learning rate or step size \(α ∈ \mathbb{R}_+\) on the objective \(f(y;x)\) gives
\[
\hat{y}_θ^t := \hat{y}_θ^{t−1} − α∇_yf(\hat{y}_θ^{t−1};x) \qquad t ∈ \{1,\ldots,K\}, \tag{2.1}
\]
where the model’s output is defined as \(\hat{y}_θ := \hat{y}_θ^K\).
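The update in eq. (2.1) can be sketched in a few lines. The quadratic objective \(f(y;x) = \frac{1}{2}(y − x)^2\) and the function names are assumptions chosen here so that the result can be checked by hand; they are not taken from the cited methods.

```python
def grad_f(y, x):
    """Gradient in y of the hypothetical objective f(y; x) = 0.5 * (y - x)**2."""
    return y - x


def semi_amortized_model(theta, x, num_steps=5, alpha=0.5):
    """Eq. (2.1): start at the learned initialization and take K gradient steps."""
    y = theta  # initial iterate, shared by all contexts
    for _ in range(num_steps):
        y = y - alpha * grad_f(y, x)  # the model queries the objective's gradient
    return y


y_hat = semi_amortized_model(theta=0.0, x=2.0)
```

For this objective each step halves the distance to the solution \(y^⋆(x) = x\), so five steps from \(θ = 0\) end at \(2 − 2 \cdot 0.5^5 = 1.9375\). Only the initialization \(θ\) is learnable in this sketch.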
Semi-amortized models over the domain can go significantly beyond gradient-based models and in theory, any algorithm to solve the original optimization problem in eq. (1.1) can be integrated into the model. Section 2.2.2 further discusses the learning of semi-amortized models by unrolling that are instantiated later:
• Section 3.2 discusses how Gregor and LeCun [2010] integrate ISTA iterates [Daubechies et al., 2004, Beck and Teboulle, 2009] into a semi-amortized model.
• Section 3.4.1 discusses models that integrate fixed-point computations into semi-amortized models. Venkataraman and Amos [2021] amortize convex cone programs by differentiating through the splitting cone solver [O’donoghue et al., 2016] and Bai et al. [2022] amortize deep equilibrium models [Bai et al., 2019, 2020].
• Section 3.4.5 discusses RLQP by Ichnowski et al. [2021] that uses the OSQP solver [Stellato et al., 2018] inside of a semi-amortized model.
Semi-amortized models over a latent space \(\mathcal{Z}\)
\[
\hat{z}_θ^0 → \hat{z}_θ^1 → \cdots → \hat{z}_θ^K → \hat{y}_θ(x)
\]
In addition to only updating iterates over the domain \(\mathcal{Y}\), a natural generalization is to introduce a latent space \(\mathcal{Z}\) that is iteratively optimized over inside of the amortization model. This is usually done to give the semi-amortized model more capacity to learn about the structure of the optimization problems that are being solved. The latent space can also be interpreted as a representation of the optimal solution space. This is useful for learning an optimizer that only searches over the optimal region of the solution space rather than the entire solution space.
Examples. The iterative gradient updates in eq. (2.1) can be replaced with a learned update function as in Ravi and Larochelle [2017], Li and Malik [2017a], Andrychowicz et al. [2016], Li and Malik [2017b]. These model the past sequence of iterates and learn how to best predict the next iterate, pushing them towards optimality. This can be done with a recurrent cell \(g_θ\) such as an LSTM [Hochreiter and Schmidhuber, 1997] or GRU [Cho et al., 2014] and leads to updates of the form
\[
z_θ^t, \hat{y}_θ^t := g_θ(z_θ^{t−1}, \hat{y}_θ^{t−1}, ∇_yf(\hat{y}_θ^{t−1};x)) \qquad t ∈ \{1,\ldots,K\} \tag{2.2}
\]
where each call to the recurrent cell \(g_θ\) takes a hidden state \(z\) along with an iterate and the derivative of the objective. This endows \(g_θ\) with the capacity to learn significant updates that leverage the problem structure and that a traditional optimization method would not be able to make. In theory, the model can also fall back on traditional update rules: the gradient step in eq. (2.1) is recovered by removing the hidden state \(z\) and setting
\[
g(y, ∇_yf(y;x)) := y − α∇_yf(y;x). \tag{2.3}
\]
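The relationship between the general update in eq. (2.2) and the gradient step in eq. (2.3) can be sketched as follows. A learned \(g_θ\) would be an LSTM or GRU; the two hand-written updates below are stand-ins assumed here for illustration, as are the quadratic objective \(f(y;x) = \frac{1}{2}(y − x)^2\) and all function names.

```python
def grad_f(y, x):
    """Gradient in y of the hypothetical objective f(y; x) = 0.5 * (y - x)**2."""
    return y - x


def unroll(update, z0, y0, x, num_steps):
    """Eq. (2.2): repeatedly apply an update g to the hidden state and iterate."""
    z, y = z0, y0
    for _ in range(num_steps):
        z, y = update(z, y, grad_f(y, x))
    return y


def gradient_step(alpha):
    """Eq. (2.3): ignore the hidden state and take a plain gradient step."""
    def g(z, y, grad):
        return z, y - alpha * grad
    return g


def momentum_step(alpha, beta):
    """A stateful stand-in for a learned cell; the hidden state z is a velocity."""
    def g(z, y, grad):
        z = beta * z - alpha * grad
        return z, y + z
    return g


y_gd = unroll(gradient_step(0.5), None, 0.0, 2.0, 5)
y_momentum = unroll(momentum_step(0.5, 0.3), 0.0, 0.0, 2.0, 5)
```

Setting `beta=0.0` in `momentum_step` makes the hidden state irrelevant and reproduces the gradient-step iterates, which mirrors how eq. (2.3) is a special case of eq. (2.2).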
Latent semi-amortized models are a budding topic and can excitingly learn many other latent representations that go beyond iterative gradient updates in the original space. Luo et al. [2018], Amos and Yarats [2020] learn a latent domain connected to the original domain where the latent domain captures hidden structures and redundancies present in the original high-dimensional domain \(\mathcal{Y}\). Luo et al. [2018] consider gradient updates in the latent domain and Amos and Yarats [2020] show that the cross-entropy method [De Boer et al., 2005] can be made differentiable and learned as an alternative to gradient updates. Amos et al. [2017] unroll and differentiate through the bundle method [Smola et al., 2007] in a convex setting as an alternative to gradient steps. The latent optimization could also be done over a learned parameter space as in POPLIN [Wang and Ba, 2020], which lifts the domain of the optimization problem eq. (1.1) from \(\mathcal{Y}\) to the parameter space of a fully-amortized neural network. This leverages the insight that the parameter space of over-parameterized neural networks can induce easier non-convex optimization problems than in the original space, which is also studied in Hoyer et al. [2019].
Comparing semi-amortized models with warm-starting
Semi-amortized models are conceptually similar to learning a fully-amortized model to warm-start an existing optimization procedure that fine-tunes the solution. The crucial difference is that semi-amortized learning often learns end-to-end through the final prediction, while warm-starting and fine-tuning only learns the initial prediction and does not integrate the knowledge of the fine-tuning procedure into the learning procedure. Choosing between these is an active research topic and while this tutorial will mostly focus on semi-amortized models, learning a fully-amortized warm-starting model brings promising results to some fields too, such as Zhang et al. [2019b], Baker [2019], Chen et al. [2022b]. In variational inference, Kim et al. [2018, Table 2] compare semi-amortized models (SA-VAE) to warm-starting and fine-tuning (VAE+SVI) and demonstrate that the end-to-end learning signal is helpful. In other words, amortization finds an initialization that is helpful for gradient-based optimization. Arbel and Mairal [2022] further study fully-amortized warm-started solvers that arise in bi-level optimization problems for hyper-parameter optimization and use the theoretical framework from singularly perturbed systems [Habets, 2010] to analyze properties of the approximate solutions.
On second-order derivatives of the objective
Training a semi-amortized model is usually more computationally challenging than training a fully-amortized model. This section looks at how second-order derivatives of the objective may come up when unrolling and create a computational bottleneck when learning a semi-amortized model. The next derivation follows Nichol et al. [2018, §5] and Weng [2018] and shows the model derivatives that arise when composing a semi-amortized model with a loss.
Starting with a single-step model. This section instantiates a single-step model similar to eq. (2.1) that parameterizes the initial iterate \(\hat{y}_θ^0(x) := θ\) and takes one gradient step:
\[
\hat{y}_θ(x) := \hat{y}_θ^0(x) − α∇_yf(\hat{y}_θ^0(x);x) \tag{2.4}
\]
Interpreting \(\hat{y}_θ(x)\) as a model is non-standard in contrast to other parametric models because it makes the optimization step internally part of the model. Gradient-based optimization of losses with respect to the model’s parameters, such as eqs. (2.9) and (2.10), requires the Jacobian of \(\hat{y}_θ(x)\) w.r.t. the parameters, i.e. \(D_θ[\hat{y}_θ(x)]\) (or Jacobian-vector products with it). Because \(\hat{y}_θ(x)\) is an optimization step, the derivative of the model requires differentiating through the optimization step, which for eq. (2.4) is
\[
D_θ[\hat{y}_θ(x)] = I − α∇_y^2f(\hat{y}_θ^0(x);x) \tag{2.5}
\]
and requires the Hessian of the objective. In Finn et al. [2017], \(∇_y^2f\) is the Hessian of the training loss with respect to the model’s parameters (!) and is compute- and memory-expensive to instantiate for large models. In practice, the Hessian in eq. (2.5) is rarely instantiated explicitly, as optimizing the loss only requires Hessian-vector products. The Hessian-vector product can be computed exactly or estimated without fully instantiating the Hessian, similar to how computing the derivative of a neural network with backprop does not instantiate the intermediate Jacobians and only computes the Jacobian-vector product. More information about efficiently computing Hessian-vector products is available in Pearlmutter [1994], Domke [2012]. JAX’s autodiff cookbook [Bradbury et al., 2020] further describes efficient Hessian-vector products. Before discussing alternatives, the next portion derives similar results for a \(K\)-step model.
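To illustrate how a Hessian-vector product can be obtained without instantiating the Hessian, the sketch below estimates it from two gradient evaluations with a central finite difference. This is a dependency-free stand-in assumed here; the cited references instead compute these products with automatic differentiation. The quadratic objective \(f(y;x) = \frac{1}{2}y^⊤Ay − x^⊤y\), the matrix `A`, and the function names are illustrative assumptions.

```python
A = [[2.0, 1.0], [1.0, 3.0]]  # Hessian of the hypothetical quadratic objective


def grad_f(y, x):
    """Gradient in y of f(y; x) = 0.5 * y^T A y - x^T y."""
    return [sum(A[i][j] * y[j] for j in range(2)) - x[i] for i in range(2)]


def hvp(y, x, v, eps=1e-4):
    """Estimate (Hessian of f at y) times v from two gradient evaluations."""
    plus = grad_f([yi + eps * vi for yi, vi in zip(y, v)], x)
    minus = grad_f([yi - eps * vi for yi, vi in zip(y, v)], x)
    return [(p - m) / (2.0 * eps) for p, m in zip(plus, minus)]


# Only gradients are evaluated; the 2 x 2 Hessian is never formed by `hvp`.
product = hvp(y=[0.5, -0.2], x=[1.0, 1.0], v=[1.0, -1.0])
```

For this quadratic the estimate matches \(Av = (1, −2)\) up to floating-point rounding. For non-quadratic objectives the finite difference introduces an \(O(ε^2)\) error that automatic differentiation avoids.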
Multi-step models. Eq. (2.4) can be extended to the \(K\)-step setting with
\[
\hat{y}_θ^K(x) := \hat{y}_θ^{K−1}(x) − α∇_yf(\hat{y}_θ^{K−1}(x);x), \tag{2.6}
\]
where the base \(\hat{y}_θ^0(x) := θ\) as before. Similar to eq. (2.5), the derivative of a single step is
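\[
D_θ[\hat{y}_θ^K(x)] = D_θ[\hat{y}_θ^{K−1}(x)] \left(I − α∇_y^2f(\hat{y}_θ^{K−1}(x);x)\right), \tag{2.7}
\]
and composing the derivatives down to \(\hat{y}_θ^0\) yields the product structure
\[
D_θ[\hat{y}_θ^K(x)] = \prod_{k=0}^{K−1} \left(I − α∇_y^2f(\hat{y}_θ^k(x);x)\right), \tag{2.8}
\]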
where \(D_θ[\hat{y}_θ^0(x)] = I\) at the base case. Computing eq. (2.8) is now \(K\) times more challenging as it requires the Hessian \(∇_y^2f\) at every iteration of the model. While using Hessian-vector products can alleviate some computational burden of this term, it often still requires significantly more operations than most other derivatives.
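The product structure in eq. (2.8) can be checked numerically on a small example. The sketch below assumes a scalar, non-quadratic objective \(f(y;x) = y^4/4 − xy\), chosen here so that the Hessian \(3y^2\) changes at every iterate; the function names are hypothetical. In the scalar case the factors commute, so the ordering of the product does not matter.

```python
def grad_f(y, x):
    """Gradient in y of the hypothetical objective f(y; x) = y**4 / 4 - x * y."""
    return y ** 3 - x


def hess_f(y, x):
    """Second derivative in y of the same objective."""
    return 3.0 * y ** 2


def unrolled_model(theta, x, num_steps=4, alpha=0.1):
    """Eq. (2.6): K gradient steps starting from the initial iterate theta."""
    y = theta
    for _ in range(num_steps):
        y = y - alpha * grad_f(y, x)
    return y


def unrolled_jacobian(theta, x, num_steps=4, alpha=0.1):
    """Eq. (2.8): accumulate one Hessian-dependent factor per step."""
    y, jac = theta, 1.0  # base case: D_theta[y^0] = I
    for _ in range(num_steps):
        jac *= 1.0 - alpha * hess_f(y, x)  # factor evaluated at the current iterate
        y = y - alpha * grad_f(y, x)
    return jac


def finite_difference(theta, x, eps=1e-6):
    """Central-difference estimate of d y^K / d theta, used only as a check."""
    return (unrolled_model(theta + eps, x) - unrolled_model(theta - eps, x)) / (2.0 * eps)


analytic = unrolled_jacobian(0.7, 1.5)
numeric = finite_difference(0.7, 1.5)
```

The two values agree up to the finite-difference error, and the loop makes the cost visible: every step of the model contributes one Hessian evaluation to the derivative.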
Computationally cheaper alternatives. The first-order MAML baseline in Finn et al. [2017] suggests simply not using the second-order terms \(∇_y^2f\) here, approximating the model derivative as the identity, i.e. \(D_θ[\hat{y}_θ^K(x)] ≈ I\), and relying only on information from the outer loss to update the parameters. They use the intuition from Goodfellow et al. [2015] that neural networks are locally linear and therefore these second-order terms of \(f\) are not too important. They show that this approximation works well in some cases, such as MiniImagenet [Ravi and Larochelle, 2017]. The MAML++ extension by Antoniou et al. [2019] proposes to use first-order MAML during the early phases of training, but to later add back this second-order information. Nichol et al. [2018] further analyze first-order approximations to MAML and propose another approximation called Reptile that also doesn’t use this second-order information. These higher-order terms also come up when unrolling in the different bi-level optimization setting for hyper-parameter optimization, and Lorraine et al. [2020, Table 1] gives a particularly good overview of approximations to these. Furthermore, memory-efficient methods for training neural networks and recurrent models with backpropagation and unrolling such as Gruslys et al. [2016], Chen et al. [2016] can also help improve the memory utilization in amortization models.
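The effect of dropping the second-order terms can be sketched by comparing the exact gradient of an outer objective-based loss with its first-order approximation. The scalar objective \(f(y;x) = y^4/4 − xy\), the choice of the outer loss as \(f\) evaluated at the final iterate, and the function names are assumptions made here for illustration; this is not the MAML implementation.

```python
def f(y, x):
    """Hypothetical objective f(y; x) = y**4 / 4 - x * y."""
    return y ** 4 / 4.0 - x * y


def grad_f(y, x):
    return y ** 3 - x


def hess_f(y, x):
    return 3.0 * y ** 2


def outer_loss(theta, x, num_steps=4, alpha=0.1):
    """Objective value at the final iterate of the K-step model."""
    y = theta
    for _ in range(num_steps):
        y = y - alpha * grad_f(y, x)
    return f(y, x)


def outer_gradients(theta, x, num_steps=4, alpha=0.1):
    """Return the exact and the first-order gradient of outer_loss w.r.t. theta."""
    y, jac = theta, 1.0
    for _ in range(num_steps):
        jac *= 1.0 - alpha * hess_f(y, x)  # second-order factor from eq. (2.8)
        y = y - alpha * grad_f(y, x)
    exact = jac * grad_f(y, x)   # chain rule through all K steps
    first_order = grad_f(y, x)   # model derivative approximated by the identity
    return exact, first_order


exact, first_order = outer_gradients(theta=0.7, x=1.5)
```

In this example both gradients have the same sign, so the first-order update still moves \(θ\) in a descent direction, but the magnitudes differ because the Hessian factors are ignored.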
Parameterizing and learning the objective. While this section has mostly not considered the setting where the objective \(f\) is also learned, the second-order derivatives appearing in eq. (2.8) also cause issues when the objective is parameterized and learned. In addition to learning an initial iterate, Belanger et al. [2017] learn the objective \(f\) representing an energy function. They parameterize \(f\) as a neural network and use softplus activation functions rather than ReLUs to ensure the objective’s second-order derivatives are non-zero.
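The difference in curvature between the two activations can be seen with a small numerical check. The sketch below is an illustration assumed here, not code from Belanger et al. [2017]: it estimates second derivatives with a central finite difference at a point away from the ReLU kink at zero.

```python
import math


def relu(u):
    return max(u, 0.0)


def softplus(u):
    return math.log1p(math.exp(u))


def second_difference(fn, u, eps=1e-3):
    """Central finite-difference estimate of the second derivative of fn at u."""
    return (fn(u + eps) - 2.0 * fn(u) + fn(u - eps)) / eps ** 2


relu_curvature = second_difference(relu, 1.0)          # piecewise linear: no curvature
softplus_curvature = second_difference(softplus, 1.0)  # strictly positive curvature

# Closed form for comparison: softplus''(u) = sigmoid(u) * (1 - sigmoid(u)).
sigmoid = 1.0 / (1.0 + math.exp(-1.0))
softplus_closed_form = sigmoid * (1.0 - sigmoid)
```

Because ReLU networks are piecewise linear in their inputs, the Hessian terms in eq. (2.8) vanish almost everywhere, whereas softplus keeps them informative.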
2.1.3 Models based on differentiable optimization
As discussed in section 2.2, the model typically needs to be (sub-)differentiable with respect to the parameters to attain the Jacobian \(D_θ[\hat{y}_θ]\) (or compute Jacobian-vector products with it) necessary to optimize the loss. These derivatives are standard backprop when the model is, for example, a fully-amortized neural network, but in the semi-amortized case, the model itself is often an optimization process that needs to be differentiated through. When the model updates are objective-based as in eq. (2.1) and eq. (2.2), the derivatives with respect to \(θ\) through the sequence of gradient updates in the domain can be attained by seeing the updates as a sequence of computations that are differentiated through, resulting in second-order derivatives. When more general optimization methods are used for the amortization model that may not have a closed-form solution, the tools of differentiable optimization [Domke, 2012, Gould et al., 2016, Amos and Kolter, 2017, Amos, 2019, Agrawal et al., 2019a] enable end-to-end learning.
2.1.4 Practically choosing a model
This section has taxonomized how to instantiate an amortization model in an application-agnostic way. As in most machine learning settings in practice, the modeling choice is often application-specific and needs to take many factors into consideration. These may include 1) the speed and expressivity of the model; 2) adapting the model to the specific context space \(\mathcal{X}\): an MLP may be good for fixed-dimensional real-valued spaces, but a convolutional neural network is likely to perform better for image-based spaces; 3) taking the solution space \(\mathcal{Y}\) into consideration: for example, if the solution space is an image space, then a standard vision model capable of predicting high-dimensional images is reasonable, such as a U-net [Ronneberger et al., 2015], dilated convolutional network [Yu and Koltun, 2016] or fully convolutional network [Long et al., 2015]; and 4) adapting the model to a variable-length context or solution space. The last case arises in VeLO [Metz et al., 2022] for learning to optimize machine learning models, where the model needs to predict the parameters of different models that may have different numbers of parameters. Their solution is to decompose the structure of the parameter space and to formulate the semi-amortized model as a sequence model that predicts smaller MLPs that operate on smaller groups of parameters.
2.2 Learning the model’s parameters θ
After specifying the amortization model \(\hat{y}_θ\), the other major design choice is how to learn the parameters \(θ\) so that the model best solves eq. (1.1). Learning is often a bi-level optimization problem where the outer level is the parameter learning problem for a model \(\hat{y}_θ(x)\) that solves the inner-level optimization problem in eq. (1.1) over the domain \(\mathcal{Y}\). While defining the best loss is application-specific, most approaches can be roughly categorized as 1) regressing a ground-truth solution (section 2.2.1), or 2) minimizing the objective (sections 2.2.1 and 2.2.3), which fig. 2.2 illustrates. Optimizing the model parameters here can in theory be done with most parameter learning methods that incorporate zeroth-, first-, and higher-order information about the loss being optimized, and this section mostly focuses on methods where \(θ\) is learned with a first-order gradient-based method such as Nesterov [1983], Duchi et al. [2010], Zeiler [2012], Kingma and Ba [2015]. The rest of this section discusses approaches for designing the loss and optimizing the parameters with first-order methods (section 2.2.1) when differentiation is easy or zeroth-order methods (section 2.2.3) otherwise, e.g., in non-differentiable settings.
Figure 2.2 (panels: Regression-Based, Objective-Based): Overview of key losses for optimizing the parameters \(θ\) of the amortization model \(\hat{y}_θ\). Regression-based losses optimize a distance between the model’s prediction \(\hat{y}_θ(x)\) and the ground-truth \(y^⋆(x)\). Objective-based methods update \(\hat{y}_θ\) using local information of the objective \(f\) and without access to the ground-truth solutions \(y^⋆\).
2.2.1 Choosing the objective for learning
Regression-based learning
Learning can be done by regressing the model’s prediction \(\hat{y}_θ(x)\) onto a ground-truth solution \(y^⋆(x)\). These methods minimize some distance between the predictions and ground-truth so that the expectation over the context distribution \(p(x)\) is minimal. With a Euclidean distance, for example, regression-based learning solves
\[
\operatorname*{arg\,min}_θ \mathcal{L}_{\mathrm{reg}}(\hat{y}_θ), \qquad \mathcal{L}_{\mathrm{reg}}(\hat{y}_θ) := \mathbb{E}_{x∼p(x)} \|y^⋆(x) − \hat{y}_θ(x)\|_2^2. \tag{2.9}
\]
\(\mathcal{L}_{\mathrm{reg}}\) is typically optimized with an adaptive first-order gradient-based method that is able to directly differentiate the loss with respect to the model’s parameters.
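A minimal sketch of regression-based learning with eq. (2.9) follows. It assumes a hypothetical solution mapping \(y^⋆(x) = 2x + 1\) that is known in closed form, an affine fully-amortized model, a fixed set of contexts standing in for samples from \(p(x)\), and plain gradient descent in place of an adaptive method. All names are illustrative.

```python
def y_star(x):
    """Hypothetical ground-truth solution mapping, assumed known for this sketch."""
    return 2.0 * x + 1.0


def model(theta, x):
    """Fully-amortized affine model y_hat(x) = theta[0] + theta[1] * x."""
    return theta[0] + theta[1] * x


def regression_gradient(theta, contexts):
    """Gradient of eq. (2.9) with the expectation replaced by a sample mean."""
    g0 = g1 = 0.0
    for x in contexts:
        residual = model(theta, x) - y_star(x)
        g0 += 2.0 * residual        # derivative w.r.t. theta[0]
        g1 += 2.0 * residual * x    # derivative w.r.t. theta[1]
    n = len(contexts)
    return [g0 / n, g1 / n]


contexts = [-1.0, -0.5, 0.0, 0.5, 1.0]  # stand-in for samples from p(x)
theta = [0.0, 0.0]
for _ in range(500):
    grad = regression_gradient(theta, contexts)
    theta = [t - 0.1 * g for t, g in zip(theta, grad)]
```

The loop never evaluates an objective \(f\): it only needs the labels \(y^⋆(x)\), which is why this loss suits settings where solutions are available but the objective is expensive to query.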
Regression-based learning works best for distilling known solutions into a faster model that can be deployed at a much lower cost, but can otherwise start failing to work. In RL and control, regression-based amortization is referred to as behavioral cloning and is a widely-used way of recovering a policy using trajectories observed from an expert policy. Using regression is also advantageous when evaluating the objective \(f(y;x)\) incurs a computationally intensive or otherwise complex procedure, such as an evaluation of the environment and dynamics in RL, or for computing the base model gradients when learning parameter optimizers. These methods work well when the ground-truth solutions are unique and semi-tractable, but can fail otherwise, i.e. if there are many possible ground-truth solutions for a context \(x\) or if computing them is intractable. After all, solving eq. (1.1) from scratch may be computationally expensive and amortization methods should improve the computation time.
Remark 1 Eq. (2.9) can be extended to other distances defined on the domain, such as non-Euclidean distances or the likelihood of a probabilistic model that predicts a distribution of possible candidate solutions. Adler et al. [2017] propose to use the Wasserstein distance for learning to predict the solutions to inverse imaging problems.
Objective-based learning
Instead of regressing onto the ground-truth solution, objective-based learning methods seek for the model’s prediction to be minimal under the objective \(f\) with
\[
\operatorname*{arg\,min}_θ \mathcal{L}_{\mathrm{obj}}(\hat{y}_θ), \qquad \mathcal{L}_{\mathrm{obj}}(\hat{y}_θ) := \mathbb{E}_{x∼p(x)} f(\hat{y}_θ(x);x). \tag{2.10}
\]
These methods use local information of the objective to provide a descent direction for the model’s parameters \(θ\). A first-order method optimizing eq. (2.10) uses updates based on the gradient
\[
\begin{aligned}
∇_θ\mathcal{L}_{\mathrm{obj}}(\hat{y}_θ) &= ∇_θ \mathbb{E}_{x∼p(x)} f(\hat{y}_θ(x);x) \\
&= \mathbb{E}_{x∼p(x)} D_θ[\hat{y}_θ(x)]^\top ∇_y[f(\hat{y}_θ(x);x)],
\end{aligned} \tag{2.11}
\]
where the last step is obtained by the chain rule. This has the interpretation that the model’s parameters \(θ\) are updated by combining the gradient information around the prediction \(∇_y[f(\hat{y}_θ(x);x)]\) shown in fig. 2.2 along with how \(θ\) impacts the model’s predictions with the derivative \(D_θ[\hat{y}_θ(x)]\). While this tutorial mostly focuses on optimizing eq. (2.11) with first-order methods that explicitly differentiate the objective, section 2.2.3 discusses alternatives to optimizing it with reinforcement learning and zeroth-order methods.
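The gradient in eq. (2.11) can be sketched for an affine model, where the Jacobian \(D_θ[\hat{y}_θ(x)]\) is simply \([1, x]\). The quadratic objective \(f(y;x) = \frac{1}{2}(y − 2x − 1)^2\), the fixed contexts standing in for samples from \(p(x)\), and the function names are assumptions made here for illustration.

```python
def grad_y_f(y, x):
    """Gradient in y of the hypothetical objective f(y; x) = 0.5 * (y - 2x - 1)**2."""
    return y - 2.0 * x - 1.0


def model(theta, x):
    """Fully-amortized affine model y_hat(x) = theta[0] + theta[1] * x."""
    return theta[0] + theta[1] * x


def model_jacobian(theta, x):
    """D_theta[y_hat(x)], a 1 x 2 Jacobian stored as a list."""
    return [1.0, x]


def objective_gradient(theta, contexts):
    """Sample-mean estimate of eq. (2.11): D_theta[y_hat]^T times grad_y f."""
    grad = [0.0, 0.0]
    for x in contexts:
        g_y = grad_y_f(model(theta, x), x)  # local information of the objective
        for i, d in enumerate(model_jacobian(theta, x)):
            grad[i] += d * g_y / len(contexts)
    return grad


contexts = [-1.0, -0.5, 0.0, 0.5, 1.0]  # stand-in for samples from p(x)
theta = [0.0, 0.0]
for _ in range(500):
    grad = objective_gradient(theta, contexts)
    theta = [t - 0.2 * g for t, g in zip(theta, grad)]
```

No ground-truth solutions appear anywhere in the loop; the parameters nonetheless converge to the mapping \(x ↦ 2x + 1\) that minimizes the objective for every context.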
Objective-based methods thrive when the gradient information is informative and the objective and models are easily differentiable. Amortized variational inference methods and actor-critic methods both make extensive use of objective-based learning.
Remark 2 A standard gradient-based optimizer for eq. (1.1) (without amortization) can be recovered from Lobj by setting the model to the identity of the parameters, i.e. yˆ (x) := θ, and p(x) to be a Dirac delta distribution.
θ
15
f(y;x)
y1 y?(x)
y0
Figure 2.3: Contours of the regression-based amortization loss Lreg (in black) alongside the contours of the objective (in purple where darker colors indicate higher values). This shows the inaccuracies of the regression-based loss, e.g. along a level set, may impact the overall objective differently.
This can be seen by taking \(D_θ[\hat{y}_θ(x)] = I\) in eq. (2.11), resulting in \(∇_θ \mathcal{L}_{\rm obj}(\hat{y}_θ) = ∇_y f(\hat{y}_θ(x);x)\). Thus optimizing θ of this parameter-identity model with gradient descent is identical to solving eq. (1.1) with gradient descent. Remark 2 shows a connection between a model trained with gradients of an objective-based loss and a non-amortized gradient-based solver for eq. (1.1). The gradient update that would originally have been applied to an iterate \(y ∈ \mathcal{Y}\) of the domain is now transferred into the model’s parameters that are shared across all problem instances. This also leads to a hypothesis that objective-based amortization works best when a gradient-based optimizer is able to successfully solve eq. (1.1) from scratch. However, there may be settings where a gradient-based optimizer performs poorly but an amortized optimizer excels because it is able to use information from the other problem instances.
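The equivalence in remark 2 can be seen in a few lines. This is a minimal sketch under assumptions of my own: a scalar objective \(f(y;x) = (y − x)^2\), a single fixed context standing in for the Dirac delta, and hypothetical function names.

```python
def grad_y_f(y, x):
    # Gradient of the hypothetical objective f(y; x) = (y - x)**2.
    return 2.0 * (y - x)

def solve_with_gradient_descent(y0, x, lr=0.1, steps=50):
    # Non-amortized solver: gradient descent directly on the iterate y.
    y = y0
    for _ in range(steps):
        y -= lr * grad_y_f(y, x)
    return y

def train_identity_model(theta0, x, lr=0.1, steps=50):
    # Objective-based learning where the model is the identity of its parameters.
    theta = theta0
    for _ in range(steps):
        y_hat = theta        # model prediction: y_hat(x) := theta
        d_theta_y = 1.0      # D_theta[y_hat] = I for the identity model
        grad_theta = d_theta_y * grad_y_f(y_hat, x)  # eq. (2.11) with a Dirac p(x)
        theta -= lr * grad_theta
    return theta

x_fixed = 3.0
y_solver = solve_with_gradient_descent(0.0, x_fixed)
y_model = train_identity_model(0.0, x_fixed)
print(y_solver, y_model)
```

Both loops perform the same sequence of updates, so the learned parameter and the solver's iterate coincide.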
Remark 3 The objective-based loss in section 2.2.1 provides a starting point for amortizing with other optimality conditions or reformulations of the optimization problem. This is done when amortizing for fixed-point computations and convex optimization in section 3.4, as well as for optimal transport in section 3.5.
Comparing the regression- and objective-based losses
Choosing between the regression- and objective-based losses is challenging as they measure the solution quality in different ways and have different convergence and locality properties. Liu et al. [2022] experimentally compare these losses for learning to optimize with fully-amortized set-based models. Figure 2.3 illustrates that the ℓ2-regression loss (the black contours) ignores the objective values (the purple contours) and thus gives the same loss to solutions that result in significantly different objective values. This could be potentially addressed by normalizing or re-weighting the dimensions for regression to be more aware of the curvature of the objective, but
this is often not done. Another idea is to combine both the objective and regression losses. Combining the losses could work especially well when only a few contexts are labeled, such as the regression and residual terms in the physics-informed neural operator paper [Li et al., 2021b]. The following summarizes some other advantages (+) and disadvantages (−):
Regression-based losses \(\mathcal{L}_{\rm reg}\)
− Often does not have access to f(y;x)
+ If f(y;x) is computationally expensive, does not need to compute it
+ Uses global information with y⋆(x)
− It may be expensive to compute y⋆(x)
+ Does not need to compute \(∇_y f(y;x)\)
− May be hard when y⋆(x) is not unique
Objective-based losses \(\mathcal{L}_{\rm obj}\)
+ Uses objective information of f(y;x)
− Can get stuck in local optima of f(y;x)
+ Faster, does not require y⋆(x)
− Often requires computing \(∇_y f(y;x)\)
+ Easily learns non-unique y⋆(x)
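The point made with fig. 2.3, that predictions with the same regression loss can have very different objective values, can be reproduced with a small example. The sketch below assumes a hypothetical ill-conditioned quadratic objective in two dimensions; none of the names or constants come from the tutorial.

```python
def y_star(x):
    # Ground-truth solution of the hypothetical objective below.
    return (x, -x)

def f(y, x):
    # Ill-conditioned quadratic: steep along y[0], flat along y[1].
    return 0.5 * (10.0 * (y[0] - x) ** 2 + 0.1 * (y[1] + x) ** 2)

def regression_loss(y, x):
    # Squared Euclidean distance to the ground-truth solution.
    target = y_star(x)
    return (y[0] - target[0]) ** 2 + (y[1] - target[1]) ** 2

x = 1.0
r = 0.5
# Two predictions on the same level set of the regression loss.
pred_steep = (y_star(x)[0] + r, y_star(x)[1])
pred_flat = (y_star(x)[0], y_star(x)[1] + r)

reg_steep, reg_flat = regression_loss(pred_steep, x), regression_loss(pred_flat, x)
obj_steep, obj_flat = f(pred_steep, x), f(pred_flat, x)
print(reg_steep, reg_flat, obj_steep, obj_flat)
```

The two predictions are indistinguishable to the regression loss, while the objective penalizes the error along the steep direction far more heavily.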
2.2.2 Learning iterative semi-amortized models
Fully-amortized or semi-amortized models can be learned with the regression- and objective-based losses. This section discusses how the loss can be further opened up and crafted to learn iterative semi-amortized methods. For example, if the model produces intermediate predictions \(\hat{y}_θ^i\) in every iteration i, then instead of optimizing the loss of just the final prediction, i.e. \(\mathcal{L}(\hat{y}_θ^K)\), a more general loss \(\mathcal{L}^Σ\) may consider the impact of every iteration of the model’s prediction
\(\operatorname*{arg\,min}_θ \mathcal{L}^Σ(\hat{y}_θ)\)
\(\mathcal{L}^Σ(\hat{y}_θ) := \sum_{i=0}^{K} w_i \mathcal{L}(\hat{y}_θ^i),\) (2.12)
where \(w_i ∈ \mathbb{R}_+\) are weights in every iteration i that give a design choice of how important it is for the earlier iterations to produce reasonable solutions. For example, setting \(w_i = 1\) encourages every iterate to attain a low loss.
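The weighted loss in eq. (2.12) is straightforward to evaluate once the iterates are available. The sketch below assumes a hypothetical semi-amortized model that predicts an initial iterate \(θx\) and refines it with gradient steps on \(f(y;x) = \frac{1}{2}(y − 2x)^2\); the model, the objective, and the names are illustrative assumptions.

```python
import math

def f(y, x):
    # Hypothetical objective with minimizer y*(x) = 2 * x.
    return 0.5 * (y - 2.0 * x) ** 2

def iterates(theta, x, K, lr=0.3):
    # Semi-amortized model: learned initial iterate, then K gradient steps on f.
    ys = [theta * x]
    for _ in range(K):
        ys.append(ys[-1] - lr * (ys[-1] - 2.0 * x))
    return ys

def loss_sigma(theta, x, weights):
    # Eq. (2.12): weighted sum of the per-iterate losses, one weight per iterate.
    ys = iterates(theta, x, K=len(weights) - 1)
    return sum(w * f(y, x) for w, y in zip(weights, ys))

theta, x, K = 0.5, 1.0, 5
uniform = [1.0] * (K + 1)             # every iterate matters equally
final_only = [0.0] * K + [1.0]        # only the final prediction matters

loss_uniform = loss_sigma(theta, x, uniform)
loss_final = loss_sigma(theta, x, final_only)
print(loss_uniform, loss_final)
```

Putting all of the weight on the last iterate recovers the loss on the final prediction, while uniform weights also penalize poor early iterates.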
Learning iterative semi-amortized methods also has (loose) connections to sequence learning models that arise in, e.g. text, audio, and language processing. Given the context x, an iterative semi-amortized model seeks to produce a sequence of intermediate predictions that ultimately results in the final prediction, which can be analogous to a language model predicting future text given the previous text as context. One difference is that semi-amortized models do not necessarily attempt to model the probabilistic dependencies of a structured output space (such as language) and instead only need to predict intermediate computation steps for solving an optimization problem. The next section discusses concepts that arise when computing the derivatives of a loss with respect to the model’s parameters.
Unrolled optimization and backpropagation through time
\(\hat{z}_θ^0 → \hat{z}_θ^1 → \cdots → \hat{z}_θ^K → \cdots → \hat{y}_θ(x) → \mathcal{L}\)
The parameterization of every iterate \(\hat{z}_θ^i\) can influence the final prediction \(\hat{y}_θ\) and thus losses on top of \(\hat{y}_θ\) need to consider the entire chain of computations. Differentiating through this kind of iterative procedure is referred to as backpropagation through time in sequence models and unrolled optimization [Pearlmutter and Siskind, 2008, Zhang and Lesser, 2010, Maclaurin et al., 2015b, Belanger and McCallum, 2016, Metz et al., 2017, Finn et al., 2017, Han et al., 2017, Belanger et al., 2017, Belanger, 2017, Foerster et al., 2017, Bhardwaj et al., 2020, Monga et al., 2021] when the iterates are solving an optimization problem. The term “unrolling” arises because the model computation is iterative and computing \(D_θ[\hat{y}_θ(x)]\) requires saving and differentiating the “unrolled” intermediate iterations, as in section 2.1.2. The terminology “unrolling” here emphasizes that the iterative computation produces a compute graph of operations and is likely inspired by loop unrolling in compiler optimization [Aho et al., 1986, Davidson and Jinturkar, 1995] where loop operations are inlined for efficiency and written as a single chain of repeated operations rather than an iterative computation of a single operation.
Even though \(D_θ\hat{y}_θ\) through unrolled optimization is well-defined, in practice it can be unstable because of exploding gradients [Pearlmutter, 1996, Pascanu et al., 2013, Maclaurin, 2016, Parmas et al., 2018] and inefficient in compute and memory because every iterate needs to be stored, as in section 2.1.2. This is why most methods using unrolled optimization for learning only unroll through tens of iterations [Metz et al., 2017, Belanger et al., 2017, Foerster et al., 2017, Finn et al., 2017] while solving the problems from scratch may require 100k-1M+ iterations. This causes the predictions to be extremely inaccurate solutions to the optimization process and has sparked the research directions, discussed next, that seek to make unrolled optimization more tractable.
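To make the storage requirement of unrolling concrete, the following sketch differentiates the loss of the final iterate with respect to a learned initial iterate by a hand-written backward pass. It assumes a hypothetical scalar objective \(f(y;x) = \frac{1}{4}(y−x)^4 + \frac{1}{2}(y−x)^2\) and a model that runs K gradient steps from \(\hat{z}_θ^0 = θ\); these choices and the function names are illustrative, not the tutorial's.

```python
def f(y, x):
    # Hypothetical non-quadratic objective with minimizer y*(x) = x.
    d = y - x
    return 0.25 * d ** 4 + 0.5 * d ** 2

def grad_f(y, x):
    d = y - x
    return d ** 3 + d

def hess_f(y, x):
    return 3.0 * (y - x) ** 2 + 1.0

def unroll(theta, x, K, lr):
    # Forward pass: store every iterate z^0, ..., z^K for the backward pass.
    zs = [theta]
    for _ in range(K):
        zs.append(zs[-1] - lr * grad_f(zs[-1], x))
    return zs

def unrolled_grad(theta, x, K, lr):
    # Backward pass through the unrolled chain for the loss L = f(z^K; x).
    zs = unroll(theta, x, K, lr)
    adjoint = grad_f(zs[-1], x)              # dL / dz^K
    for z in reversed(zs[:-1]):
        adjoint *= 1.0 - lr * hess_f(z, x)   # dz^{i+1} / dz^i
    return adjoint                           # dL / dtheta since z^0 = theta

theta, x, K, lr, eps = 1.5, 0.5, 10, 0.1, 1e-6
analytic = unrolled_grad(theta, x, K, lr)
numeric = (f(unroll(theta + eps, x, K, lr)[-1], x)
           - f(unroll(theta - eps, x, K, lr)[-1], x)) / (2.0 * eps)
print(analytic, numeric)
```

The backward pass multiplies K per-step factors together, which is where exploding or vanishing gradients come from when those factors are consistently larger or smaller than one in magnitude.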
Truncated backpropagation through time and biased gradients
\(\hat{z}_θ^0 → \hat{z}_θ^1 → \cdots → \hat{z}_θ^{K−H} → \cdots → \hat{z}_θ^K → \cdots → \hat{y}_θ(x) → \mathcal{L}\)
Truncated backpropagation through time (TBPTT) [Werbos, 1990, Jaeger, 2002] is a crucial idea that has enabled the training of sequence models over long sequences. TBPTT’s idea is that not every iteration needs to be differentiated through and that the derivative can be computed using smaller subsequences from the full sequence of model predictions by truncating the history of iterates. For example, the derivative of a model running for K iterations with a truncation length of H can be approximated by considering the influence of the last H iterates \(\hat{z}_θ^i\), starting from iteration \(i = K−H\), on the loss \(\mathcal{L}\).
Figure 2.4: Illustration of the penalty used in the Implicit MAML by Rajeswaran et al. [2019] in eq. (2.13). The original loss f(y;x) is shown in black for a fixed context x and the lighter grey colors show the impact of varying λ. This shows that the quadratic term of the penalization eventually overtakes the original loss and makes an optimum appear close to \(\hat{y}_θ^0\).
Truncation significantly improves the computational and memory efficiency of the unrolled optimization procedure but results in harmful biased gradients as these approximate derivatives do not contain the full amount of information that the model used to compute the prediction. This is especially damaging in approaches such as MAML [Finn et al., 2017] that only parameterize the first iterate and is why MAML-based approaches often don’t use TBPTT. Tallec and Ollivier [2017], Wu et al. [2018], Liao et al. [2018], Shaban et al. [2019], Vicol et al. [2021] seek to further theoretically understand the properties of TBPTT, including the bias of the estimator and how to unbias it.
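The bias introduced by truncation can be computed exactly in a simple case. The sketch below assumes a hypothetical model whose only parameter is a step size θ shared by all K gradient steps on \(f(z) = \frac{1}{2}az^2\), and it accumulates the derivative in forward mode, resetting the accumulated derivative at iteration K−H to mimic a truncation length of H. The setup and names are illustrative assumptions.

```python
import math

def step_size_grad(theta, z0, a, K, H):
    # Derivative of L = 0.5 * a * (z^K)**2 with respect to the shared step size
    # theta, ignoring how theta influenced the iterates before iteration K - H.
    z, tangent = z0, 0.0
    for i in range(K):
        if i == K - H:
            tangent = 0.0  # truncation: treat z^{K-H} as a constant
        # Differentiate the update z^{i+1} = z^i - theta * a * z^i w.r.t. theta.
        tangent = (1.0 - theta * a) * tangent - a * z
        z = z - theta * a * z
    return a * z * tangent  # chain rule through the final loss

theta, z0, a, K = 0.1, 1.0, 1.0, 20
full = step_size_grad(theta, z0, a, K, H=K)       # no truncation
truncated = step_size_grad(theta, z0, a, K, H=5)  # only the last 5 steps
print(full, truncated, truncated / full)
```

For this particular problem the truncated derivative works out to the fraction H/K of the full one: it keeps the correct sign but underestimates the magnitude, which is one concrete form the bias can take.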
Other gradient estimators for sequential models
In addition to truncating the iterations, other approaches attempt to improve the efficiency of learning through unrolled iterations with other approximations that retain the influence of the entire sequence of predictions on the loss [Finn et al., 2017, Nichol et al., 2018, Lorraine et al., 2020] which will be further discussed in section 2.1.2. Some optimization procedures, such as gradient descent with momentum, can also be “reversed” without needing to retain the intermediate states [Maclaurin et al., 2015b, Franceschi et al., 2017]. Real-Time Recurrent Learning (RTRL) by Williams and Zipser [1989] uses forward-mode automatic differentiation to compute unbiased gradient estimates in an online fashion. Unbiased Online Recurrent Optimization (UORO) by Tallec and Ollivier [2018] improves upon RTRL with a rank-1 approximation of the gradient of the hidden state with respect to the parameters. Silver et al. [2022] consider the directional derivative of a recurrent model along a candidate direction, which can be efficiently computed to construct a descent direction.
Semi-amortized learning with shrinkage and implicit differentiation
A major issue arising in semi-amortized models is that adapting to long time horizons is inefficient in compute and memory and, even if it were not, causes exploding, vanishing, or otherwise unstable gradients. An active direction of research seeks to solve these issues by solving a smaller, local problem with the semi-amortized model, such as in Chen et al. [2020], Rajeswaran et al. [2019]. Implicit differentiation is an alternative to unrolling through the iterations of a semi-amortized model in settings where the model is able to successfully solve an optimization problem.
This section briefly summarizes Implicit MAML (iMAML) by Rajeswaran et al. [2019], which notably brings this insight to MAML. MAML methods usually take only a few gradient steps, which are usually not enough to globally solve eq. (1.1), especially at the beginning of training. Rajeswaran et al. [2019] observe that adding a penalty to the objective around the initial iterate \(\hat{y}_θ^0\) makes it easy for the model to globally (!) solve the problem
\(\hat{y}_θ(x) ∈ \operatorname*{arg\,min}_y f(y;x) + \frac{λ}{2}\|y − \hat{y}_θ^0\|_2^2,\) (2.13)
where the parameter λ encourages the solution to stay close to some initial iterate. Figure 2.4 visualizes a function f(y;x) in black and adds penalties in grey with λ ∈ [0,12]; it shows that a global minimum is difficult to find without adding a penalty around the initial iterate. This global solution can then be implicitly differentiated to obtain a derivative of the loss with respect to the model’s parameters without needing to unroll, which requires significantly fewer computational and memory resources. Huszár [2019] further analyzes and discusses iMAML, compares it to a Bayesian approach, and observes that the insights from iMAML can transfer from gradient-based meta-learning to other amortized optimization settings.
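For a scalar version of eq. (2.13), the implicit derivative has a closed form that follows from differentiating the optimality condition \(∇_y f(y;x) + λ(y − \hat{y}^0) = 0\) with respect to the initial iterate, giving \(λ / (∇_y^2 f(y;x) + λ)\) at the solution. The sketch below checks this against finite differences. The objective \(f(y;x) = \sin(3y) + \frac{1}{2}(y − x)^2\), the value λ = 12, and the simple inner solver are assumptions for illustration and are not the implementation of Rajeswaran et al. [2019].

```python
import math

LAM = 12.0  # penalty weight, large enough to make the penalized problem convex here

def grad_f(y, x):
    # Gradient of the hypothetical non-convex objective sin(3y) + 0.5 * (y - x)**2.
    return 3.0 * math.cos(3.0 * y) + (y - x)

def hess_f(y, x):
    # Second derivative of the same objective, bounded in [-8, 10].
    return -9.0 * math.sin(3.0 * y) + 1.0

def solve_penalized(y0, x, steps=300):
    # Inner solver for eq. (2.13): gradient descent on f(y; x) + LAM / 2 * (y - y0)**2.
    # The penalized curvature lies in [LAM - 8, LAM + 10], so this step size contracts.
    y = y0
    step = 1.0 / (LAM + 10.0)
    for _ in range(steps):
        y -= step * (grad_f(y, x) + LAM * (y - y0))
    return y

def implicit_derivative(y_star, x):
    # Implicit function theorem applied to grad_f(y) + LAM * (y - y0) = 0.
    return LAM / (hess_f(y_star, x) + LAM)

y0, x, eps = 0.4, 1.0, 1e-5
y_star = solve_penalized(y0, x)
implicit = implicit_derivative(y_star, x)
numeric = (solve_penalized(y0 + eps, x) - solve_penalized(y0 - eps, x)) / (2.0 * eps)
print(y_star, implicit, numeric)
```

The implicit derivative only needs the solution and the local curvature, so no intermediate iterates of the inner solver have to be stored.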
Warning. Implicit differentiation is only useful when optimization problems are exactly solved and satisfy the conditions of the implicit function theorem in theorem 1. This is why Rajeswaran et al. [2019] needed to add a penalty to MAML’s inner optimization problem in eq. (2.13) to make the problem exactly solvable. While they showed that this works and results in significant improvements for differentiation, it comes at the expense of changing the objective to penalize the distance from the previous iterate. In other words, iMAML modifies MAML’s semi-amortized model and in general is not helpful for estimating the derivative through the original formulation of MAML. Furthermore, computing the implicit derivative by solving the linear system with the Jacobian in eq. (2.30) may be memory and compute expensive to form and estimate exactly. In practice, some methods such as Bai et al. [2019] successfully use indirect and approximation methods to solve for the system in eq. (2.30).
Figure 2.5: Illustration of perturbing \(\hat{y}_θ\). A zeroth-order optimizer may make perturbations like this to search for an improved parameterization.
2.2.3 Learning with zeroth-order methods and RL
Computing the derivatives to learn \(\hat{y}_θ\) with a first-order method may be impossible or unstable. These problems typically arise when learning components of the model that are truly non-differentiable, or when attempting to unroll a semi-amortized model for a lot of steps. In these settings, research has successfully explored other optimizers that do not need the gradient information. These methods often consider settings that improve an objective-based loss with small local perturbations rather than differentiation. Figure 2.5 illustrates that most of these methods can be interpreted as locally perturbing the model’s prediction and updating the parameters to move towards the best perturbations.
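A very simple instance of this perturb-and-select idea is random search over the parameters. The sketch below is an illustrative assumption rather than any of the cited methods: it uses a hypothetical model \(\hat{y}_θ(x) = θx\), an objective \(f(y;x) = |y − 2x|\) that is not differentiable at its minimizer, and a fixed set of contexts.

```python
import random

def model(theta, x):
    # Hypothetical fully-amortized model with a single parameter.
    return theta * x

def f(y, x):
    # Hypothetical objective that is non-differentiable at its minimizer y*(x) = 2x.
    return abs(y - 2.0 * x)

def loss_obj(theta, xs):
    # Objective-based loss averaged over a fixed set of contexts.
    return sum(f(model(theta, x), x) for x in xs) / len(xs)

def random_search(theta, xs, sigma=0.1, candidates=8, iters=200, seed=0):
    # Zeroth-order optimization: perturb the parameters and keep the best candidate.
    rng = random.Random(seed)
    best, best_loss = theta, loss_obj(theta, xs)
    for _ in range(iters):
        center = best
        for _ in range(candidates):
            cand = center + sigma * rng.gauss(0.0, 1.0)
            cand_loss = loss_obj(cand, xs)
            if cand_loss < best_loss:
                best, best_loss = cand, cand_loss
    return best

xs = [-1.0, -0.5, 0.5, 1.0]
theta_init = 0.0
theta_final = random_search(theta_init, xs)
print(theta_final, loss_obj(theta_final, xs))
```

Only loss evaluations are used, so the same loop applies when the model or objective cannot be differentiated, at the cost of many more evaluations than a first-order method would need.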
Reinforcement learning
Li and Malik [2017a,b], Ichnowski et al. [2021] view their semi-amortized models as a Markov decision process (MDP) that they solve with reinforcement learning. The MDP interpretation uses the insight that the iterations \(x^i\) are the actions, the context and previous iterations or losses are typically the states, the associated losses \(\mathcal{L}(x^i)\) are the rewards, \(\hat{y}_θ^i(x)\) is a (deterministic) policy, and the transitions are given by updating the current iterate, either with a quantity defined by the policy or by running one or more iterations from an existing optimizer. Once this viewpoint is taken, then the optimal amortized model can be found by using standard reinforcement learning methods, e.g. Li and Malik [2017a,b] use Guided Policy Search [Levine and Koltun, 2013] and Ichnowski et al. [2021] use TD3 [Fujimoto et al., 2018]. The notation \(\mathcal{L}^{\rm RL}\) indicates that a loss is optimized with reinforcement learning, typically on the objective-based loss.
Loss smoothing and optimization with zeroth-order methods
Figure 2.6: Gaussian smoothing of a loss using eq. (2.14). The colors show different values of the variance \(σ^2\) of the Gaussian. Selecting a high enough variance results in smoothing out most of the suboptimal minima.
Objective-based losses can have a high-frequency structure with many poor local minima. Metz et al. [2019a] overcome this by smoothing the loss with a Gaussian over the parameter space, i.e.,
\(\mathcal{L}_{\rm smooth}(\hat{y}_θ) := \mathbb{E}_{ϵ∼\mathcal{N}(0,σ^2 I)}[\mathcal{L}(\hat{y}_{θ+ϵ})],\) (2.14)
where \(σ^2\) is a fixed variance. Figure 2.6 illustrates a loss function \(\mathcal{L}\) in black and shows smoothed versions in color. They consider optimizing the smoothed loss with reparameterization gradients and zeroth-order evolutionary methods. Merchant et al. [2021] further build upon this for learned optimization in atomic structural optimization and study 1) clipping the values of the gradient estimator, and 2) parameter optimization with genetic algorithms.
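Eq. (2.14) can be estimated by sampling perturbations of the parameters. The sketch below assumes a hypothetical one-dimensional loss \(\mathcal{L}(θ) = θ^2 + \frac{1}{2}\sin(10θ)\), for which the Gaussian-smoothed loss is available in closed form as \(θ^2 + σ^2 + \frac{1}{2}e^{−50σ^2}\sin(10θ)\), so the Monte Carlo estimate can be compared against it. The loss and the names are illustrative assumptions.

```python
import math
import random

def loss(theta):
    # Hypothetical loss: a quadratic bowl with high-frequency ripples.
    return theta ** 2 + 0.5 * math.sin(10.0 * theta)

def smoothed_loss_mc(theta, sigma, n=20000, seed=0):
    # Monte Carlo estimate of eq. (2.14) with eps ~ N(0, sigma**2).
    rng = random.Random(seed)
    return sum(loss(theta + rng.gauss(0.0, sigma)) for _ in range(n)) / n

def smoothed_loss_exact(theta, sigma):
    # Closed form for this specific loss: the ripples are damped by exp(-50 sigma**2)
    # and the quadratic picks up an additive sigma**2.
    damping = math.exp(-50.0 * sigma ** 2)
    return theta ** 2 + sigma ** 2 + 0.5 * damping * math.sin(10.0 * theta)

theta, sigma = 1.0, 0.5
estimate = smoothed_loss_mc(theta, sigma)
exact = smoothed_loss_exact(theta, sigma)
print(estimate, exact)
```

In this example the damping factor shows how quickly the high-frequency term disappears as σ grows, while the additive \(σ^2\) term is a reminder that smoothing also changes the loss values themselves.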
Remark 4 While smoothing can help reduce suboptimal local minima, it may also undesirably change the location of the global minimum. One potential solution to this is to decay the smoothing throughout training, as done in Amos et al. [2021, Appendix A.1].
Connection to smoothing in reinforcement learning. The Gaussian smoothing of the objective \(\mathcal{L}\) in eq. (2.14) is conceptually similar to Gaussian smoothing of the objective in reinforcement learning, i.e. the −Q-value, by a Gaussian policy. This happens in eq. (3.39) and is further discussed in section 3.6. The policy’s variance is typically controlled to match a target entropy [Haarnoja et al., 2018] and the learning typically starts with a high variance so the policy has a broad view of the objective landscape and is then able to focus in on an optimal region of the value distribution. Amos et al. [2021] use a fixed entropy decay schedule to explicitly control this behavior. In contrast, Metz et al. [2019a], Merchant et al. [2021] do not turn the loss into a distribution and more directly smooth the loss with a Gaussian with a fixed variance \(σ^2\) that is not optimized over.
2.3 Extensions
I have intentionally scoped definition 1 to optimization problems over deterministic, unconstrained, finite-dimensional, Euclidean domains \(\mathcal{Y}\) where the context distribution p(x) remains fixed the entire training time to provide a simple mental model that allows us to focus on the core amortization principles that consistently show up between applications. This section covers extensions from this setting that may come up in practice.
2.3.1 Extensions of the domain \(\mathcal{Y}\)
Deterministic → stochastic optimization
A crucial extension is from optimization over deterministic vector spaces \(\mathcal{Y}\) to stochastic optimization where \(\mathcal{Y}\) represents a space of distributions, turning \(y ∈ \mathcal{Y}\) from a vector in Euclidean space into a distribution. This comes up, for example, in section 3.6 for control.
Transforming parameterized stochastic problems back into deterministic ones. This portion will mostly focus on settings that optimize over parametric distributions. This may arise in stochastic domains for variational inference in section 3.1 and stochastic control in section 3.6. These settings optimize over a constrained parametric family of distributions parameterized by some λ, for example over a multivariate normal \(\mathcal{N}(µ,Σ)\) parameterized by λ := (µ,Σ). Here, problems can be transformed back to eq. (1.1) by optimizing over the parameters with
\(λ^⋆(x) ∈ \operatorname*{arg\,min}_λ f(λ;x),\) (2.15)
where λ induces a distribution that the objective f may use. When λ is not an unconstrained real space, the differentiable projections discussed in section 2.3.1 could be used to transform λ back into this form for amortization.
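As a small illustration of eq. (2.15), the sketch below optimizes over the parameters of a one-dimensional Gaussian \(\mathcal{N}(µ,σ^2)\) for a single context. Everything specific here is an assumption for illustration: the objective is a hypothetical entropy-regularized expected cost \(\mathbb{E}_{y∼\mathcal{N}(µ,σ^2)}[(y − x)^2] − τ\,H\) with a closed-form expectation, and the constraint σ > 0 is handled by optimizing the unconstrained value ρ = log σ.

```python
import math

TAU = 0.5  # hypothetical entropy-regularization weight

def objective(mu, rho, x):
    # f(lambda; x) for lambda = (mu, rho) with sigma = exp(rho):
    # E[(y - x)**2] = (mu - x)**2 + sigma**2, and the Gaussian entropy is
    # rho + 0.5 * log(2 * pi * e).
    sigma = math.exp(rho)
    entropy = rho + 0.5 * math.log(2.0 * math.pi * math.e)
    return (mu - x) ** 2 + sigma ** 2 - TAU * entropy

def gradient(mu, rho, x):
    # Partial derivatives of the objective with respect to mu and rho.
    return 2.0 * (mu - x), 2.0 * math.exp(2.0 * rho) - TAU

def solve(x, lr=0.1, steps=500):
    # Deterministic gradient descent over the unconstrained parameters.
    mu, rho = 0.0, 0.0
    for _ in range(steps):
        g_mu, g_rho = gradient(mu, rho, x)
        mu, rho = mu - lr * g_mu, rho - lr * g_rho
    return mu, math.exp(rho)

x = 1.5
mu_star, sigma_star = solve(x)
print(mu_star, sigma_star)
```

Setting the gradient to zero gives µ = x and \(σ^2 = τ/2\) for this objective, which the iterates approach. An amortization model would predict (µ, ρ) from x instead of running this loop for every context.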
Optimizing over distributions and densities. More general stochastic optimization settings involve optimizing over spaces representing distributions, such as the functional space of all continuous densities. Many standard probability distributions can be obtained and characterized as the solution to a maximum entropy optimization problem; this is explored, e.g., in Cover and Thomas [2006, Ch. 12], Guiasu and Shenitzer [1985, p. 47], and Pennec [2006, §6.2]. For example, a Gaussian distribution \(\mathcal{N}(µ,Σ)\) solves the following constrained maximum entropy optimization problem over the space of continuous densities \(\mathcal{P}\):
\(p^⋆(µ,Σ) ∈ \operatorname*{arg\,max}_{p∈\mathcal{P}} H_p[X] \;\; \text{subject to} \;\; \mathbb{E}_p[X] = µ \;\; \text{and} \;\; \mathrm{Var}_p[X] = Σ,\) (2.16)
where \(H_p[X] := −\int p(x)\log p(x)\,dx\) is the differential entropy and the constraints are on the mean \(\mathbb{E}_p[X]\) and covariance \(\mathrm{Var}_p[X]\). Cover and Thomas [2006, Theorem 8.6.5 and Example 12.2.8] prove that the closed-form solution of \(p^⋆\) is the Gaussian density. This Gaussian setting therefore does not need amortization as the closed-form solution is known and easily computed, but more general optimization problems over densities do not necessarily have closed-form solutions and could benefit from amortization.
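The maximum entropy characterization can be sanity-checked in one dimension with the standard closed-form differential entropies of three families matched to the same variance. This is only a comparison of three feasible densities, not a solution of eq. (2.16); the function names are my own.

```python
import math

def gaussian_entropy(sigma):
    # Differential entropy of N(mu, sigma**2).
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)

def uniform_entropy(sigma):
    # Uniform on an interval of width w has variance w**2 / 12 and entropy log(w).
    width = sigma * math.sqrt(12.0)
    return math.log(width)

def laplace_entropy(sigma):
    # Laplace with scale b has variance 2 * b**2 and entropy 1 + log(2 * b).
    b = sigma / math.sqrt(2.0)
    return 1.0 + math.log(2.0 * b)

sigma = 1.0
h_gauss = gaussian_entropy(sigma)
h_laplace = laplace_entropy(sigma)
h_uniform = uniform_entropy(sigma)
print(h_gauss, h_laplace, h_uniform)
```

Among these three densities with equal variance, the Gaussian has the largest differential entropy, consistent with the result cited above.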
Figure 2.7: The Gaussian distribution can be characterized as the result of the optimization problem in eq. (2.16): constrained to the space of continuous distributions with a given mean and variance, the Gaussian distribution has the maximum entropy in comparison to every other distribution. This example parameterizes a non-Gaussian density (shown in grey) and optimizes over it using gradient steps of eq. (2.16), eventually converging to a Gaussian. An animated version is available in the repository associated with this tutorial. While the Gaussian is the known closed-form solution to this optimization problem and analytically known, more general optimization problems over densities without known solutions can also be amortized.
While this tutorial does not study amortizing these problems, in some cases it may be possible to again transform them back into deterministic optimization problems over Euclidean space for amortization by approximating the density gθ with an expressive family of densities parameterized by θ.
Unconstrained → constrained optimization
Amortized constrained optimization problems may naturally arise, for example in the convex optimization settings in section 3.4 and for optimization over the sphere in section 4.2. Constrained optimization problems for amortization can often be represented as an extension of eq. (1.1) with
\(y^⋆(x) ∈ \operatorname*{arg\,min}_{y∈\mathcal{C}} f(y;x),\) (2.17)
where the constraints \(\mathcal{C}\) may also depend on the context x. Remark 3 suggests one way of amortizing eq. (2.17) by amortizing the objective-based loss associated with the optimality conditions of the constrained problem. A budding research area studies how to more generally include constraints into the formulation. Baker [2019], Dong et al. [2020], Zamzam and Baker [2020], Pan et al. [2020], Klamkin et al. [2025], Hentenryck [2025] predict solutions to optimal power flow problems. Misra et al. [2021] learn active sets for constrained optimization. Kriváchy et al. [2020] solve constrained feasibility semi-definite programs with a fully-amortized neural network model using an objective-based loss. Donti et al. [2021] learn a fully-amortized model and optimize an objective-based loss with additional completion and correction terms to ensure the prediction satisfies the constraints of the original problem.
Figure 2.8: Illustration of definition 4 showing a Euclidean projection \(π_\mathcal{C}(x)\) of a point x onto a set \(\mathcal{C}\).
Differentiable projections. When the constraints are relatively simple, a differentiable projection can transform a constrained optimization problem into an unconstrained one, e.g., in reinforcement learning constrained action spaces can be transformed from the box \([−1,1]^n\) to the reals \(\mathbb{R}^n\) by using the tanh to project from \(\mathbb{R}^n\) to \([−1,1]^n\). Section 4.2 also uses a differentiable projection from \(\mathbb{R}^n\) onto the sphere \(\mathcal{S}^{n−1}\). These are illustrated in fig. 2.8 and defined as:
Definition 4 A projection from \(\mathbb{R}^n\) onto a set \(\mathcal{C} ⊆ \mathbb{R}^n\) is
\(π_\mathcal{C} : \mathbb{R}^n → \mathcal{C}, \qquad π_\mathcal{C}(x) ∈ \operatorname*{arg\,min}_{y∈\mathcal{C}} D(x,y) + Ω(y),\) (2.18)
where \(D : \mathbb{R}^n × \mathbb{R}^n → \mathbb{R}\) is a distance and \(Ω : \mathbb{R}^n → \mathbb{R}\) is a regularizer that can ensure invertibility or help spread \(\mathbb{R}^n\) more uniformly throughout \(\mathcal{C}\). A (sub)differentiable projection has (sub)derivatives \(∇_x π_\mathcal{C}(x)\). I sometimes omit the dependence of π on the choice of D, Ω, and \(\mathcal{C}\) when they are given by the surrounding context.
Lack of idempotency. In linear algebra, a projection is defined to be idempotent, i.e. applying the projection twice gives the same result so that π ◦ π = π. Unfortunately, projections as defined in definition 4, such as Bregman projections, are not idempotent in general and often \(π_\mathcal{C} ◦ π_\mathcal{C} ≠ π_\mathcal{C}\) as the regularizer Ω may cause points that are already on \(\mathcal{C}\) to move to a different position on \(\mathcal{C}\).
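The difference between an idempotent Euclidean projection and the smooth squashing maps used as differentiable projections is easy to see numerically. The sketch below applies the scalar ReLU, sigmoid, and tanh twice and compares the result with a single application; treating the sigmoid and tanh as instances of definition 4 follows the interpretation in the text, and the test points are arbitrary.

```python
import math

def relu(x):
    # Euclidean projection of a scalar onto the non-negative reals.
    return max(0.0, x)

def sigmoid(x):
    # Smooth map from the reals into the open interval (0, 1).
    return 1.0 / (1.0 + math.exp(-x))

def apply_twice(projection, x):
    return projection(projection(x))

points = [-2.0, -0.5, 0.5, 2.0]
relu_idempotent = all(apply_twice(relu, x) == relu(x) for x in points)
sigmoid_shift = [abs(apply_twice(sigmoid, x) - sigmoid(x)) for x in points]
tanh_shift = [abs(apply_twice(math.tanh, x) - math.tanh(x)) for x in points]
print(relu_idempotent, sigmoid_shift, tanh_shift)
```

The ReLU leaves points that are already feasible untouched, whereas the sigmoid and tanh move points that already lie inside their target sets, which is the lack of idempotency described above.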
Differentiable projections for constrained amortization. These can be used to cast eq. (2.17) as the unconstrained problem eq. (1.1) by composing the objective with a projection \(f ◦ π_\mathcal{C}\). (Sub)differentiable projections enable gradient-based learning through the projection, which is most easily attainable when the projection has an explicit closed-form solution. For intuition, the ReLU, sigmoid, and softargmax can be interpreted as differentiable projections that solve convex optimization problems in the form of eq. (2.18). Amos [2019, §2.4.4] further discusses these and proves them using the KKT conditions:
Figure 2.9: Illustration of the second-order cone in eq. (2.24).
• The standard Euclidean projection onto the non-negative orthant \(\mathbb{R}_+^n\) is defined by
\(π(x) ∈ \operatorname*{arg\,min}_y \frac{1}{2}\|x − y\|_2^2 \;\; \text{s.t.} \;\; y ≥ 0,\) (2.19)
and has a closed-form solution given by the ReLU, i.e. \(π(x) := \max\{0,x\}\).
• The interior of the unit hypercube \([0,1]^n\) can be projected onto with the entropy-regularized optimization problem
π(x) ∈ argmin 0