|
1 | 1 | from enum import Enum |
2 | 2 | from operator import itemgetter |
3 | | -from typing import Dict, Mapping, NamedTuple, Optional, Union, cast, List, Tuple |
| 3 | +from typing import Dict, Mapping, NamedTuple, Optional, Tuple, Union, cast |
4 | 4 |
|
5 | 5 | from numpy import inf, sign |
6 | 6 | from numpy import round as np_round |
|
27 | 27 | from rqalpha.utils.arg_checker import assure_active_instrument |
28 | 28 | from rqalpha.utils.exception import RQApiNotSupportedError, RQInvalidArgument |
29 | 29 | from rqalpha.utils.functools import lru_cache |
30 | | -from rqalpha.utils.i18n import gettext as _, lazy_gettext |
| 30 | +from rqalpha.utils.i18n import gettext as _ |
| 31 | +from rqalpha.utils.i18n import lazy_gettext |
31 | 32 | from rqalpha.utils.price_limits import reaches_limit_down_vectorized, reaches_limit_up_vectorized |
32 | 33 |
|
33 | 34 |
|
@@ -62,7 +63,10 @@ def translation(self) -> str: |
62 | 63 |
|
63 | 64 |
|
64 | 65 | class DenialReason(CommentedEnum): |
65 | | - less_than_half = 'less_than_half', lazy_gettext('Order creation failed: quantity less than half of minimum order quantity') |
| 66 | + less_than_half = ( |
| 67 | + 'less_than_half', |
| 68 | + lazy_gettext('Order creation failed: quantity less than half of minimum order quantity'), |
| 69 | + ) |
66 | 70 | suspended_buy = 'suspended_buy', lazy_gettext('Order creation failed: cannot buy due to suspension') |
67 | 71 | suspended_sell = 'suspended_sell', lazy_gettext('Order creation failed: cannot sell due to suspension') |
68 | 72 | no_price = 'no_price', lazy_gettext('Order creation failed: no market data available') |
@@ -205,18 +209,26 @@ def _round_adjusting_odd_lots(self, adjusting: Series) -> Tuple[Series, Dict[Den |
205 | 209 | def _calc_adjusting( |
206 | 210 | self, target_quantities: Series, direction: POSITION_DIRECTION |
207 | 211 | ) -> Tuple[Series, Dict[DenialReason, Series]]: |
208 | | - # caller should ensure the index of diff, price_df and suspended are the same |
| 212 | + """计算调仓数量并应用各类约束。 |
| 213 | +
|
| 214 | + Returns: |
| 215 | + (diff, denials): 调整后的数量变化和各类拒绝原因 |
| 216 | + """ |
209 | 217 | diff, denials = self._round_adjusting_odd_lots(target_quantities.sub(self._current_quantities, fill_value=0)) |
210 | 218 | prices, limit_up, limit_down = itemgetter('last', 'limit_up', 'limit_down')(self._prices) |
| 219 | + |
| 220 | + # 构建完全不可调整的资产掩码(停牌、无行情) |
211 | 221 | adjusting_denied = ( |
212 | 222 | self._suspended # 停牌 |
213 | 223 | | prices.isna() # 无行情 |
214 | 224 | ) |
215 | 225 |
|
| 226 | + # 记录各类拒绝原因(用于向用户报告) |
216 | 227 | denials[DenialReason.suspended_buy] = (diff > 0) & self._suspended |
217 | 228 | denials[DenialReason.suspended_sell] = (diff < 0) & self._suspended |
218 | 229 | denials[DenialReason.no_price] = prices.isna() & (diff != 0) |
219 | 230 |
|
| 231 | + # 涨跌停限制(方向相关) |
220 | 232 | limit_up = reaches_limit_up_vectorized(prices, limit_up, self._tick_sizes) |
221 | 233 | limit_down = reaches_limit_down_vectorized(prices, limit_down, self._tick_sizes) |
222 | 234 | if direction == POSITION_DIRECTION.LONG: |
@@ -259,55 +271,111 @@ def _trans_cost_decider(self, market: MARKET) -> AbstractStockTransactionCostDec |
259 | 271 | ) |
260 | 272 | return decider |
261 | 273 |
|
262 | | - SAFETY: float = 1.2 |
| 274 | + def _estimate_transaction_costs(self, diff: Series, prices: Series) -> float: |
| 275 | + """估算交易成本(手续费 + 汇率成本)。""" |
| 276 | + delta_mv = diff * prices |
| 277 | + costs = 0.0 |
| 278 | + for market, group in self._market.groupby(by=self._market): |
| 279 | + # 税费等成本 |
| 280 | + costs += self._trans_cost_decider(market).batch_estimate(diff[group.index], prices[group.index]).sum() # type: ignore |
| 281 | + if market != MARKET.CN: |
| 282 | + # 汇率成本 |
| 283 | + exchange_rate = self._exchange_rates[market] # type: ignore |
| 284 | + buy_mask = (diff > 0) & (diff.index.isin(group.index)) |
| 285 | + sell_mask = (diff < 0) & (diff.index.isin(group.index)) |
| 286 | + costs += delta_mv[buy_mask].sum() * (exchange_rate.ask / exchange_rate.middle - 1) |
| 287 | + costs += delta_mv[sell_mask].sum() * (exchange_rate.middle / exchange_rate.bid - 1) |
| 288 | + return costs |
| 289 | + |
| 290 | + def _calc_min_adjustable(self, denials: Dict[DenialReason, Series], prices: Series) -> float: |
| 291 | + """计算最小可调精度,仅基于可调资产(排除所有拒绝的资产)。无可调资产时返回 inf 触发退出。""" |
| 292 | + adjusting_denied = Series(False, index=prices.index) |
| 293 | + for reason, mask in denials.items(): |
| 294 | + adjusting_denied |= mask |
| 295 | + can_adjust = ~adjusting_denied |
| 296 | + if can_adjust.any(): |
| 297 | + return (self._min_qty[can_adjust] * prices[can_adjust] / self._total_value).min() |
| 298 | + return inf |
| 299 | + |
| 300 | + MAX_ITERATIONS: int = 150 |
| 301 | + KP_INIT: float = 0.382 # 比例增益初始值(黄金分割比) |
| 302 | + KP_MIN: float = 0.01 # 比例增益下限 |
| 303 | + KP_DECAY: float = 0.382 # 振荡时 kp 衰减因子 |
| 304 | + PRECISION: float = 0.0001 # 硬性精度要求(万分之一) |
| 305 | + MAX_OSCILLATIONS: int = 10 # kp 达到下限后允许的最大振荡次数 |
263 | 306 |
|
264 | 307 | def __call__(self, direction: POSITION_DIRECTION = POSITION_DIRECTION.LONG) -> AdjustingResult: |
| 308 | + """使用 P 控制器迭代求解最优调仓方案。 |
| 309 | +
|
| 310 | + 算法核心思路: |
| 311 | + 1. 通过 safety 系数缩放目标持仓量,用比例控制器(P-controller)调节 safety 使实际持仓比例逼近目标权重 |
| 312 | + 2. 仅基于可调资产(排除停牌、涨跌停、不可平仓等)计算最小可调精度,无可调资产时立即退出 |
| 313 | + tips: |
| 314 | + 1. 误差定义:diff_proportion = 持仓市值 / (总资产 - 交易成本),将手续费、税费、汇率成本纳入控制目标 |
| 315 | + 2. 振荡抑制:检测误差符号翻转时降低比例增益 kp,防止在离散约束下来回震荡 |
| 316 | + """ |
265 | 317 | if self._current_quantities.empty and self._target_weights.empty: |
266 | 318 | return AdjustingResult(adjustments=Series(dtype='float64'), denials=dict()) |
267 | 319 |
|
268 | | - if self._target_weights.sum() > 0.95: |
269 | | - # 如果目标是满仓或者接近满仓,则使用一个较高的 safety 开始下降 |
270 | | - safety = self.SAFETY |
271 | | - else: |
272 | | - safety = 1. |
273 | | - last_proportion_diff = inf |
274 | | - last_diff = None |
275 | | - last_denials = None |
| 320 | + # 初始化 P 控制器状态 |
| 321 | + total_target_weight = self._target_weights.sum() |
| 322 | + safety = 1.0 # 目标持仓缩放系数,通过迭代调节逼近目标权重 |
| 323 | + kp = self.KP_INIT # 比例增益,控制每次 safety 调整幅度 |
| 324 | + prev_error = 0.0 # 上一次带符号误差,用于振荡检测 |
| 325 | + oscillation_count = 0 # 振荡次数,误差符号翻转时累加 |
| 326 | + |
| 327 | + # 历史最优可行解 |
| 328 | + best_error = inf |
| 329 | + best_diff = Series(dtype='float64') |
| 330 | + best_denials = None |
276 | 331 | prices = self._prices_settle_ccy |
277 | | - while True: |
278 | | - if safety < 0: |
279 | | - # 防止 bug 导致的死循环 |
280 | | - raise RuntimeError('safety < 0: {}'.format(safety)) |
| 332 | + |
| 333 | + for iteration in range(self.MAX_ITERATIONS): |
| 334 | + # 1. 根据当前 safety 计算目标持仓量,并应用各类约束 |
281 | 335 | target_quantities: Series = (self._total_value * safety * self._target_weights / prices).round(0) |
282 | 336 | diff, denials = self._calc_adjusting(target_quantities, direction) |
283 | 337 |
|
284 | | - delta_mv = diff * prices |
285 | | - cash_consumed = delta_mv.sum() |
286 | | - for market, group in self._market.groupby(by=self._market): |
287 | | - # 税费等成本 |
288 | | - cash_consumed += ( |
289 | | - self._trans_cost_decider(market).batch_estimate(diff[group.index], prices[group.index]).sum() |
290 | | - ) # type: ignore |
291 | | - # 汇率成本 |
292 | | - if market != MARKET.CN: |
293 | | - exchange_rate = self._exchange_rates[market] # type: ignore |
294 | | - cash_consumed += delta_mv[(diff > 0) & (diff.index.isin(group.index))].sum() * ( |
295 | | - exchange_rate.ask / exchange_rate.middle - 1 |
296 | | - ) |
297 | | - cash_consumed += delta_mv[(diff < 0) & (diff.index.isin(group.index))].sum() * ( |
298 | | - exchange_rate.middle / exchange_rate.bid - 1 |
299 | | - ) |
300 | | - |
301 | | - total_proportion = ((self._current_quantities.add(diff, fill_value=0)) * prices).sum() / self._total_value |
302 | | - proportion_diff = abs(total_proportion - self._target_weights.sum()) |
303 | | - if cash_consumed < self._cash_available: |
| 338 | + # 2. 计算交易成本和现金消耗 |
| 339 | + transaction_costs = self._estimate_transaction_costs(diff, prices) |
| 340 | + cash_consumed = (diff * prices).sum() + transaction_costs |
| 341 | + |
| 342 | + # 3. 计算成本感知误差 |
| 343 | + # signed_error > 0 表示实际比例低于目标,需增大 safety;反之需减小 |
| 344 | + total_market_value = ((self._current_quantities.add(diff, fill_value=0)) * prices).sum() |
| 345 | + diff_proportion = total_market_value / (self._total_value - transaction_costs) |
| 346 | + signed_error = total_target_weight - diff_proportion |
| 347 | + current_error = abs(signed_error) |
| 348 | + |
| 349 | + # 4. 更新最优可行解 |
| 350 | + if current_error < best_error: |
| 351 | + best_error = current_error |
| 352 | + best_diff = diff |
| 353 | + best_denials = denials |
| 354 | + |
| 355 | + # 5. 振荡检测与 kp 衰减 |
| 356 | + if signed_error * prev_error < 0: |
| 357 | + oscillation_count += 1 |
| 358 | + kp *= self.KP_DECAY |
| 359 | + kp = max(kp, self.KP_MIN) |
| 360 | + |
| 361 | + # 检查退出条件 |
| 362 | + min_adjustable = self._calc_min_adjustable(denials, prices) |
| 363 | + if ( |
| 364 | + cash_consumed < self._cash_available |
| 365 | + # 寻找基于当前可用现金下的最优解 |
304 | 366 | # TODO: 分别计算 A H 股的可用资金 |
305 | | - if proportion_diff > last_proportion_diff and last_diff is not None and last_denials is not None: |
306 | | - break |
307 | | - last_diff, last_denials = diff, denials |
308 | | - last_proportion_diff = proportion_diff |
309 | | - safety -= min(max(proportion_diff / 10, 0.0001), 0.002) |
310 | | - return AdjustingResult(adjustments=last_diff, denials=self._format_denials(last_denials)) |
| 367 | + and ( |
| 368 | + current_error < min_adjustable |
| 369 | + or current_error <= self.PRECISION |
| 370 | + ) |
| 371 | + ) or (kp <= self.KP_MIN and oscillation_count > self.MAX_OSCILLATIONS): |
| 372 | + break |
| 373 | + |
| 374 | + # safety 调整量 = kp × 误差 |
| 375 | + safety += kp * signed_error |
| 376 | + prev_error = signed_error |
| 377 | + |
| 378 | + return AdjustingResult(adjustments=best_diff, denials=self._format_denials(best_denials or dict())) |
311 | 379 |
|
312 | 380 |
|
313 | 381 | @export_as_api |
|
0 commit comments