跳转至

核心模块 API

以下 API 文档由源码自动生成,内容始终与代码保持同步。


qka.Data

数据管理类

负责股票数据的获取、缓存和管理,支持多数据源、并发下载和自定义因子计算。 通过 indicators 参数统一处理技术指标和自定义因子,在数据加载时一次性预计算。

属性:

名称 类型 描述
symbols List[str]

股票代码列表

period str

数据周期,如 '1d'、'1m' 等

adjust str

复权方式,如 'qfq'、'hfq'、'bfq'

indicators dict | Callable

预计算指标/因子

source str

数据源,如 'baostock'(默认)、'akshare'、'qmt'

pool_size int

并发下载线程数

datadir Path

数据缓存目录

target_dir Path

目标存储目录

源代码位于: qka/core/data.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
class Data():
    """
    数据管理类

    负责股票数据的获取、缓存和管理,支持多数据源、并发下载和自定义因子计算。
    通过 `indicators` 参数统一处理技术指标和自定义因子,在数据加载时一次性预计算。

    Attributes:
        symbols (List[str]): 股票代码列表
        period (str): 数据周期,如 '1d'、'1m' 等
        adjust (str): 复权方式,如 'qfq'、'hfq'、'bfq'
        indicators (dict | Callable): 预计算指标/因子
        source (str): 数据源,如 'baostock'(默认)、'akshare'、'qmt'
        pool_size (int): 并发下载线程数
        datadir (Path): 数据缓存目录
        target_dir (Path): 目标存储目录
    """

    def __init__(
        self, 
        symbols: Optional[List[str]] = None,
        period: str = '1d',
        adjust: str = 'qfq',
        source: str = 'baostock',
        pool_size: int = 10,
        datadir: Optional[Path] = None,
        indicators: Optional[dict] = None,
    ):
        """
        初始化数据对象

        Args:
            symbols: 股票代码列表,如 ['000001.SZ', '600000.SH']
            period: 数据周期,如 '1d'(日线)、'1m'(分钟)
            adjust: 复权方式,'qfq'(前复权)、'hfq'(后复权)、'bfq'(不复权)
            source: 数据来源,'baostock'(默认)、'akshare'、'qmt'
            pool_size: 并发下载线程数
            datadir: 缓存目录路径
            indicators: 预计算指标/因子,支持三种格式:

                **1. 字典(混搭 TA 指标和自定义因子):**
                ```python
                {
                    'sma_20': ('sma', 20),           # TA 指标:列名 → (指标名, *参数)
                    'rsi_14': ('rsi', 14),
                    'macd': ('macd', 12, 26, 9),
                    'ma5': lambda df: df['close'].rolling(5).mean(),  # 自定义因子
                }
                ```
                支持的 TA 指标:sma、ema、wma、rsi、macd、bbands、atr
                factor 默认为 'close',可指定:`('sma', 'high', 20)`

                **2. 函数(自定义因子,替代旧版 factor 参数):**
                ```python
                indicators=lambda df: df.assign(ma5=df['close'].rolling(5).mean())
                ```
                函数接收单只股票的 DataFrame,返回添加了额外列的 DataFrame。
        """
        self.symbols = symbols or []
        self.period = period
        self.adjust = adjust
        self.source = source
        self.pool_size = pool_size

        # 统一处理 indicators 参数
        if callable(indicators):
            # 函数形式 → 保存为 callable
            self._indicators = indicators
        elif isinstance(indicators, dict):
            # 字典形式
            self._indicators = indicators
        elif indicators is None:
            self._indicators = {}
        else:
            raise TypeError(
                f"indicators 必须是 dict、callable 或 None,got {type(indicators)}"
            )

        # 初始化缓存目录
        if datadir is None:
            # 默认使用当前工作目录下的 datadir
            self.datadir = Path.cwd() / "datadir"
        else:
            self.datadir = Path(datadir)

        self.datadir.mkdir(parents=True, exist_ok=True)

        self.target_dir = self.datadir / self.source / self.period / (self.adjust or "bfq")
        self.target_dir.mkdir(parents=True, exist_ok=True)

    def _download(self, symbol: str) -> Path:
        """
        下载单个股票的数据

        Args:
            symbol (str): 股票代码

        Returns:
            Path: 数据文件路径
        """
        path = self.target_dir / f"{symbol}.parquet"

        if path.exists():
             return path

        if self.source == 'akshare':
            df = self._get_from_akshare(symbol)
        elif self.source == 'baostock':
            df = self._get_from_baostock(symbol)
        else:
            df = pd.DataFrame()

        if len(df) == 0:
            raise RuntimeError(f"{symbol}: baostock 返回空数据")

        table = pa.Table.from_pandas(df)
        pq.write_table(table, path)

        return path

    def get(self, lazy: bool = False):
        """
        获取历史数据。

        并发下载所有股票数据,应用因子计算,并返回合并后的数据。

        Args:
            lazy: 是否以懒加载模式返回 dask DataFrame(支持大规模数据分区迭代)。
                  默认 False,返回 compute() 后的 pandas DataFrame(向后兼容)。

        Returns:
            lazy=False: pd.DataFrame,列名格式 {symbol}|{factor}
            lazy=True: dd.DataFrame,列名格式 {symbol}|{factor}
            没有数据时抛出 RuntimeError
        """
        if not self.symbols:
            return pd.DataFrame()

        # 缓存
        if lazy:
            if hasattr(self, '_cached_dask') and self._cached_dask is not None:
                return self._cached_dask
        else:
            if hasattr(self, '_cached') and self._cached is not None:
                return self._cached

        # baostock 需要先登录,且其 C/S 架构不支持多线程并发
        bs_logged_in = False
        if self.source == 'baostock':
            lg = bs.login()
            if lg.error_code != '0':
                raise RuntimeError(f"baostock 登录失败: {lg.error_msg}")
            bs_logged_in = True

        errors = []
        try:
            if self.source == 'baostock':
                # baostock 串行下载(C/S 架构不支持并发)
                for symbol in tqdm(self.symbols, desc="下载数据"):
                    try:
                        self._download(symbol)
                    except Exception as e:
                        errors.append(f"{symbol}: {e}")
                        print(f"\n[警告] 下载 {symbol} 失败: {e}")
            else:
                # 其他数据源(akshare 等)并发下载
                with ThreadPoolExecutor(max_workers=self.pool_size) as executor:
                    futures = {
                        executor.submit(self._download, symbol): symbol
                        for symbol in self.symbols
                    }
                    with tqdm(total=len(self.symbols), desc="下载数据") as pbar:
                        for future in as_completed(futures):
                            symbol = futures[future]
                            try:
                                future.result()
                            except Exception as e:
                                errors.append(f"{symbol}: {e}")
                                print(f"\n[警告] 下载 {symbol} 失败: {e}")
                            pbar.update(1)
                            pbar.set_postfix_str(f"当前: {symbol}")
            if errors:
                raise RuntimeError(
                    f"共 {len(errors)} 只股票下载失败:\n" +
                    "\n".join(f"  - {e}" for e in errors)
                )
        finally:
            if bs_logged_in:
                bs.logout()

        if lazy:
            # 懒加载模式:返回 dask DataFrame,列名 {symbol}|{factor}
            dfs = []
            for symbol in self.symbols:
                parquet_path = self.target_dir / f"{symbol}.parquet"
                if not parquet_path.exists():
                    logger.warning(f"数据文件不存在,跳过: {parquet_path}")
                    continue
                ddf = dd.read_parquet(str(parquet_path))
                ddf = self._apply_indicators(ddf)
                column_mapping = {col: f'{symbol}|{col}' for col in ddf.columns}
                dfs.append(ddf.rename(columns=column_mapping))

            if not dfs:
                raise RuntimeError(
                    f"所有股票数据加载失败(共 {len(self.symbols)} 只),"
                    f"请检查网络连接和股票代码是否正确"
                )

            ddf = dd.concat(dfs, axis=1, join='outer')
            self._cached_dask = ddf
            return ddf

        else:
            # 全量模式(默认)
            dfs = []
            for symbol in self.symbols:
                parquet_path = self.target_dir / f"{symbol}.parquet"
                if not parquet_path.exists():
                    logger.warning(f"数据文件不存在,跳过: {parquet_path}")
                    continue
                df = dd.read_parquet(str(parquet_path))
                df = self._apply_indicators(df)
                column_mapping = {col: f'{symbol}|{col}' for col in df.columns}
                dfs.append(df.rename(columns=column_mapping))

            if not dfs:
                raise RuntimeError(
                    f"所有股票数据加载失败(共 {len(self.symbols)} 只),"
                    f"请检查网络连接和股票代码是否正确"
                )

            ddf = dd.concat(dfs, axis=1, join='outer')
            self._cached = ddf.compute()
            return self._cached

    def _apply_indicators(self, df: 'pd.DataFrame'):
        """
        对单只股票的数据计算预定义的技术指标。

        Args:
            df: 单只股票的 DataFrame

        Returns:
            DataFrame: 包含原始列和指标列
        """
        if not self.indicators:
            return df

    def _apply_indicators(self, df):
        """
        对单只股票的数据应用预定义的指标/因子。

        支持三种形式:
        - 空 dict → 跳过
        - callable → 旧版 factor 风格,接收 df 返回 df
        - dict → 混合 TA 指标和自定义 callable

        Args:
            df: 单只股票的 DataFrame

        Returns:
            DataFrame: 包含原始列和指标列
        """
        inds = self._indicators

        # 空 → 跳过
        if not inds:
            return df

        # 单函数形式(旧版 factor 的替代)
        if callable(inds):
            if isinstance(df, dd.DataFrame):
                return df.map_partitions(lambda p: inds(p.copy()))
            return inds(df.copy())

        # 字典形式
        if isinstance(df, dd.DataFrame):
            # 先用样本分区计算指标,获得准确的 meta(含新增的指标列)
            # 避免 dask 在迷你分区上推理 meta 时因窗口不足而崩溃
            sample = df.head(200)
            meta = self._compute_indicator_cols(sample.copy())
            return df.map_partitions(
                lambda partition: self._compute_indicator_cols(partition.copy()),
                meta=meta,
            )
        return self._compute_indicator_cols(df.copy())

    def _compute_indicator_cols(self, df):
        """在 pandas DataFrame 上计算指标/因子列(单只股票)。"""
        import ta as _ta_lib

        # 分区过小时跳过,避免 ta-lib(如 ATR window=14)在 dask meta 推断时崩溃
        min_rows = self._min_rows_for_indicators()
        if len(df) < min_rows:
            return df

        for col_name, spec in self._indicators.items():
            # 自定义因子(callable 值)
            if callable(spec):
                result = spec(df)
                if isinstance(result, pd.DataFrame):
                    # 多列返回 → 逐列添加
                    for c in result.columns:
                        df[c] = result[c]
                else:
                    # 单值返回 → 列名 = key
                    df[col_name] = result
                continue

            # TA 指标(tuple 值)
            if not isinstance(spec, (list, tuple)):
                logger.warning(f"指标 {col_name} 的规格必须为 tuple 或 callable,跳过")
                continue

            ind_type = spec[0]
            args = list(spec[1:])

            # 如果第一个参数是字符串,视为自定义 factor 列名;否则默认为 'close'
            factor = 'close'
            rest = args
            if args and isinstance(args[0], str):
                factor = args[0]
                rest = args[1:]

            if ind_type == 'sma':
                window = int(rest[0]) if rest else 20
                df[col_name] = _ta_lib.trend.sma_indicator(df[factor], window=window)

            elif ind_type == 'ema':
                window = int(rest[0]) if rest else 30
                df[col_name] = _ta_lib.trend.ema_indicator(df[factor], window=window)

            elif ind_type == 'wma':
                window = int(rest[0]) if rest else 30
                df[col_name] = _ta_lib.trend.wma_indicator(df[factor], window=window)

            elif ind_type == 'rsi':
                window = int(rest[0]) if rest else 14
                df[col_name] = _ta_lib.momentum.rsi(df[factor], window=window)

            elif ind_type == 'macd':
                fast = int(rest[0]) if len(rest) >= 1 else 12
                slow = int(rest[1]) if len(rest) >= 2 else 26
                signal = int(rest[2]) if len(rest) >= 3 else 9
                _macd = _ta_lib.trend.MACD(
                    df[factor], window_slow=slow, window_fast=fast, window_sign=signal,
                )
                df[col_name] = _macd.macd()
                df[f'{col_name}_signal'] = _macd.macd_signal()
                df[f'{col_name}_histogram'] = _macd.macd_diff()

            elif ind_type == 'bbands':
                window = int(rest[0]) if rest else 20
                std = int(rest[1]) if len(rest) >= 2 else 2
                _bb = _ta_lib.volatility.BollingerBands(
                    df[factor], window=window, window_dev=std,
                )
                df[f'{col_name}_upper'] = _bb.bollinger_hband()
                df[f'{col_name}_middle'] = _bb.bollinger_mavg()
                df[f'{col_name}_lower'] = _bb.bollinger_lband()

            elif ind_type == 'atr':
                window = int(rest[0]) if rest else 14
                df[col_name] = _ta_lib.volatility.average_true_range(
                    df['high'], df['low'], df['close'], window=window,
                )

            else:
                logger.warning(f"未知指标类型: {ind_type},跳过 {col_name}")

        return df

    def _min_rows_for_indicators(self):
        """计算所有指标所需的最小行数。"""
        if not self._indicators or not isinstance(self._indicators, dict):
            return 0
        max_window = 0
        for spec in self._indicators.values():
            if callable(spec) or not isinstance(spec, (list, tuple)):
                continue
            args = list(spec[1:])
            rest = args
            if args and isinstance(args[0], str):
                rest = args[1:]
            ind_type = spec[0]
            if ind_type in ('sma', 'ema', 'wma', 'rsi', 'atr'):
                w = int(rest[0]) if rest else (30 if ind_type in ('ema', 'wma') else (20 if ind_type == 'sma' else 14))
                max_window = max(max_window, w)
            elif ind_type == 'bbands':
                w = int(rest[0]) if rest else 20
                max_window = max(max_window, w)
            elif ind_type == 'macd':
                fast = int(rest[0]) if len(rest) >= 1 else 12
                slow = int(rest[1]) if len(rest) >= 2 else 26
                max_window = max(max_window, fast, slow)
        return max_window

    def _get_from_akshare(self, symbol: str) -> pd.DataFrame:
        """
        从 akshare 获取单个股票的数据。

        Args:
            symbol (str): 股票代码,支持带后缀如 000001.SZ 或不带后缀的 000001

        Returns:
            pd.DataFrame: 股票数据,以 date 为索引,包含 open, high, low, close, volume, amount 列
        """
        column_mapping = {
            "日期": "date",
            "开盘": "open",
            "收盘": "close",
            "最高": "high",
            "最低": "low",
            "成交量": "volume",
            "成交额": "amount",
        }

        # 下载数据
        # akshare 不支持带 .SZ/.SH 后缀,需去除
        clean_symbol = symbol.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')
        df = ak.stock_zh_a_hist(symbol=clean_symbol, period='daily', adjust=self.adjust)

        # 数据标准化处理
        # 1. 标准化列名
        df = df.rename(columns=column_mapping)
        if "date" in df.columns:
            df["date"] = pd.to_datetime(df["date"])

        # 2. 确保数值列为数值类型
        numeric_cols = [c for c in ("open", "high", "low", "close", "volume", "amount") if c in df.columns]
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors="coerce")

        # 3. 只保留需要的列
        mapped_columns = list(column_mapping.values())
        available_columns = [col for col in mapped_columns if col in df.columns]
        df = df[available_columns]

        df = df.set_index('date')
        # 设置索引
        return df

    def _get_from_baostock(self, symbol: str) -> pd.DataFrame:
        """
        从 baostock 获取单个股票的数据。

        Args:
            symbol (str): 股票代码,支持带后缀如 000001.SZ 或 600000.SH

        Returns:
            pd.DataFrame: 股票数据,以 date 为索引,包含 open, high, low, close, volume, amount 列
        """
        # baostock 代码格式:sz.000001 / sh.600000
        exchange = symbol[-2:].lower()  # 'sz', 'sh', 'bj'
        code = symbol.split('.')[0]  # '000001'
        bs_code = f"{exchange}.{code}"

        # adjustflag: 1=不复权, 2=前复权, 3=后复权
        adjust_map = {'bfq': '1', 'qfq': '2', 'hfq': '3'}
        adjustflag = adjust_map.get(self.adjust, '2')

        rs = bs.query_history_k_data_plus(
            bs_code,
            "date,open,high,low,close,volume,amount",
            start_date='1990-01-01',
            end_date='2050-12-31',
            frequency='d',
            adjustflag=adjustflag,
        )
        if rs.error_code != '0':
            raise RuntimeError(f"baostock 查询 {symbol}({bs_code}) 失败: {rs.error_msg}")
        df = rs.get_data()

        if len(df) == 0:
            return df

        # baostock 返回的数值列是字符串,转数值类型
        numeric_cols = ["open", "high", "low", "close", "volume", "amount"]
        for col in numeric_cols:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors="coerce")

        df["date"] = pd.to_datetime(df["date"])
        df = df.set_index("date")
        return df

