Browse Source

[ADD] Sync tdx concepts.

wingedox 2 months ago
commit
9f456a3043
10 changed files with 615 additions and 0 deletions
  1. 36 0
      config.py
  2. 15 0
      config.yaml
  3. 79 0
      main.py
  4. 62 0
      services/service.py
  5. 0 0
      tdx/__init__.py
  6. 259 0
      tdx/tdx_local.py
  7. 0 0
      tests/__init__.py
  8. 12 0
      tests/conftest.py
  9. 6 0
      tests/test_config.py
  10. 146 0
      tests/test_tdx.py

+ 36 - 0
config.py

@@ -0,0 +1,36 @@
+import configargparse
+import logging
+
+from chive.config import config_parse
+
+_logger = logging.getLogger(__name__)
+
+
+def config_argparse():
+    parser = configargparse.ArgumentParser(
+        # default_config_files=['config.yaml', 'config.ini'],
+        default_config_files=['config.yaml'],
+        config_file_parser_class=configargparse.YAMLConfigFileParser,  # 使用YAML解析器
+        auto_env_var_prefix='TRADERX_'  # 环境变量前缀
+    )
+
+    common_group = parser.add_argument_group('Common')
+    common_group.add_argument('--app_name', default='traderx', help='TraderXApp name')
+    common_group.add_argument('--log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO')
+    common_group.add_argument('--zookeeper_url', default='192.168.3.3:2181', help='Zookeeper URL')
+    common_group.add_argument('--pulsar_url', default='pulsar://192.168.3.3:6650', help='Pulsar URL')
+    common_group.add_argument('--rpc_server', default='tcp://10.0.0.1:9483', help='RPC Server URL')
+
+    tdx_group = parser.add_argument_group('TRADER')
+    tdx_group.add_argument('--trader.account', help='Tdx Main Folder')
+
+    tdx_group = parser.add_argument_group('TDX')
+    tdx_group.add_argument('--tdx.main_folder', default='C:/new_tdx', help='Tdx Main Folder')
+    tdx_group.add_argument('--tdx.export_concept_file', default=u'T0002/export/概念板块.txt', help='Concepts file.')
+    tdx_group.add_argument('--tdx.watch_files', default=['tdx_block.yaml', 'pecat_block.yaml'], help='Files to watch.')
+    tdx_group.add_argument('--tdx.local_host', default='10.0.0.1:50051', help='Service host')
+
+    return config_parse(parser)
+
+
+config = config_argparse()

+ 15 - 0
config.yaml

@@ -0,0 +1,15 @@
+app_name: mpb_win
+redis_url: 192.168.3.3:6379
+zookeeper_url: 192.168.3.3:2181
+pulsar_url: pulsar://192.168.3.3:6650
+log_level: INFO
+rpc_server: tcp://10.0.0.2:9483
+
+trader:
+  account: guojin-gaojun
+
+tdx:
+  main_folder: C:/new_tdx
+  export_concept_file: T0002/export/概念板块.txt
+  block_files: tdx_block.yaml
+  local_host: 10.0.0.2:50051

+ 79 - 0
main.py

