Constraints
How to handle constraints in nonlinear least squares problems using jaxls. This guide covers equality and inequality constraints, with a portfolio optimization example.
Features used:
- @jaxls.Cost.factory with the kind parameter for constraints
- AugmentedLagrangianConfig for solver tuning
import jax
import jax.numpy as jnp
import jaxls
Types of constraints
jaxls supports three constraint types, specified via the kind parameter in @jaxls.Cost.factory:
| Constraint type | kind value | Mathematical form |
|---|---|---|
| Equality | kind="constraint_eq_zero" | \(h(x) = 0\) |
| Inequality (upper bound) | kind="constraint_leq_zero" | \(g(x) \leq 0\) |
| Inequality (lower bound) | kind="constraint_geq_zero" | \(g(x) \geq 0\) |
The default kind="l2_squared" creates a standard least-squares cost term.
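The sketch below makes the sign conventions of the three constraint forms concrete. It measures how far a residual is from satisfying each form. The helper `violation` and its "eq"/"leq"/"geq" labels are hypothetical illustrations, not part of jaxls. The sketch also shows that an upper-bound constraint \(g(x) \leq 0\) is the same as the lower-bound constraint \(-g(x) \geq 0\).

```python
import jax.numpy as jnp


def violation(residual: jnp.ndarray, kind: str) -> jnp.ndarray:
    """Hypothetical helper (not a jaxls API): elementwise constraint violation.

    Returns zero wherever the residual satisfies the given constraint form.
    """
    if kind == "eq":  # h(x) = 0: any nonzero value is a violation.
        return jnp.abs(residual)
    if kind == "leq":  # g(x) <= 0: only positive values violate.
        return jnp.maximum(residual, 0.0)
    if kind == "geq":  # g(x) >= 0: only negative values violate.
        return jnp.maximum(-residual, 0.0)
    raise ValueError(f"unknown constraint kind: {kind!r}")


r = jnp.array([-0.5, 0.0, 0.25])
print(violation(r, "eq"), violation(r, "leq"), violation(r, "geq"))
# Negating the residual turns an upper bound into a lower bound.
print(jnp.allclose(violation(r, "leq"), violation(-r, "geq")))
```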
Example: portfolio optimization
We’ll optimize a portfolio of 4 assets to minimize variance (risk) subject to:
- Budget constraint: weights sum to 1 (equality)
- Return target: expected return >= minimum threshold (inequality)
- No short-selling: all weights >= 0 (inequality)
# Asset data: 4 assets with expected returns and covariance.
n_assets = 4
asset_names = ["Tech", "Healthcare", "Energy", "Bonds"]
# Expected annual returns.
expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])
# Covariance matrix (annual).
covariance = jnp.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
print(
"Expected returns:", {n: f"{r:.1%}" for n, r in zip(asset_names, expected_returns)}
)
Expected returns: {'Tech': '12.0%', 'Healthcare': '8.0%', 'Energy': '10.0%', 'Bonds': '4.0%'}
# Define the portfolio weights variable.
class WeightsVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(n_assets) / n_assets
):
    """Portfolio weights (n_assets-dimensional vector)."""

weights_var = WeightsVar(id=0)
Defining costs and constraints
Objective: minimize variance
Portfolio variance is \(w^T \Sigma w\). With the Cholesky factorization \(\Sigma = LL^T\), we have \(\|L^T w\|^2 = w^T L L^T w = w^T \Sigma w\). Minimizing the squared norm of the residual \(L^T w\) is therefore the same as minimizing variance.
# Cholesky decomposition for the variance cost.
cov_chol = jnp.linalg.cholesky(covariance)
@jaxls.Cost.factory # Default kind="l2_squared".
def variance_cost(
    vals: jaxls.VarValues, var: WeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
    return cov_chol.T @ vals[var]
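As a standalone check, the sketch below uses plain JAX (no jaxls) and the covariance matrix above. It verifies that the residual returned by variance_cost reproduces the portfolio variance when squared and summed. With equal weights of 0.25, the variance is the sum of all entries of \(\Sigma\) divided by 16, that is 0.141 / 16 = 0.0088125. This agrees with the variance_cost value of 0.00881 reported at step #0 of the solver log later in this guide.

```python
import jax.numpy as jnp

covariance = jnp.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
cov_chol = jnp.linalg.cholesky(covariance)  # Lower-triangular L with Sigma = L @ L.T.

w = jnp.ones(4) / 4  # Equal weights, matching the WeightsVar default.
residual = cov_chol.T @ w  # The residual variance_cost returns.
squared_norm = jnp.sum(residual**2)  # ||L.T @ w||^2, the least-squares cost.
quadratic_form = w @ covariance @ w  # w.T @ Sigma @ w, the portfolio variance.

print(float(squared_norm), float(quadratic_form))  # Both approximately 0.0088125.
```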
Budget constraint (equality)
The weights must sum to 1 (fully invested portfolio): \(\sum_i w_i = 1\)
We write this as \(h(w) = \sum_i w_i - 1 = 0\).
@jaxls.Cost.factory(kind="constraint_eq_zero")
def budget_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """Weights must sum to 1 (fully invested)."""
    weights = vals[var]
    return jnp.array([jnp.sum(weights) - 1.0])
Return target (inequality >= 0)
Expected portfolio return must meet a minimum target: \(w^T \mu \geq r_{\text{target}}\)
We write this as \(g(w) = w^T \mu - r_{\text{target}} \geq 0\).
@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: WeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Expected return must meet target: E[r] >= target."""
    weights = vals[var]
    return jnp.array([jnp.dot(weights, exp_ret) - target])
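This is a quick check of the sign convention: the residual \(w^T \mu - r_{\text{target}}\) is nonnegative exactly when the target is met. The sketch below evaluates the same expression as return_constraint outside jaxls, at the 8% target used later. It uses two illustrative single-asset portfolios. All Tech gives 0.12 - 0.08 = 0.04, so the constraint is satisfied. All Bonds gives 0.04 - 0.08 = -0.04, so it is violated.

```python
import jax.numpy as jnp

expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])  # Tech, Healthcare, Energy, Bonds.
target = 0.08


def return_residual(weights: jnp.ndarray) -> jnp.ndarray:
    # Same expression as return_constraint: g(w) = w.T @ mu - target, required >= 0.
    return jnp.array([jnp.dot(weights, expected_returns) - target])


all_tech = jnp.array([1.0, 0.0, 0.0, 0.0])
all_bonds = jnp.array([0.0, 0.0, 0.0, 1.0])
print(return_residual(all_tech))  # Positive: target met.
print(return_residual(all_bonds))  # Negative: target missed.
```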
No short-selling (inequality >= 0)
All weights must be non-negative: \(w_i \geq 0\) for all \(i\)
Since \(g(w) = w\), the cost simply returns the weight vector, which gives one inequality constraint per asset.
@jaxls.Cost.factory(kind="constraint_geq_zero")
def no_short_selling(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """No short-selling: weights >= 0."""
    return vals[var]
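The three constraints contribute residuals of different sizes. The budget and return residuals have one entry each, and the no-short-selling residual has one entry per asset. That makes 1 + 1 + 4 = 6 constraint values in total, consistent with constraint_dim=6 in the solver log later in this guide.

The sketch below stacks the same residual expressions and evaluates them at the equal-weight default. It is a plain-JAX illustration and does not show how jaxls assembles constraints internally. All three constraints hold at that starting point:

- the weights sum to 1,
- the expected return is 0.34 / 4 = 0.085 ≥ 0.08,
- every weight is positive.

```python
import jax.numpy as jnp

expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])
target_return = 0.08
w = jnp.ones(4) / 4  # Equal-weight default of WeightsVar.

budget = jnp.array([jnp.sum(w) - 1.0])  # Equality: should be 0.
ret = jnp.array([jnp.dot(w, expected_returns) - target_return])  # Inequality: >= 0.
no_short = w  # Inequality: each entry >= 0.

all_constraints = jnp.concatenate([budget, ret, no_short])
print(all_constraints.shape)  # (6,)
print(budget, ret, no_short)
```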
Solving the problem
We’ll solve for a target return of 8% (between the lowest-return Bonds at 4% and highest-return Tech at 12%).
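The target has to lie between 4% and 12% for a reason. With nonnegative weights summing to 1, the portfolio return \(w^T \mu\) is a weighted average of the asset returns. It can therefore never fall below the smallest entry of \(\mu\) or exceed the largest.

The sketch below illustrates this by sampling random long-only portfolios and checking that every return lands in that interval. The sampling scheme (uniform draws normalized to sum to 1) is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])

key = jax.random.PRNGKey(0)
raw = jax.random.uniform(key, (1000, 4))  # Nonnegative samples in [0, 1).
weights = raw / raw.sum(axis=1, keepdims=True)  # Long-only, each row sums to 1.
returns = weights @ expected_returns

# Every feasible return is a convex combination of 4%, 8%, 10%, and 12%.
print(float(returns.min()), float(returns.max()))
```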
target_return = 0.08
costs = [
    variance_cost(weights_var, cov_chol),
    budget_constraint(weights_var),
    return_constraint(weights_var, expected_returns, target_return),
    no_short_selling(weights_var),
]
# Build the problem (before .analyze(), so we can visualize its structure).
unanalyzed_problem = jaxls.LeastSquaresProblem(costs, [weights_var])
# Visualize the problem structure: costs, constraints, and variables.
unanalyzed_problem.show()
# Analyze and solve the problem.
problem = unanalyzed_problem.analyze()
solution = problem.solve(
    linear_solver="dense_cholesky",
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
INFO | Building optimization problem with 4 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 2 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_budget_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_no_short_selling
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0088 lambd=0.0005
INFO | - variance_cost(1): 0.00881 (avg 0.00220)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | step #1: cost=0.0088 lambd=0.0010
INFO | - variance_cost(1): 0.00881 (avg 0.00220)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | step #2: cost=0.0088 lambd=0.0020
INFO | - variance_cost(1): 0.00881 (avg 0.00220)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.50e-02 cost_prev=0.0088 cost_new=0.0085
INFO | step #3: cost=0.0036 lambd=0.0010
INFO | - variance_cost(1): 0.00356 (avg 0.00089)
INFO | - augmented_return_constraint(1): 0.00498 (avg 0.00498)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=4.96e-02 cost_prev=0.0085 cost_new=0.0057
INFO | step #4: cost=0.0049 lambd=0.0005
INFO | - variance_cost(1): 0.00489 (avg 0.00122)
INFO | - augmented_return_constraint(1): 0.00086 (avg 0.00086)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=1.86e-04 cost_prev=0.0057 cost_new=0.0057
INFO | AL update: snorm=9.0866e-03, csupn=9.0866e-03, max_rho=4.0000e+01
INFO | step #5: cost=0.0049 lambd=0.0003
INFO | - variance_cost(1): 0.00492 (avg 0.00123)
INFO | - augmented_return_constraint(1): 0.00516 (avg 0.00516)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=6.70e-02 cost_prev=0.0101 cost_new=0.0072
INFO | step #6: cost=0.0067 lambd=0.0001
INFO | - variance_cost(1): 0.00673 (avg 0.00168)
INFO | - augmented_return_constraint(1): 0.00048 (avg 0.00048)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.92e-05 cost_prev=0.0072 cost_new=0.0072
INFO | AL update: snorm=1.1763e-03, csupn=1.1763e-03, max_rho=4.0000e+01
INFO | step #7: cost=0.0067 lambd=0.0001
INFO | - variance_cost(1): 0.00673 (avg 0.00168)
INFO | - augmented_return_constraint(1): 0.00086 (avg 0.00086)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=4.85e-03 cost_prev=0.0076 cost_new=0.0075
INFO | AL update: snorm=1.5295e-04, csupn=1.5295e-04, max_rho=4.0000e+01
INFO | step #8: cost=0.0070 lambd=0.0000
INFO | - variance_cost(1): 0.00702 (avg 0.00175)
INFO | - augmented_return_constraint(1): 0.00056 (avg 0.00056)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=6.48e-04 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=1.9595e-05, csupn=1.9595e-05, max_rho=4.0000e+01
INFO | step #9: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00176)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=8.60e-05 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=2.5183e-06, csupn=2.5183e-06, max_rho=4.0000e+01
INFO | Terminated @ iteration #10: cost=0.0071 criteria=[0 1 0], term_deltas=7.0e-04,6.6e-05,6.0e-04
# Extract results.
optimal_weights = solution[weights_var]
print("\nOptimal allocation:")
for name, w in zip(asset_names, optimal_weights):
    print(f" {name}: {float(w):.1%}")
portfolio_return = float(jnp.dot(optimal_weights, expected_returns))
portfolio_std = float(jnp.sqrt(optimal_weights @ covariance @ optimal_weights))
print("\nPortfolio metrics:")
print(f" Expected return: {portfolio_return:.2%}")
print(f" Std deviation: {portfolio_std:.2%}")
print(f" Weights sum: {float(jnp.sum(optimal_weights)):.6f}")
Optimal allocation:
Tech: 26.4%
Healthcare: 17.3%
Energy: 20.0%
Bonds: 36.4%
Portfolio metrics:
Expected return: 8.00%
Std deviation: 8.40%
Weights sum: 1.000000
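As a sanity check, the reported metrics can be recomputed with plain JAX from the printed allocation. Because those weights are rounded to 0.1%, the results agree only approximately. They sum to 1.001 rather than 1, and the return comes out near 8.01%. The standard deviation comes out near 8.41%. The implied variance, about 0.0071, is close to the final variance_cost of 0.00706 in the solver log above.

```python
import jax.numpy as jnp

expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])
covariance = jnp.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
# Allocation as printed above (Tech, Healthcare, Energy, Bonds), rounded to 0.1%.
w = jnp.array([0.264, 0.173, 0.200, 0.364])

weights_sum = float(jnp.sum(w))
portfolio_return = float(jnp.dot(w, expected_returns))
portfolio_var = float(w @ covariance @ w)
portfolio_std = portfolio_var**0.5

print(f"sum={weights_sum:.3f} return={portfolio_return:.2%} std={portfolio_std:.2%}")
```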
Efficient frontier
By varying the target return, we can trace out the efficient frontier. Using jax.lax.scan, we solve sequentially, passing each solution in as the initial guess for the next (warm-starting).
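The warm-starting pattern can be illustrated without jaxls. In the toy sketch below, `toy_solve` is a hypothetical stand-in for problem.solve. It runs a fixed number of gradient steps on \((x - t)^2\), starting from whatever value it is given.

Threading the previous solution through the jax.lax.scan carry means each solve starts close to its answer. With the same iteration budget, the warm-started results therefore end up much closer to their targets than cold starts from zero. The target grid matches the frontier sweep below: 15 points from 4.5% to 11.5%.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 5  # Fixed iteration budget per solve.


def toy_solve(x_init: jax.Array, target: jax.Array) -> jax.Array:
    """Hypothetical stand-in for a solver: gradient descent on (x - target)^2."""

    def step(x, _):
        grad = 2.0 * (x - target)
        return x - 0.25 * grad, None  # Each step halves the distance to target.

    x_final, _ = jax.lax.scan(step, x_init, None, length=NUM_STEPS)
    return x_final


def warm_started(x_prev, target):
    x = toy_solve(x_prev, target)
    return x, x  # (carry = next initial guess, per-target output)


targets = jnp.linspace(0.045, 0.115, 15)
_, warm = jax.lax.scan(warm_started, jnp.array(0.0), targets)
cold = jax.vmap(lambda t: toy_solve(jnp.array(0.0), t))(targets)

# Skip the first target, where both variants start from zero.
print(jnp.abs(warm - targets)[1:].max(), jnp.abs(cold - targets)[1:].max())
```

The first target has no previous solution, so both variants start from zero there. From the second target on, the warm-started errors stay far below the cold-start errors.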
# Range of target returns.
min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return + 0.005, max_return - 0.005, 15)
def solve_for_target(
    current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Solve portfolio optimization for a given target return.

    Args:
        current_vals: Solution from previous target (used as initial guess).
        target: Target return for this solve.

    Returns:
        Tuple of (solution values, optimal weights).
    """
    costs = [
        variance_cost(weights_var, cov_chol),
        budget_constraint(weights_var),
        return_constraint(weights_var, expected_returns, target),
        no_short_selling(weights_var),
    ]
    problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
    solution = problem.solve(
        current_vals,
        linear_solver="dense_cholesky",
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution, solution[weights_var]
# Solve sequentially with warm-starting.
initial_vals = jaxls.VarValues.make([weights_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)
std_devs = jax.vmap(lambda w: jnp.sqrt(w @ covariance @ w))(all_weights)
print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO | Building optimization problem with 4 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 2 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_budget_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_no_short_selling
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0088 lambd=0.0005
INFO | - variance_cost(1): 0.00881 (avg 0.00220)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.50e-02 cost_prev=0.0088 cost_new=0.0035
INFO | step #1: cost=0.0035 lambd=0.0003
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.26e-04 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=3.5286e-04, csupn=3.5286e-04, max_rho=4.0000e+01
INFO | step #2: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.50e-02 cost_prev=0.0035 cost_new=0.0035
INFO | step #3: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=4.10e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=4.0000e+01
INFO | step #4: cost=0.0035 lambd=0.0000
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=False ATb_norm=1.78e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=4.0000e+01
INFO | Terminated @ iteration #5: cost=0.0035 criteria=[1 0 1], term_deltas=0.0e+00,8.9e-07,2.6e-08
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0035 lambd=0.0005
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=8.75e-03 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=3.5274e-04, csupn=3.5274e-04, max_rho=4.0000e+01
INFO | step #1: cost=0.0035 lambd=0.0003
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.50e-02 cost_prev=0.0035 cost_new=0.0035
INFO | step #2: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=7.05e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=4.0000e+01
INFO | step #3: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=False ATb_norm=4.73e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=4.0000e+01
INFO | Terminated @ iteration #4: cost=0.0035 criteria=[0 0 1], term_deltas=1.3e-07,2.4e-06,5.8e-08
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0035 lambd=0.0005
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=8.75e-03 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=3.5280e-04, csupn=3.5280e-04, max_rho=4.0000e+01
INFO | step #1: cost=0.0035 lambd=0.0003
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.50e-02 cost_prev=0.0035 cost_new=0.0035
INFO | step #2: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.30e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=4.0000e+01
INFO | step #3: cost=0.0035 lambd=0.0001
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.26e-06 cost_prev=0.0035 cost_new=0.0035
INFO | AL update: snorm=5.9605e-08, csupn=5.9605e-08, max_rho=1.6000e+02
INFO | Terminated @ iteration #4: cost=0.0035 criteria=[0 0 1], term_deltas=6.6e-08,1.6e-06,3.5e-08
INFO | Augmented Lagrangian: initial snorm=4.3689e-03, csupn=4.3689e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0035 lambd=0.0005
INFO | - variance_cost(1): 0.00353 (avg 0.00088)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00019 (avg 0.00019)
INFO | accepted=True ATb_norm=3.24e-03 cost_prev=0.0037 cost_new=0.0036
INFO | AL update: snorm=1.6532e-03, csupn=1.6532e-03, max_rho=4.0000e+01
INFO | step #1: cost=0.0036 lambd=0.0003
INFO | - variance_cost(1): 0.00357 (avg 0.00089)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00011 (avg 0.00011)
INFO | accepted=True ATb_norm=2.93e-02 cost_prev=0.0037 cost_new=0.0037
INFO | step #2: cost=0.0036 lambd=0.0001
INFO | - variance_cost(1): 0.00361 (avg 0.00090)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=True ATb_norm=4.11e-06 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=5.9253e-04, csupn=5.9253e-04, max_rho=4.0000e+01
INFO | step #3: cost=0.0036 lambd=0.0001
INFO | - variance_cost(1): 0.00361 (avg 0.00090)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | accepted=True ATb_norm=6.31e-04 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=2.2139e-04, csupn=2.2139e-04, max_rho=4.0000e+01
INFO | step #4: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00363 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=2.27e-04 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=8.2355e-05, csupn=8.2355e-05, max_rho=4.0000e+01
INFO | step #5: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=7.84e-05 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=3.0681e-05, csupn=3.0681e-05, max_rho=4.0000e+01
INFO | step #6: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=3.80e-05 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=1.1444e-05, csupn=1.1444e-05, max_rho=1.6000e+02
INFO | step #7: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=7.73e-05 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=4.2655e-06, csupn=4.2655e-06, max_rho=1.6000e+02
INFO | step #8: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | step #9: cost=0.0036 lambd=0.0000
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=False ATb_norm=5.58e-05 cost_prev=0.0037 cost_new=0.0037
INFO | AL update: snorm=4.2655e-06, csupn=4.2655e-06, max_rho=6.4000e+02
INFO | Terminated @ iteration #10: cost=0.0036 criteria=[0 1 0], term_deltas=3.8e-05,3.0e-05,6.9e-05
INFO | Augmented Lagrangian: initial snorm=5.0043e-03, csupn=5.0043e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0036 lambd=0.0005
INFO | - variance_cost(1): 0.00364 (avg 0.00091)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=1.77e-03 cost_prev=0.0039 cost_new=0.0039
INFO | AL update: snorm=3.5075e-03, csupn=3.5075e-03, max_rho=4.0000e+01
INFO | step #1: cost=0.0037 lambd=0.0003
INFO | - variance_cost(1): 0.00373 (avg 0.00093)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00077 (avg 0.00077)
INFO | accepted=True ATb_norm=4.63e-02 cost_prev=0.0045 cost_new=0.0041
INFO | step #2: cost=0.0040 lambd=0.0001
INFO | - variance_cost(1): 0.00400 (avg 0.00100)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | step #3: cost=0.0040 lambd=0.0003
INFO | - variance_cost(1): 0.00400 (avg 0.00100)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=1.13e-05 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=4.4919e-04, csupn=4.4919e-04, max_rho=4.0000e+01
INFO | step #4: cost=0.0040 lambd=0.0001
INFO | - variance_cost(1): 0.00400 (avg 0.00100)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00013 (avg 0.00013)
INFO | accepted=True ATb_norm=1.87e-03 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=5.8517e-05, csupn=5.8517e-05, max_rho=4.0000e+01
INFO | step #5: cost=0.0040 lambd=0.0001
INFO | - variance_cost(1): 0.00404 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | accepted=True ATb_norm=2.35e-04 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=7.3835e-06, csupn=7.3835e-06, max_rho=4.0000e+01
INFO | step #6: cost=0.0041 lambd=0.0000
INFO | - variance_cost(1): 0.00405 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | accepted=True ATb_norm=2.24e-05 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=9.4622e-07, csupn=9.4622e-07, max_rho=4.0000e+01
INFO | step #7: cost=0.0041 lambd=0.0000
INFO | - variance_cost(1): 0.00405 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | accepted=True ATb_norm=1.57e-05 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=1.1921e-07, csupn=1.1921e-07, max_rho=4.0000e+01
INFO | step #8: cost=0.0041 lambd=0.0000
INFO | - variance_cost(1): 0.00405 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | step #9: cost=0.0041 lambd=0.0000
INFO | - variance_cost(1): 0.00405 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00008 (avg 0.00008)
INFO | accepted=False ATb_norm=3.21e-06 cost_prev=0.0041 cost_new=0.0041
INFO | AL update: snorm=1.1921e-07, csupn=1.1921e-07, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0041 criteria=[0 1 0], term_deltas=2.9e-06,1.8e-06,2.9e-06
INFO | Augmented Lagrangian: initial snorm=5.0001e-03, csupn=5.0001e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0041 lambd=0.0005
INFO | - variance_cost(1): 0.00405 (avg 0.00101)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=2.31e-03 cost_prev=0.0043 cost_new=0.0043
INFO | AL update: snorm=5.3576e-03, csupn=5.3576e-03, max_rho=4.0000e+01
INFO | step #1: cost=0.0040 lambd=0.0003
INFO | - variance_cost(1): 0.00401 (avg 0.00100)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00179 (avg 0.00179)
INFO | accepted=True ATb_norm=5.30e-02 cost_prev=0.0058 cost_new=0.0048
INFO | step #2: cost=0.0046 lambd=0.0001
INFO | - variance_cost(1): 0.00464 (avg 0.00116)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00017 (avg 0.00017)
INFO | accepted=True ATb_norm=1.71e-05 cost_prev=0.0048 cost_new=0.0048
INFO | AL update: snorm=6.9366e-04, csupn=6.9366e-04, max_rho=1.6000e+02
INFO | step #3: cost=0.0046 lambd=0.0001
INFO | - variance_cost(1): 0.00464 (avg 0.00116)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00030 (avg 0.00030)
INFO | accepted=True ATb_norm=9.69e-03 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=9.0286e-05, csupn=9.0286e-05, max_rho=1.6000e+02
INFO | step #4: cost=0.0047 lambd=0.0000
INFO | - variance_cost(1): 0.00474 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00020 (avg 0.00020)
INFO | accepted=True ATb_norm=4.10e-04 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=1.1481e-05, csupn=1.1481e-05, max_rho=1.6000e+02
INFO | step #5: cost=0.0048 lambd=0.0000
INFO | - variance_cost(1): 0.00476 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00018 (avg 0.00018)
INFO | accepted=True ATb_norm=3.42e-05 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=1.4752e-06, csupn=1.4752e-06, max_rho=1.6000e+02
INFO | step #6: cost=0.0048 lambd=0.0000
INFO | - variance_cost(1): 0.00476 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00018 (avg 0.00018)
INFO | accepted=True ATb_norm=1.36e-05 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=1.7881e-07, csupn=1.7881e-07, max_rho=1.6000e+02
INFO | step #7: cost=0.0048 lambd=0.0000
INFO | - variance_cost(1): 0.00476 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00018 (avg 0.00018)
INFO | accepted=True ATb_norm=2.97e-06 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=2.9802e-08, csupn=2.9802e-08, max_rho=1.6000e+02
INFO | step #8: cost=0.0048 lambd=0.0000
INFO | - variance_cost(1): 0.00476 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00018 (avg 0.00018)
INFO | accepted=False ATb_norm=1.86e-06 cost_prev=0.0049 cost_new=0.0049
INFO | AL update: snorm=2.9802e-08, csupn=2.9802e-08, max_rho=1.6000e+02
INFO | Terminated @ iteration #9: cost=0.0048 criteria=[0 0 1], term_deltas=1.1e-06,9.8e-07,8.4e-07
INFO | Augmented Lagrangian: initial snorm=5.0000e-03, csupn=5.0000e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0048 lambd=0.0005
INFO | - variance_cost(1): 0.00476 (avg 0.00119)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=5.14e-03 cost_prev=0.0050 cost_new=0.0049
INFO | AL update: snorm=7.2076e-03, csupn=7.2076e-03, max_rho=4.0000e+01
INFO | step #1: cost=0.0044 lambd=0.0003
INFO | - variance_cost(1): 0.00441 (avg 0.00110)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00325 (avg 0.00325)
INFO | accepted=True ATb_norm=5.99e-02 cost_prev=0.0077 cost_new=0.0058
INFO | step #2: cost=0.0055 lambd=0.0001
INFO | - variance_cost(1): 0.00555 (avg 0.00139)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00030 (avg 0.00030)
INFO | accepted=True ATb_norm=2.31e-05 cost_prev=0.0058 cost_new=0.0058
INFO | AL update: snorm=9.3812e-04, csupn=9.3812e-04, max_rho=1.6000e+02
INFO | step #3: cost=0.0055 lambd=0.0001
INFO | - variance_cost(1): 0.00555 (avg 0.00139)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00054 (avg 0.00054)
INFO | accepted=True ATb_norm=1.32e-02 cost_prev=0.0061 cost_new=0.0061
INFO | step #4: cost=0.0057 lambd=0.0000
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | step #5: cost=0.0057 lambd=0.0001
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | step #6: cost=0.0057 lambd=0.0001
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | step #7: cost=0.0057 lambd=0.0003
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | step #8: cost=0.0057 lambd=0.0005
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | accepted=True ATb_norm=4.31e-06 cost_prev=0.0061 cost_new=0.0061
INFO | AL update: snorm=1.2180e-04, csupn=1.2180e-04, max_rho=1.6000e+02
INFO | step #9: cost=0.0057 lambd=0.0003
INFO | - variance_cost(1): 0.00573 (avg 0.00143)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00036 (avg 0.00036)
INFO | accepted=True ATb_norm=5.11e-04 cost_prev=0.0061 cost_new=0.0061
INFO | AL update: snorm=1.5900e-05, csupn=1.5900e-05, max_rho=1.6000e+02
INFO | step #10: cost=0.0058 lambd=0.0001
INFO | - variance_cost(1): 0.00576 (avg 0.00144)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00033 (avg 0.00033)
INFO | accepted=True ATb_norm=5.83e-05 cost_prev=0.0061 cost_new=0.0061
INFO | AL update: snorm=1.9372e-06, csupn=1.9372e-06, max_rho=1.6000e+02
INFO | Terminated @ iteration #11: cost=0.0058 criteria=[0 1 0], term_deltas=5.6e-04,4.5e-05,4.7e-04
INFO | Augmented Lagrangian: initial snorm=5.0019e-03, csupn=5.0019e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0058 lambd=0.0005
INFO | - variance_cost(1): 0.00576 (avg 0.00144)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=8.14e-03 cost_prev=0.0060 cost_new=0.0057
INFO | AL update: snorm=9.0577e-03, csupn=9.0577e-03, max_rho=4.0000e+01
INFO | step #1: cost=0.0049 lambd=0.0003
INFO | - variance_cost(1): 0.00492 (avg 0.00123)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00513 (avg 0.00513)
INFO | accepted=True ATb_norm=6.70e-02 cost_prev=0.0101 cost_new=0.0072
INFO | step #2: cost=0.0067 lambd=0.0001
INFO | - variance_cost(1): 0.00672 (avg 0.00168)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00048 (avg 0.00048)
INFO | accepted=True ATb_norm=2.90e-05 cost_prev=0.0072 cost_new=0.0072
INFO | AL update: snorm=1.1826e-03, csupn=1.1826e-03, max_rho=4.0000e+01
INFO | step #3: cost=0.0067 lambd=0.0001
INFO | - variance_cost(1): 0.00673 (avg 0.00168)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00086 (avg 0.00086)
INFO | accepted=True ATb_norm=4.86e-03 cost_prev=0.0076 cost_new=0.0075
INFO | AL update: snorm=1.5377e-04, csupn=1.5377e-04, max_rho=4.0000e+01
INFO | step #4: cost=0.0070 lambd=0.0000
INFO | - variance_cost(1): 0.00702 (avg 0.00175)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00056 (avg 0.00056)
INFO | accepted=True ATb_norm=6.32e-04 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=1.9699e-05, csupn=1.9699e-05, max_rho=4.0000e+01
INFO | step #5: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00176)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | accepted=True ATb_norm=8.24e-05 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=2.5332e-06, csupn=2.5332e-06, max_rho=4.0000e+01
INFO | step #6: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00177)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | accepted=True ATb_norm=1.30e-05 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=3.2783e-07, csupn=3.2783e-07, max_rho=4.0000e+01
INFO | step #7: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00177)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | step #8: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00177)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | step #9: cost=0.0071 lambd=0.0000
INFO | - variance_cost(1): 0.00706 (avg 0.00177)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00053 (avg 0.00053)
INFO | accepted=False ATb_norm=5.79e-06 cost_prev=0.0076 cost_new=0.0076
INFO | AL update: snorm=3.2783e-07, csupn=3.2783e-07, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0071 criteria=[0 1 0], term_deltas=1.2e-05,3.4e-06,9.8e-06
INFO | Augmented Lagrangian: initial snorm=5.0003e-03, csupn=5.0003e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0071 lambd=0.0005
INFO | - variance_cost(1): 0.00706 (avg 0.00177)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=1.12e-02 cost_prev=0.0073 cost_new=0.0067
INFO | step #1: cost=0.0056 lambd=0.0003
INFO | - variance_cost(1): 0.00556 (avg 0.00139)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00119 (avg 0.00119)
INFO | accepted=True ATb_norm=4.34e-05 cost_prev=0.0067 cost_new=0.0067
INFO | AL update: snorm=1.0948e-02, csupn=1.0948e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0055 lambd=0.0001
INFO | - variance_cost(1): 0.00555 (avg 0.00139)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00749 (avg 0.00749)
INFO | accepted=True ATb_norm=7.42e-02 cost_prev=0.0130 cost_new=0.0089
INFO | step #3: cost=0.0082 lambd=0.0001
INFO | - variance_cost(1): 0.00817 (avg 0.00204)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00069 (avg 0.00069)
INFO | accepted=True ATb_norm=1.76e-05 cost_prev=0.0089 cost_new=0.0089
INFO | AL update: snorm=1.4184e-03, csupn=1.4184e-03, max_rho=4.0000e+01
INFO | step #4: cost=0.0082 lambd=0.0000
INFO | - variance_cost(1): 0.00817 (avg 0.00204)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00124 (avg 0.00124)
INFO | accepted=True ATb_norm=5.86e-03 cost_prev=0.0094 cost_new=0.0093
INFO | AL update: snorm=1.8425e-04, csupn=1.8425e-04, max_rho=4.0000e+01
INFO | step #5: cost=0.0086 lambd=0.0000
INFO | - variance_cost(1): 0.00859 (avg 0.00215)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00082 (avg 0.00082)
INFO | accepted=True ATb_norm=7.61e-04 cost_prev=0.0094 cost_new=0.0094
INFO | AL update: snorm=2.3782e-05, csupn=2.3782e-05, max_rho=4.0000e+01
INFO | step #6: cost=0.0087 lambd=0.0000
INFO | - variance_cost(1): 0.00865 (avg 0.00216)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00077 (avg 0.00077)
INFO | accepted=True ATb_norm=8.45e-05 cost_prev=0.0094 cost_new=0.0094
INFO | AL update: snorm=3.0696e-06, csupn=3.0696e-06, max_rho=4.0000e+01
INFO | step #7: cost=0.0087 lambd=0.0000
INFO | - variance_cost(1): 0.00866 (avg 0.00216)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00076 (avg 0.00076)
INFO | step #8: cost=0.0087 lambd=0.0000
INFO | - variance_cost(1): 0.00866 (avg 0.00216)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00076 (avg 0.00076)
INFO | step #9: cost=0.0087 lambd=0.0000
INFO | - variance_cost(1): 0.00866 (avg 0.00216)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00076 (avg 0.00076)
INFO | accepted=False ATb_norm=3.31e-05 cost_prev=0.0094 cost_new=0.0094
INFO | AL update: snorm=3.0696e-06, csupn=3.0696e-06, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0087 criteria=[0 1 0], term_deltas=1.1e-04,2.1e-05,9.6e-05
INFO | Augmented Lagrangian: initial snorm=5.0031e-03, csupn=5.0031e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0087 lambd=0.0005
INFO | - variance_cost(1): 0.00866 (avg 0.00216)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=1.42e-02 cost_prev=0.0089 cost_new=0.0079
INFO | step #1: cost=0.0063 lambd=0.0003
INFO | - variance_cost(1): 0.00630 (avg 0.00158)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00163 (avg 0.00163)
INFO | accepted=True ATb_norm=5.70e-05 cost_prev=0.0079 cost_new=0.0079
INFO | AL update: snorm=1.2810e-02, csupn=1.2810e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0063 lambd=0.0001
INFO | - variance_cost(1): 0.00629 (avg 0.00157)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01026 (avg 0.01026)
INFO | accepted=True ATb_norm=8.14e-02 cost_prev=0.0166 cost_new=0.0108
INFO | step #3: cost=0.0099 lambd=0.0001
INFO | - variance_cost(1): 0.00989 (avg 0.00247)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00095 (avg 0.00095)
INFO | accepted=True ATb_norm=2.24e-05 cost_prev=0.0108 cost_new=0.0108
INFO | AL update: snorm=1.6602e-03, csupn=1.6602e-03, max_rho=4.0000e+01
INFO | step #4: cost=0.0099 lambd=0.0000
INFO | - variance_cost(1): 0.00989 (avg 0.00247)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00170 (avg 0.00170)
INFO | accepted=True ATb_norm=6.85e-03 cost_prev=0.0116 cost_new=0.0115
INFO | AL update: snorm=2.1566e-04, csupn=2.1566e-04, max_rho=4.0000e+01
INFO | step #5: cost=0.0105 lambd=0.0000
INFO | - variance_cost(1): 0.01047 (avg 0.00262)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00112 (avg 0.00112)
INFO | accepted=True ATb_norm=8.85e-04 cost_prev=0.0116 cost_new=0.0116
INFO | AL update: snorm=2.7820e-05, csupn=2.7820e-05, max_rho=4.0000e+01
INFO | step #6: cost=0.0105 lambd=0.0000
INFO | - variance_cost(1): 0.01054 (avg 0.00264)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00105 (avg 0.00105)
INFO | accepted=True ATb_norm=1.28e-04 cost_prev=0.0116 cost_new=0.0116
INFO | AL update: snorm=3.5986e-06, csupn=3.5986e-06, max_rho=4.0000e+01
INFO | step #7: cost=0.0106 lambd=0.0000
INFO | - variance_cost(1): 0.01055 (avg 0.00264)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00105 (avg 0.00105)
INFO | accepted=True ATb_norm=1.36e-05 cost_prev=0.0116 cost_new=0.0116
INFO | AL update: snorm=4.6194e-07, csupn=4.6194e-07, max_rho=4.0000e+01
INFO | step #8: cost=0.0106 lambd=0.0000
INFO | - variance_cost(1): 0.01055 (avg 0.00264)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00104 (avg 0.00104)
INFO | step #9: cost=0.0106 lambd=0.0000
INFO | - variance_cost(1): 0.01055 (avg 0.00264)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00104 (avg 0.00104)
INFO | accepted=False ATb_norm=8.47e-06 cost_prev=0.0116 cost_new=0.0116
INFO | AL update: snorm=4.6194e-07, csupn=4.6194e-07, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0106 criteria=[0 1 0], term_deltas=1.6e-05,5.0e-06,1.4e-05
INFO | Augmented Lagrangian: initial snorm=5.0005e-03, csupn=5.0005e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0106 lambd=0.0005
INFO | - variance_cost(1): 0.01055 (avg 0.00264)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=1.73e-02 cost_prev=0.0108 cost_new=0.0093
INFO | step #1: cost=0.0072 lambd=0.0003
INFO | - variance_cost(1): 0.00717 (avg 0.00179)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00213 (avg 0.00213)
INFO | accepted=True ATb_norm=7.06e-05 cost_prev=0.0093 cost_new=0.0093
INFO | AL update: snorm=1.4673e-02, csupn=1.4673e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0072 lambd=0.0001
INFO | - variance_cost(1): 0.00715 (avg 0.00179)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01346 (avg 0.01346)
INFO | accepted=True ATb_norm=8.88e-02 cost_prev=0.0206 cost_new=0.0131
INFO | step #3: cost=0.0119 lambd=0.0001
INFO | - variance_cost(1): 0.01187 (avg 0.00297)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00124 (avg 0.00124)
INFO | accepted=True ATb_norm=2.37e-05 cost_prev=0.0131 cost_new=0.0131
INFO | AL update: snorm=1.9019e-03, csupn=1.9019e-03, max_rho=4.0000e+01
INFO | step #4: cost=0.0119 lambd=0.0000
INFO | - variance_cost(1): 0.01188 (avg 0.00297)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00223 (avg 0.00223)
INFO | accepted=True ATb_norm=7.85e-03 cost_prev=0.0141 cost_new=0.0140
INFO | AL update: snorm=2.4708e-04, csupn=2.4708e-04, max_rho=4.0000e+01
INFO | step #5: cost=0.0126 lambd=0.0000
INFO | - variance_cost(1): 0.01263 (avg 0.00316)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00147 (avg 0.00147)
INFO | accepted=True ATb_norm=1.04e-03 cost_prev=0.0141 cost_new=0.0141
INFO | AL update: snorm=3.1859e-05, csupn=3.1859e-05, max_rho=4.0000e+01
INFO | step #6: cost=0.0127 lambd=0.0000
INFO | - variance_cost(1): 0.01273 (avg 0.00318)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00138 (avg 0.00138)
INFO | accepted=True ATb_norm=1.08e-04 cost_prev=0.0141 cost_new=0.0141
INFO | AL update: snorm=4.1276e-06, csupn=4.1276e-06, max_rho=4.0000e+01
INFO | step #7: cost=0.0127 lambd=0.0000
INFO | - variance_cost(1): 0.01275 (avg 0.00319)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00137 (avg 0.00137)
INFO | step #8: cost=0.0127 lambd=0.0000
INFO | - variance_cost(1): 0.01275 (avg 0.00319)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00137 (avg 0.00137)
INFO | step #9: cost=0.0127 lambd=0.0000
INFO | - variance_cost(1): 0.01275 (avg 0.00319)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00137 (avg 0.00137)
INFO | accepted=False ATb_norm=4.61e-05 cost_prev=0.0141 cost_new=0.0141
INFO | AL update: snorm=4.1276e-06, csupn=4.1276e-06, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0127 criteria=[0 1 0], term_deltas=1.3e-04,2.9e-05,1.2e-04
INFO | Augmented Lagrangian: initial snorm=5.0041e-03, csupn=5.0041e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0127 lambd=0.0005
INFO | - variance_cost(1): 0.01275 (avg 0.00319)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=2.04e-02 cost_prev=0.0130 cost_new=0.0109
INFO | step #1: cost=0.0082 lambd=0.0003
INFO | - variance_cost(1): 0.00816 (avg 0.00204)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00271 (avg 0.00271)
INFO | accepted=True ATb_norm=8.42e-05 cost_prev=0.0109 cost_new=0.0109
INFO | AL update: snorm=1.6535e-02, csupn=1.6535e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0081 lambd=0.0001
INFO | - variance_cost(1): 0.00813 (avg 0.00203)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01709 (avg 0.01709)
INFO | accepted=True ATb_norm=9.62e-02 cost_prev=0.0252 cost_new=0.0157
INFO | step #3: cost=0.0141 lambd=0.0001
INFO | - variance_cost(1): 0.01413 (avg 0.00353)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00158 (avg 0.00158)
INFO | accepted=True ATb_norm=2.81e-05 cost_prev=0.0157 cost_new=0.0157
INFO | AL update: snorm=2.1437e-03, csupn=2.1437e-03, max_rho=4.0000e+01
INFO | step #4: cost=0.0141 lambd=0.0000
INFO | - variance_cost(1): 0.01413 (avg 0.00353)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00284 (avg 0.00284)
INFO | accepted=True ATb_norm=8.85e-03 cost_prev=0.0170 cost_new=0.0168
INFO | AL update: snorm=2.7848e-04, csupn=2.7848e-04, max_rho=4.0000e+01
INFO | step #5: cost=0.0151 lambd=0.0000
INFO | - variance_cost(1): 0.01509 (avg 0.00377)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00187 (avg 0.00187)
INFO | accepted=True ATb_norm=1.15e-03 cost_prev=0.0170 cost_new=0.0170
INFO | AL update: snorm=3.5904e-05, csupn=3.5904e-05, max_rho=4.0000e+01
INFO | step #6: cost=0.0152 lambd=0.0000
INFO | - variance_cost(1): 0.01522 (avg 0.00380)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00176 (avg 0.00176)
INFO | accepted=True ATb_norm=1.60e-04 cost_prev=0.0170 cost_new=0.0170
INFO | AL update: snorm=4.6566e-06, csupn=4.6566e-06, max_rho=4.0000e+01
INFO | step #7: cost=0.0152 lambd=0.0000
INFO | - variance_cost(1): 0.01524 (avg 0.00381)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00174 (avg 0.00174)
INFO | accepted=True ATb_norm=1.37e-05 cost_prev=0.0170 cost_new=0.0170
INFO | AL update: snorm=5.9605e-07, csupn=5.9605e-07, max_rho=4.0000e+01
INFO | step #8: cost=0.0152 lambd=0.0000
INFO | - variance_cost(1): 0.01524 (avg 0.00381)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00174 (avg 0.00174)
INFO | step #9: cost=0.0152 lambd=0.0000
INFO | - variance_cost(1): 0.01524 (avg 0.00381)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00174 (avg 0.00174)
INFO | accepted=False ATb_norm=5.35e-06 cost_prev=0.0170 cost_new=0.0170
INFO | AL update: snorm=5.9605e-07, csupn=5.9605e-07, max_rho=1.6000e+02
INFO | Terminated @ iteration #10: cost=0.0152 criteria=[0 1 0], term_deltas=1.8e-05,3.5e-06,1.7e-05
INFO | Augmented Lagrangian: initial snorm=5.0006e-03, csupn=5.0006e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0152 lambd=0.0005
INFO | - variance_cost(1): 0.01524 (avg 0.00381)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=2.34e-02 cost_prev=0.0155 cost_new=0.0126
INFO | step #1: cost=0.0093 lambd=0.0003
INFO | - variance_cost(1): 0.00926 (avg 0.00231)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00335 (avg 0.00335)
INFO | accepted=True ATb_norm=9.78e-05 cost_prev=0.0126 cost_new=0.0126
INFO | AL update: snorm=1.8398e-02, csupn=1.8398e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0092 lambd=0.0001
INFO | - variance_cost(1): 0.00923 (avg 0.00231)
INFO | - augmented_budget_constraint(1): 0.00003 (avg 0.00003)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02115 (avg 0.02115)
INFO | accepted=True ATb_norm=1.04e-01 cost_prev=0.0304 cost_new=0.0186
INFO | step #3: cost=0.0166 lambd=0.0001
INFO | - variance_cost(1): 0.01665 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00196 (avg 0.00196)
INFO | accepted=True ATb_norm=2.94e-05 cost_prev=0.0186 cost_new=0.0186
INFO | AL update: snorm=2.3855e-03, csupn=2.3855e-03, max_rho=4.0000e+01
INFO | step #4: cost=0.0167 lambd=0.0000
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #5: cost=0.0167 lambd=0.0001
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #6: cost=0.0167 lambd=0.0001
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #7: cost=0.0167 lambd=0.0003
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #8: cost=0.0167 lambd=0.0005
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #9: cost=0.0167 lambd=0.0010
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #10: cost=0.0167 lambd=0.0020
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #11: cost=0.0167 lambd=0.0040
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #12: cost=0.0167 lambd=0.0080
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #13: cost=0.0167 lambd=0.0160
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #14: cost=0.0167 lambd=0.0320
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #15: cost=0.0167 lambd=0.0640
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #16: cost=0.0167 lambd=0.1280
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | step #17: cost=0.0167 lambd=0.2560
INFO | - variance_cost(1): 0.01666 (avg 0.00416)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00351 (avg 0.00351)
INFO | accepted=True ATb_norm=9.86e-03 cost_prev=0.0202 cost_new=0.0201
INFO | AL update: snorm=3.3871e-03, csupn=3.3871e-03, max_rho=1.6000e+02
INFO | step #18: cost=0.0172 lambd=0.1280
INFO | - variance_cost(1): 0.01723 (avg 0.00431)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00072 (avg 0.00018)
INFO | - augmented_return_constraint(1): 0.00192 (avg 0.00192)
INFO | accepted=True ATb_norm=1.96e-01 cost_prev=0.0199 cost_new=0.0189
INFO | step #19: cost=0.0177 lambd=0.0640
INFO | - variance_cost(1): 0.01772 (avg 0.00443)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00115 (avg 0.00115)
INFO | accepted=True ATb_norm=3.78e-03 cost_prev=0.0189 cost_new=0.0188
INFO | AL update: snorm=7.4049e-04, csupn=1.6312e-04, max_rho=1.6000e+02
INFO | step #20: cost=0.0180 lambd=0.0320
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #21: cost=0.0180 lambd=0.0640
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #22: cost=0.0180 lambd=0.1280
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #23: cost=0.0180 lambd=0.2560
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #24: cost=0.0180 lambd=0.5120
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #25: cost=0.0180 lambd=1.0240
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #26: cost=0.0180 lambd=2.0480
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #27: cost=0.0180 lambd=4.0960
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | step #28: cost=0.0180 lambd=8.1920
INFO | - variance_cost(1): 0.01800 (avg 0.00450)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00093 (avg 0.00093)
INFO | accepted=True ATb_norm=8.19e-03 cost_prev=0.0189 cost_new=0.0189
INFO | AL update: snorm=1.1352e-04, csupn=1.1352e-04, max_rho=6.4000e+02
INFO | step #29: cost=0.0180 lambd=4.0960
INFO | - variance_cost(1): 0.01803 (avg 0.00451)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00032 (avg 0.00032)
INFO | accepted=True ATb_norm=1.90e-02 cost_prev=0.0184 cost_new=0.0183
INFO | step #30: cost=0.0181 lambd=2.0480
INFO | - variance_cost(1): 0.01805 (avg 0.00451)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00029 (avg 0.00029)
INFO | accepted=True ATb_norm=3.23e-03 cost_prev=0.0183 cost_new=0.0183
INFO | AL update: snorm=4.3578e-05, csupn=4.3578e-05, max_rho=6.4000e+02
INFO | step #31: cost=0.0181 lambd=1.0240
INFO | - variance_cost(1): 0.01808 (avg 0.00452)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00029 (avg 0.00029)
INFO | accepted=True ATb_norm=3.08e-02 cost_prev=0.0184 cost_new=0.0184
INFO | step #32: cost=0.0181 lambd=0.5120
INFO | - variance_cost(1): 0.01812 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00024 (avg 0.00024)
INFO | accepted=True ATb_norm=1.88e-03 cost_prev=0.0184 cost_new=0.0184
INFO | AL update: snorm=6.5833e-05, csupn=6.6757e-06, max_rho=6.4000e+02
INFO | step #33: cost=0.0182 lambd=0.2560
INFO | - variance_cost(1): 0.01816 (avg 0.00454)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00016 (avg 0.00016)
INFO | accepted=True ATb_norm=7.70e-04 cost_prev=0.0183 cost_new=0.0183
INFO | AL update: snorm=4.1537e-05, csupn=4.1723e-06, max_rho=2.5600e+03
INFO | step #34: cost=0.0181 lambd=0.1280
INFO | - variance_cost(1): 0.01814 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00015 (avg 0.00015)
INFO | accepted=True ATb_norm=2.00e-02 cost_prev=0.0183 cost_new=0.0183
INFO | step #35: cost=0.0181 lambd=0.0640
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00018 (avg 0.00018)
INFO | accepted=True ATb_norm=1.60e-04 cost_prev=0.0183 cost_new=0.0183
INFO | AL update: snorm=5.0589e-06, csupn=5.0589e-06, max_rho=2.5600e+03
INFO | step #36: cost=0.0181 lambd=0.0320
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=True ATb_norm=2.66e-03 cost_prev=0.0182 cost_new=0.0182
INFO | AL update: snorm=2.2352e-07, csupn=2.2352e-07, max_rho=1.0240e+04
INFO | step #37: cost=0.0181 lambd=0.0160
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=True ATb_norm=3.37e-03 cost_prev=0.0182 cost_new=0.0182
INFO | AL update: snorm=2.9802e-08, csupn=0.0000e+00, max_rho=1.0240e+04
INFO | step #38: cost=0.0181 lambd=0.0080
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #39: cost=0.0181 lambd=0.0160
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #40: cost=0.0181 lambd=0.0320
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=True ATb_norm=5.23e-04 cost_prev=0.0182 cost_new=0.0182
INFO | AL update: snorm=1.1921e-07, csupn=1.1921e-07, max_rho=4.0960e+04
INFO | step #41: cost=0.0181 lambd=0.0160
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=True ATb_norm=1.46e-02 cost_prev=0.0182 cost_new=0.0182
INFO | step #42: cost=0.0181 lambd=0.0080
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #43: cost=0.0181 lambd=0.0160
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #44: cost=0.0181 lambd=0.0320
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #45: cost=0.0181 lambd=0.0640
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #46: cost=0.0181 lambd=0.1280
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #47: cost=0.0181 lambd=0.2560
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | step #48: cost=0.0181 lambd=0.5120
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00005 (avg 0.00005)
INFO | accepted=False ATb_norm=2.51e-03 cost_prev=0.0182 cost_new=0.0182
INFO | AL update: snorm=7.4506e-09, csupn=0.0000e+00, max_rho=4.0960e+04
INFO | Terminated @ iteration #49: cost=0.0181 criteria=[0 0 1], term_deltas=3.1e-07,1.3e-03,6.4e-07
INFO | Augmented Lagrangian: initial snorm=5.0000e-03, csupn=5.0000e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0181 lambd=0.0005
INFO | - variance_cost(1): 0.01811 (avg 0.00453)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=2.68e-02 cost_prev=0.0184 cost_new=0.0145
INFO | step #1: cost=0.0105 lambd=0.0003
INFO | - variance_cost(1): 0.01047 (avg 0.00262)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00407 (avg 0.00407)
INFO | accepted=True ATb_norm=1.05e-04 cost_prev=0.0145 cost_new=0.0145
INFO | AL update: snorm=2.0260e-02, csupn=2.0260e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0104 lambd=0.0001
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #3: cost=0.0104 lambd=0.0003
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #4: cost=0.0104 lambd=0.0005
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #5: cost=0.0104 lambd=0.0010
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #6: cost=0.0104 lambd=0.0020
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #7: cost=0.0104 lambd=0.0040
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #8: cost=0.0104 lambd=0.0080
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | step #9: cost=0.0104 lambd=0.0160
INFO | - variance_cost(1): 0.01044 (avg 0.00261)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.02565 (avg 0.02565)
INFO | accepted=True ATb_norm=1.11e-01 cost_prev=0.0361 cost_new=0.0306
INFO | step #10: cost=0.0188 lambd=0.0080
INFO | - variance_cost(1): 0.01876 (avg 0.00469)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00873 (avg 0.00218)
INFO | - augmented_return_constraint(1): 0.00315 (avg 0.00315)
INFO | accepted=True ATb_norm=3.63e-01 cost_prev=0.0306 cost_new=0.0221
INFO | step #11: cost=0.0186 lambd=0.0040
INFO | - variance_cost(1): 0.01862 (avg 0.00466)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00347 (avg 0.00347)
INFO | accepted=True ATb_norm=3.65e-04 cost_prev=0.0221 cost_new=0.0221
INFO | AL update: snorm=4.1455e-03, csupn=4.1455e-03, max_rho=4.0000e+01
INFO | step #12: cost=0.0187 lambd=0.0020
INFO | - variance_cost(1): 0.01870 (avg 0.00468)
INFO | - augmented_budget_constraint(1): 0.00003 (avg 0.00003)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00714 (avg 0.00714)
INFO | accepted=True ATb_norm=1.81e-02 cost_prev=0.0259 cost_new=0.0255
INFO | step #13: cost=0.0205 lambd=0.0010
INFO | - variance_cost(1): 0.02050 (avg 0.00513)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00497 (avg 0.00497)
INFO | accepted=True ATb_norm=1.28e-04 cost_prev=0.0255 cost_new=0.0255
INFO | AL update: snorm=1.8864e-03, csupn=1.8864e-03, max_rho=1.6000e+02
INFO | step #14: cost=0.0205 lambd=0.0005
INFO | - variance_cost(1): 0.02055 (avg 0.00514)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00674 (avg 0.00674)
INFO | accepted=True ATb_norm=5.26e-02 cost_prev=0.0273 cost_new=0.0272
INFO | step #15: cost=0.0215 lambd=0.0003
INFO | - variance_cost(1): 0.02150 (avg 0.00537)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00573 (avg 0.00573)
INFO | accepted=False ATb_norm=2.00e-05 cost_prev=0.0272 cost_new=0.0272
INFO | AL update: snorm=8.6731e-04, csupn=8.6731e-04, max_rho=1.6000e+02
INFO | step #16: cost=0.0215 lambd=0.0003
INFO | - variance_cost(1): 0.02150 (avg 0.00537)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00659 (avg 0.00659)
INFO | accepted=True ATb_norm=1.30e-03 cost_prev=0.0281 cost_new=0.0281
INFO | AL update: snorm=3.8928e-04, csupn=3.8928e-04, max_rho=1.6000e+02
INFO | step #17: cost=0.0220 lambd=0.0001
INFO | - variance_cost(1): 0.02196 (avg 0.00549)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00650 (avg 0.00650)
INFO | accepted=True ATb_norm=5.69e-04 cost_prev=0.0285 cost_new=0.0285
INFO | AL update: snorm=1.7570e-04, csupn=1.7570e-04, max_rho=1.6000e+02
INFO | step #18: cost=0.0222 lambd=0.0001
INFO | - variance_cost(1): 0.02217 (avg 0.00554)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00646 (avg 0.00646)
INFO | accepted=True ATb_norm=2.55e-04 cost_prev=0.0286 cost_new=0.0286
INFO | AL update: snorm=7.9513e-05, csupn=7.9513e-05, max_rho=1.6000e+02
INFO | step #19: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02227 (avg 0.00557)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00644 (avg 0.00644)
INFO | accepted=True ATb_norm=1.36e-04 cost_prev=0.0287 cost_new=0.0287
INFO | AL update: snorm=3.6038e-05, csupn=3.6038e-05, max_rho=1.6000e+02
INFO | step #20: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02232 (avg 0.00558)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=6.28e-05 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=1.6354e-05, csupn=1.6354e-05, max_rho=6.4000e+02
INFO | step #21: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02234 (avg 0.00558)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=6.18e-04 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=7.4357e-06, csupn=7.4357e-06, max_rho=6.4000e+02
INFO | step #22: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02234 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=3.12e-04 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=3.3528e-06, csupn=3.3528e-06, max_rho=2.5600e+03
INFO | step #23: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #24: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #25: cost=0.0223 lambd=0.0000
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #26: cost=0.0223 lambd=0.0001
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #27: cost=0.0223 lambd=0.0002
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=1.68e-03 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=1.5348e-06, csupn=1.5348e-06, max_rho=2.5600e+03
INFO | step #28: cost=0.0224 lambd=0.0001
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=5.84e-04 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=6.9290e-07, csupn=6.9290e-07, max_rho=1.0240e+04
INFO | step #29: cost=0.0224 lambd=0.0000
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #30: cost=0.0224 lambd=0.0001
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #31: cost=0.0224 lambd=0.0002
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #32: cost=0.0224 lambd=0.0003
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #33: cost=0.0224 lambd=0.0006
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=2.09e-03 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=3.1292e-07, csupn=3.1292e-07, max_rho=1.0240e+04
INFO | step #34: cost=0.0224 lambd=0.0003
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #35: cost=0.0224 lambd=0.0006
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #36: cost=0.0224 lambd=0.0013
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | step #37: cost=0.0224 lambd=0.0026
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=5.81e-04 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=1.4156e-07, csupn=1.4156e-07, max_rho=4.0960e+04
INFO | step #38: cost=0.0224 lambd=0.0013
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=True ATb_norm=1.45e-02 cost_prev=0.0288 cost_new=0.0288
INFO | step #39: cost=0.0224 lambd=0.0006
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00643 (avg 0.00643)
INFO | accepted=False ATb_norm=3.61e-03 cost_prev=0.0288 cost_new=0.0288
INFO | AL update: snorm=6.7055e-08, csupn=6.7055e-08, max_rho=4.0960e+04
INFO | Terminated @ iteration #40: cost=0.0224 criteria=[0 0 1], term_deltas=8.3e-08,1.8e-03,1.9e-07
INFO | Augmented Lagrangian: initial snorm=5.0001e-03, csupn=5.0001e-03, max_rho=1.0000e+01, constraint_dim=6
INFO | step #0: cost=0.0224 lambd=0.0005
INFO | - variance_cost(1): 0.02235 (avg 0.00559)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00025 (avg 0.00025)
INFO | accepted=True ATb_norm=3.11e-02 cost_prev=0.0226 cost_new=0.0177
INFO | step #1: cost=0.0153 lambd=0.0003
INFO | - variance_cost(1): 0.01533 (avg 0.00383)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00241 (avg 0.00241)
INFO | accepted=True ATb_norm=7.94e-03 cost_prev=0.0177 cost_new=0.0167
INFO | AL update: snorm=2.2089e-02, csupn=2.2089e-02, max_rho=4.0000e+01
INFO | step #2: cost=0.0118 lambd=0.0001
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #3: cost=0.0118 lambd=0.0003
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #4: cost=0.0118 lambd=0.0005
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #5: cost=0.0118 lambd=0.0010
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #6: cost=0.0118 lambd=0.0020
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #7: cost=0.0118 lambd=0.0040
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #8: cost=0.0118 lambd=0.0080
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #9: cost=0.0118 lambd=0.0160
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #10: cost=0.0118 lambd=0.0320
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #11: cost=0.0118 lambd=0.0640
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | step #12: cost=0.0118 lambd=0.1280
INFO | - variance_cost(1): 0.01178 (avg 0.00294)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.03049 (avg 0.03049)
INFO | accepted=True ATb_norm=1.19e-01 cost_prev=0.0423 cost_new=0.0277
INFO | step #13: cost=0.0182 lambd=0.0640
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #14: cost=0.0182 lambd=0.1280
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #15: cost=0.0182 lambd=0.2560
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #16: cost=0.0182 lambd=0.5120
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #17: cost=0.0182 lambd=1.0240
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #18: cost=0.0182 lambd=2.0480
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | step #19: cost=0.0182 lambd=4.0960
INFO | - variance_cost(1): 0.01821 (avg 0.00455)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00940 (avg 0.00940)
INFO | accepted=True ATb_norm=2.21e-02 cost_prev=0.0277 cost_new=0.0276
INFO | step #20: cost=0.0184 lambd=2.0480
INFO | - variance_cost(1): 0.01843 (avg 0.00461)
INFO | - augmented_budget_constraint(1): 0.00004 (avg 0.00004)
INFO | - augmented_no_short_selling(1): 0.00020 (avg 0.00005)
INFO | - augmented_return_constraint(1): 0.00897 (avg 0.00897)
INFO | accepted=True ATb_norm=3.82e-02 cost_prev=0.0276 cost_new=0.0275
INFO | step #21: cost=0.0185 lambd=1.0240
INFO | - variance_cost(1): 0.01846 (avg 0.00462)
INFO | - augmented_budget_constraint(1): 0.00005 (avg 0.00005)
INFO | - augmented_no_short_selling(1): 0.00005 (avg 0.00001)
INFO | - augmented_return_constraint(1): 0.00894 (avg 0.00894)
INFO | accepted=True ATb_norm=9.58e-03 cost_prev=0.0275 cost_new=0.0274
INFO | AL update: snorm=9.1852e-03, csupn=9.1852e-03, max_rho=1.6000e+02
INFO | step #22: cost=0.0186 lambd=0.5120
INFO | - variance_cost(1): 0.01864 (avg 0.00466)
INFO | - augmented_budget_constraint(1): 0.00020 (avg 0.00020)
INFO | - augmented_no_short_selling(1): 0.00020 (avg 0.00005)
INFO | - augmented_return_constraint(1): 0.02283 (avg 0.02283)
INFO | accepted=True ATb_norm=2.19e-01 cost_prev=0.0419 cost_new=0.0404
INFO | step #23: cost=0.0195 lambd=0.2560
INFO | - variance_cost(1): 0.01951 (avg 0.00488)
INFO | - augmented_budget_constraint(1): 0.00003 (avg 0.00003)
INFO | - augmented_no_short_selling(1): 0.00004 (avg 0.00001)
INFO | - augmented_return_constraint(1): 0.02082 (avg 0.02082)
INFO | accepted=True ATb_norm=1.79e-02 cost_prev=0.0404 cost_new=0.0388
INFO | step #24: cost=0.0211 lambd=0.1280
INFO | - variance_cost(1): 0.02107 (avg 0.00527)
INFO | - augmented_budget_constraint(1): 0.00003 (avg 0.00003)
INFO | - augmented_no_short_selling(1): 0.00003 (avg 0.00001)
INFO | - augmented_return_constraint(1): 0.01766 (avg 0.01766)
INFO | accepted=True ATb_norm=1.32e-02 cost_prev=0.0388 cost_new=0.0375
INFO | step #25: cost=0.0232 lambd=0.0640
INFO | - variance_cost(1): 0.02319 (avg 0.00580)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00002 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01428 (avg 0.01428)
INFO | accepted=True ATb_norm=7.80e-03 cost_prev=0.0375 cost_new=0.0369
INFO | AL update: snorm=2.4174e-03, csupn=2.4174e-03, max_rho=1.6000e+02
INFO | step #26: cost=0.0252 lambd=0.0320
INFO | - variance_cost(1): 0.02518 (avg 0.00630)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00002 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01528 (avg 0.01528)
INFO | step #27: cost=0.0252 lambd=0.0640
INFO | - variance_cost(1): 0.02518 (avg 0.00630)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00002 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01528 (avg 0.01528)
INFO | step #28: cost=0.0252 lambd=0.1280
INFO | - variance_cost(1): 0.02518 (avg 0.00630)
INFO | - augmented_budget_constraint(1): 0.00002 (avg 0.00002)
INFO | - augmented_no_short_selling(1): 0.00002 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.01528 (avg 0.01528)
INFO | accepted=True ATb_norm=7.24e-03 cost_prev=0.0405 cost_new=0.0402
INFO | AL update: snorm=1.3512e-03, csupn=1.3512e-03, max_rho=6.4000e+02
INFO | step #29: cost=0.0265 lambd=0.0640
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #30: cost=0.0265 lambd=0.1280
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #31: cost=0.0265 lambd=0.2560
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #32: cost=0.0265 lambd=0.5120
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #33: cost=0.0265 lambd=1.0240
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #34: cost=0.0265 lambd=2.0480
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #35: cost=0.0265 lambd=4.0960
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | step #36: cost=0.0265 lambd=8.1920
INFO | - variance_cost(1): 0.02649 (avg 0.00662)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00570 (avg 0.00570)
INFO | accepted=True ATb_norm=1.16e-02 cost_prev=0.0322 cost_new=0.0322
INFO | step #37: cost=0.0265 lambd=4.0960
INFO | - variance_cost(1): 0.02655 (avg 0.00664)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00561 (avg 0.00561)
INFO | step #38: cost=0.0265 lambd=8.1920
INFO | - variance_cost(1): 0.02655 (avg 0.00664)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00561 (avg 0.00561)
INFO | step #39: cost=0.0265 lambd=16.3840
INFO | - variance_cost(1): 0.02655 (avg 0.00664)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00561 (avg 0.00561)
INFO | accepted=True ATb_norm=1.12e-02 cost_prev=0.0322 cost_new=0.0322
INFO | step #40: cost=0.0266 lambd=8.1920
INFO | - variance_cost(1): 0.02658 (avg 0.00664)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00557 (avg 0.00557)
INFO | accepted=True ATb_norm=2.72e-02 cost_prev=0.0322 cost_new=0.0322
INFO | step #41: cost=0.0266 lambd=4.0960
INFO | - variance_cost(1): 0.02659 (avg 0.00665)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00555 (avg 0.00555)
INFO | accepted=True ATb_norm=5.42e-03 cost_prev=0.0322 cost_new=0.0321
INFO | AL update: snorm=1.2476e-03, csupn=1.2476e-03, max_rho=2.5600e+03
INFO | step #42: cost=0.0266 lambd=2.0480
INFO | - variance_cost(1): 0.02662 (avg 0.00665)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00002 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00472 (avg 0.00472)
INFO | accepted=True ATb_norm=2.36e-02 cost_prev=0.0314 cost_new=0.0311
INFO | step #43: cost=0.0268 lambd=1.0240
INFO | - variance_cost(1): 0.02681 (avg 0.00670)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00422 (avg 0.00422)
INFO | accepted=True ATb_norm=1.73e-02 cost_prev=0.0311 cost_new=0.0306
INFO | step #44: cost=0.0271 lambd=0.5120
INFO | - variance_cost(1): 0.02713 (avg 0.00678)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00348 (avg 0.00348)
INFO | accepted=True ATb_norm=1.40e-02 cost_prev=0.0306 cost_new=0.0302
INFO | step #45: cost=0.0276 lambd=0.2560
INFO | - variance_cost(1): 0.02760 (avg 0.00690)
INFO | - augmented_budget_constraint(1): 0.00001 (avg 0.00001)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00258 (avg 0.00258)
INFO | accepted=True ATb_norm=9.58e-03 cost_prev=0.0302 cost_new=0.0299
INFO | AL update: snorm=2.0260e-04, csupn=2.0260e-04, max_rho=2.5600e+03
INFO | step #46: cost=0.0281 lambd=0.1280
INFO | - variance_cost(1): 0.02811 (avg 0.00703)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00001 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00224 (avg 0.00224)
INFO | accepted=True ATb_norm=7.32e-03 cost_prev=0.0304 cost_new=0.0302
INFO | AL update: snorm=1.4236e-04, csupn=4.1127e-06, max_rho=1.0240e+04
INFO | step #47: cost=0.0287 lambd=0.0640
INFO | - variance_cost(1): 0.02867 (avg 0.00717)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00123 (avg 0.00123)
INFO | accepted=True ATb_norm=7.91e-02 cost_prev=0.0299 cost_new=0.0299
INFO | step #48: cost=0.0288 lambd=0.0320
INFO | - variance_cost(1): 0.02876 (avg 0.00719)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00113 (avg 0.00113)
INFO | accepted=True ATb_norm=8.98e-04 cost_prev=0.0299 cost_new=0.0299
INFO | AL update: snorm=2.0987e-04, csupn=1.3709e-06, max_rho=1.0240e+04
INFO | step #49: cost=0.0288 lambd=0.0160
INFO | - variance_cost(1): 0.02879 (avg 0.00720)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00078 (avg 0.00078)
INFO | accepted=True ATb_norm=3.94e-03 cost_prev=0.0296 cost_new=0.0295
INFO | AL update: snorm=4.9837e-05, csupn=4.1723e-07, max_rho=1.0240e+04
INFO | step #50: cost=0.0285 lambd=0.0080
INFO | - variance_cost(1): 0.02852 (avg 0.00713)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00095 (avg 0.00095)
INFO | accepted=True ATb_norm=2.57e-03 cost_prev=0.0295 cost_new=0.0295
INFO | AL update: snorm=3.7998e-07, csupn=3.7998e-07, max_rho=1.0240e+04
INFO | step #51: cost=0.0284 lambd=0.0040
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #52: cost=0.0284 lambd=0.0080
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #53: cost=0.0284 lambd=0.0160
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #54: cost=0.0284 lambd=0.0320
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #55: cost=0.0284 lambd=0.0640
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #56: cost=0.0284 lambd=0.1280
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #57: cost=0.0284 lambd=0.2560
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | step #58: cost=0.0284 lambd=0.5120
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00026 (avg 0.00026)
INFO | accepted=True ATb_norm=3.88e-03 cost_prev=0.0287 cost_new=0.0287
INFO | AL update: snorm=4.2468e-07, csupn=4.2468e-07, max_rho=4.0960e+04
INFO | step #59: cost=0.0284 lambd=0.2560
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | step #60: cost=0.0284 lambd=0.5120
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | step #61: cost=0.0284 lambd=1.0240
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | step #62: cost=0.0284 lambd=2.0480
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00007 (avg 0.00007)
INFO | accepted=True ATb_norm=1.14e-02 cost_prev=0.0285 cost_new=0.0285
INFO | step #63: cost=0.0284 lambd=1.0240
INFO | - variance_cost(1): 0.02844 (avg 0.00711)
INFO | - augmented_budget_constraint(1): 0.00000 (avg 0.00000)
INFO | - augmented_no_short_selling(1): 0.00000 (avg 0.00000)
INFO | - augmented_return_constraint(1): 0.00006 (avg 0.00006)
INFO | accepted=False ATb_norm=7.36e-05 cost_prev=0.0285 cost_new=0.0285
INFO | AL update: snorm=1.7881e-07, csupn=1.7881e-07, max_rho=4.0960e+04
INFO | Terminated @ iteration #64: cost=0.0284 criteria=[0 1 0], term_deltas=7.0e-06,5.6e-05,1.1e-05
Computed 15 points on the efficient frontier
How it works: Augmented Lagrangian#
When constraints are present, jaxls uses an Augmented Lagrangian method.
The augmented Lagrangian#
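The method adds both a linear term (Lagrange multiplier \(\lambda\)) and a quadratic penalty (weight \(\rho\)) to the objective \(f(x)\). For an equality constraint \(h(x) = 0\):
\[\mathcal{L}(x, \lambda) = f(x) + \lambda^T h(x) + \frac{\rho}{2} \|h(x)\|^2\]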
The multiplier \(\lambda\) lets the solver satisfy the constraints exactly at a finite penalty weight, while the penalty \(\rho\) pulls iterates toward feasibility and speeds up convergence.
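To see why the multiplier matters, here is a minimal standalone sketch in plain JAX. It is not jaxls's internal implementation. It uses a hypothetical toy problem: minimize \(\|x - c\|^2\) subject to \(h(x) = \sum_i x_i - 1 = 0\). The sketch forms \(f(x) + \lambda h(x) + \frac{\rho}{2} h(x)^2\) and minimizes it over \(x\) with a single Newton step. That step is exact here only because the toy objective is quadratic. The vector c and the multiplier value 0.1 are specific to this toy problem.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: minimize f(x) = ||x - c||^2 subject to h(x) = sum(x) - 1 = 0.
c = jnp.array([0.5, 0.3, 0.2, 0.2])


def h(x):
    return jnp.sum(x) - 1.0


def augmented_lagrangian(x, lam, rho):
    # f(x) + lambda * h(x) + (rho / 2) * h(x)^2
    hx = h(x)
    return jnp.sum((x - c) ** 2) + lam * hx + 0.5 * rho * hx**2


def minimize_over_x(lam, rho, x0):
    # The toy objective is quadratic in x, so a single Newton step is exact.
    g = jax.grad(augmented_lagrangian)(x0, lam, rho)
    H = jax.hessian(augmented_lagrangian)(x0, lam, rho)
    return x0 - jnp.linalg.solve(H, g)


x0 = jnp.zeros(4)
rho = 10.0
x_penalty = minimize_over_x(0.0, rho, x0)  # Penalty only (lambda = 0).
x_al = minimize_over_x(0.1, rho, x0)  # Exact multiplier lambda* = 0.1 for this toy.
print(float(h(x_penalty)), float(h(x_al)))  # approx. 0.0095 and 0.0
```

With \(\lambda = 0\) (a pure penalty), the minimizer still violates the constraint by \(0.2 / (1 + 2\rho)\). With the exact multiplier, the same finite \(\rho\) gives a feasible point. For this toy problem, \(\lambda^* = 0.1\) follows from the stationarity condition \(2(x^* - c) + \lambda^* = 0\) with \(x^* = c - 0.05\).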
Multiplier updates#
As the solver runs, it updates the multipliers based on constraint violations. For an equality constraint \(h(x) = 0\):
\[\lambda_{\text{new}} = \lambda + \rho \cdot h(x)\]
For inequality constraints \(g(x) \leq 0\), this is projected to stay non-negative: \(\lambda_{\text{new}} = \max(0, \lambda + \rho \cdot g(x))\).
Updates occur when the cost stabilizes, indicating the current subproblem is solved. The penalty \(\rho\) increases if constraints aren’t improving fast enough.
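The update rules and a simple penalty schedule can be sketched as follows. This is a standalone plain-JAX illustration on a hypothetical toy problem, not jaxls's actual implementation. The stagnation threshold (0.25) is an assumption. The local names penalty_factor and penalty_max only mirror the AugmentedLagrangianConfig parameters, and jaxls's exact rule for raising \(\rho\) may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: minimize ||x - c||^2 subject to h(x) = sum(x) - 1 = 0.
c = jnp.array([0.5, 0.3, 0.2, 0.2])


def h(x):
    return jnp.sum(x) - 1.0


def augmented_lagrangian(x, lam, rho):
    hx = h(x)
    return jnp.sum((x - c) ** 2) + lam * hx + 0.5 * rho * hx**2


def solve_subproblem(lam, rho, x0):
    # One Newton step is exact because this toy objective is quadratic in x.
    g = jax.grad(augmented_lagrangian)(x0, lam, rho)
    H = jax.hessian(augmented_lagrangian)(x0, lam, rho)
    return x0 - jnp.linalg.solve(H, g)


def update_eq_multiplier(lam, rho, h_val):
    # Equality constraint h(x) = 0: lambda_new = lambda + rho * h(x).
    return lam + rho * h_val


def update_ineq_multiplier(lam, rho, g_val):
    # Inequality constraint g(x) <= 0: project so that lambda stays non-negative.
    return jnp.maximum(0.0, lam + rho * g_val)


lam, rho = 0.0, 1.0
penalty_factor, penalty_max = 10.0, 1e7  # Hypothetical values for this sketch.
x = jnp.zeros(4)
prev_violation = jnp.inf
for outer in range(20):
    x = solve_subproblem(lam, rho, x)  # Solve the current subproblem.
    violation = jnp.abs(h(x))
    if violation < 1e-6:
        break
    lam = update_eq_multiplier(lam, rho, h(x))
    # Raise rho only if the violation did not shrink enough (assumed 0.25 ratio).
    if violation > 0.25 * prev_violation:
        rho = min(rho * penalty_factor, penalty_max)
    prev_violation = violation

print(f"lambda={float(lam):.4f}, rho={rho}, |h(x)|={float(jnp.abs(h(x))):.2e}")

# Inequality case: a constraint satisfied with slack (g < 0) releases its multiplier.
print(update_ineq_multiplier(0.5, 10.0, -0.2))  # 0.0
print(update_ineq_multiplier(0.5, 10.0, 0.1))  # 1.5
```

On this toy problem, \(\rho\) is raised once (from 1 to 10) when the violation stalls. After that, the multiplier converges to about 0.1 within a few outer iterations.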
Advanced: tuning the solver#
For difficult problems, you can tune the Augmented Lagrangian solver via AugmentedLagrangianConfig, whose parameters map to the concepts above. The maximum number of iterations is set separately, via TerminationConfig:
# Custom configuration example.
al_config = jaxls.AugmentedLagrangianConfig(
penalty_factor=10.0, # Multiply rho by this when constraints stagnate.
penalty_max=1e7, # Cap on rho to prevent ill-conditioning.
tolerance_absolute=1e-6, # Constraint violation tolerance for convergence.
)
# Use with solve(). Max iterations is controlled via TerminationConfig.
solution = problem.solve(
verbose=False,
linear_solver="dense_cholesky",
augmented_lagrangian=al_config,
termination=jaxls.TerminationConfig(max_iterations=150),
)
print("Solution with custom config:")
for name, w in zip(asset_names, solution[weights_var]):
print(f" {name}: {float(w):.1%}")
Solution with custom config:
Tech: 26.4%
Healthcare: 17.3%
Energy: 20.0%
Bonds: 36.4%
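As a quick sanity check, the sketch below recomputes the constraint quantities from the rounded percentages printed above. It uses the expected returns defined at the start of the example. The helper constraint_report is hypothetical and not part of jaxls.

```python
import jax.numpy as jnp

# Expected returns from the example setup (Tech, Healthcare, Energy, Bonds).
expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])
# Rounded weights from the "Solution with custom config" output above.
weights = jnp.array([0.264, 0.173, 0.200, 0.364])


def constraint_report(w, mu):
    """Hypothetical helper summarizing the portfolio constraints for weights w."""
    return {
        # Budget constraint: |sum(w) - 1| should be ~0.
        "budget_violation": float(jnp.abs(jnp.sum(w) - 1.0)),
        # No short-selling: the smallest weight should be >= 0.
        "min_weight": float(jnp.min(w)),
        # Expected portfolio return w^T mu, to compare with the return target.
        "expected_return": float(w @ mu),
    }


report = constraint_report(weights, expected_returns)
print(report)
```

Here the budget residual is about 0.001, which comes from rounding the printed percentages. Passing solution[weights_var] directly would give the solver's actual residual.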
Alternative: manifold reformulation#
Some constraints can instead be handled by reformulating the problem as optimization over a manifold, using a custom retract_fn. For example, the budget constraint (weights sum to 1) and the no-short-selling constraint (weights non-negative) together define the probability simplex. With a retraction that keeps the weights on the simplex, the problem can be optimized directly, without explicit constraints.
This approach can be more numerically stable when the constraints define a smooth manifold with a simple retraction. See Manifold allocation for an example.
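One common way to map an arbitrary vector back onto the probability simplex is the Euclidean projection, computed by the standard sort-and-threshold algorithm. Below is a standalone plain-JAX sketch of it. The name project_to_simplex is hypothetical. How it would be wired into a jaxls retract_fn depends on that function's signature, which is not shown here.

```python
import jax
import jax.numpy as jnp


@jax.jit
def project_to_simplex(v: jax.Array) -> jax.Array:
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    n = v.shape[0]
    u = jnp.sort(v)[::-1]  # Entries sorted in descending order.
    cssv = jnp.cumsum(u)  # Cumulative sums of the sorted entries.
    ks = jnp.arange(1, n + 1)
    # Largest k with u_k - (sum_{j<=k} u_j - 1) / k > 0 (always true for k = 1).
    k_star = jnp.max(jnp.where(u - (cssv - 1.0) / ks > 0, ks, 0))
    theta = (cssv[k_star - 1] - 1.0) / k_star  # Shift that makes the result sum to 1.
    return jnp.maximum(v - theta, 0.0)


w = project_to_simplex(jnp.array([0.3, -0.2, 0.4, 0.6]))
print(w)  # approximately [0.2, 0.0, 0.3, 0.5]
print(float(jnp.sum(w)))  # approximately 1.0
```

Points already on the simplex are returned unchanged. For example, the equal-weight default of 0.25 per asset stays the same.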
Summary#
Key points for constrained optimization in jaxls:
Use kind="constraint_eq_zero" for equality constraints \(h(x) = 0\)
Use kind="constraint_leq_zero" for upper bounds \(g(x) \leq 0\)
Use kind="constraint_geq_zero" for lower bounds \(g(x) \geq 0\)
Constraints are handled automatically via the Augmented Lagrangian method
Tune with AugmentedLagrangianConfig if needed
For more constraint examples, see:
Mean-variance allocation: Portfolio optimization with return targets
Inverse kinematics: Position/orientation constraints with joint limits
Obstacle avoidance: Collision avoidance inequalities
Cart-pole (collocation): Dynamics as equality constraints