__init__(symbols=None, period='1d', adjust='qfq', source='baostock', pool_size=10, datadir=None, indicators=None)

初始化数据对象

参数:

名称 类型 描述 默认
symbols Optional[List[str]]

股票代码列表,如 ['000001.SZ', '600000.SH']

None
period str

数据周期,如 '1d'(日线)、'1m'(分钟)

'1d'
adjust str

复权方式,'qfq'(前复权)、'hfq'(后复权)、'bfq'(不复权)

'qfq'
source str

数据来源,'baostock'(默认)、'akshare'、'qmt'

'baostock'
pool_size int

并发下载线程数

10
datadir Optional[Path]

缓存目录路径

None
indicators Optional[dict]

预计算指标/因子,支持三种格式:

1. 字典(混搭 TA 指标和自定义因子):

{
    'sma_20': ('sma', 20),           # TA 指标:列名 → (指标名, *参数)
    'rsi_14': ('rsi', 14),
    'macd': ('macd', 12, 26, 9),
    'ma5': lambda df: df['close'].rolling(5).mean(),  # 自定义因子
}
支持的 TA 指标:sma、ema、wma、rsi、macd、bbands、atr factor 默认为 'close',可指定:('sma', 'high', 20)

2. 函数(自定义因子,替代旧版 factor 参数):

indicators=lambda df: df.assign(ma5=df['close'].rolling(5).mean())
函数接收单只股票的 DataFrame,返回添加了额外列的 DataFrame。