@@ -0,0 +1,79 @@
+
+import logging
+import sys
+import os
+import threading
+import pulsar
+
+from kazoo.client import KazooClient
+
+from chive.service_runner import ServiceRunner
+from config import config
+
+current = os.path.dirname(os.path.abspath(__file__))
+parent = os.path.dirname(current)
+if parent not in sys.path:
+    sys.path.append(parent)
+
+# logformat = '%(asctime)s %(pathname)s %(funcName)s%(lineno)d %(levelname)s: %(message)s'
+log_level = config['log_level']
+logging.basicConfig(
+    level=getattr(logging, log_level.upper())
+    , stream=sys.stdout
+    , format='%(asctime)s %(levelname)s: %(message)s')
+
+_logger = logging.getLogger(__name__)
+
+
+class TraderXApp(ServiceRunner):
+    def __init__(self, name, zk_client, pulsar_client, address):
+        super().__init__(name, zk_client, pulsar_client, address)
+        self.tick_listener = []
+
+    def start(self):
+        self.load_modules('traderx.services.service')
+        super().start()
+
+    def stop(self, block=True, timeout=15):
+        super().stop(block=block, timeout=timeout)
+
+
+if __name__ == '__main__':
+    threading.current_thread().testing = False
+
+    # log_server = config['log_server'] or "tcp://*:5555"
+    # start_log_server(server_address=log_server)
+    # client_id = config.pecat['name']
+    # configure_logging(client_id, server_address=log_server)
+
+    zk_client = KazooClient(hosts=config['zookeeper_url'])
+    zk_client.start()
+
+    _pulsar_logger = logging.getLogger('pulsar')
+    _pulsar_logger.setLevel(logging.WARNING)
+    pulsar_client = pulsar.Client(config['pulsar_url'], logger=_pulsar_logger)
+
+    server_address = config['rpc_server']
+    app = TraderXApp(f'TraderXApp-{config["app_name"]}', zk_client, pulsar_client, server_address)
+
+    def exit_app(*args):
+        _logger.info('User interrupt.')
+
+        try:
+            app.shutdown()
+            app.stop()
+        except Exception as e:
+            _logger.info(e)
+        finally:
+            _logger.info("Main thread done")
+            sys.exit(0)
+
+    import signal
+    signal.signal(signal.SIGINT, exit_app)
+
+    app.start()
+
+    import time
+    while True:
+        time.sleep(1)
+

+ 62 - 0
services/service.py

@@ -0,0 +1,62 @@
+import logging
+
+from pathlib import Path
+
+from chive.service import Service
+from chive.file_watcher import FileWatcher
+
+from traderx.config import config
+
+_logger = logging.getLogger(__name__)
+
+
+class TraderXService(Service):
+    _name = 'traderx.service'
+
+    def __init__(self, main_folder, pulsar_client):
+        super().__init__(pulsar_client=pulsar_client)
+        self.main_path = Path(main_folder)
+
+    def _watch_concept_export_file(self):
+        def hander(name, files):
+            file_name = files[0]
+            data = {}
+            with Path(file_name).open(encoding='gbk', errors='ignore') as file:
+                for line in file:
+                    try:
+                        concept_code, concept_name, stock_code, stock_name = line.strip().split(',')
+                        data.setdefault(concept_code, {'concept_name': concept_name, 'stocks': []})['stocks'].append([stock_code, stock_name])
+                    except Exception as e:
+                        ...
+
+            self.send_message(name, data, persistent=False, user='public', domain='default')
+
+        name = 'tdx_concepts_updated'
+        export_file = config['tdx_export_concept_file']
+        if Path(export_file).exists():
+            concept_file = Path(export_file)
+        else:
+            concept_file = self.main_path / export_file
+        if not concept_file.exists():
+            _logger.warning(f'Export concept file not found: {export_file}.')
+            return
+        concept_path = concept_file.parent
+
+        watcher_t0002_export = FileWatcher(hander, str(concept_path))
+        watcher_t0002_export.add_file_to_watch(name, str(concept_file))
+        watcher_t0002_export.start()
+
+    def start(self):
+        super().start()
+        self._watch_concept_export_file()
+
+
+service = None
+
+
+def start_service(runner):
+    global service
+
+    main_folder = config['tdx_main_folder']
+    service = TraderXService(main_folder, pulsar_client=runner.pulsar_client)
+    return service

+ 0 - 0
tdx/__init__.py


+ 259 - 0
tdx/tdx_local.py

