CVaR allocation

In this notebook, we solve a CVaR portfolio optimization problem: minimizing expected losses in the worst-case scenarios rather than overall variance.

Unlike mean-variance optimization, which penalizes all volatility equally, CVaR focuses on the tail of the loss distribution: the expected loss in the worst 100α% of outcomes. This makes it useful for risk-averse investors concerned about extreme market downturns.

This example is based on the formulation from PyPortfolioOpt and Rockafellar & Uryasev (2000).

Features used:

  • Var with vector-valued and scalar defaults

  • Inequality constraints (constraint_geq_zero): CVaR slack constraints, no short-selling

  • Equality constraints (constraint_eq_zero): budget constraint

  • Augmented Lagrangian solver for constrained optimization


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

Historical stock data

We use the same dataset as the Mean-variance allocation example: monthly stock prices from November 2000 to November 2001 for IBM, Walmart (WMT), and Southern Electric (SEHI).

stock_names = ["IBM", "WMT", "SEHI"]

# Monthly prices (13 months: Nov 2000 - Nov 2001)
prices = jnp.array(
    [
        [93.043, 51.826, 1.063],
        [84.585, 52.823, 0.938],
        [111.453, 56.477, 1.0],
        [99.525, 49.805, 0.938],
        [95.819, 50.287, 1.438],
        [114.708, 51.521, 1.7],
        [111.515, 51.531, 2.54],
        [113.211, 48.664, 2.39],
        [104.942, 55.744, 3.12],
        [99.827, 47.916, 2.98],
        [91.607, 49.438, 1.9],
        [107.937, 51.336, 1.75],
        [115.59, 55.081, 1.8],
    ]
)

# Monthly returns: (P[t+1] - P[t]) / P[t]
returns = jnp.diff(prices, axis=0) / prices[:-1]
num_scenarios, num_assets = returns.shape

print(
    f"Returns shape: {returns.shape} ({num_scenarios} scenarios x {num_assets} assets)"
)
print(f"\nMonthly returns (%):\n{returns * 100}")
Returns shape: (12, 3) (12 scenarios x 3 assets)

Monthly returns (%):
[[-9.09042072e+00  1.92374790e+00 -1.17591667e+01]
 [ 3.17645016e+01  6.91743946e+00  6.60980558e+00]
 [-1.07022705e+01 -1.18136597e+01 -6.19999790e+00]
 [-3.72368884e+00  9.67771173e-01  5.33048973e+01]
 [ 1.97132092e+01  2.45391679e+00  1.82197552e+01]
 [-2.78359032e+00  1.94063038e-02  4.94117584e+01]
 [ 1.52087092e+00 -5.56363535e+00 -5.90550661e+00]
 [-7.30405807e+00  1.45487385e+01  3.05439224e+01]
 [-4.87411880e+00 -1.40427647e+01 -4.48717546e+00]
 [-8.23424625e+00  3.17639065e+00 -3.62416115e+01]
 [ 1.78261414e+01  3.83914971e+00 -7.89473581e+00]
 [ 7.09024763e+00  7.29508114e+00  2.85714006e+00]]

CVaR vs variance

Variance measures the average squared deviation from the mean, so it penalizes upside and downside deviations equally.

CVaR (Conditional Value at Risk) measures the expected loss in the worst \(100\alpha\%\) of scenarios. For \(\alpha = 0.05\) (95% confidence), CVaR answers: “What’s my average loss over the worst 5% of months?” (each scenario in this dataset is one month of returns).

Key advantages of CVaR:

  • Focuses on tail risk (extreme losses) rather than general volatility

  • Coherent risk measure (subadditive, convex)

  • Does not assume normally distributed returns

  • Not inflated by large gains, which raise variance just as much as large losses do
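To make the contrast concrete, here is a minimal sketch that computes historical CVaR for equally likely scenarios and compares it with variance on two made-up return series. The helper `empirical_cvar` is our own illustration, not part of jaxls; it gives the boundary scenario a fractional weight when \(\alpha T\) is not an integer, which is what the Rockafellar-Uryasev formula yields for a discrete sample.

```python
import jax.numpy as jnp


def empirical_cvar(losses: jnp.ndarray, alpha: float) -> jnp.ndarray:
    """Average loss over the worst alpha-fraction of equally likely scenarios.

    Hypothetical helper: when alpha * T is not an integer, the boundary
    scenario gets a fractional weight so the tail always has mass alpha * T.
    """
    num_scenarios = losses.shape[0]
    tail_mass = alpha * num_scenarios
    worst_first = jnp.sort(losses)[::-1]
    # Weight of the k-th worst scenario: 1 inside the tail, fractional at the edge, 0 outside.
    weights = jnp.clip(tail_mass - jnp.arange(num_scenarios), 0.0, 1.0)
    return jnp.sum(weights * worst_first) / tail_mass


# Two made-up return series: same bad months, very different upside.
returns_a = jnp.array([-0.10, 0.00, 0.02, 0.02, 0.02])
returns_b = jnp.array([-0.10, 0.00, 0.02, 0.10, 0.20])

for label, r in [("A", returns_a), ("B", returns_b)]:
    # Losses are negated returns; alpha = 0.2 makes the tail exactly one of five scenarios.
    print(label, "variance:", float(jnp.var(r)), "CVaR:", float(empirical_cvar(-r, alpha=0.2)))
```

Both series share the same worst month, so their CVaR at \(\alpha = 0.2\) is identical, while the large gains in series B raise its variance several-fold.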

CVaR formulation

The CVaR optimization uses the Rockafellar-Uryasev formulation:

\[\text{CVaR}_\alpha = \min_{\zeta} \left[ \zeta + \frac{1}{\alpha T} \sum_{t=1}^T \max(-w^\top r_t - \zeta, 0) \right]\]

where:

  • \(w\) = portfolio weights

  • \(r_t\) = returns in scenario \(t\)

  • \(\zeta\) = VaR threshold (auxiliary variable)

  • \(\alpha\) = tail probability (e.g., 0.05 for 95% CVaR)

  • \(T\) = number of scenarios
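For a fixed portfolio, the minimization over \(\zeta\) is one-dimensional and can be checked directly. A sketch, assuming equally likely scenarios: the bracketed objective is convex and piecewise linear in \(\zeta\) with kinks at the scenario losses \(-w^\top r_t\), so evaluating it at every loss and taking the smallest value recovers CVaR, and the minimizing \(\zeta\) is a VaR. The function names `ru_objective` and `cvar_and_var` are hypothetical.

```python
import jax
import jax.numpy as jnp


def ru_objective(zeta: jax.Array, losses: jax.Array, alpha: float) -> jax.Array:
    """F_alpha(zeta) = zeta + 1/(alpha T) * sum_t max(loss_t - zeta, 0)."""
    num_scenarios = losses.shape[0]
    return zeta + jnp.sum(jnp.maximum(losses - zeta, 0.0)) / (alpha * num_scenarios)


def cvar_and_var(losses: jax.Array, alpha: float):
    """Minimize F_alpha over zeta by evaluating it at every kink (scenario loss)."""
    values = jax.vmap(lambda z: ru_objective(z, losses, alpha))(losses)
    best = jnp.argmin(values)
    return values[best], losses[best]  # (CVaR, a minimizing zeta, i.e. a VaR)


demo_losses = jnp.arange(1.0, 11.0)  # ten equally likely losses: 1 through 10
demo_cvar, demo_var = cvar_and_var(demo_losses, alpha=0.25)
print(float(demo_cvar), float(demo_var))
```

Here \(\alpha T = 2.5\), so the tail is the two worst losses plus half of the third: \((10 + 9 + 0.5 \cdot 8)/2.5 = 9.2\), attained at \(\zeta = 8\).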

To handle the \(\max(\cdot, 0)\) term, we introduce slack variables \(u_t \geq 0\):

\[\min_{w, \zeta, u} \quad \zeta + \frac{1}{\alpha T} \sum_{t=1}^T u_t\]

subject to:

  • \(u_t \geq -w^\top r_t - \zeta\) (\(u_t\) covers the loss in excess of \(\zeta\))

  • \(u_t \geq 0\) (slack non-negativity)

  • \(\sum_i w_i = 1\) (budget constraint)

  • \(w_i \geq 0\) (no short-selling)
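Every term in this slack formulation is linear in \((w, \zeta, u)\), so the problem is a linear program. As an independent cross-check that is not part of the jaxls workflow, the sketch below hands the same LP to `scipy.optimize.linprog` with the HiGHS backend; it should land near the weights the augmented Lagrangian solver finds later in this notebook. The variable layout and the helper name `solve_cvar_lp` are our own choices.

```python
import numpy as np
from scipy.optimize import linprog

# Same monthly prices as in this notebook (columns: IBM, WMT, SEHI).
lp_prices = np.array(
    [
        [93.043, 51.826, 1.063],
        [84.585, 52.823, 0.938],
        [111.453, 56.477, 1.0],
        [99.525, 49.805, 0.938],
        [95.819, 50.287, 1.438],
        [114.708, 51.521, 1.7],
        [111.515, 51.531, 2.54],
        [113.211, 48.664, 2.39],
        [104.942, 55.744, 3.12],
        [99.827, 47.916, 2.98],
        [91.607, 49.438, 1.9],
        [107.937, 51.336, 1.75],
        [115.59, 55.081, 1.8],
    ]
)


def solve_cvar_lp(scenario_returns: np.ndarray, alpha: float):
    """Solve the slack-variable CVaR LP; the decision vector is x = [w, zeta, u]."""
    T, n = scenario_returns.shape
    # Objective: zeta + 1/(alpha T) * sum_t u_t (weights carry no direct cost).
    c = np.concatenate([np.zeros(n), [1.0], np.full(T, 1.0 / (alpha * T))])
    # u_t >= -w'r_t - zeta, rewritten as -r_t'w - zeta - u_t <= 0.
    A_ub = np.hstack([-scenario_returns, -np.ones((T, 1)), -np.eye(T)])
    b_ub = np.zeros(T)
    # Budget: sum_i w_i = 1.
    A_eq = np.concatenate([np.ones(n), [0.0], np.zeros(T)])[None, :]
    b_eq = np.array([1.0])
    # w_i >= 0 (no short-selling), zeta free, u_t >= 0.
    bounds = [(0, None)] * n + [(None, None)] + [(0, None)] * T
    return linprog(
        c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method="highs"
    )


lp_returns = np.diff(lp_prices, axis=0) / lp_prices[:-1]
lp_result = solve_cvar_lp(lp_returns, alpha=0.05)
lp_weights = lp_result.x[: lp_returns.shape[1]]
print("status:", lp_result.status, "weights:", np.round(lp_weights, 3))
print("CVaR:", round(lp_result.fun, 4))
```

Comparing `lp_result.fun` with the CVaR printed after the jaxls solve is the most direct check, since the objective value is what both solvers minimize.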

# CVaR tail probability (the confidence level is 1 - alpha).
alpha = 0.05  # 95% CVaR (worst 5% of scenarios)


class WeightsVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(num_assets) / num_assets
):
    """Portfolio weights (one entry per asset)."""


class VaRVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(1)):
    """Value-at-Risk threshold (scalar)."""


class SlackVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(num_scenarios)):
    """Slack variables for max(loss - VaR, 0) per scenario."""


weights_var = WeightsVar(id=0)
var_var = VaRVar(id=0)
slack_var = SlackVar(id=0)


@jaxls.Cost.factory
def cvar_objective(
    vals: jaxls.VarValues,
    var_v: VaRVar,
    slack_v: SlackVar,
    alpha: float,
    num_scenarios: int,
) -> jax.Array:
    """CVaR objective: VaR + (1/alpha) * mean(slack).

    Since this is the only cost term, the solver minimizes CVaR^2. Squaring
    preserves the minimizer only when the optimal CVaR is non-negative, which
    holds whenever the worst scenarios are losses (as in this dataset).
    """
    var_threshold = vals[var_v]
    slack = vals[slack_v]
    return var_threshold + jnp.sum(slack) / (alpha * num_scenarios)


@jaxls.Cost.factory(kind="constraint_geq_zero")
def slack_lower_bound(
    vals: jaxls.VarValues,
    weights_v: WeightsVar,
    var_v: VaRVar,
    slack_v: SlackVar,
    scenario_returns: jax.Array,
) -> jax.Array:
    """Constraint: u_t >= -w'r_t - zeta (slack captures excess loss)."""
    weights = vals[weights_v]
    var_threshold = vals[var_v]
    slack = vals[slack_v]
    # Portfolio return for each scenario.
    portfolio_returns = scenario_returns @ weights
    # Loss = negative return.
    losses = -portfolio_returns
    # u_t >= loss_t - VaR.
    return slack - (losses - var_threshold)


@jaxls.Cost.factory(kind="constraint_geq_zero")
def slack_nonneg(vals: jaxls.VarValues, slack_v: SlackVar) -> jax.Array:
    """Constraint: u_t >= 0."""
    return vals[slack_v]


@jaxls.Cost.factory(kind="constraint_eq_zero")
def budget_constraint(vals: jaxls.VarValues, weights_v: WeightsVar) -> jax.Array:
    """Weights must sum to 1 (fully invested)."""
    return jnp.sum(vals[weights_v]) - 1.0


@jaxls.Cost.factory(kind="constraint_geq_zero")
def no_short_constraint(vals: jaxls.VarValues, weights_v: WeightsVar) -> jax.Array:
    """No short-selling: weights >= 0."""
    return vals[weights_v]
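The non-negativity caveat in `cvar_objective` can be seen on a hypothetical one-dimensional toy problem, not the CVaR objective itself: minimizing \(f\) and minimizing \(f^2\) pick the same point when \(f\) stays positive, but not when the minimum of \(f\) is negative, because the squared residual is then pulled toward zero instead.

```python
import jax.numpy as jnp

# Hypothetical stand-in objective f(x) = x + offset on the feasible set 0 <= x <= 2.
xs = jnp.linspace(0.0, 2.0, 201)


def argmins(offset: float):
    """Return (argmin of f, argmin of f**2) over the grid xs."""
    f = xs + offset
    return float(xs[jnp.argmin(f)]), float(xs[jnp.argmin(f**2)])


print("f > 0 everywhere:", argmins(1.0))   # both minimizers at x = 0
print("f < 0 at the optimum:", argmins(-1.0))  # squaring moves the minimizer to where f = 0
```

The optimal CVaR in this notebook is a positive loss, so minimizing the squared residual is safe here.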

Solving

costs = [
    cvar_objective(var_var, slack_var, alpha, num_scenarios),
    slack_lower_bound(weights_var, var_var, slack_var, returns),
    slack_nonneg(slack_var),
    budget_constraint(weights_var),
    no_short_constraint(weights_var),
]

# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [weights_var, var_var, slack_var])

# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()

solution = problem.solve(
    verbose=True,
    linear_solver="dense_cholesky",
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
INFO     | Building optimization problem with 5 terms and 3 variables: 1 costs, 1 eq_zero, 0 leq_zero, 3 geq_zero
INFO     | Vectorizing group with 1 costs, 2 variables each: cvar_objective
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 3 variables each: augmented_slack_lower_bound
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_slack_nonneg
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_budget_constraint
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_no_short_constraint
INFO     | Augmented Lagrangian: initial snorm=1.3766e-01, csupn=1.3766e-01, max_rho=1.0000e+01, constraint_dim=28
INFO     |  step #0: cost=0.0000 lambd=0.0005
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.39280 (avg 0.03273)
INFO     |      - augmented_slack_nonneg(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.26e+00 cost_prev=0.3928 cost_new=0.0269
INFO     |  step #1: cost=0.0000 lambd=0.0003
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #2: cost=0.0000 lambd=0.0005
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #3: cost=0.0000 lambd=0.0010
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #4: cost=0.0000 lambd=0.0020
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #5: cost=0.0000 lambd=0.0040
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #6: cost=0.0000 lambd=0.0080
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #7: cost=0.0000 lambd=0.0160
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #8: cost=0.0000 lambd=0.0320
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #9: cost=0.0000 lambd=0.0640
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #10: cost=0.0000 lambd=0.1280
INFO     |      - cvar_objective(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02689 (avg 0.00224)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.65e-01 cost_prev=0.0269 cost_new=0.0264
INFO     |  step #11: cost=0.0000 lambd=0.0640
INFO     |      - cvar_objective(1): 0.00002 (avg 0.00002)
INFO     |      - augmented_slack_lower_bound(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.02642 (avg 0.00220)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.18e-01 cost_prev=0.0264 cost_new=0.0071
INFO     |  step #12: cost=0.0002 lambd=0.0320
INFO     |      - cvar_objective(1): 0.00021 (avg 0.00021)
INFO     |      - augmented_slack_lower_bound(1): 0.00012 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00674 (avg 0.00056)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #13: cost=0.0002 lambd=0.0640
INFO     |      - cvar_objective(1): 0.00021 (avg 0.00021)
INFO     |      - augmented_slack_lower_bound(1): 0.00012 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00674 (avg 0.00056)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.99e-01 cost_prev=0.0071 cost_new=0.0061
INFO     |  step #14: cost=0.0004 lambd=0.0320
INFO     |      - cvar_objective(1): 0.00039 (avg 0.00039)
INFO     |      - augmented_slack_lower_bound(1): 0.00010 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00558 (avg 0.00046)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #15: cost=0.0004 lambd=0.0640
INFO     |      - cvar_objective(1): 0.00039 (avg 0.00039)
INFO     |      - augmented_slack_lower_bound(1): 0.00010 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00558 (avg 0.00046)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #16: cost=0.0004 lambd=0.1280
INFO     |      - cvar_objective(1): 0.00039 (avg 0.00039)
INFO     |      - augmented_slack_lower_bound(1): 0.00010 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00558 (avg 0.00046)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #17: cost=0.0004 lambd=0.2560
INFO     |      - cvar_objective(1): 0.00039 (avg 0.00039)
INFO     |      - augmented_slack_lower_bound(1): 0.00010 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00558 (avg 0.00046)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.52e-01 cost_prev=0.0061 cost_new=0.0053
INFO     |  step #18: cost=0.0006 lambd=0.1280
INFO     |      - cvar_objective(1): 0.00056 (avg 0.00056)
INFO     |      - augmented_slack_lower_bound(1): 0.00297 (avg 0.00025)
INFO     |      - augmented_slack_nonneg(1): 0.00175 (avg 0.00015)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.92e-01 cost_prev=0.0053 cost_new=0.0025
INFO     |  step #19: cost=0.0006 lambd=0.0640
INFO     |      - cvar_objective(1): 0.00061 (avg 0.00061)
INFO     |      - augmented_slack_lower_bound(1): 0.00003 (avg 0.00000)
INFO     |      - augmented_slack_nonneg(1): 0.00185 (avg 0.00015)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=4.15e-03 cost_prev=0.0025 cost_new=0.0024
INFO     |  AL update: snorm=4.0531e-03, csupn=4.0531e-03, max_rho=4.0000e+01
INFO     |  step #20: cost=0.0006 lambd=0.0320
INFO     |      - cvar_objective(1): 0.00059 (avg 0.00059)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.01135 (avg 0.00095)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=7.17e-01 cost_prev=0.0121 cost_new=0.0080
INFO     |  step #21: cost=0.0043 lambd=0.0160
INFO     |      - cvar_objective(1): 0.00431 (avg 0.00431)
INFO     |      - augmented_slack_lower_bound(1): 0.00033 (avg 0.00003)
INFO     |      - augmented_slack_nonneg(1): 0.00332 (avg 0.00028)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.65e-04 cost_prev=0.0080 cost_new=0.0080
INFO     |  AL update: snorm=3.7433e-03, csupn=3.7433e-03, max_rho=4.0000e+01
INFO     |  step #22: cost=0.0043 lambd=0.0080
INFO     |      - cvar_objective(1): 0.00431 (avg 0.00431)
INFO     |      - augmented_slack_lower_bound(1): 0.00109 (avg 0.00009)
INFO     |      - augmented_slack_nonneg(1): 0.00880 (avg 0.00073)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=4.21e-01 cost_prev=0.0142 cost_new=0.0130
INFO     |  step #23: cost=0.0072 lambd=0.0040
INFO     |      - cvar_objective(1): 0.00725 (avg 0.00725)
INFO     |      - augmented_slack_lower_bound(1): 0.00014 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00558 (avg 0.00046)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.81e-04 cost_prev=0.0130 cost_new=0.0129
INFO     |  AL update: snorm=8.0387e-04, csupn=8.0387e-04, max_rho=4.0000e+01
INFO     |  step #24: cost=0.0072 lambd=0.0020
INFO     |      - cvar_objective(1): 0.00721 (avg 0.00721)
INFO     |      - augmented_slack_lower_bound(1): 0.00021 (avg 0.00002)
INFO     |      - augmented_slack_nonneg(1): 0.00836 (avg 0.00070)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #25: cost=0.0072 lambd=0.0040
INFO     |      - cvar_objective(1): 0.00721 (avg 0.00721)
INFO     |      - augmented_slack_lower_bound(1): 0.00021 (avg 0.00002)
INFO     |      - augmented_slack_nonneg(1): 0.00836 (avg 0.00070)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #26: cost=0.0072 lambd=0.0080
INFO     |      - cvar_objective(1): 0.00721 (avg 0.00721)
INFO     |      - augmented_slack_lower_bound(1): 0.00021 (avg 0.00002)
INFO     |      - augmented_slack_nonneg(1): 0.00836 (avg 0.00070)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #27: cost=0.0072 lambd=0.0160
INFO     |      - cvar_objective(1): 0.00721 (avg 0.00721)
INFO     |      - augmented_slack_lower_bound(1): 0.00021 (avg 0.00002)
INFO     |      - augmented_slack_nonneg(1): 0.00836 (avg 0.00070)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.48e-01 cost_prev=0.0158 cost_new=0.0156
INFO     |  step #28: cost=0.0087 lambd=0.0080
INFO     |      - cvar_objective(1): 0.00872 (avg 0.00872)
INFO     |      - augmented_slack_lower_bound(1): 0.00018 (avg 0.00002)
INFO     |      - augmented_slack_nonneg(1): 0.00671 (avg 0.00056)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.74e-02 cost_prev=0.0156 cost_new=0.0156
INFO     |  step #29: cost=0.0087 lambd=0.0040
INFO     |      - cvar_objective(1): 0.00874 (avg 0.00874)
INFO     |      - augmented_slack_lower_bound(1): 0.00015 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00671 (avg 0.00056)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=False ATb_norm=1.11e-05 cost_prev=0.0156 cost_new=0.0156
INFO     |  AL update: snorm=3.5812e-04, csupn=3.5812e-04, max_rho=4.0000e+01
INFO     |  step #30: cost=0.0087 lambd=0.0040
INFO     |      - cvar_objective(1): 0.00874 (avg 0.00874)
INFO     |      - augmented_slack_lower_bound(1): 0.00016 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00797 (avg 0.00066)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.58e-02 cost_prev=0.0169 cost_new=0.0168
INFO     |  step #31: cost=0.0094 lambd=0.0020
INFO     |      - cvar_objective(1): 0.00944 (avg 0.00944)
INFO     |      - augmented_slack_lower_bound(1): 0.00016 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00724 (avg 0.00060)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.97e-05 cost_prev=0.0168 cost_new=0.0168
INFO     |  AL update: snorm=1.5261e-04, csupn=1.5261e-04, max_rho=4.0000e+01
INFO     |  step #32: cost=0.0094 lambd=0.0010
INFO     |      - cvar_objective(1): 0.00944 (avg 0.00944)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00780 (avg 0.00065)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.80e-02 cost_prev=0.0174 cost_new=0.0174
INFO     |  step #33: cost=0.0098 lambd=0.0005
INFO     |      - cvar_objective(1): 0.00976 (avg 0.00976)
INFO     |      - augmented_slack_lower_bound(1): 0.00016 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00748 (avg 0.00062)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.71e-06 cost_prev=0.0174 cost_new=0.0174
INFO     |  AL update: snorm=6.6933e-05, csupn=6.6933e-05, max_rho=4.0000e+01
INFO     |  step #34: cost=0.0098 lambd=0.0003
INFO     |      - cvar_objective(1): 0.00976 (avg 0.00976)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00773 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.23e-02 cost_prev=0.0177 cost_new=0.0177
INFO     |  step #35: cost=0.0099 lambd=0.0001
INFO     |      - cvar_objective(1): 0.00990 (avg 0.00990)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00759 (avg 0.00063)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.94e-06 cost_prev=0.0177 cost_new=0.0177
INFO     |  AL update: snorm=2.9420e-05, csupn=2.9420e-05, max_rho=4.0000e+01
INFO     |  step #36: cost=0.0099 lambd=0.0001
INFO     |      - cvar_objective(1): 0.00990 (avg 0.00990)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00770 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.41e-03 cost_prev=0.0178 cost_new=0.0178
INFO     |  AL update: snorm=1.2926e-05, csupn=1.2926e-05, max_rho=4.0000e+01
INFO     |  step #37: cost=0.0100 lambd=0.0000
INFO     |      - cvar_objective(1): 0.00996 (avg 0.00996)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.38e-03 cost_prev=0.0178 cost_new=0.0178
INFO     |  AL update: snorm=5.6801e-06, csupn=5.6801e-06, max_rho=4.0000e+01
INFO     |  step #38: cost=0.0100 lambd=0.0000
INFO     |      - cvar_objective(1): 0.00999 (avg 0.00999)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.04e-03 cost_prev=0.0178 cost_new=0.0178
INFO     |  AL update: snorm=2.4961e-06, csupn=2.4961e-06, max_rho=4.0000e+01
INFO     |  step #39: cost=0.0100 lambd=0.0000
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #40: cost=0.0100 lambd=0.0000
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #41: cost=0.0100 lambd=0.0000
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #42: cost=0.0100 lambd=0.0001
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #43: cost=0.0100 lambd=0.0002
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |  step #44: cost=0.0100 lambd=0.0003
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00768 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=4.58e-04 cost_prev=0.0178 cost_new=0.0178
INFO     |  AL update: snorm=1.0965e-06, csupn=1.0965e-06, max_rho=1.6000e+02
INFO     |  step #45: cost=0.0100 lambd=0.0002
INFO     |      - cvar_objective(1): 0.01000 (avg 0.01000)
INFO     |      - augmented_slack_lower_bound(1): 0.00017 (avg 0.00001)
INFO     |      - augmented_slack_nonneg(1): 0.00767 (avg 0.00064)
INFO     |      - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO     |      - augmented_no_short_constraint(1): 0.00000 (avg 0.00000)
INFO     |      accepted=False ATb_norm=2.06e-04 cost_prev=0.0178 cost_new=0.0178
INFO     |  AL update: snorm=1.0965e-06, csupn=1.0965e-06, max_rho=6.4000e+02
INFO     | Terminated @ iteration #46: cost=0.0100 criteria=[0 1 0], term_deltas=2.3e-04,6.0e-05,1.6e-05
# Extract solution.
optimal_weights = solution[weights_var]
optimal_var = float(solution[var_var][0])

# Compute CVaR from the solution.
portfolio_returns = returns @ optimal_weights
losses = -portfolio_returns
# Evaluate the Rockafellar-Uryasev objective at the optimal zeta; at the
# optimum this equals the CVaR.
cvar_value = optimal_var + jnp.sum(jnp.maximum(losses - optimal_var, 0)) / (
    alpha * num_scenarios
)

print("\n=== CVaR-Optimal Portfolio ===")
print(f"\nAlpha (tail probability): {alpha:.0%}")
print("\nOptimal weights:")
for name, w in zip(stock_names, optimal_weights):
    print(f"  {name}: {float(w) * 100:.1f}%")
print(f"\nVaR (95%): {optimal_var * 100:.2f}% monthly loss")
print(
    f"CVaR (95%): {float(cvar_value) * 100:.2f}% expected loss in worst {alpha:.0%} of scenarios"
)
=== CVaR-Optimal Portfolio ===

Alpha (tail probability): 5%

Optimal weights:
  IBM: 13.2%
  WMT: 57.2%
  SEHI: 29.6%

VaR (95%): 10.00% monthly loss
CVaR (95%): 10.00% expected loss in worst 5% of scenarios
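VaR and CVaR coincide here because of the small sample: with \(T = 12\) monthly scenarios and \(\alpha = 0.05\), the tail mass is \(\alpha T = 0.6\), less than one scenario. The Rockafellar-Uryasev objective is then minimized at \(\zeta\) equal to the largest loss, so both numbers equal the worst month's loss and the optimization reduces to minimizing the maximum loss. A minimal sketch (the helper `tail_weights` is hypothetical) of the weight each scenario, sorted worst first, receives in the tail average:

```python
import jax.numpy as jnp


def tail_weights(num_scenarios: int, alpha: float) -> jnp.ndarray:
    """Weight of each scenario (sorted worst first) in the empirical CVaR average."""
    tail_mass = alpha * num_scenarios
    k = jnp.arange(num_scenarios)
    # Full weight inside the tail, fractional weight at the boundary, zero outside.
    return jnp.clip(tail_mass - k, 0.0, 1.0) / tail_mass


print(tail_weights(12, 0.05))  # this notebook's setting: alpha * T = 0.6
print(tail_weights(12, 0.25))  # alpha * T = 3: the three worst months, equally weighted
```

For the tail to cover more than one month, \(\alpha\) would need to exceed \(1/12 \approx 0.083\).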

Comparison: CVaR vs mean-variance

Let’s compare the CVaR-optimal portfolio with a minimum-variance portfolio.

# Compute covariance matrix for mean-variance comparison.
expected_returns = jnp.mean(returns, axis=0)
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (returns.shape[0] - 1)
cov_chol = jnp.linalg.cholesky(covariance)


class MVWeightsVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(3) / 3):
    """Portfolio weights for mean-variance optimization."""


@jaxls.Cost.factory
def variance_cost(
    vals: jaxls.VarValues, var: MVWeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
    return cov_chol.T @ vals[var]


@jaxls.Cost.factory(kind="constraint_eq_zero")
def mv_budget_constraint(vals: jaxls.VarValues, var: MVWeightsVar) -> jax.Array:
    """Weights must sum to 1."""
    return jnp.sum(vals[var]) - 1.0


@jaxls.Cost.factory(kind="constraint_geq_zero")
def mv_no_short_constraint(vals: jaxls.VarValues, var: MVWeightsVar) -> jax.Array:
    """No short-selling."""
    return vals[var]


mv_weights_var = MVWeightsVar(id=0)
mv_costs = [
    variance_cost(mv_weights_var, cov_chol),
    mv_budget_constraint(mv_weights_var),
    mv_no_short_constraint(mv_weights_var),
]

mv_problem = jaxls.LeastSquaresProblem(mv_costs, [mv_weights_var]).analyze()
mv_solution = mv_problem.solve(
    verbose=False,
    linear_solver="dense_cholesky",
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)

mv_weights = mv_solution[mv_weights_var]
INFO     | Building optimization problem with 3 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 1 geq_zero
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_mv_budget_constraint
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_mv_no_short_constraint
INFO     | Vectorizing group with 1 costs, 1 variables each: variance_cost
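`variance_cost` relies on the identity \(w^\top \Sigma w = \lVert L^\top w \rVert^2\) for the Cholesky factorization \(\Sigma = L L^\top\), which lets a quadratic objective be written as a least-squares residual. A quick numerical check on synthetic data (names prefixed `demo_` so they do not overwrite the notebook's `returns` or `covariance`):

```python
import jax
import jax.numpy as jnp

key_r, key_w = jax.random.split(jax.random.PRNGKey(0))

# Synthetic data: 12 scenarios x 3 assets of random monthly returns.
demo_returns = 0.05 * jax.random.normal(key_r, (12, 3))
demo_centered = demo_returns - demo_returns.mean(axis=0)
demo_cov = demo_centered.T @ demo_centered / (demo_returns.shape[0] - 1)
demo_chol = jnp.linalg.cholesky(demo_cov)  # lower-triangular L with L @ L.T == demo_cov

# Random long-only weights that sum to 1.
demo_w = jax.random.dirichlet(key_w, jnp.ones(3))
demo_residual = demo_chol.T @ demo_w  # same form as the residual in variance_cost

print(float(demo_residual @ demo_residual), float(demo_w @ demo_cov @ demo_w))
```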


import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def compute_metrics(weights: jax.Array) -> dict:
    """Compute risk metrics for a portfolio."""
    port_returns = returns @ weights
    losses = -port_returns

    # Variance and std dev.
    variance = float(weights @ covariance @ weights)
    std_dev = float(jnp.sqrt(variance))

    # VaR (95%): the 95th percentile of the empirical losses
    # (alpha is the 5% tail probability).
    sorted_losses = jnp.sort(losses)
    var_95 = float(sorted_losses[int(num_scenarios * (1 - alpha))])

    # CVaR (95%) via the Rockafellar-Uryasev formula:
    # VaR + (sum of losses in excess of VaR) / (alpha * num_scenarios).
    cvar_95 = float(
        var_95 + jnp.sum(jnp.maximum(losses - var_95, 0)) / (alpha * num_scenarios)
    )

    # Expected return.
    exp_return = float(jnp.dot(weights, expected_returns))

    return {
        "std_dev": std_dev,
        "var_95": var_95,
        "cvar_95": cvar_95,
        "exp_return": exp_return,
    }


cvar_metrics = compute_metrics(optimal_weights)
mv_metrics = compute_metrics(mv_weights)

print("=" * 50)
print(f"{'Metric':<25} {'CVaR-Opt':>12} {'Min-Var':>12}")
print("=" * 50)
print(
    f"{'Expected Return (monthly)':<25} {cvar_metrics['exp_return'] * 100:>11.2f}% {mv_metrics['exp_return'] * 100:>11.2f}%"
)
print(
    f"{'Std Dev (monthly)':<25} {cvar_metrics['std_dev'] * 100:>11.2f}% {mv_metrics['std_dev'] * 100:>11.2f}%"
)
print(
    f"{'VaR 95% (monthly loss)':<25} {cvar_metrics['var_95'] * 100:>11.2f}% {mv_metrics['var_95'] * 100:>11.2f}%"
)
print(
    f"{'CVaR 95% (monthly loss)':<25} {cvar_metrics['cvar_95'] * 100:>11.2f}% {mv_metrics['cvar_95'] * 100:>11.2f}%"
)
print("=" * 50)
==================================================
Metric                        CVaR-Opt      Min-Var
==================================================
Expected Return (monthly)        2.99%        1.26%
Std Dev (monthly)               10.37%        7.71%
VaR 95% (monthly loss)          10.00%       12.33%
CVaR 95% (monthly loss)         10.00%       12.33%
==================================================
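With N equally likely scenarios, the α tail covers αN scenarios. When αN < 1 (here 12 × 0.05 = 0.6), the 95% CVaR equals the single largest loss, and the VaR index chosen in `compute_metrics` points at that same loss, so the VaR and CVaR rows in the table coincide. The sketch below, using made-up losses rather than the portfolios above, computes CVaR two ways: as a tail average that counts the boundary scenario fractionally, and with the Rockafellar-Uryasev expression VaR + Σ max(ℓᵢ − VaR, 0) / (αN). The helper names are hypothetical, and the VaR convention here (smallest loss whose empirical CDF reaches 1 − α) may differ slightly from the notebook's index for other sample sizes.

```python
import numpy as np


def tail_average_cvar(losses, alpha):
    """Mean of the worst alpha-fraction of equally likely losses.

    If alpha * n is not an integer, the boundary scenario is counted fractionally.
    """
    worst_first = np.sort(np.asarray(losses, dtype=float))[::-1]
    n = len(worst_first)
    tail_mass = alpha * n                 # scenarios in the tail (may be fractional)
    k = int(np.floor(tail_mass + 1e-9))   # whole scenarios fully inside the tail
    total = worst_first[:k].sum()
    if k < n:
        total += (tail_mass - k) * worst_first[k]
    return total / tail_mass


def ru_var_cvar(losses, alpha):
    """VaR and CVaR at level 1 - alpha via the Rockafellar-Uryasev expression."""
    sorted_losses = np.sort(np.asarray(losses, dtype=float))
    n = len(sorted_losses)
    # VaR: smallest loss whose empirical CDF reaches 1 - alpha (tolerance guards round-off).
    idx = int(np.ceil(n * (1 - alpha) - 1e-9)) - 1
    var = sorted_losses[idx]
    cvar = var + np.maximum(sorted_losses - var, 0).sum() / (alpha * n)
    return var, cvar


# Hypothetical monthly losses (positive = loss) for 12 scenarios.
losses = np.array([0.04, -0.02, 0.10, -0.05, 0.01, 0.03,
                   -0.01, 0.07, 0.02, -0.03, 0.05, 0.00])

for alpha in (0.05, 0.25):
    var, cvar = ru_var_cvar(losses, alpha)
    print(f"alpha={alpha}: tail size={alpha * len(losses):.1f} scenarios, "
          f"VaR={var:.4f}, CVaR(RU)={cvar:.4f}, "
          f"CVaR(tail avg)={tail_average_cvar(losses, alpha):.4f}")
```

At α = 0.05 both methods return the largest loss; at α = 0.25 the tail spans three whole scenarios and CVaR becomes their average. The Rockafellar-Uryasev expression evaluated at any other threshold ζ is never smaller than this CVaR, which is why the optimization formulation can treat ζ as a free variable and minimize over it.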



fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Portfolio Weights", "Return Distribution"),
    column_widths=[0.4, 0.6],
)

# Left plot: Weight comparison.
x_labels = stock_names
fig.add_trace(
    go.Bar(
        x=x_labels,
        y=optimal_weights * 100,
        name="CVaR-Optimal",
        marker_color="#E91E63",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Bar(
        x=x_labels,
        y=mv_weights * 100,
        name="Min-Variance",
        marker_color="#3F51B5",
    ),
    row=1,
    col=1,
)

# Right plot: Return distributions.
cvar_returns = returns @ optimal_weights
mv_returns = returns @ mv_weights

# Sort scenarios by the CVaR-optimal portfolio's return and reuse that order for
# the min-variance bars, so each x position is the same historical month.
sorted_idx = jnp.argsort(cvar_returns)

fig.add_trace(
    go.Bar(
        x=list(range(num_scenarios)),
        y=cvar_returns[sorted_idx] * 100,
        name="CVaR-Optimal Returns",
        marker_color="#E91E63",
        opacity=0.7,
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        x=list(range(num_scenarios)),
        y=mv_returns[sorted_idx] * 100,
        name="Min-Variance Returns",
        marker_color="#3F51B5",
        opacity=0.7,
    ),
    row=1,
    col=2,
)

# Add VaR threshold line.
fig.add_hline(
    y=-cvar_metrics["var_95"] * 100,
    line_dash="dash",
    line_color="#E91E63",
    annotation_text="95% VaR (CVaR-optimal)",
    row=1,
    col=2,
)

fig.update_xaxes(title_text="Asset", row=1, col=1)
fig.update_yaxes(title_text="Weight (%)", row=1, col=1)
fig.update_xaxes(title_text="Scenario (sorted by CVaR-optimal return)", row=1, col=2)
fig.update_yaxes(title_text="Monthly Return (%)", row=1, col=2)

fig.update_layout(
    barmode="group",
    height=400,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Key observations

The CVaR-optimal portfolio differs from the minimum-variance portfolio in three ways:

  1. Different risk focus: CVaR optimization targets the tail of the loss distribution, while minimum-variance penalizes upside and downside deviations alike.

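  2. Asset allocation: CVaR can favor assets with better worst-case behavior even when their overall variance is higher. In this example, the CVaR-optimal portfolio has a higher monthly standard deviation than the minimum-variance portfolio (10.37% vs. 7.71%) but a smaller 95% CVaR (10.00% vs. 12.33%).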

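  3. Scenario-based: CVaR is computed directly from the historical scenarios, so it makes no normality assumption about returns. The trade-off is that the tail estimate rests on very few data points: with only 12 monthly scenarios, the 95% CVaR is simply the worst month's loss.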
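Point 1 can be illustrated with a toy example: mirroring a return series (r → −r) leaves its variance unchanged but moves the large outcome from the gain side to the loss side. The sketch below uses two hypothetical 12-month return series, one with a single large gain and one with a single large loss, and computes their sample variance and a 95% CVaR of losses (tail average over equally likely scenarios). The data and the `cvar_of_losses` helper are illustrative assumptions, not taken from the notebook.

```python
import numpy as np


def cvar_of_losses(returns, alpha=0.05):
    """Average of the worst alpha-fraction of losses (loss = -return), equally likely scenarios."""
    worst_first = np.sort(-np.asarray(returns, dtype=float))[::-1]
    tail_mass = alpha * len(worst_first)   # may be fractional
    k = int(np.floor(tail_mass + 1e-9))    # whole scenarios fully inside the tail
    total = worst_first[:k].sum()
    if k < len(worst_first):
        total += (tail_mass - k) * worst_first[k]
    return total / tail_mass


# Hypothetical series: one big +11% month and eleven -1% months ...
upside = np.array([0.11] + [-0.01] * 11)
# ... and its mirror image: one big -11% month and eleven +1% months.
downside = -upside

for name, r in [("big gain", upside), ("big loss", downside)]:
    print(f"{name:>8}: variance={np.var(r, ddof=1):.6f}, "
          f"95% CVaR of losses={cvar_of_losses(r):.4f}")
```

Variance rates the two series as equally risky, while CVaR rates the big-loss series as far riskier (11% vs. 1%). A minimum-variance objective cannot express this asymmetry.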

For investors primarily concerned with extreme losses (e.g., pension funds, insurance companies), CVaR is often a more relevant risk measure than variance, because it measures the size of losses in bad scenarios rather than dispersion in both directions.

For more details on constrained optimization in jaxls, see jaxls.Cost and jaxls.LeastSquaresProblem.