Reparameterized allocation#

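We’ve now seen two approaches to portfolio allocation: a constrained formulation, which enforces the budget, non-negativity, and return constraints with an augmented Lagrangian, and a manifold formulation, which keeps the weights on the simplex with a log-space retraction and constrains only the return.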

Here we explore a third option: reparameterization. We transform the problem so constraints are automatically satisfied.

Instead of optimizing weights \(w\) subject to \(w_i \geq 0\) and \(\sum w_i = 1\), we optimize unconstrained logits \(\ell \in \mathbb{R}^n\) and compute weights as \(w = \text{softmax}(\ell)\).

Features used:

  • Var for unconstrained logit variables

  • Softmax reparameterization for automatic constraint satisfaction

Hide code cell source

import sys
from loguru import logger

# Drop loguru's previously registered handlers (`remove()` with no handler id)
# so that the stdout sink added below is the only one left.
logger.remove()
# Log to stdout as "<LEVEL padded to 8 chars> | <message>".
# NOTE(review): the trailing semicolon presumably suppresses the notebook's
# echo of the value returned by `logger.add` -- confirm before removing it.
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

Softmax reparameterization#

The softmax function maps unconstrained reals to the probability simplex:

\[w_i = \frac{\exp(\ell_i)}{\sum_j \exp(\ell_j)}\]

This automatically ensures:

  • All weights are positive (\(\exp\) is always positive)

  • Weights sum to 1 (by construction)

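No simplex constraints or custom retractions are needed: the budget and non-negativity conditions hold for any choice of logits, so the logits themselves can be optimized without restriction.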

Logit variable type#

We define a variable for unconstrained logits. The default is zeros, which maps to uniform weights via softmax:

# Number of assets in the portfolio; sets the length of the default logit vector.
n_assets = 3


class LogitsVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(n_assets)):
    """Unconstrained logits. Apply softmax to get portfolio weights.

    The default value is a zero vector of length ``n_assets``, which softmax
    maps to uniform weights.
    """


# Sanity check: all-zero logits should map to a uniform allocation.
test_logits = jnp.zeros(n_assets)
uniform_weights = jax.nn.softmax(test_logits)
print(f"Logits: {test_logits}")
print(f"Weights (softmax): {uniform_weights}")
Logits: [0. 0. 0.]
Weights (softmax): [0.33333334 0.33333334 0.33333334]

Asset data#

We use the same historical stock data as the other portfolio notebooks.

# Ticker symbols, in the same order as the columns of `prices`.
stock_names = ["IBM", "WMT", "SEHI"]

# Monthly prices, one row per month (13 months: Nov 2000 - Nov 2001) and one
# column per asset.
prices = jnp.array(
    [
        [93.043, 51.826, 1.063],
        [84.585, 52.823, 0.938],
        [111.453, 56.477, 1.0],
        [99.525, 49.805, 0.938],
        [95.819, 50.287, 1.438],
        [114.708, 51.521, 1.7],
        [111.515, 51.531, 2.54],
        [113.211, 48.664, 2.39],
        [104.942, 55.744, 3.12],
        [99.827, 47.916, 2.98],
        [91.607, 49.438, 1.9],
        [107.937, 51.336, 1.75],
        [115.59, 55.081, 1.8],
    ]
)

# Simple month-over-month returns: (p[t + 1] - p[t]) / p[t].
returns = (prices[1:] - prices[:-1]) / prices[:-1]
n_periods = returns.shape[0]

# Per-asset mean return and the sample covariance (n - 1 normalization).
expected_returns = returns.mean(axis=0)
returns_centered = returns - expected_returns
covariance = returns_centered.T @ returns_centered / (n_periods - 1)

# Cholesky factor of the covariance, used to write the variance as a
# squared norm in the least-squares cost.
cov_chol = jnp.linalg.cholesky(covariance)

print("Expected monthly returns:")
for ticker, mean_return in zip(stock_names, expected_returns):
    print(f"  {ticker}: {float(mean_return) * 100:+.2f}%")
Expected monthly returns:
  IBM: +2.60%
  WMT: +0.81%
  SEHI: +7.37%

Problem formulation#

The key difference: our cost functions take logits as input and apply softmax internally to get weights. No budget or non-negativity constraints are needed; the return target is the only explicit constraint.

logits_var = LogitsVar(id=0)


@jaxls.Cost.factory
def variance_cost(
    vals: jaxls.VarValues, var: LogitsVar, cov_chol: jax.Array
) -> jax.Array:
    """Residual vector for the portfolio-variance objective.

    The logits stored under ``var`` are mapped to weights with softmax, and
    ``cov_chol.T @ weights`` is returned. Assuming ``cov_chol`` is a factor
    with ``covariance = cov_chol @ cov_chol.T``, the squared norm of this
    residual equals the portfolio variance ``weights @ covariance @ weights``.

    Args:
        vals: Current variable values.
        var: Logit variable to read from ``vals``.
        cov_chol: Factor of the return covariance matrix.

    Returns:
        The residual ``cov_chol.T @ softmax(logits)``.
    """
    logits = vals[var]
    # Softmax turns the unconstrained logits into simplex weights.
    portfolio_weights = jax.nn.softmax(logits)
    return cov_chol.T @ portfolio_weights


@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: LogitsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Inequality constraint: expected portfolio return must reach ``target``.

    Args:
        vals: Current variable values.
        var: Logit variable to read from ``vals``.
        exp_ret: Expected return of each asset.
        target: Required expected return (callers in this file also pass a
            scalar ``jax.Array``).

    Returns:
        A length-1 array holding ``dot(weights, exp_ret) - target``; the
        ``constraint_geq_zero`` kind asks for this value to be >= 0.
    """
    # Softmax turns the unconstrained logits into simplex weights.
    portfolio_weights = jax.nn.softmax(vals[var])
    expected_portfolio_return = jnp.dot(portfolio_weights, exp_ret)
    return jnp.array([expected_portfolio_return - target])

Efficient frontier#

We compute the efficient frontier by solving for different target returns. Only one constraint (return target) is needed; budget and non-negativity are handled by softmax.

Using jax.lax.scan, we solve sequentially while using each solution as the initial guess for the next (warm-starting).

# Sweep 50 target returns between the lowest and the highest single-asset
# expected return.
# NOTE(review): softmax weights are strictly positive, so the upper endpoint
# is only reachable in the limit of unbounded logits -- confirm the solver's
# behavior at the last target is acceptable.
min_return = float(jnp.min(expected_returns))
max_return = float(jnp.max(expected_returns))
target_returns = jnp.linspace(min_return, max_return, num=50)


def solve_for_target(
    current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Scan step: solve the allocation problem for one target return.

    Written in the ``(carry, x) -> (carry, y)`` form expected by
    ``jax.lax.scan``: the solved variable values are carried forward so the
    next solve starts from them.

    Args:
        current_vals: Initial guess, i.e. the solution of the previous target.
        target: Target return for this solve.

    Returns:
        Tuple of (solved variable values, portfolio weights obtained by
        applying softmax to the solved logits).
    """
    termination = jaxls.TerminationConfig(cost_tolerance=1e-8)

    # One variance objective plus one return constraint on the shared logits.
    problem = jaxls.LeastSquaresProblem(
        [
            variance_cost(logits_var, cov_chol),
            return_constraint(logits_var, expected_returns, target),
        ],
        [logits_var],
    ).analyze()

    solved_vals = problem.solve(
        current_vals,
        verbose=False,
        linear_solver="dense_cholesky",
        termination=termination,
    )

    # The carry keeps the logits; the per-step output is the weights.
    optimal_weights = jax.nn.softmax(solved_vals[logits_var])
    return solved_vals, optimal_weights


# Run the solves in order with `jax.lax.scan`; each step starts from the
# previous step's solution (warm start). Only the stacked weights are kept.
initial_vals = jaxls.VarValues.make([logits_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)


def _portfolio_variance(w: jax.Array) -> jax.Array:
    """Variance of one weight vector under `covariance`."""
    return w @ covariance @ w


def _portfolio_return(w: jax.Array) -> jax.Array:
    """Expected return of one weight vector."""
    return jnp.dot(w, expected_returns)


# Evaluate both statistics for every row of `all_weights`.
variances = jax.vmap(_portfolio_variance)(all_weights)
returns_achieved = jax.vmap(_portfolio_return)(all_weights)

print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO     | Building optimization problem with 2 terms and 1 variables: 1 costs, 0 eq_zero, 0 leq_zero, 1 geq_zero
INFO     | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
Computed 50 points on the efficient frontier
# Report the simplex conditions for every solution: per-solution weight sums
# (expected to be 1) and the smallest weight seen (expected to be >= 0).
weight_sums = all_weights.sum(axis=1)
min_weights = all_weights.min(axis=1)

lowest_sum = float(weight_sums.min())
highest_sum = float(weight_sums.max())
smallest_weight = float(min_weights.min())

print(f"Weight sums: min={lowest_sum:.6f}, max={highest_sum:.6f}")
print(f"Min weight across all solutions: {smallest_weight:.6f}")
print("Constraints satisfied by softmax construction!")
Weight sums: min=1.000000, max=1.000000
Min weight across all solutions: 0.000000
Constraints satisfied by softmax construction!

Results#

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML
import math

# Annualize the monthly figures: returns scale with the number of periods per
# year, standard deviations with its square root.
annual_factor = 12
risk_free_rate = 0.05  # 5% annual risk-free rate.

std_dev_scale = math.sqrt(annual_factor)
std_devs_annual = [
    float(jnp.sqrt(monthly_var)) * std_dev_scale for monthly_var in variances
]
returns_annual = [
    float(monthly_ret) * annual_factor for monthly_ret in returns_achieved
]


def _sharpe(annual_return: float, annual_std: float) -> float:
    """Excess return over the risk-free rate per unit of standard deviation.

    Returns 0.0 when the standard deviation is not positive.
    """
    if annual_std > 0:
        return (annual_return - risk_free_rate) / annual_std
    return 0.0


sharpe_ratios = [
    _sharpe(annual_return, annual_std)
    for annual_return, annual_std in zip(returns_annual, std_devs_annual)
]

# One bar color per asset, in the same order as `stock_names`.
colors = ["#2196F3", "#4CAF50", "#FF9800"]

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Efficient Frontier", "Sharpe Ratio", "Asset Allocation"),
    column_widths=[0.33, 0.33, 0.34],
)

# Express everything in dollars for a $1000 investment.
investment = 1000
std_dev_dollars = [s * investment for s in std_devs_annual]
return_dollars = [r * investment for r in returns_annual]
target_dollars = [float(t) * annual_factor * investment for t in target_returns]


def _sharpe_colored_line(x_values, y_values, hover):
    """Build a line+marker trace whose markers are colored by Sharpe ratio."""
    return go.Scatter(
        x=x_values,
        y=y_values,
        mode="lines+markers",
        marker=dict(size=6, color=sharpe_ratios, colorscale="Viridis", showscale=False),
        line=dict(color="steelblue", width=2),
        hovertemplate=hover,
        showlegend=False,
    )


# Left plot: Efficient frontier (risk on x, return on y).
fig.add_trace(
    _sharpe_colored_line(
        std_dev_dollars,
        return_dollars,
        "Std Dev: $%{x:.0f}<br>Return: $%{y:.0f}<extra></extra>",
    ),
    row=1,
    col=1,
)

# Middle plot: Sharpe ratio vs return.
fig.add_trace(
    _sharpe_colored_line(
        return_dollars,
        sharpe_ratios,
        "Return: $%{x:.0f}<br>Sharpe: %{y:.2f}<extra></extra>",
    ),
    row=1,
    col=2,
)

# Right plot: Asset allocation, one stacked bar series per asset.
for i, (name, color) in enumerate(zip(stock_names, colors)):
    dollars_in_asset = [float(w) * investment for w in all_weights[:, i]]
    fig.add_trace(
        go.Bar(
            x=target_dollars,
            y=dollars_in_asset,
            name=name,
            marker_color=color,
            hovertemplate=f"{name}: $%{{y:.0f}}<extra></extra>",
        ),
        row=1,
        col=3,
    )

# (x title, y title) for subplot columns 1..3.
axis_titles = [
    ("Std Dev ($)", "Return ($)"),
    ("Return ($)", "Sharpe Ratio"),
    ("Target Return ($)", "Investment ($)"),
]
for col_index, (x_title, y_title) in enumerate(axis_titles, start=1):
    fig.update_xaxes(title_text=x_title, row=1, col=col_index)
    fig.update_yaxes(title_text=y_title, row=1, col=col_index)
# Leave a little headroom above the $1000 stack in the allocation plot.
fig.update_yaxes(range=[0, 1050], row=1, col=3)

fig.update_layout(
    barmode="stack",
    height=400,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.83),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Comparison of approaches#

All three approaches produce the same efficient frontier:

| Approach | Variables | Constraints | Implementation |
|---|---|---|---|
| Constrained | Weights | Budget, non-negativity, return | Augmented Lagrangian |
| Manifold | Weights | Return only | Log-space retraction |
| Reparameterized | Logits | Return only | Softmax transformation |