@@ -0,0 +1,259 @@
+import os
+import struct
+import pandas as pd
+
+from datetime import datetime
+from pathlib import Path
+from struct import pack, unpack
+from collections import defaultdict
+
+from config import config
+from chive.service import MessageThread
+from chive.file_watcher import HandlerBase, FileWatcher, WatcherMessageHandler, WatcherZookeeperHandler
+
+
+BLOCK_FOLDER = r'T0002\blocknew'
+BLOCK_CFG_FILE = 'blocknew.cfg'
+
+CACHE_FOLDER = r"T0002\hq_cache"
+HY_CFG_FILE = r"tdxhy.cfg"
+SZM_FILE = r"szm.tnf"
+SHM_FILE = r"shm.tnf"
+
+
+class TDXConceptsHandler(HandlerBase):
+    _name = 'tdx.concepts.handler'
+
+    def dispatch(self, name, file_names):
+        for file_name in file_names:
+            file_path = Path(file_name)
+            content = file_path.read_text(encoding='gbk', errors='ignore')
+
+
+# 负责读取本地通达信数据
+class TDXClient(MessageThread):
+    def __init__(self, tdx_main_folder, pulsar_client=None, zk_client=None):
+        super().__init__(pulsar_client=pulsar_client, zk_client=zk_client)
+        self.main_path = Path(tdx_main_folder)
+        self.block_path = Path(self.resolve(BLOCK_FOLDER))
+        self.__base_all = None
+
+    def resolve(self, folder):
+        return str(self.main_path / folder)
+
+    def get_block_stocks(self, block_name=None):
+        now = datetime.now()
+        block_name = block_name or now.strftime('%y%m%d')
+        block_path = self.block_path / f'{block_name}.blk'
+        if block_path.exists():
+            block_stocks = [s for s in block_path.read_text().split('\n') if s.strip()]
+            return [f'{stock_code[1:]}.SZ' if stock_code[0] == '0' else f'{stock_code[1:]}.SH'
+                    for stock_code in block_stocks]
+
+    # region custom block
+    def get_cust_blocks(self):
+        blocks = {}
+
+        cfg_file = self.block_path / BLOCK_CFG_FILE
+        if not cfg_file.exists():
+            return blocks
+        buff = cfg_file.read_bytes()
+        count = int(len(buff) / 120)
+
+        buff_blocks = [buff[i * 120:i * 120 + 120] for i in range(count)]
+        for block in buff_blocks:
+            name = block[:50].decode('gbk', errors="ignore").replace('\x00', '')
+            display_name = block[50:].decode('gbk', errors="ignore").replace('\x00', '')
+            blocks[name] = display_name
+        return blocks
+
+    def save_cust_block(self, blocks: dict, replace=True):
+        file_blocks = self.get_cust_blocks()
+        for key, stocks in blocks.items():
+            name, display_name = key
+            block_path = (self.block_path / f'{name}.blk')
+            if not replace:
+                if block_path.exists():
+                    file_stocks = block_path.read_text().split('\n')
+                    stocks = file_stocks + [s for s in stocks if s and s not in file_stocks]
+            block_path.write_text('\n'.join(stocks))
+            file_blocks[name] = display_name
+
+        buff = b''
+        for block in file_blocks:
+            buff += pack(f'>{50}s', block.encode('gbk'))
+            buff += pack(f'>{70}s', file_blocks[block].encode('gbk'))
+
+        cfg_file = self.block_path / BLOCK_CFG_FILE
+        cfg_file.write_bytes(buff)
+    # endregion
+
+    # 获取行业数据
+    def _get_stock_industry(self):
+        tdxhy_path = self.resolve(HY_CFG_FILE)
+        names = "market symbol industry_code idontcare1 idontcare2 industry_second_code".split()
+        usecols = "market symbol industry_code industry_second_code".split()
+        tdxhy = pd.read_csv(
+            tdxhy_path,
+            sep="|",
+            names=names,
+            usecols=usecols,
+            dtype={"symbol": str},
+        )
+        return tdxhy
+
+    # 行业代码码值
+    def _get_industry(self):
+        incon_path = self.resolve(r"incon.dat")
+        incon = pd.read_csv(incon_path, encoding="gb2312", names=["index"])
+        fx0 = lambda x: x.split("|")[0] if "|" in x else ""
+        incon["industry_code"] = incon["index"].apply(fx0)
+        fx1 = lambda x: x.split("|")[1] if "|" in x else ""
+        incon["industry_name"] = incon["index"].apply(fx1)
+        usecols = ["industry_code", "industry_name"]
+        return incon[usecols]
+
+    # 股票代码对应拼音缩写
+    def _read_tnf(self, path):
+        market = path.split(".")[0][-3:]
+        with open(path, "rb") as f:
+            buff = f.read()
+
+        data = buff[50:]
+        l = len(data) // 314
+        fx = lambda x: str(x, encoding="gbk").strip("\x00")
+        sm = {"szm": ("00", "30"), "shm": ("60", "68")}
+
+        stocks = []
+        for x in [data[i * 314 : (i + 1) * 314] for i in range(l)]:
+            code = fx(x[:6])
+            if code.startswith(sm[market]):
+                name = fx(x[23:41])
+                shortcode = fx(x[285:293])
+
+                stocks += [[code, name, shortcode]]
+        return stocks
+
+    # 股票K线数据源文件
+    def _get_stock_names(self):
+        szm_path = self.resolve(SZM_FILE)
+        shm_path = self.resolve(SHM_FILE)
+
+        szm = self._read_tnf(szm_path)
+        shm = self._read_tnf(shm_path)
+
+        stocks = pd.DataFrame(szm + shm, columns=["symbol", "name", "shortcode"])
+        return stocks
+
+    # 整合基本数据
+    def _get_base_all(self):
+        stock_industry = self._get_stock_industry()
+        industry_name = self._get_industry()
+        stock_name = self._get_stock_names()
+
+        base = pd.merge(stock_name, stock_industry, how="left", on="symbol")
+        base = pd.merge(base, industry_name, how="left", on="industry_code")
+        base = pd.merge(
+            base,
+            industry_name,
+            how="left",
+            left_on="industry_second_code",
+            right_on="industry_code",
+        )
+
+        fx = lambda x: ".sh" if x else ".sz"
+        base["ts_code"] = base["symbol"] + base["market"].apply(fx)
+
+        base.rename(
+            columns={
+                "industry_name_x": "industry_name",
+                "industry_name_y": "industry_detail",
+                "industry_code_x": "industry_code",
+            },
+            inplace=True,
+        )
+        # base = base.drop(['industry_code_y'], axis=1)
+        usecols = "ts_code symbol name shortcode industry_name industry_detail".split()
+        return base[usecols]
+
+    def get_block_file(self, block='gn'):
+        file_name = f'T0002/hq_cache/block_{block}.dat'
+        file_path = self.resolve(file_name)
+        with open(file_path, 'rb') as f:
+            buff = f.read()
+
+        head = unpack('<384sh', buff[:386])
+        blk = buff[386:]
+        blocks = [blk[i * 2813:(i + 1) * 2813] for i in range(head[1])]
+        bk_list = []
+        for bk in blocks:
+            name = bk[:8].decode('gbk', 'ignore').strip('\x00')
+            num, t = unpack('<2h', bk[9:13])
+            stks = bk[13:(12 + 7 * num)].decode('gbk', 'ignore').split('\x00')
+            bk_list.append([name, block, num, stks])
+        # return pd.DataFrame(bk_list, columns=['name', 'tp', 'num', 'stocks'])
+        return bk_list
+
+    def _read_concepts(self, content):
+        concepts = {}
+        for line in content.split('\n'):
+            row = line.split(',')
+            if len(row) < 4:
+                continue
+            concept_code, concept_name, stock_code, stock_name = line.split(',')
+            concepts.setdefault(concept_code, {'name': concept_name, 'stocks': []})
+            concepts[concept_code]['stocks'].append([stock_code, stock_name])
+        return concepts
+
+    def get_export_concepts(self):
+        file_name = u'T0002/export/概念板块.txt'
+        file_path = Path(self.resolve(file_name))
+        if file_path.exists():
+            content = file_path.read_text(encoding='gbk', errors='ignore')
+            return self._read_concepts(content)
+
+    # 读取K线源文件
+    def _read_kline(self, filepath):
+        with open(filepath, "rb") as f:
+            usecols = "trade_date open high low close amount vol openinterest".split()
+            buffers = []
+            while True:
+                buffer = f.read(32)
+                if not buffer:
+                    break
+                buffer = struct.unpack("lllllfll", buffer)
+                buffers.append(buffer)
+            kline = pd.DataFrame(buffers, columns=usecols)
+
+        kline["trade_date"] = kline["trade_date"].astype(str)
+
+        price_columns = ["open", "high", "low", "close"]
+        kline[price_columns] = kline[price_columns].apply(lambda x: x / 100)
+        return kline
+
+    # 获取基本数据
+    def get_base_all(self):
+        if not self.__base_all:
+            self.__base_all = self._get_base_all()
+        return self.__base_all
+
+    # 获取日K线数据
+    def get_kline_daily(self, ts_code):
+        filename = ts_code.split(".")[1] + ts_code.split(".")[0] + ".day"
+        filepath = self.resolve('/'.join(["vipdoc", ts_code.split(".")[1], "lday", filename]))
+        kline = self._read_kline(filepath)
+        kline["ts_code"] = ts_code
+        kline.index = pd.to_datetime(kline["trade_date"])
+        kline.index.name = "index"
+        kline = kline.rename(columns={"vol": "volume"})
+        usecols = (
+            "ts_code trade_date open high low close amount volume openinterest".split()
+        )
+        return kline[usecols]
+
+    def start(self):
+        super().start()
+        if self.pulsar_client:
+            handlers = [WatcherMessageHandler(pulsar_client=self.pulsar_client)]
+            watcher = FileWatcher(self.main_path, handlers)
+            watcher.add_file_to_watch('tdx_concept_block', [HY_CFG_FILE, SZM_FILE, SHM_FILE])