None
源代码位于: qka/core/data.py
def __init__(
    self, 
    symbols: Optional[List[str]] = None,
    period: str = '1d',
    adjust: str = 'qfq',
    source: str = 'baostock',
    pool_size: int = 10,
    datadir: Optional[Path] = None,
    indicators: Optional[dict] = None,
):
    """
    初始化数据对象

    Args:
        symbols: 股票代码列表,如 ['000001.SZ', '600000.SH']
        period: 数据周期,如 '1d'(日线)、'1m'(分钟)
        adjust: 复权方式,'qfq'(前复权)、'hfq'(后复权)、'bfq'(不复权)
        source: 数据来源,'baostock'(默认)、'akshare'、'qmt'
        pool_size: 并发下载线程数
        datadir: 缓存目录路径
        indicators: 预计算指标/因子,支持三种格式:

            **1. 字典(混搭 TA 指标和自定义因子):**
            ```python
            {
                'sma_20': ('sma', 20),           # TA 指标:列名 → (指标名, *参数)
                'rsi_14': ('rsi', 14),
                'macd': ('macd', 12, 26, 9),
                'ma5': lambda df: df['close'].rolling(5).mean(),  # 自定义因子
            }
            ```
            支持的 TA 指标:sma、ema、wma、rsi、macd、bbands、atr
            factor 默认为 'close',可指定:`('sma', 'high', 20)`

            **2. 函数(自定义因子,替代旧版 factor 参数):**
            ```python
            indicators=lambda df: df.assign(ma5=df['close'].rolling(5).mean())
            ```
            函数接收单只股票的 DataFrame,返回添加了额外列的 DataFrame。
    """
    self.symbols = symbols or []
    self.period = period
    self.adjust = adjust
    self.source = source
    self.pool_size = pool_size

    # 统一处理 indicators 参数
    if callable(indicators):
        # 函数形式 → 保存为 callable
        self._indicators = indicators
    elif isinstance(indicators, dict):
        # 字典形式
        self._indicators = indicators
    elif indicators is None:
        self._indicators = {}
    else:
        raise TypeError(
            f"indicators 必须是 dict、callable 或 None,got {type(indicators)}"
        )

    # 初始化缓存目录
    if datadir is None:
        # 默认使用当前工作目录下的 datadir
        self.datadir = Path.cwd() / "datadir"
    else:
        self.datadir = Path(datadir)

    self.datadir.mkdir(parents=True, exist_ok=True)

    self.target_dir = self.datadir / self.source / self.period / (self.adjust or "bfq")
    self.target_dir.mkdir(parents=True, exist_ok=True)

get(lazy=False)

获取历史数据。

并发下载所有股票数据,应用因子计算,并返回合并后的数据。

参数:

名称 类型 描述 默认
lazy bool

是否以懒加载模式返回 dask DataFrame(支持大规模数据分区迭代)。 默认 False,返回 compute() 后的 pandas DataFrame(向后兼容)。

False

返回:

类型 描述

lazy=False: pd.DataFrame,列名格式 {symbol}|{factor}

lazy=True: dd.DataFrame,列名格式 {symbol}|{factor}

没有数据时抛出 RuntimeError

源代码位于: qka/core/data.py
def get(self, lazy: bool = False):
    """
    获取历史数据。

    并发下载所有股票数据,应用因子计算,并返回合并后的数据。

    Args:
        lazy: 是否以懒加载模式返回 dask DataFrame(支持大规模数据分区迭代)。
              默认 False,返回 compute() 后的 pandas DataFrame(向后兼容)。

    Returns:
        lazy=False: pd.DataFrame,列名格式 {symbol}|{factor}
        lazy=True: dd.DataFrame,列名格式 {symbol}|{factor}
        没有数据时抛出 RuntimeError
    """
    if not self.symbols:
        return pd.DataFrame()

    # 缓存
    if lazy:
        if hasattr(self, '_cached_dask') and self._cached_dask is not None:
            return self._cached_dask
    else:
        if hasattr(self, '_cached') and self._cached is not None:
            return self._cached

    # baostock 需要先登录,且其 C/S 架构不支持多线程并发
    bs_logged_in = False
    if self.source == 'baostock':
        lg = bs.login()
        if lg.error_code != '0':
            raise RuntimeError(f"baostock 登录失败: {lg.error_msg}")
        bs_logged_in = True

    errors = []
    try:
        if self.source == 'baostock':
            # baostock 串行下载(C/S 架构不支持并发)
            for symbol in tqdm(self.symbols, desc="下载数据"):
                try:
                    self._download(symbol)
                except Exception as e:
                    errors.append(f"{symbol}: {e}")
                    print(f"\n[警告] 下载 {symbol} 失败: {e}")
        else:
            # 其他数据源(akshare 等)并发下载
            with ThreadPoolExecutor(max_workers=self.pool_size) as executor:
                futures = {
                    executor.submit(self._download, symbol): symbol
                    for symbol in self.symbols
                }
                with tqdm(total=len(self.symbols), desc="下载数据") as pbar:
                    for future in as_completed(futures):
                        symbol = futures[future]
                        try:
                            future.result()
                        except Exception as e:
                            errors.append(f"{symbol}: {e}")
                            print(f"\n[警告] 下载 {symbol} 失败: {e}")
                        pbar.update(1)
                        pbar.set_postfix_str(f"当前: {symbol}")
        if errors:
            raise RuntimeError(
                f"共 {len(errors)} 只股票下载失败:\n" +
                "\n".join(f"  - {e}" for e in errors)
            )
    finally:
        if bs_logged_in:
            bs.logout()

    if lazy:
        # 懒加载模式:返回 dask DataFrame,列名 {symbol}|{factor}
        dfs = []
        for symbol in self.symbols:
            parquet_path = self.target_dir / f"{symbol}.parquet"
            if not parquet_path.exists():
                logger.warning(f"数据文件不存在,跳过: {parquet_path}")
                continue
            ddf = dd.read_parquet(str(parquet_path))
            ddf = self._apply_indicators(ddf)
            column_mapping = {col: f'{symbol}|{col}' for col in ddf.columns}
            dfs.append(ddf.rename(columns=column_mapping))

        if not dfs:
            raise RuntimeError(
                f"所有股票数据加载失败(共 {len(self.symbols)} 只),"
                f"请检查网络连接和股票代码是否正确"
            )

        ddf = dd.concat(dfs, axis=1, join='outer')
        self._cached_dask = ddf
        return ddf

    else:
        # 全量模式(默认)
        dfs = []
        for symbol in self.symbols:
            parquet_path = self.target_dir / f"{symbol}.parquet"
            if not parquet_path.exists():
                logger.warning(f"数据文件不存在,跳过: {parquet_path}")
                continue
            df = dd.read_parquet(str(parquet_path))
            df = self._apply_indicators(df)
            column_mapping = {col: f'{symbol}|{col}' for col in df.columns}
            dfs.append(df.rename(columns=column_mapping))

        if not dfs:
            raise RuntimeError(
                f"所有股票数据加载失败(共 {len(self.symbols)} 只),"
                f"请检查网络连接和股票代码是否正确"
            )

        ddf = dd.concat(dfs, axis=1, join='outer')
        self._cached = ddf.compute()
        return self._cached

qka.Backtest

QKA回测引擎类

提供基于时间序列的回测功能,支持多股票横截面数据处理, 以及绩效指标计算和可视化。

属性:

名称 类型 描述
data Data

数据对象实例

strategy Strategy

策略对象实例

results DataFrame

回测结果数据

initial_cash float

初始资金

