# extract_pdf_segments.py 4.47 KB
#!/usr/bin/env python3
"""
将培训文档 PDF 提取为结构化段落块,供模型蒸馏 doc_atoms 使用。

输出:
  build/<app_version>/doc_segments.jsonl

用法:
  python3 scripts/extract_pdf_segments.py
  python3 scripts/extract_pdf_segments.py 4.57.3
"""

from __future__ import annotations

import json
import re
import sys
from pathlib import Path


BASE_DIR = Path(__file__).parent.parent
PDF_DIR = BASE_DIR / "pdf"
BUILD_DIR = BASE_DIR / "build"


def extract_version_from_filename(name: str) -> str | None:
    match = re.search(r"(\d+\.\d+(?:\.\d+)?)", name)
    return f"v{match.group(1)}" if match else None


def clean_pdf_line(line: str) -> str:
    line = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", line)
    return line.strip()


def extract_text_from_pdf(pdf_path: Path) -> list[dict]:
    def read_pages(reader) -> list[dict]:
        pages = []
        total = len(reader.pages)
        for index, page in enumerate(reader.pages, start=1):
            raw = page.extract_text() or ""
            lines = [clean_pdf_line(line) for line in raw.split("\n")]
            lines = [line for line in lines if line]
            pages.append(
                {
                    "page": index,
                    "page_label": f"{index}/{total}",
                    "lines": lines,
                }
            )
        return pages

    try:
        from pypdf import PdfReader  # type: ignore
        return read_pages(PdfReader(str(pdf_path)))
    except ImportError:
        pass

    try:
        from PyPDF2 import PdfReader  # type: ignore
        return read_pages(PdfReader(str(pdf_path)))
    except ImportError:
        pass

    raise RuntimeError("未找到 PDF 解析库,请先安装 pypdf")


def is_title_line(line: str) -> bool:
    stripped = line.strip()
    if re.match(r"^\|.{2,40}\|?$", stripped):
        return True
    if re.match(r"^|.{2,40}|?$", stripped):
        return True
    return False


def normalize_title(line: str) -> str:
    return line.strip().strip("||").strip()


def chunk_page_lines(page: dict) -> list[dict]:
    blocks = []
    current_title = ""
    buffer: list[str] = []

    def flush() -> None:
        nonlocal buffer
        if not buffer:
            return
        text = " ".join(buffer).strip()
        if text:
            blocks.append(
                {
                    "page": page["page"],
                    "title": current_title or "未识别标题",
                    "text": text,
                    "line_count": len(buffer),
                }
            )
        buffer = []

    for line in page["lines"]:
        if is_title_line(line):
            flush()
            current_title = normalize_title(line)
            continue
        if re.match(r"^\d+[..、)]\s*", line) and buffer:
            flush()
        buffer.append(line)
    flush()
    return blocks


def process_pdf(pdf_path: Path) -> tuple[str, int]:
    app_version = extract_version_from_filename(pdf_path.name)
    if not app_version:
        return "", 0
    pages = extract_text_from_pdf(pdf_path)
    segments = []
    for page in pages:
        for index, block in enumerate(chunk_page_lines(page), start=1):
            segments.append(
                {
                    "candidate_type": "doc_segment",
                    "app_version": app_version,
                    "source_file": str(pdf_path.relative_to(BASE_DIR)),
                    "page": block["page"],
                    "segment_index": index,
                    "title": block["title"],
                    "text": block["text"][:4000],
                    "line_count": block["line_count"],
                }
            )

    out_dir = BUILD_DIR / app_version
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "doc_segments.jsonl"
    with out_path.open("w", encoding="utf-8") as handle:
        for row in segments:
            handle.write(json.dumps(row, ensure_ascii=False) + "\n")
    return app_version, len(segments)


def main() -> None:
    version_filter = next((arg for arg in sys.argv[1:] if re.match(r"\d+\.\d+", arg)), None)
    pdf_files = sorted(PDF_DIR.glob("*.pdf"))
    if version_filter:
        pdf_files = [path for path in pdf_files if version_filter in path.name]

    total = 0
    for pdf_file in pdf_files:
        version, count = process_pdf(pdf_file)
        if not version:
            continue
        total += count
        print(f"{version} segments={count} file={pdf_file.name}")
    print(f"total={total}")


if __name__ == "__main__":
    main()