+ 0 - 0
tests/__init__.py


+ 12 - 0
tests/conftest.py

@@ -0,0 +1,12 @@
+import pytest
+from tdx.tdx_local import TDXClient
+
+
+@pytest.fixture(scope='session')
+def tdx_main_folder():
+    return r'C:\new_tdx'
+
+
+@pytest.fixture(scope='session')
+def client(tdx_main_folder):
+    return TDXClient(tdx_main_folder)

+ 6 - 0
tests/test_config.py

@@ -0,0 +1,6 @@
+
+
+def test_config():
+    from config import config
+
+    assert config.tdx.main_folder

+ 146 - 0
tests/test_tdx.py

@@ -0,0 +1,146 @@
+
+from tdx.tdx_local import TDXClient
+
+
+def test_tdx(tdx_main_folder):
+    loader = TDXClient(tdx_main_folder)
+    # 获取所有股票基本数据
+    base_all = loader.get_base_all()
+    assert base_all
+
+    # 获取单股日K线数据
+    kline = loader.get_kline_daily("600645.sh")
+    assert kline
+
+
+def test_gn_export(client):
+    client.get_export_concepts()
+
+
+def test_gn(client):
+    data = client.get_block_file('gn')
+    [d for d in data if '雅' in d]
+    assert data
+
+
+def test_pytdx():
+    from pytdx.hq import TdxHq_API
+    api = TdxHq_API()
+    # 121.37.16.86:7615
+    # 43.136.49.71:7719
+    with api.connect(ip='43.136.49.71', port=7719):
+        # ret = api.get_block_info("block_zs.dat", 0, 100)
+        # print(len(ret))
+        # ret = api.get_and_parse_block_info("block_fg.dat")
+        # ret = api.get_and_parse_block_info("block_zs.dat")
+        ret = api.get_and_parse_block_info("block_gn.dat")
+        # ret = api.get_and_parse_block_info("block.dat")
+        # ret = api.get_and_parse_block_info("block.dat")
+        print(api.to_df(ret))
+
+
+def test_ts():
+    import tushare as ts
+
+    concepts = ts.get_concept_classified()
+    assert concepts
+
+
+def test_stock():
+    import struct
+    import pandas as pd
+
+    def get_stock_name_shm(mkt='sz'):
+        # 修正文件路径:通达信通常使用 .dat 文件
+        file_path = f'c:/new_tdx/T0002/hq_cache/{mkt}s.tnf'  # 常见的股票基本信息文件
+
+        try:
+            with open(file_path, 'rb') as f:
+                buff = f.read()
+
+            # 跳过文件头,从数据部分开始
+            data = buff[50:]
+            record_length = 314  # 每条记录的长度
+            num_records = len(data) // record_length
+
+            # 定义解码函数
+            def decode_string(byte_data):
+                return str(byte_data, encoding='gbk', errors='ignore').strip('\x00')
+
+            # 市场代码映射
+            market_codes = {'sz': ('00', '30'), 'sh': ('60', '68')}
+
+            stocks = []
+            for i in range(num_records):
+                start_idx = i * record_length
+                end_idx = (i + 1) * record_length
+                record = data[start_idx:end_idx]
+
+                # 解析各个字段
+                code = decode_string(record[:6])
+
+                # 只处理指定市场的股票
+                if 1 or code.startswith(market_codes[mkt]):
+                    name = decode_string(record[23:41])
+
+                    # 解析昨收价格
+                    try:
+                        lclose = round(struct.unpack('<f', record[276:280])[0], 2)
+                    except:
+                        lclose = 0.0
+
+                    attr = decode_string(record[285:293])
+
+                    # 正确添加到列表
+                    stocks.append([code, name, lclose, attr])
+
+            data_head = buff[:50]
+            return data_head, stocks
+
+        except FileNotFoundError:
+            print(f"文件未找到: {file_path}")
+            return None, []
+        except Exception as e:
+            print(f"读取文件时出错: {e}")
+            return None, []
+
+
+    # 获取深圳市场股票
+    header, sz_stocks = get_stock_name_shm('sz')
+
+    if sz_stocks:
+        # 转换为DataFrame方便查看
+        df = pd.DataFrame(sz_stocks, columns=['代码', '名称', '昨收', '属性'])
+        print(f"找到 {len(df)} 只深圳股票")
+        print(df.head())
+
+        # 保存到CSV
+        # df.to_csv('sz_stocks.csv', index=False, encoding='gbk')
+        print("数据已保存到 sz_stocks.csv")
+
+
+def test_mootdx():
+    from mootdx.quotes import Quotes
+
+    client = Quotes.factory(market='std', heartbeat=True)
+    block = client.block('block_gn.dat')
+    assert block
+
+
+def test_gn_block():
+    from struct import pack, unpack
+
+    # with open('c:\\new_tdx\\block_gn.dat', 'rb') as f:
+    with open(r'C:\new_tdx\T0002\hq_cache\block_gn.dat', 'rb') as f:
+        buff = f.read()
+
+    head = unpack('<384sh', buff[:386])
+    blk = buff[386:]
+    blocks = [blk[i * 2813:(i + 1) * 2813] for i in range(head[1])]
+    bk_list = []
+    for bk in blocks:
+        name = bk[:8].decode('gbk', 'ignore').strip('\x00')
+        num, t = unpack('<2h', bk[9:13])
+        stks = bk[13:(12 + 7 * num)].decode('gbk', 'ignore').split('\x00')
+        bk_list.append([name, num, stks])
+    assert bk_list