源代码位于: qka/core/backtest.py
class Backtest:
    """
    QKA回测引擎类

    提供基于时间序列的回测功能,支持多股票横截面数据处理,
    以及绩效指标计算和可视化。

    Attributes:
        data (Data): 数据对象实例
        strategy (Strategy): 策略对象实例
        results (pd.DataFrame): 回测结果数据
        initial_cash (float): 初始资金
    """

    def __init__(self, data, strategy):
        """
        初始化回测引擎

        Args:
            data (Data): Data类的实例,包含股票数据
            strategy (Strategy): 策略对象,必须包含on_bar方法
        """
        self.data = data
        self.strategy = strategy
        self.results = None
        self.initial_cash = strategy.broker.cash
        self._benchmark_data = None

    @staticmethod
    def _parse_row(row):
        """
        解析一行 iterrows 数据为 per-factor 字典。

        列名格式: {symbol}|{factor}
        例: '000001.SZ|close' → factor='close', symbol='000001.SZ'
            '000001.SZ|sma_5' → factor='sma_5', symbol='000001.SZ'

        Args:
            row: pd.Series,列名格式 {symbol}|{factor}

        Returns:
            dict: {factor: {symbol: value}}
        """
        by_factor = defaultdict(dict)
        for col, val in row.items():
            if not isinstance(col, str) or '|' not in col:
                continue
            *sym_parts, factor = col.rsplit('|', 1)
            symbol = '|'.join(sym_parts)
            by_factor[factor][symbol] = val
        return dict(by_factor)

    def run(self, benchmark: Optional[str] = None):
        """
        执行回测

        遍历所有时间点,在每个时间点调用策略的on_bar方法进行交易决策,
        并记录交易后的状态。

        大规模回测(>500 bar)时自动使用分区迭代,分块加载数据到内存,
        避免一次性加载全量数据。

        Args:
            benchmark (str, optional): 基准代码,如 '000300.SH'(沪深300)。
                                       如果提供,会下载基准数据用于对比。

        Returns:
            None。回测结果保存在 self.results 中,可通过
            self.summary() 查看绩效指标,self.report() 生成报告。
        """
        # 获取数据(优先用 lazy 模式,由 Backtest 决定是否分区)
        raw = self.data.get(lazy=True)

        # 加载基准数据
        if benchmark:
            self._load_benchmark(benchmark)

        if isinstance(raw, dd.DataFrame):
            # ── dask 模式:分区迭代 ──
            ddf: dd.DataFrame = raw
            n_rows = len(ddf)

            if n_rows > 500:
                # 大规模:先算 index,再按日期分块读取
                index = ddf.index.compute()
                chunk_size = 500

                for start in range(0, n_rows, chunk_size):
                    end = min(start + chunk_size, n_rows)
                    date_start, date_end = index[start], index[end - 1]

                    chunk = ddf.loc[date_start:date_end].compute()
                    for dt, row in chunk.iterrows():
                        by_factor = self._parse_row(row)
                        for factor, data in by_factor.items():
                            self.strategy._data.push(dt, factor, data)
                        # dask 路径:策略使用 self.get() / self.history()
                        self.strategy.on_bar(dt)
                        self.strategy.broker.on_bar(
                            dt, self.strategy._data.get
                        )
            else:
                # 小规模:一次加载,零开销
                df = ddf.compute()
                for date, row in df.iterrows():
                    by_factor = self._parse_row(row)
                    for factor, data in by_factor.items():
                        self.strategy._data.push(date, factor, data)
                    self.strategy.on_bar(date)
                    self.strategy.broker.on_bar(
                        date, self.strategy._data.get
                    )
        else:
            # ── pandas 模式:向后兼容(测试 mock 数据等) ──
            df: pd.DataFrame = raw
            for date, row in df.iterrows():
                by_factor = self._parse_row(row)
                for factor, data in by_factor.items():
                    self.strategy._data.push(date, factor, data)
                self.strategy.on_bar(date)
                self.strategy.broker.on_bar(
                    date, self.strategy._data.get
                )

        # 保存回测结果
        self.results = self.strategy.broker.trades

    def _load_benchmark(self, benchmark_code: str):
        """
        加载基准指数数据

        Args:
            benchmark_code: 基准代码,如 '000300.SH'
        """
        try:
            import akshare as ak
            clean_code = benchmark_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '')
            bm_df = ak.stock_zh_index_daily(symbol=f"sh{clean_code}")
            if bm_df is not None and not bm_df.empty:
                bm_df['date'] = pd.to_datetime(bm_df['date'])
                bm_df = bm_df.set_index('date')
                bm_df = bm_df.sort_index()
                self._benchmark_data = bm_df['close']
                print(f"基准数据加载成功: {benchmark_code}{len(bm_df)} 个交易日")
        except Exception as e:
            print(f"基准数据加载失败: {e}")

    def summary(self) -> dict:
        """
        计算并打印回测绩效指标

        返回包含以下指标的字典:
        - 总收益率、年化收益率、年化波动率
        - 夏普比率、最大回撤、Calmar比率
        - 胜率、盈亏比、交易次数
        - 最终资产、总手续费

        Returns:
            dict: 绩效指标字典
        """
        if self.results is None or self.results.empty:
            print("请先运行回测 (backtest.run())")
            return {}

        totals = self.results['total']
        if len(totals) < 2:
            print("回测数据不足(至少需要2个交易周期)")
            return {}

        # 基本数据
        initial = self.initial_cash
        final = totals.iloc[-1]
        total_return = (final / initial - 1) * 100

        # 交易天数 / 年化
        n_days = len(totals)
        years = n_days / 252  # A股年均约252个交易日

        # 日收益率序列
        daily_returns = totals.pct_change().dropna()
        if len(daily_returns) == 0:
            print("没有足够的收益率数据")
            return {}

        # 年化收益率
        annual_return = (final / initial) ** (1 / years) - 1 if years > 0 else 0

        # 年化波动率
        annual_vol = daily_returns.std() * np.sqrt(252)

        # 夏普比率(无风险利率假设 3%)
        risk_free_rate = 0.03
        sharpe = (annual_return - risk_free_rate) / annual_vol if annual_vol > 0 else 0

        # 最大回撤
        cumulative = (1 + daily_returns).cumprod()
        running_max = cumulative.cummax()
        drawdown = (cumulative - running_max) / running_max
        max_drawdown = drawdown.min() * 100

        # Calmar 比率
        calmar = annual_return / abs(max_drawdown / 100) if max_drawdown != 0 else 0

        # 交易分析
        trades = self.strategy.broker.trade_history
        n_trades = len(trades)

        if n_trades > 0:
            # 统计每笔交易的盈亏
            trade_pnl = []
            buy_prices = {}
            for t in trades:
                if t['action'] == 'buy':
                    if t['symbol'] not in buy_prices:
                        buy_prices[t['symbol']] = []
                    buy_prices[t['symbol']].append((t['size'], t['exec_price'], t['total_cost']))
                elif t['action'] == 'sell':
                    symbol = t['symbol']
                    size = t['size']
                    net_proceeds = t['net_proceeds']
                    # 按先进先出匹配买入
                    if symbol in buy_prices and buy_prices[symbol]:
                        total_buy_cost = 0
                        remaining = size
                        while remaining > 0 and buy_prices[symbol]:
                            b_size, b_price, b_cost = buy_prices[symbol][0]
                            if b_size <= remaining:
                                total_buy_cost += b_cost
                                remaining -= b_size
                                buy_prices[symbol].pop(0)
                            else:
                                ratio = remaining / b_size
                                total_buy_cost += b_cost * ratio
                                buy_prices[symbol][0] = (b_size - remaining, b_price, b_cost * (1 - ratio))
                                remaining = 0
                        pnl = net_proceeds - total_buy_cost
                        trade_pnl.append(pnl)

            win_trades = sum(1 for p in trade_pnl if p > 0)
            win_rate = (win_trades / len(trade_pnl) * 100) if trade_pnl else 0
            avg_win = np.mean([p for p in trade_pnl if p > 0]) if any(p > 0 for p in trade_pnl) else 0
            avg_loss = abs(np.mean([p for p in trade_pnl if p <= 0])) if any(p <= 0 for p in trade_pnl) else 0
            profit_loss_ratio = avg_win / avg_loss if avg_loss > 0 else 0
        else:
            win_rate = 0
            profit_loss_ratio = 0
            trade_pnl = []

        # 总手续费
        total_commission = self.strategy.broker.total_commission

        # 打印报告
        print("=" * 55)
        print("           回测绩效报告")
        print("=" * 55)
        print(f"  初始资金:        RMB {initial:>10,.2f}")
        print(f"  最终资产:        RMB {final:>10,.2f}")
        print(f"  总收益率:         {total_return:>+8.2f}%")
        print(f"  年化收益率:       {annual_return * 100:>+8.2f}%")
        print(f"  年化波动率:       {annual_vol * 100:>8.2f}%")
        print(f"  夏普比率:         {sharpe:>8.2f}")
        print(f"  最大回撤:         {max_drawdown:>8.2f}%")
        print(f"  Calmar比率:       {calmar:>8.2f}")
        print(f"  交易次数:         {n_trades:>8}")
        print(f"  胜率:             {win_rate:>8.2f}%")
        print(f"  盈亏比:           {profit_loss_ratio:>8.2f}")
        print(f"  总手续费:         RMB {total_commission:>10,.2f}")
        print(f"  回测天数:         {n_days:>8} 天")
        print("=" * 55)

        return {
            'initial_cash': initial,
            'final_equity': final,
            'total_return_pct': total_return,
            'annual_return_pct': annual_return * 100,
            'annual_volatility_pct': annual_vol * 100,
            'sharpe_ratio': sharpe,
            'max_drawdown_pct': max_drawdown,
            'calmar_ratio': calmar,
            'total_trades': n_trades,
            'win_rate_pct': win_rate,
            'profit_loss_ratio': profit_loss_ratio,
            'total_commission': total_commission,
            'n_days': n_days,
        }

    def report(self, title: str = "未命名策略", output_path: Optional[str] = None) -> str:
        """
        生成自包含的 HTML 回测报告

        包含绩效指标卡片、净值曲线、回撤曲线、月度收益率热力图、
        交易明细表和回撤分析。可直接在浏览器中打开。

        Args:
            title: 策略名称(显示在报告标题中)
            output_path: 输出 HTML 文件路径。
                         None 则自动保存在 reports/ 目录下

        Returns:
            str: HTML 文件路径
        """
        from qka.core.report import generate_report

        if self.results is None or self.results.empty:
            print("错误:请先运行回测 (bt.run())")
            return ""

        bm = getattr(self, '_benchmark_data', None)
        return generate_report(
            results=self.results,
            broker=self.strategy.broker,
            initial_cash=self.initial_cash,
            benchmark_data=bm,
            strategy_name=title,
            output_path=output_path,
        )

__init__(data, strategy)

初始化回测引擎

参数:

名称 类型 描述 默认
data Data

Data类的实例,包含股票数据

必需
strategy Strategy

策略对象,必须包含on_bar方法

必需
源代码位于: qka/core/backtest.py
def __init__(self, data, strategy):
    """
    初始化回测引擎

    Args:
        data (Data): Data类的实例,包含股票数据
        strategy (Strategy): 策略对象,必须包含on_bar方法
    """
    self.data = data
    self.strategy = strategy
    self.results = None
    self.initial_cash = strategy.broker.cash
    self._benchmark_data = None

run(benchmark=None)

执行回测

遍历所有时间点,在每个时间点调用策略的on_bar方法进行交易决策, 并记录交易后的状态。

大规模回测(>500 bar)时自动使用分区迭代,分块加载数据到内存, 避免一次性加载全量数据。

参数:

名称 类型 描述 默认
benchmark str

基准代码,如 '000300.SH'(沪深300)。 如果提供,会下载基准数据用于对比。

None

返回:

类型 描述

None。回测结果保存在 self.results 中,可通过

self.summary() 查看绩效指标,self.report() 生成报告。

源代码位于: qka/core/backtest.py
def run(self, benchmark: Optional[str] = None):
    """
    执行回测

    遍历所有时间点,在每个时间点调用策略的on_bar方法进行交易决策,
    并记录交易后的状态。

    大规模回测(>500 bar)时自动使用分区迭代,分块加载数据到内存,
    避免一次性加载全量数据。

    Args:
        benchmark (str, optional): 基准代码,如 '000300.SH'(沪深300)。
                                   如果提供,会下载基准数据用于对比。

    Returns:
        None。回测结果保存在 self.results 中,可通过
        self.summary() 查看绩效指标,self.report() 生成报告。
    """
    # 获取数据(优先用 lazy 模式,由 Backtest 决定是否分区)
    raw = self.data.get(lazy=True)

    # 加载基准数据
    if benchmark:
        self._load_benchmark(benchmark)

    if isinstance(raw, dd.DataFrame):
        # ── dask 模式:分区迭代 ──
        ddf: dd.DataFrame = raw
        n_rows = len(ddf)

        if n_rows > 500:
            # 大规模:先算 index,再按日期分块读取
            index = ddf.index.compute()
            chunk_size = 500

            for start in range(0, n_rows, chunk_size):
                end = min(start + chunk_size, n_rows)
                date_start, date_end = index[start], index[end - 1]

                chunk = ddf.loc[date_start:date_end].compute()
                for dt, row in chunk.iterrows():
                    by_factor = self._parse_row(row)
                    for factor, data in by_factor.items():
                        self.strategy._data.push(dt, factor, data)
                    # dask 路径:策略使用 self.get() / self.history()
                    self.strategy.on_bar(dt)
                    self.strategy.broker.on_bar(
                        dt, self.strategy._data.get
                    )
        else:
            # 小规模:一次加载,零开销
            df = ddf.compute()
            for date, row in df.iterrows():
                by_factor = self._parse_row(row)
                for factor, data in by_factor.items():
                    self.strategy._data.push(date, factor, data)
                self.strategy.on_bar(date)
                self.strategy.broker.on_bar(
                    date, self.strategy._data.get
                )
    else:
        # ── pandas 模式:向后兼容(测试 mock 数据等) ──
        df: pd.DataFrame = raw
        for date, row in df.iterrows():
            by_factor = self._parse_row(row)
            for factor, data in by_factor.items():
                self.strategy._data.push(date, factor, data)
            self.strategy.on_bar(date)
            self.strategy.broker.on_bar(
                date, self.strategy._data.get
            )

    # 保存回测结果
    self.results = self.strategy.broker.trades

