Reparameterized allocation#
We’ve now seen two approaches to portfolio allocation:

- Mean-Variance Allocation: explicit constraints via Augmented Lagrangian
- Manifold Allocation: constraints built into `retract_fn`
Here we explore a third option: reparameterization. We transform the problem so constraints are automatically satisfied.
Instead of optimizing weights \(w\) subject to \(w_i \geq 0\) and \(\sum w_i = 1\), we optimize unconstrained logits \(\ell \in \mathbb{R}^n\) and compute weights as \(w = \text{softmax}(\ell)\).
Features used:
- `Var` for unconstrained logit variables
- Softmax reparameterization for automatic constraint satisfaction
import jax
import jax.numpy as jnp
import jaxls
Softmax reparameterization#
The softmax function maps unconstrained reals to the probability simplex:

\[w_i = \frac{\exp(\ell_i)}{\sum_{j=1}^n \exp(\ell_j)}\]

This automatically ensures:

- All weights are positive (\(\exp\) is always positive)
- Weights sum to 1 (by construction)

No constraints or custom retractions are needed; this is just standard unconstrained optimization.
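A useful consequence: softmax is shift invariant (\(\text{softmax}(\ell + c\mathbf{1}) = \text{softmax}(\ell)\)), so the gradient of any objective with respect to the logits sums to zero. A minimal sketch, using a stand-in quadratic objective:

```python
# Gradients through softmax sum to zero: adding a constant to every
# logit leaves the weights unchanged, so that direction is flat.
def toy_objective(logits: jax.Array) -> jax.Array:
    w = jax.nn.softmax(logits)
    return w @ w  # Stand-in scalar objective in the weights.

grad = jax.grad(toy_objective)(jnp.array([0.5, -1.0, 2.0]))
print(grad, grad.sum())  # Sum is ~0 up to float32 precision.
```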
Logit variable type#
We define a variable for unconstrained logits. The default is zeros, which maps to uniform weights via softmax:
n_assets = 3
class LogitsVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(n_assets)):
"""Unconstrained logits. Apply softmax to get portfolio weights."""
# Verify softmax of zeros gives uniform weights.
test_logits = jnp.zeros(n_assets)
print(f"Logits: {test_logits}")
print(f"Weights (softmax): {jax.nn.softmax(test_logits)}")
Logits: [0. 0. 0.]
Weights (softmax): [0.33333334 0.33333334 0.33333334]
Asset data#
We use the same historical stock data as the other portfolio notebooks.
stock_names = ["IBM", "WMT", "SEHI"]
# Monthly prices (13 months: Nov 2000 - Nov 2001).
prices = jnp.array(
[
[93.043, 51.826, 1.063],
[84.585, 52.823, 0.938],
[111.453, 56.477, 1.0],
[99.525, 49.805, 0.938],
[95.819, 50.287, 1.438],
[114.708, 51.521, 1.7],
[111.515, 51.531, 2.54],
[113.211, 48.664, 2.39],
[104.942, 55.744, 3.12],
[99.827, 47.916, 2.98],
[91.607, 49.438, 1.9],
[107.937, 51.336, 1.75],
[115.59, 55.081, 1.8],
]
)
# Compute returns and covariance.
returns = jnp.diff(prices, axis=0) / prices[:-1]
expected_returns = jnp.mean(returns, axis=0)
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (returns.shape[0] - 1)
cov_chol = jnp.linalg.cholesky(covariance)
print("Expected monthly returns:")
for name, r in zip(stock_names, expected_returns):
print(f" {name}: {float(r) * 100:+.2f}%")
Expected monthly returns:
IBM: +2.60%
WMT: +0.81%
SEHI: +7.37%
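Why the Cholesky factor? jaxls minimizes sums of squared residuals, and since \(\Sigma = L L^\top\), the portfolio variance can be written as \(w^\top \Sigma w = \|L^\top w\|^2\); the vector \(L^\top w\) then serves directly as a residual. A quick check with uniform weights:

```python
# Sanity check: w^T Sigma w equals ||L^T w||^2 (uniform weights).
w_uniform = jnp.full(n_assets, 1.0 / n_assets)
print(w_uniform @ covariance @ w_uniform)      # Direct quadratic form.
print(jnp.sum((cov_chol.T @ w_uniform) ** 2))  # Via the Cholesky residual.
```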
Problem formulation#
The key difference: our cost functions take logits as input and apply softmax internally to recover weights. The simplex constraints (budget and non-negativity) disappear; only the return target remains as an explicit constraint.
logits_var = LogitsVar(id=0)
@jaxls.Cost.factory
def variance_cost(
vals: jaxls.VarValues, var: LogitsVar, cov_chol: jax.Array
) -> jax.Array:
"""Minimize portfolio variance."""
weights = jax.nn.softmax(vals[var]) # Convert logits to weights.
    # Residual r = L^T w; the solver minimizes ||r||^2 = w^T Sigma w.
    return cov_chol.T @ weights
@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
vals: jaxls.VarValues, var: LogitsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
"""Expected return must meet target."""
weights = jax.nn.softmax(vals[var]) # Convert logits to weights.
    # kind="constraint_geq_zero" enforces this residual to be >= 0.
    return jnp.array([jnp.dot(weights, exp_ret) - target])
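To see what the solver enforces, we can evaluate the constraint residual by hand. A small sketch, where the 2% monthly target is an arbitrary value chosen for illustration:

```python
# Evaluate the return-constraint residual at uniform weights.
# (The 2% monthly target here is arbitrary, for illustration only.)
w0 = jnp.full(n_assets, 1.0 / n_assets)
print(jnp.dot(w0, expected_returns) - 0.02)  # >= 0 means target is met.
```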
Efficient frontier#
We compute the efficient frontier by solving for different target returns. Only one constraint (return target) is needed; budget and non-negativity are handled by softmax.
Using `jax.lax.scan`, we solve sequentially while using each solution as the initial guess for the next (warm-starting).
min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return, max_return, 50)
def solve_for_target(
current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
"""Solve portfolio optimization for a given target return.
Args:
current_vals: Solution from previous target (used as initial guess).
target: Target return for this solve.
Returns:
Tuple of (solution values, optimal weights via softmax).
"""
costs = [
variance_cost(logits_var, cov_chol),
return_constraint(logits_var, expected_returns, target),
]
problem = jaxls.LeastSquaresProblem(costs, [logits_var]).analyze()
solution = problem.solve(
current_vals,
verbose=False,
linear_solver="dense_cholesky",
termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
# Return weights (via softmax), not logits.
return solution, jax.nn.softmax(solution[logits_var])
# Solve sequentially with warm-starting.
initial_vals = jaxls.VarValues.make([logits_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)
variances = jax.vmap(lambda w: w @ covariance @ w)(all_weights)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)
print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO | Building optimization problem with 2 terms and 1 variables: 1 costs, 0 eq_zero, 0 leq_zero, 1 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
Computed 50 points on the efficient frontier
# Verify all solutions satisfy simplex constraints.
weight_sums = jnp.sum(all_weights, axis=1)
min_weights = jnp.min(all_weights, axis=1)
print(
f"Weight sums: min={float(weight_sums.min()):.6f}, max={float(weight_sums.max()):.6f}"
)
print(f"Min weight across all solutions: {float(min_weights.min()):.6f}")
print("Constraints satisfied by softmax construction!")
Weight sums: min=1.000000, max=1.000000
Min weight across all solutions: 0.000000
Constraints satisfied by softmax construction!
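(In exact arithmetic, softmax weights are strictly positive; the 0.000000 above is a very small weight rounded at six decimal places, or underflowed to zero in float32 when one logit dominates.)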
Results#
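The frontier can be visualized by plotting achieved return against risk. A minimal sketch, assuming matplotlib is available (it is not part of jaxls):

```python
import matplotlib.pyplot as plt

# Plot the efficient frontier: risk (monthly std. dev.) vs. return.
plt.plot(jnp.sqrt(variances), returns_achieved * 100, marker=".")
plt.xlabel("Risk (monthly standard deviation)")
plt.ylabel("Expected monthly return (%)")
plt.title("Efficient frontier (softmax reparameterization)")
plt.show()
```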
Comparison of approaches#
All three approaches produce the same efficient frontier:
| Approach | Variables | Constraints | Implementation |
|---|---|---|---|
| Constrained | Weights | Budget, non-negativity, return | Augmented Lagrangian |
| Manifold | Weights | Return only | Log-space retraction |
| Reparameterized | Logits | Return only | Softmax transformation |