Skip to content

Commit 030d754

Browse files
authored
order_target_portfolio算法性能优化 (#1001)
* order_target_portfolio算法性能优化 1. 改用P 控制器算法缩减循环开支 2. 新增测试用例检测算法效果 具体效果概要: a. 对比测试的权重信号,平均耗时为原本的1/3左右 b. 相比原算法100、500、1000持仓下一个月平均差异均有小幅提升,以1000持仓举例:0.00079 -> 0.00081 c. 相比原算法1500持仓有十万分之1的误差增加 * order_target_portfolio算法性能优化_v1: 还原 test_order_target_portfolio_smart_all_denials 无需改动 * order_target_portfolio算法性能优化_v2: 1. benchmark 测试过重,暂且移除出标准测试,改为内部测试 2. fix testcase
1 parent a784c60 commit 030d754

4 files changed

Lines changed: 2985 additions & 2941 deletions

File tree

rqalpha/mod/rqalpha_mod_sys_accounts/api/order_target_portfolio.py

Lines changed: 111 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from operator import itemgetter
3-
from typing import Dict, Mapping, NamedTuple, Optional, Union, cast, List, Tuple
3+
from typing import Dict, Mapping, NamedTuple, Optional, Tuple, Union, cast
44

55
from numpy import inf, sign
66
from numpy import round as np_round
@@ -27,7 +27,8 @@
2727
from rqalpha.utils.arg_checker import assure_active_instrument
2828
from rqalpha.utils.exception import RQApiNotSupportedError, RQInvalidArgument
2929
from rqalpha.utils.functools import lru_cache
30-
from rqalpha.utils.i18n import gettext as _, lazy_gettext
30+
from rqalpha.utils.i18n import gettext as _
31+
from rqalpha.utils.i18n import lazy_gettext
3132
from rqalpha.utils.price_limits import reaches_limit_down_vectorized, reaches_limit_up_vectorized
3233

3334

@@ -62,7 +63,10 @@ def translation(self) -> str:
6263

6364

6465
class DenialReason(CommentedEnum):
65-
less_than_half = 'less_than_half', lazy_gettext('Order creation failed: quantity less than half of minimum order quantity')
66+
less_than_half = (
67+
'less_than_half',
68+
lazy_gettext('Order creation failed: quantity less than half of minimum order quantity'),
69+
)
6670
suspended_buy = 'suspended_buy', lazy_gettext('Order creation failed: cannot buy due to suspension')
6771
suspended_sell = 'suspended_sell', lazy_gettext('Order creation failed: cannot sell due to suspension')
6872
no_price = 'no_price', lazy_gettext('Order creation failed: no market data available')
@@ -205,18 +209,26 @@ def _round_adjusting_odd_lots(self, adjusting: Series) -> Tuple[Series, Dict[Den
205209
def _calc_adjusting(
206210
self, target_quantities: Series, direction: POSITION_DIRECTION
207211
) -> Tuple[Series, Dict[DenialReason, Series]]:
208-
# caller should ensure the index of diff, price_df and suspended are the same
212+
"""计算调仓数量并应用各类约束。
213+
214+
Returns:
215+
(diff, denials): 调整后的数量变化和各类拒绝原因
216+
"""
209217
diff, denials = self._round_adjusting_odd_lots(target_quantities.sub(self._current_quantities, fill_value=0))
210218
prices, limit_up, limit_down = itemgetter('last', 'limit_up', 'limit_down')(self._prices)
219+
220+
# 构建完全不可调整的资产掩码(停牌、无行情)
211221
adjusting_denied = (
212222
self._suspended # 停牌
213223
| prices.isna() # 无行情
214224
)
215225

226+
# 记录各类拒绝原因(用于向用户报告)
216227
denials[DenialReason.suspended_buy] = (diff > 0) & self._suspended
217228
denials[DenialReason.suspended_sell] = (diff < 0) & self._suspended
218229
denials[DenialReason.no_price] = prices.isna() & (diff != 0)
219230

231+
# 涨跌停限制(方向相关)
220232
limit_up = reaches_limit_up_vectorized(prices, limit_up, self._tick_sizes)
221233
limit_down = reaches_limit_down_vectorized(prices, limit_down, self._tick_sizes)
222234
if direction == POSITION_DIRECTION.LONG:
@@ -259,55 +271,111 @@ def _trans_cost_decider(self, market: MARKET) -> AbstractStockTransactionCostDec
259271
)
260272
return decider
261273

262-
SAFETY: float = 1.2
274+
def _estimate_transaction_costs(self, diff: Series, prices: Series) -> float:
275+
"""估算交易成本(手续费 + 汇率成本)。"""
276+
delta_mv = diff * prices
277+
costs = 0.0
278+
for market, group in self._market.groupby(by=self._market):
279+
# 税费等成本
280+
costs += self._trans_cost_decider(market).batch_estimate(diff[group.index], prices[group.index]).sum() # type: ignore
281+
if market != MARKET.CN:
282+
# 汇率成本
283+
exchange_rate = self._exchange_rates[market] # type: ignore
284+
buy_mask = (diff > 0) & (diff.index.isin(group.index))
285+
sell_mask = (diff < 0) & (diff.index.isin(group.index))
286+
costs += delta_mv[buy_mask].sum() * (exchange_rate.ask / exchange_rate.middle - 1)
287+
costs += delta_mv[sell_mask].sum() * (exchange_rate.middle / exchange_rate.bid - 1)
288+
return costs
289+
290+
def _calc_min_adjustable(self, denials: Dict[DenialReason, Series], prices: Series) -> float:
291+
"""计算最小可调精度,仅基于可调资产(排除所有拒绝的资产)。无可调资产时返回 inf 触发退出。"""
292+
adjusting_denied = Series(False, index=prices.index)
293+
for reason, mask in denials.items():
294+
adjusting_denied |= mask
295+
can_adjust = ~adjusting_denied
296+
if can_adjust.any():
297+
return (self._min_qty[can_adjust] * prices[can_adjust] / self._total_value).min()
298+
return inf
299+
300+
MAX_ITERATIONS: int = 150
301+
KP_INIT: float = 0.382 # 比例增益初始值(黄金分割比)
302+
KP_MIN: float = 0.01 # 比例增益下限
303+
KP_DECAY: float = 0.382 # 振荡时 kp 衰减因子
304+
PRECISION: float = 0.0001 # 硬性精度要求(万分之一)
305+
MAX_OSCILLATIONS: int = 10 # kp 达到下限后允许的最大振荡次数
263306

264307
def __call__(self, direction: POSITION_DIRECTION = POSITION_DIRECTION.LONG) -> AdjustingResult:
308+
"""使用 P 控制器迭代求解最优调仓方案。
309+
310+
算法核心思路:
311+
1. 通过 safety 系数缩放目标持仓量,用比例控制器(P-controller)调节 safety 使实际持仓比例逼近目标权重
312+
2. 仅基于可调资产(排除停牌、涨跌停、不可平仓等)计算最小可调精度,无可调资产时立即退出
313+
tips:
314+
1. 误差定义:diff_proportion = 持仓市值 / (总资产 - 交易成本),将手续费、税费、汇率成本纳入控制目标
315+
2. 振荡抑制:检测误差符号翻转时降低比例增益 kp,防止在离散约束下来回震荡
316+
"""
265317
if self._current_quantities.empty and self._target_weights.empty:
266318
return AdjustingResult(adjustments=Series(dtype='float64'), denials=dict())
267319

268-
if self._target_weights.sum() > 0.95:
269-
# 如果目标是满仓或者接近满仓,则使用一个较高的 safety 开始下降
270-
safety = self.SAFETY
271-
else:
272-
safety = 1.
273-
last_proportion_diff = inf
274-
last_diff = None
275-
last_denials = None
320+
# 初始化 P 控制器状态
321+
total_target_weight = self._target_weights.sum()
322+
safety = 1.0 # 目标持仓缩放系数,通过迭代调节逼近目标权重
323+
kp = self.KP_INIT # 比例增益,控制每次 safety 调整幅度
324+
prev_error = 0.0 # 上一次带符号误差,用于振荡检测
325+
oscillation_count = 0 # 振荡次数,误差符号翻转时累加
326+
327+
# 历史最优可行解
328+
best_error = inf
329+
best_diff = Series(dtype='float64')
330+
best_denials = None
276331
prices = self._prices_settle_ccy
277-
while True:
278-
if safety < 0:
279-
# 防止 bug 导致的死循环
280-
raise RuntimeError('safety < 0: {}'.format(safety))
332+
333+
for iteration in range(self.MAX_ITERATIONS):
334+
# 1. 根据当前 safety 计算目标持仓量,并应用各类约束
281335
target_quantities: Series = (self._total_value * safety * self._target_weights / prices).round(0)
282336
diff, denials = self._calc_adjusting(target_quantities, direction)
283337

284-
delta_mv = diff * prices
285-
cash_consumed = delta_mv.sum()
286-
for market, group in self._market.groupby(by=self._market):
287-
# 税费等成本
288-
cash_consumed += (
289-
self._trans_cost_decider(market).batch_estimate(diff[group.index], prices[group.index]).sum()
290-
) # type: ignore
291-
# 汇率成本
292-
if market != MARKET.CN:
293-
exchange_rate = self._exchange_rates[market] # type: ignore
294-
cash_consumed += delta_mv[(diff > 0) & (diff.index.isin(group.index))].sum() * (
295-
exchange_rate.ask / exchange_rate.middle - 1
296-
)
297-
cash_consumed += delta_mv[(diff < 0) & (diff.index.isin(group.index))].sum() * (
298-
exchange_rate.middle / exchange_rate.bid - 1
299-
)
300-
301-
total_proportion = ((self._current_quantities.add(diff, fill_value=0)) * prices).sum() / self._total_value
302-
proportion_diff = abs(total_proportion - self._target_weights.sum())
303-
if cash_consumed < self._cash_available:
338+
# 2. 计算交易成本和现金消耗
339+
transaction_costs = self._estimate_transaction_costs(diff, prices)
340+
cash_consumed = (diff * prices).sum() + transaction_costs
341+
342+
# 3. 计算成本感知误差
343+
# signed_error > 0 表示实际比例低于目标,需增大 safety;反之需减小
344+
total_market_value = ((self._current_quantities.add(diff, fill_value=0)) * prices).sum()
345+
diff_proportion = total_market_value / (self._total_value - transaction_costs)
346+
signed_error = total_target_weight - diff_proportion
347+
current_error = abs(signed_error)
348+
349+
# 4. 更新最优可行解
350+
if current_error < best_error:
351+
best_error = current_error
352+
best_diff = diff
353+
best_denials = denials
354+
355+
# 5. 振荡检测与 kp 衰减
356+
if signed_error * prev_error < 0:
357+
oscillation_count += 1
358+
kp *= self.KP_DECAY
359+
kp = max(kp, self.KP_MIN)
360+
361+
# 检查退出条件
362+
min_adjustable = self._calc_min_adjustable(denials, prices)
363+
if (
364+
cash_consumed < self._cash_available
365+
# 寻找基于当前可用现金下的最优解
304366
# TODO: 分别计算 A H 股的可用资金
305-
if proportion_diff > last_proportion_diff and last_diff is not None and last_denials is not None:
306-
break
307-
last_diff, last_denials = diff, denials
308-
last_proportion_diff = proportion_diff
309-
safety -= min(max(proportion_diff / 10, 0.0001), 0.002)
310-
return AdjustingResult(adjustments=last_diff, denials=self._format_denials(last_denials))
367+
and (
368+
current_error < min_adjustable
369+
or current_error <= self.PRECISION
370+
)
371+
) or (kp <= self.KP_MIN and oscillation_count > self.MAX_OSCILLATIONS):
372+
break
373+
374+
# safety 调整量 = kp × 误差
375+
safety += kp * signed_error
376+
prev_error = signed_error
377+
378+
return AdjustingResult(adjustments=best_diff, denials=self._format_denials(best_denials or dict()))
311379

312380

313381
@export_as_api

tests/integration_tests/test_api/test_order_target_portfolio_smart_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def handle_bar(context, bar_dict):
7070
},
7171
)
7272
assert get_position('000001.XSHE').quantity == 0 # 清仓
73-
assert get_position('000004.XSHE').quantity == 5500 # (993695.7496 * 0.1) / 18 = 5520.53
74-
assert get_position('000005.XSHE').quantity == 67600 # (993695.7496 * 0.2) / 2.92 = 68061.35
73+
assert get_position('000004.XSHE').quantity == 5700 # (993695.7496 * 0.1) / 18 = 5520.53
74+
assert get_position('000005.XSHE').quantity == 69700 # (993695.7496 * 0.2) / 2.92 = 68061.35
7575
assert get_position('600519.XSHG').quantity == 0 # 970 低于 收盘价 无法买进
7676

7777
run_func(config=config, init=init, handle_bar=handle_bar)

0 commit comments

Comments
 (0)