summary()

计算并打印回测绩效指标

返回包含以下指标的字典: - 总收益率、年化收益率、年化波动率 - 夏普比率、最大回撤、Calmar比率 - 胜率、盈亏比、交易次数 - 最终资产、总手续费

返回:

名称 类型 描述
dict dict

绩效指标字典

源代码位于: qka/core/backtest.py
def summary(self) -> dict:
    """
    计算并打印回测绩效指标

    返回包含以下指标的字典:
    - 总收益率、年化收益率、年化波动率
    - 夏普比率、最大回撤、Calmar比率
    - 胜率、盈亏比、交易次数
    - 最终资产、总手续费

    Returns:
        dict: 绩效指标字典
    """
    if self.results is None or self.results.empty:
        print("请先运行回测 (backtest.run())")
        return {}

    totals = self.results['total']
    if len(totals) < 2:
        print("回测数据不足(至少需要2个交易周期)")
        return {}

    # 基本数据
    initial = self.initial_cash
    final = totals.iloc[-1]
    total_return = (final / initial - 1) * 100

    # 交易天数 / 年化
    n_days = len(totals)
    years = n_days / 252  # A股年均约252个交易日

    # 日收益率序列
    daily_returns = totals.pct_change().dropna()
    if len(daily_returns) == 0:
        print("没有足够的收益率数据")
        return {}

    # 年化收益率
    annual_return = (final / initial) ** (1 / years) - 1 if years > 0 else 0

    # 年化波动率
    annual_vol = daily_returns.std() * np.sqrt(252)

    # 夏普比率(无风险利率假设 3%)
    risk_free_rate = 0.03
    sharpe = (annual_return - risk_free_rate) / annual_vol if annual_vol > 0 else 0

    # 最大回撤
    cumulative = (1 + daily_returns).cumprod()
    running_max = cumulative.cummax()
    drawdown = (cumulative - running_max) / running_max
    max_drawdown = drawdown.min() * 100

    # Calmar 比率
    calmar = annual_return / abs(max_drawdown / 100) if max_drawdown != 0 else 0

    # 交易分析
    trades = self.strategy.broker.trade_history
    n_trades = len(trades)

    if n_trades > 0:
        # 统计每笔交易的盈亏
        trade_pnl = []
        buy_prices = {}
        for t in trades:
            if t['action'] == 'buy':
                if t['symbol'] not in buy_prices:
                    buy_prices[t['symbol']] = []
                buy_prices[t['symbol']].append((t['size'], t['exec_price'], t['total_cost']))
            elif t['action'] == 'sell':
                symbol = t['symbol']
                size = t['size']
                net_proceeds = t['net_proceeds']
                # 按先进先出匹配买入
                if symbol in buy_prices and buy_prices[symbol]:
                    total_buy_cost = 0
                    remaining = size
                    while remaining > 0 and buy_prices[symbol]:
                        b_size, b_price, b_cost = buy_prices[symbol][0]
                        if b_size <= remaining:
                            total_buy_cost += b_cost
                            remaining -= b_size
                            buy_prices[symbol].pop(0)
                        else:
                            ratio = remaining / b_size
                            total_buy_cost += b_cost * ratio
                            buy_prices[symbol][0] = (b_size - remaining, b_price, b_cost * (1 - ratio))
                            remaining = 0
                    pnl = net_proceeds - total_buy_cost
                    trade_pnl.append(pnl)

        win_trades = sum(1 for p in trade_pnl if p > 0)
        win_rate = (win_trades / len(trade_pnl) * 100) if trade_pnl else 0
        avg_win = np.mean([p for p in trade_pnl if p > 0]) if any(p > 0 for p in trade_pnl) else 0
        avg_loss = abs(np.mean([p for p in trade_pnl if p <= 0])) if any(p <= 0 for p in trade_pnl) else 0
        profit_loss_ratio = avg_win / avg_loss if avg_loss > 0 else 0
    else:
        win_rate = 0
        profit_loss_ratio = 0
        trade_pnl = []

    # 总手续费
    total_commission = self.strategy.broker.total_commission

    # 打印报告
    print("=" * 55)
    print("           回测绩效报告")
    print("=" * 55)
    print(f"  初始资金:        RMB {initial:>10,.2f}")
    print(f"  最终资产:        RMB {final:>10,.2f}")
    print(f"  总收益率:         {total_return:>+8.2f}%")
    print(f"  年化收益率:       {annual_return * 100:>+8.2f}%")
    print(f"  年化波动率:       {annual_vol * 100:>8.2f}%")
    print(f"  夏普比率:         {sharpe:>8.2f}")
    print(f"  最大回撤:         {max_drawdown:>8.2f}%")
    print(f"  Calmar比率:       {calmar:>8.2f}")
    print(f"  交易次数:         {n_trades:>8}")
    print(f"  胜率:             {win_rate:>8.2f}%")
    print(f"  盈亏比:           {profit_loss_ratio:>8.2f}")
    print(f"  总手续费:         RMB {total_commission:>10,.2f}")
    print(f"  回测天数:         {n_days:>8} 天")
    print("=" * 55)

    return {
        'initial_cash': initial,
        'final_equity': final,
        'total_return_pct': total_return,
        'annual_return_pct': annual_return * 100,
        'annual_volatility_pct': annual_vol * 100,
        'sharpe_ratio': sharpe,
        'max_drawdown_pct': max_drawdown,
        'calmar_ratio': calmar,
        'total_trades': n_trades,
        'win_rate_pct': win_rate,
        'profit_loss_ratio': profit_loss_ratio,
        'total_commission': total_commission,
        'n_days': n_days,
    }

report(title='未命名策略', output_path=None)

生成自包含的 HTML 回测报告

包含绩效指标卡片、净值曲线、回撤曲线、月度收益率热力图、 交易明细表和回撤分析。可直接在浏览器中打开。

参数:

名称 类型 描述 默认
title str

策略名称(显示在报告标题中)

'未命名策略'
output_path Optional[str]

输出 HTML 文件路径。 None 则自动保存在 reports/ 目录下

None

返回:

名称 类型 描述
str str

HTML 文件路径

源代码位于: qka/core/backtest.py
def report(self, title: str = "未命名策略", output_path: Optional[str] = None) -> str:
    """
    生成自包含的 HTML 回测报告

    包含绩效指标卡片、净值曲线、回撤曲线、月度收益率热力图、
    交易明细表和回撤分析。可直接在浏览器中打开。

    Args:
        title: 策略名称(显示在报告标题中)
        output_path: 输出 HTML 文件路径。
                     None 则自动保存在 reports/ 目录下

    Returns:
        str: HTML 文件路径
    """
    from qka.core.report import generate_report

    if self.results is None or self.results.empty:
        print("错误:请先运行回测 (bt.run())")
        return ""

    bm = getattr(self, '_benchmark_data', None)
    return generate_report(
        results=self.results,
        broker=self.strategy.broker,
        initial_cash=self.initial_cash,
        benchmark_data=bm,
        strategy_name=title,
        output_path=output_path,
    )

qka.Strategy

Bases: ABC

策略抽象基类

所有自定义策略都应该继承此类,并实现 on_bar 方法。

属性:

名称 类型 描述
broker Broker

交易经纪商实例,用于执行交易操作

sizing SizingAccessor

仓位计算工具,提供 self.sizing.percent() 等方法

_data DataAccessor

数据访问器,提供 self.get() 和 self.history() 接口

源代码位于: qka/core/strategy.py
class Strategy(ABC):
    """
    策略抽象基类

    所有自定义策略都应该继承此类,并实现 on_bar 方法。

    Attributes:
        broker (Broker): 交易经纪商实例,用于执行交易操作
        sizing (SizingAccessor): 仓位计算工具,提供 self.sizing.percent() 等方法
        _data (DataAccessor): 数据访问器,提供 self.get() 和 self.history() 接口
    """

    def __init__(self, cash: float = 100000.0):
        """
        初始化策略

        Args:
            cash: 初始资金,默认 10 万元
        """
        self.broker = Broker(initial_cash=cash)
        self.sizing = SizingAccessor(self.broker)
        self._data = DataAccessor(max_window=750)

    def get(self, factor: str):
        """
        获取当前 bar 的横截面数据。

        替代旧的 on_bar(date, get) 中的 get 参数。
        仅当 on_bar 通过 self._data 注入数据后才能使用。

        Args:
            factor: 因子名,如 'close', 'volume'

        Returns:
            pd.Series,index=股票代码,values=最新值
        """
        return self._data.get(factor)

    def history(self, factor: str, window: int = 20):
        """
        获取因子的历史窗口数据。

        Args:
            factor: 因子名
            window: 窗口大小

        Returns:
            pd.DataFrame,行=日期,列=股票代码
        """
        return self._data.history(factor, window)

    @abstractmethod
    def on_bar(self, date):
        """
        每个 bar 的处理逻辑,必须由子类实现。

        使用 self.get(factor) / self.history(factor, window) 获取数据。

        --- 用法 ---

        class MyStrategy(Strategy):
            def on_bar(self, date):
                # 横截面数据(当前 bar 所有股票)
                close = self.get('close')

                # 历史序列(过去 N 天)
                hist = self.history('close', 20)

                # 仓位管理
                size = self.sizing.percent(0.1, float(close['000001.SZ']))

                # 交易操作
                if not close.empty and '000001.SZ' in close.index:
                    price = float(close['000001.SZ'])
                    size = self.sizing.percent(0.1, price)
                    self.broker.buy('000001.SZ', price, size)

        Args:
            date: 当前日期(pd.Timestamp)
        """

__init__(cash=100000.0)

初始化策略

参数:

名称 类型 描述 默认
cash float

初始资金,默认 10 万元

100000.0
源代码位于: qka/core/strategy.py
def __init__(self, cash: float = 100000.0):
    """
    初始化策略

    Args:
        cash: 初始资金,默认 10 万元
    """
    self.broker = Broker(initial_cash=cash)
    self.sizing = SizingAccessor(self.broker)
    self._data = DataAccessor(max_window=750)

get(factor)

获取当前 bar 的横截面数据。

替代旧的 on_bar(date, get) 中的 get 参数。 仅当 on_bar 通过 self._data 注入数据后才能使用。

参数:

名称 类型 描述 默认
factor str

因子名,如 'close', 'volume'

必需

返回:

类型 描述

pd.Series,index=股票代码,values=最新值

源代码位于: qka/core/strategy.py
def get(self, factor: str):
    """
    获取当前 bar 的横截面数据。

    替代旧的 on_bar(date, get) 中的 get 参数。
    仅当 on_bar 通过 self._data 注入数据后才能使用。

    Args:
        factor: 因子名,如 'close', 'volume'

    Returns:
        pd.Series,index=股票代码,values=最新值
    """
    return self._data.get(factor)

history(factor, window=20)

获取因子的历史窗口数据。

参数:

名称 类型 描述 默认
factor str

因子名

必需
window int

窗口大小

20

返回:

类型 描述

pd.DataFrame,行=日期,列=股票代码

源代码位于: qka/core/strategy.py
def history(self, factor: str, window: int = 20):
    """
    获取因子的历史窗口数据。

    Args:
        factor: 因子名
        window: 窗口大小

    Returns:
        pd.DataFrame,行=日期,列=股票代码
    """
    return self._data.history(factor, window)

on_bar(date) abstractmethod

每个 bar 的处理逻辑,必须由子类实现。

使用 self.get(factor) / self.history(factor, window) 获取数据。

--- 用法 ---

class MyStrategy(Strategy): def on_bar(self, date): # 横截面数据(当前 bar 所有股票) close = self.get('close')

    # 历史序列(过去 N 天)
    hist = self.history('close', 20)

    # 仓位管理
    size = self.sizing.percent(0.1, float(close['000001.SZ']))

    # 交易操作
    if not close.empty and '000001.SZ' in close.index:
        price = float(close['000001.SZ'])
        size = self.sizing.percent(0.1, price)
        self.broker.buy('000001.SZ', price, size)

参数:

名称 类型 描述 默认
date

当前日期(pd.Timestamp)

必需
源代码位于: qka/core/strategy.py
@abstractmethod
def on_bar(self, date):
    """
    每个 bar 的处理逻辑,必须由子类实现。

    使用 self.get(factor) / self.history(factor, window) 获取数据。

    --- 用法 ---

    class MyStrategy(Strategy):
        def on_bar(self, date):
            # 横截面数据(当前 bar 所有股票)
            close = self.get('close')

            # 历史序列(过去 N 天)
            hist = self.history('close', 20)

            # 仓位管理
            size = self.sizing.percent(0.1, float(close['000001.SZ']))

            # 交易操作
            if not close.empty and '000001.SZ' in close.index:
                price = float(close['000001.SZ'])
                size = self.sizing.percent(0.1, price)
                self.broker.buy('000001.SZ', price, size)

    Args:
        date: 当前日期(pd.Timestamp)
    """

qka.Broker

虚拟交易经纪商类

管理资金、持仓和交易记录,提供买入卖出操作接口。 支持佣金、印花税、滑点等真实交易成本模拟。

属性:

名称 类型 描述
cash float

可用现金

positions Dict

持仓记录

trade_history List

交易历史记录

commission_rate float

佣金费率

stamp_duty_rate float

印花税费率(仅卖出)

slippage float

滑点比率

total_commission float

累计佣金

total_stamp_duty float

累计印花税

total_slippage_cost float

累计滑点成本

trades DataFrame

逐日状态记录

源代码位于: qka/core/broker.py
class Broker:
    """
    虚拟交易经纪商类

    管理资金、持仓和交易记录,提供买入卖出操作接口。
    支持佣金、印花税、滑点等真实交易成本模拟。

    Attributes:
        cash (float): 可用现金
        positions (Dict): 持仓记录
        trade_history (List): 交易历史记录
        commission_rate (float): 佣金费率
        stamp_duty_rate (float): 印花税费率(仅卖出)
        slippage (float): 滑点比率
        total_commission (float): 累计佣金
        total_stamp_duty (float): 累计印花税
        total_slippage_cost (float): 累计滑点成本
        trades (pd.DataFrame): 逐日状态记录
    """

    def __init__(self, initial_cash=100000.0,
                 commission_rate=DEFAULT_COMMISSION_RATE,
                 stamp_duty_rate=DEFAULT_STAMP_DUTY_RATE,
                 slippage=DEFAULT_SLIPPAGE):
        """
        初始化Broker

        Args:
            initial_cash (float): 初始资金,默认10万元
            commission_rate (float): 佣金费率,默认万2.5
            stamp_duty_rate (float): 印花税费率(仅卖出),默认万5
            slippage (float): 滑点比率,默认0.1%
        """
        self.cash = initial_cash
        self.positions = {}
        self.trade_history = []
        self.timestamp = None

        self.commission_rate = commission_rate
        self.stamp_duty_rate = stamp_duty_rate
        self.slippage = slippage

        self.total_commission = 0.0
        self.total_stamp_duty = 0.0
        self.total_slippage_cost = 0.0

        self.trades = pd.DataFrame(columns=[
            'cash', 'value', 'total', 'positions', 'trades'
        ])

    def on_bar(self, date, get):
        """
        Bar结束时记录当前状态。

        Args:
            date: 当前时间戳
            get: 获取因子数据的函数
        """
        self.timestamp = date
        total_value = self.cash
        position_summary = {}
        for symbol, pos in self.positions.items():
            price = get('close')
            if symbol in price.index:
                current_price = price[symbol]
                market_value = pos['size'] * current_price
                total_value += market_value
                position_summary[symbol] = {
                    'size': pos['size'],
                    'avg_price': pos['avg_price'],
                    'current_price': current_price,
                    'market_value': market_value,
                    'profit_pct': (current_price / pos['avg_price'] - 1) * 100 if pos['avg_price'] > 0 else 0,
                }

        self.trades.loc[self.timestamp] = {
            'cash': self.cash,
            'value': total_value - self.cash,
            'total': total_value,
            'positions': position_summary,
            'trades': list(self.trade_history),
        }

    def buy(self, symbol: str, price: float, size: int) -> bool:
        """
        买入操作

        考虑滑点(买入价上移)和佣金(最低 5 元)。

        Args:
            symbol (str): 交易标的代码
            price (float): 市价
            size (int): 买入数量

        Returns:
            bool: 交易是否成功
        """
        if size <= 0:
            logger.warning(f"买入数量必须大于 0!当前: {size}")
            return False

        if price <= 0:
            logger.warning(f"价格 {price:.2f} 不合法(前复权可能导致早期价格为负),跳过买入 {symbol}")
            return False

        exec_price = price * (1 + self.slippage)
        amount = exec_price * size
        if self.commission_rate > 0:
            commission = max(amount * self.commission_rate, MIN_COMMISSION)
        else:
            commission = 0.0
        total_cost = amount + commission

        if self.cash < total_cost:
            logger.debug(f"资金不足!需要 {total_cost:.2f}(佣金 {commission:.2f}),当前可用 {self.cash:.2f}")
            return False

        # 执行买入
        self.cash -= total_cost
        self.total_commission += commission
        self.total_slippage_cost += amount - price * size

        # 更新持仓(按实际成交价记录成本)
        if symbol in self.positions:
            old = self.positions[symbol]
            new_total = old['size'] * old['avg_price'] + amount
            new_size = old['size'] + size
            self.positions[symbol] = {'size': new_size, 'avg_price': new_total / new_size}
        else:
            self.positions[symbol] = {'size': size, 'avg_price': exec_price}

        self.trade_history.append({
            'action': 'buy', 'symbol': symbol,
            'price': price, 'exec_price': exec_price,
            'size': size, 'amount': amount,
            'commission': commission, 'total_cost': total_cost,
            'timestamp': self.timestamp,
        })

        logger.debug(f"买入成功: {symbol} {size}股 @ {exec_price:.2f},花费 {total_cost:.2f}(佣金 {commission:.2f})")
        return True

    def sell(self, symbol: str, price: float, size: int) -> bool:
        """
        卖出操作

        考虑滑点(卖出价下移)、佣金(最低 5 元)和印花税。

        Args:
            symbol (str): 交易标的代码
            price (float): 市价
            size (int): 卖出数量

        Returns:
            bool: 交易是否成功
        """
        if size <= 0:
            logger.warning(f"卖出数量必须大于 0!当前: {size}")
            return False

        if price <= 0:
            logger.warning(f"价格 {price:.2f} 不合法,跳过卖出 {symbol}")
            return False

        if symbol not in self.positions:
            logger.warning(f"没有 {symbol} 的持仓!")
            return False

        position = self.positions[symbol]
        if position['size'] < size:
            logger.warning(f"持仓不足!当前持有 {position['size']},尝试卖出 {size}")
            return False

        exec_price = price * (1 - self.slippage)
        amount = exec_price * size
        if self.commission_rate > 0:
            commission = max(amount * self.commission_rate, MIN_COMMISSION)
        else:
            commission = 0.0
        stamp_duty = amount * self.stamp_duty_rate
        net_proceeds = amount - commission - stamp_duty

        # 执行卖出
        self.cash += net_proceeds
        self.total_commission += commission
        self.total_stamp_duty += stamp_duty
        self.total_slippage_cost += price * size - amount

        # 更新持仓
        if position['size'] == size:
            del self.positions[symbol]
        else:
            self.positions[symbol]['size'] -= size

        self.trade_history.append({
            'action': 'sell', 'symbol': symbol,
            'price': price, 'exec_price': exec_price,
            'size': size, 'amount': amount,
            'commission': commission, 'stamp_duty': stamp_duty,
            'net_proceeds': net_proceeds,
            'timestamp': self.timestamp,
        })

        logger.debug(f"卖出成功: {symbol} {size}股 @ {exec_price:.2f},获得 {net_proceeds:.2f}(佣金 {commission:.2f} + 印花税 {stamp_duty:.2f})")
        return True

    def get(self, factor: str, timestamp=None) -> Any:
        """
        从trades DataFrame中获取数据

        Args:
            factor (str): 列名,可选 'cash', 'value', 'total', 'positions', 'trades'
            timestamp: 时间戳,为None则使用当前时间戳

        Returns:
            Any: 对应列的数据,不存在则返回None
        """
        ts = timestamp if timestamp is not None else self.timestamp
        if ts is None or ts not in self.trades.index:
            return None
        if factor not in self.trades.columns:
            return None
        return self.trades.at[ts, factor]

__init__(initial_cash=100000.0, commission_rate=DEFAULT_COMMISSION_RATE, stamp_duty_rate=DEFAULT_STAMP_DUTY_RATE, slippage=DEFAULT_SLIPPAGE)

初始化Broker

参数:

名称 类型 描述 默认
initial_cash float

初始资金,默认10万元

100000.0
commission_rate float

佣金费率,默认万2.5

DEFAULT_COMMISSION_RATE
stamp_duty_rate float

印花税费率(仅卖出),默认万5

DEFAULT_STAMP_DUTY_RATE
slippage float

滑点比率,默认0.1%

DEFAULT_SLIPPAGE
源代码位于: qka/core/broker.py
def __init__(self, initial_cash=100000.0,
             commission_rate=DEFAULT_COMMISSION_RATE,
             stamp_duty_rate=DEFAULT_STAMP_DUTY_RATE,
             slippage=DEFAULT_SLIPPAGE):
    """
    初始化Broker

    Args:
        initial_cash (float): 初始资金,默认10万元
        commission_rate (float): 佣金费率,默认万2.5
        stamp_duty_rate (float): 印花税费率(仅卖出),默认万5
        slippage (float): 滑点比率,默认0.1%
    """
    self.cash = initial_cash
    self.positions = {}
    self.trade_history = []
    self.timestamp = None

    self.commission_rate = commission_rate
    self.stamp_duty_rate = stamp_duty_rate
    self.slippage = slippage

    self.total_commission = 0.0
    self.total_stamp_duty = 0.0
    self.total_slippage_cost = 0.0

    self.trades = pd.DataFrame(columns=[
        'cash', 'value', 'total', 'positions', 'trades'
    ])

on_bar(date, get)

Bar结束时记录当前状态。

参数:

名称 类型 描述 默认
date

当前时间戳

必需
get

获取因子数据的函数

必需
源代码位于: qka/core/broker.py
def on_bar(self, date, get):
    """
    Bar结束时记录当前状态。

    Args:
        date: 当前时间戳
        get: 获取因子数据的函数
    """
    self.timestamp = date
    total_value = self.cash
    position_summary = {}
    for symbol, pos in self.positions.items():
        price = get('close')
        if symbol in price.index:
            current_price = price[symbol]
            market_value = pos['size'] * current_price
            total_value += market_value
            position_summary[symbol] = {
                'size': pos['size'],
                'avg_price': pos['avg_price'],
                'current_price': current_price,
                'market_value': market_value,
                'profit_pct': (current_price / pos['avg_price'] - 1) * 100 if pos['avg_price'] > 0 else 0,
            }

    self.trades.loc[self.timestamp] = {
        'cash': self.cash,
        'value': total_value - self.cash,
        'total': total_value,
        'positions': position_summary,
        'trades': list(self.trade_history),
    }

buy(symbol, price, size)

买入操作

考虑滑点(买入价上移)和佣金(最低 5 元)。

参数:

名称 类型 描述 默认
symbol str

交易标的代码

必需
price float

市价

必需
size int

买入数量

必需

返回:

名称 类型 描述
bool bool

交易是否成功

源代码位于: qka/core/broker.py
def buy(self, symbol: str, price: float, size: int) -> bool:
    """
    买入操作

    考虑滑点(买入价上移)和佣金(最低 5 元)。

    Args:
        symbol (str): 交易标的代码
        price (float): 市价
        size (int): 买入数量

    Returns:
        bool: 交易是否成功
    """
    if size <= 0:
        logger.warning(f"买入数量必须大于 0!当前: {size}")
        return False

    if price <= 0:
        logger.warning(f"价格 {price:.2f} 不合法(前复权可能导致早期价格为负),跳过买入 {symbol}")
        return False

    exec_price = price * (1 + self.slippage)
    amount = exec_price * size
    if self.commission_rate > 0:
        commission = max(amount * self.commission_rate, MIN_COMMISSION)
    else:
        commission = 0.0
    total_cost = amount + commission

    if self.cash < total_cost:
        logger.debug(f"资金不足!需要 {total_cost:.2f}(佣金 {commission:.2f}),当前可用 {self.cash:.2f}")
        return False

    # 执行买入
    self.cash -= total_cost
    self.total_commission += commission
    self.total_slippage_cost += amount - price * size

    # 更新持仓(按实际成交价记录成本)
    if symbol in self.positions:
        old = self.positions[symbol]
        new_total = old['size'] * old['avg_price'] + amount
        new_size = old['size'] + size
        self.positions[symbol] = {'size': new_size, 'avg_price': new_total / new_size}
    else:
        self.positions[symbol] = {'size': size, 'avg_price': exec_price}

    self.trade_history.append({
        'action': 'buy', 'symbol': symbol,
        'price': price, 'exec_price': exec_price,
        'size': size, 'amount': amount,
        'commission': commission, 'total_cost': total_cost,
        'timestamp': self.timestamp,
    })

    logger.debug(f"买入成功: {symbol} {size}股 @ {exec_price:.2f},花费 {total_cost:.2f}(佣金 {commission:.2f})")
    return True

sell(symbol, price, size)

卖出操作

考虑滑点(卖出价下移)、佣金(最低 5 元)和印花税。

参数:

名称 类型 描述 默认
symbol str

交易标的代码

必需
price float

市价

必需
size int

卖出数量

必需

返回:

名称 类型 描述
bool bool

交易是否成功

源代码位于: qka/core/broker.py
def sell(self, symbol: str, price: float, size: int) -> bool:
    """
    卖出操作

    考虑滑点(卖出价下移)、佣金(最低 5 元)和印花税。

    Args:
        symbol (str): 交易标的代码
        price (float): 市价
        size (int): 卖出数量

    Returns:
        bool: 交易是否成功
    """
    if size <= 0:
        logger.warning(f"卖出数量必须大于 0!当前: {size}")
        return False

    if price <= 0:
        logger.warning(f"价格 {price:.2f} 不合法,跳过卖出 {symbol}")
        return False

    if symbol not in self.positions:
        logger.warning(f"没有 {symbol} 的持仓!")
        return False

    position = self.positions[symbol]
    if position['size'] < size:
        logger.warning(f"持仓不足!当前持有 {position['size']},尝试卖出 {size}")
        return False

    exec_price = price * (1 - self.slippage)
    amount = exec_price * size
    if self.commission_rate > 0:
        commission = max(amount * self.commission_rate, MIN_COMMISSION)
    else:
        commission = 0.0
    stamp_duty = amount * self.stamp_duty_rate
    net_proceeds = amount - commission - stamp_duty

    # 执行卖出
    self.cash += net_proceeds
    self.total_commission += commission
    self.total_stamp_duty += stamp_duty
    self.total_slippage_cost += price * size - amount

    # 更新持仓
    if position['size'] == size:
        del self.positions[symbol]
    else:
        self.positions[symbol]['size'] -= size

    self.trade_history.append({
        'action': 'sell', 'symbol': symbol,
        'price': price, 'exec_price': exec_price,
        'size': size, 'amount': amount,
        'commission': commission, 'stamp_duty': stamp_duty,
        'net_proceeds': net_proceeds,
        'timestamp': self.timestamp,
    })

    logger.debug(f"卖出成功: {symbol} {size}股 @ {exec_price:.2f},获得 {net_proceeds:.2f}(佣金 {commission:.2f} + 印花税 {stamp_duty:.2f})")
    return True

get(factor, timestamp=None)

从trades DataFrame中获取数据

参数:

名称 类型 描述 默认
factor str

列名,可选 'cash', 'value', 'total', 'positions', 'trades'

必需
timestamp

时间戳,为None则使用当前时间戳

None

返回:

名称 类型 描述
Any Any

对应列的数据,不存在则返回None

源代码位于: qka/core/broker.py
def get(self, factor: str, timestamp=None) -> Any:
    """
    从trades DataFrame中获取数据

    Args:
        factor (str): 列名,可选 'cash', 'value', 'total', 'positions', 'trades'
        timestamp: 时间戳,为None则使用当前时间戳

    Returns:
        Any: 对应列的数据,不存在则返回None
    """
    ts = timestamp if timestamp is not None else self.timestamp
    if ts is None or ts not in self.trades.index:
        return None
    if factor not in self.trades.columns:
        return None
    return self.trades.at[ts, factor]

qka.SizingAccessor

仓位计算工具

挂载在 Strategy.sizing 下,用于计算每次交易的股数。 自动获取 Broker 的资金信息(cash、positions、total_value)。

A 股最小交易单位 100 股,所有方法自动按手向下取整。

源代码位于: qka/core/sizing.py
class SizingAccessor:
    """
    仓位计算工具

    挂载在 Strategy.sizing 下,用于计算每次交易的股数。
    自动获取 Broker 的资金信息(cash、positions、total_value)。

    A 股最小交易单位 100 股,所有方法自动按手向下取整。
    """

    # A 股最小手数
    MIN_LOT = 100

    def __init__(self, broker):
        self._broker = broker

    def _round_lot(self, shares: float) -> int:
        """按手取整(向下),不足一手返回 0"""
        if shares < self.MIN_LOT:
            return 0
        return int(floor(shares / self.MIN_LOT) * self.MIN_LOT)

    def _validate_positive(self, value, name: str):
        if value <= 0:
            raise ValueError(f"{name} 必须为正数,got {value}")

    def _validate_non_negative(self, value, name: str):
        if value < 0:
            raise ValueError(f"{name} 不能为负数,got {value}")

    def _validate_range(self, value, lo, hi, name: str):
        if not (lo <= value <= hi):
            raise ValueError(f"{name} 必须在 [{lo}, {hi}] 范围内,got {value}")

    # ── 核心方法 ──

    def fixed_shares(self, n: int) -> int:
        """
        固定股数。

        如果 n 不足一手(100股),返回 0。

        Args:
            n: 股数

        Returns:
            int: 按手取整后的股数
        """
        self._validate_positive(n, 'n')
        return self._round_lot(n)

    def fixed_amount(self, amount: float, price: float) -> int:
        """
        固定金额。

        计算 amount 能买多少股,向下按手取整。

        Args:
            amount: 投入金额
            price: 每股价格

        Returns:
            int: 可买入股数(按手取整)
        """
        self._validate_positive(amount, 'amount')
        self._validate_positive(price, 'price')
        return self._round_lot(amount / price)

    def percent(self, ratio: float, price: float) -> int:
        """
        资金百分比。

        使用可用现金的 ratio 比例买入,按手取整。

        Args:
            ratio: 0~1 之间的比例,如 0.1 表示 10%
            price: 每股价格

        Returns:
            int: 可买入股数
        """
        self._validate_range(ratio, 0, 1, 'ratio')
        self._validate_positive(price, 'price')
        if self._broker.cash <= 0:
            return 0
        return self._round_lot(self._broker.cash * ratio / price)

    def atr_risk(self, risk_ratio: float, price: float, atr_value: float,
                 multiplier: float = 2.0) -> int:
        """
        ATR 风险仓位。

        基于 ATR(平均真实波幅)计算仓位,确保单笔亏损不超过 risk_ratio 比例。

        公式:股数 = (cash * risk_ratio) / (atr_value * multiplier)

        Args:
            risk_ratio: 0~1,单笔最大亏损占资金比例
            price: 每股价格
            atr_value: ATR 当前值
            multiplier: 止损倍数,默认 2(即止损位为入场价 ± 2 * ATR)

        Returns:
            int: 可买入股数
        """
        self._validate_range(risk_ratio, 0, 1, 'risk_ratio')
        self._validate_positive(price, 'price')
        self._validate_non_negative(atr_value, 'atr_value')
        self._validate_positive(multiplier, 'multiplier')
        if self._broker.cash <= 0 or atr_value <= 0:
            return 0
        risk_per_share = atr_value * multiplier
        if risk_per_share <= 0:
            return 0
        return self._round_lot(self._broker.cash * risk_ratio / risk_per_share)

    def kelly(self, win_rate: float, win_loss_ratio: float, price: float) -> int:
        """
        凯利公式。

        f* = (p * b - q) / b

        其中:
        - p = 胜率
        - b = 赔率(盈利/亏损)
        - q = 1 - p(败率)

        当 f* ≤ 0 时返回 0(不建议下注)。

        Args:
            win_rate: 胜率,0~1
            win_loss_ratio: 赔率(平均盈利 / 平均亏损)
            price: 每股价格

        Returns:
            int: 可买入股数
        """
        self._validate_range(win_rate, 0, 1, 'win_rate')
        self._validate_positive(win_loss_ratio, 'win_loss_ratio')
        self._validate_positive(price, 'price')
        if self._broker.cash <= 0:
            return 0
        p = win_rate
        b = win_loss_ratio
        q = 1 - p
        fraction = (p * b - q) / b
        if fraction <= 0:
            return 0
        return self._round_lot(self._broker.cash * fraction / price)

fixed_shares(n)

固定股数。

如果 n 不足一手(100股),返回 0。

参数:

名称 类型 描述 默认
n int

股数

必需

返回:

名称 类型 描述
int int

按手取整后的股数

源代码位于: qka/core/sizing.py
def fixed_shares(self, n: int) -> int:
    """
    固定股数。

    如果 n 不足一手(100股),返回 0。

    Args:
        n: 股数

    Returns:
        int: 按手取整后的股数
    """
    self._validate_positive(n, 'n')
    return self._round_lot(n)

fixed_amount(amount, price)

固定金额。

计算 amount 能买多少股,向下按手取整。

参数:

名称 类型 描述 默认
amount float

投入金额

必需
price float

每股价格

必需

返回:

名称 类型 描述
int int

可买入股数(按手取整)

源代码位于: qka/core/sizing.py
def fixed_amount(self, amount: float, price: float) -> int:
    """
    固定金额。

    计算 amount 能买多少股,向下按手取整。

    Args:
        amount: 投入金额
        price: 每股价格

    Returns:
        int: 可买入股数(按手取整)
    """
    self._validate_positive(amount, 'amount')
    self._validate_positive(price, 'price')
    return self._round_lot(amount / price)

percent(ratio, price)

资金百分比。

使用可用现金的 ratio 比例买入,按手取整。

参数:

名称 类型 描述 默认
ratio float

0~1 之间的比例,如 0.1 表示 10%

必需
price float

每股价格

必需

返回:

名称 类型 描述
int int

可买入股数

源代码位于: qka/core/sizing.py
def percent(self, ratio: float, price: float) -> int:
    """
    资金百分比。

    使用可用现金的 ratio 比例买入,按手取整。

    Args:
        ratio: 0~1 之间的比例,如 0.1 表示 10%
        price: 每股价格

    Returns:
        int: 可买入股数
    """
    self._validate_range(ratio, 0, 1, 'ratio')
    self._validate_positive(price, 'price')
    if self._broker.cash <= 0:
        return 0
    return self._round_lot(self._broker.cash * ratio / price)

atr_risk(risk_ratio, price, atr_value, multiplier=2.0)

ATR 风险仓位。

基于 ATR(平均真实波幅)计算仓位,确保单笔亏损不超过 risk_ratio 比例。

公式:股数 = (cash * risk_ratio) / (atr_value * multiplier)

参数:

名称 类型 描述 默认
risk_ratio float

0~1,单笔最大亏损占资金比例

必需
price float

每股价格

必需
atr_value float

ATR 当前值

必需
multiplier float

止损倍数,默认 2(即止损位为入场价 ± 2 * ATR)

2.0

返回:

名称 类型 描述
int int

可买入股数

源代码位于: qka/core/sizing.py
def atr_risk(self, risk_ratio: float, price: float, atr_value: float,
             multiplier: float = 2.0) -> int:
    """
    ATR 风险仓位。

    基于 ATR(平均真实波幅)计算仓位,确保单笔亏损不超过 risk_ratio 比例。

    公式:股数 = (cash * risk_ratio) / (atr_value * multiplier)

    Args:
        risk_ratio: 0~1,单笔最大亏损占资金比例
        price: 每股价格
        atr_value: ATR 当前值
        multiplier: 止损倍数,默认 2(即止损位为入场价 ± 2 * ATR)

    Returns:
        int: 可买入股数
    """
    self._validate_range(risk_ratio, 0, 1, 'risk_ratio')
    self._validate_positive(price, 'price')
    self._validate_non_negative(atr_value, 'atr_value')
    self._validate_positive(multiplier, 'multiplier')
    if self._broker.cash <= 0 or atr_value <= 0:
        return 0
    risk_per_share = atr_value * multiplier
    if risk_per_share <= 0:
        return 0
    return self._round_lot(self._broker.cash * risk_ratio / risk_per_share)

kelly(win_rate, win_loss_ratio, price)

凯利公式。

f* = (p * b - q) / b

其中: - p = 胜率 - b = 赔率(盈利/亏损) - q = 1 - p(败率)

当 f* ≤ 0 时返回 0(不建议下注)。

参数:

名称 类型 描述 默认
win_rate float

胜率,0~1

必需
win_loss_ratio float

赔率(平均盈利 / 平均亏损)

必需
price float

每股价格

必需

返回:

名称 类型 描述
int int

可买入股数

源代码位于: qka/core/sizing.py
def kelly(self, win_rate: float, win_loss_ratio: float, price: float) -> int:
    """
    凯利公式。

    f* = (p * b - q) / b

    其中:
    - p = 胜率
    - b = 赔率(盈利/亏损)
    - q = 1 - p(败率)

    当 f* ≤ 0 时返回 0(不建议下注)。

    Args:
        win_rate: 胜率,0~1
        win_loss_ratio: 赔率(平均盈利 / 平均亏损)
        price: 每股价格

    Returns:
        int: 可买入股数
    """
    self._validate_range(win_rate, 0, 1, 'win_rate')
    self._validate_positive(win_loss_ratio, 'win_loss_ratio')
    self._validate_positive(price, 'price')
    if self._broker.cash <= 0:
        return 0
    p = win_rate
    b = win_loss_ratio
    q = 1 - p
    fraction = (p * b - q) / b
    if fraction <= 0:
        return 0
    return self._round_lot(self._broker.cash * fraction / price)