diff --git a/bitnet_tools/cli.py b/bitnet_tools/cli.py index 7756a5f..760f563 100644 --- a/bitnet_tools/cli.py +++ b/bitnet_tools/cli.py @@ -8,6 +8,7 @@ from .analysis import DataSummary, build_analysis_payload, build_analysis_payload_from_request, build_markdown_report from .doctor import collect_environment +from .document_extract import extract_document_tables, table_to_analysis_request from .multi_csv import analyze_multiple_csv, build_multi_csv_markdown, result_to_json from .visualize import create_multi_charts from .web import serve @@ -32,7 +33,7 @@ def _build_parser() -> argparse.ArgumentParser: analyze_parser = subparsers.add_parser( "analyze", help="Build analysis payload from a CSV file" ) - analyze_parser.add_argument("csv", type=Path, help="Input CSV path") + analyze_parser.add_argument("csv", type=Path, help="Input data path (csv/pdf/docx/pptx)") analyze_parser.add_argument("--question", required=True, help="Analysis question") analyze_parser.add_argument( "--model", @@ -45,6 +46,17 @@ def _build_parser() -> argparse.ArgumentParser: default=Path("analysis_payload.json"), help="Where to store generated payload JSON", ) + analyze_parser.add_argument( + "--table-index", + type=int, + default=0, + help="Document table index to use when input is pdf/docx/pptx", + ) + analyze_parser.add_argument( + "--list-tables", + action="store_true", + help="List extracted document tables and exit", + ) ui_parser = subparsers.add_parser("ui", help="Run local web UI") ui_parser.add_argument("--host", default="127.0.0.1", help="Bind host") @@ -150,12 +162,23 @@ def main(argv: list[str] | None = None) -> int: return 0 if args.command == "analyze": - request_payload = { - "input_type": "csv", - "source_name": args.csv.name, - "normalized_csv_text": args.csv.read_text(encoding="utf-8"), - "meta": {"csv_path": str(args.csv)}, - } + suffix = args.csv.suffix.lower() + if suffix in {".pdf", ".docx", ".pptx"}: + extract_result = extract_document_tables(args.csv) + if args.list_tables: + print(json.dumps(extract_result.to_dict(), ensure_ascii=False, indent=2)) + return 0 + if not extract_result.tables: + raise ValueError(extract_result.failure_detail or extract_result.failure_reason or "표 추출 실패") + request_payload = table_to_analysis_request(extract_result, args.table_index) + request_payload["meta"] = {**request_payload.get("meta", {}), "document_path": str(args.csv)} + else: + request_payload = { + "input_type": "csv", + "source_name": args.csv.name, + "normalized_csv_text": args.csv.read_text(encoding="utf-8"), + "meta": {"csv_path": str(args.csv)}, + } payload = build_analysis_payload_from_request(request_payload, args.question, csv_path_override=str(args.csv)) args.out.write_text( json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8" diff --git a/bitnet_tools/document_extract.py b/bitnet_tools/document_extract.py new file mode 100644 index 0000000..a3b7e2f --- /dev/null +++ b/bitnet_tools/document_extract.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import base64 +import csv +import io +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import xml.etree.ElementTree as ET +import zipfile + + +SUPPORTED_DOCUMENT_EXTENSIONS = {".pdf", ".docx", ".pptx"} + + +@dataclass +class ExtractedTable: + table_id: str + source: str + rows: list[list[str]] + header_inferred: bool + missing_ratio: float + confidence: float + + @property + def row_count(self) -> int: + return len(self.rows) + + @property + def column_count(self) -> int: + return max((len(r) for r in self.rows), default=0) + + def to_csv(self) -> str: + if not self.rows: + return "" + max_len = self.column_count + output = io.StringIO() + writer = csv.writer(output) + for row in self.rows: + padded = row + [""] * (max_len - len(row)) + writer.writerow(padded) + return output.getvalue() + + def to_dict(self) -> dict[str, Any]: + return { + "table_id": self.table_id, + "source": self.source, + "row_count": self.row_count, + "column_count": self.column_count, + "header_inferred": self.header_inferred, + "missing_ratio": round(self.missing_ratio, 4), + "confidence": round(self.confidence, 4), + "preview": self.rows[:5], + } + + +@dataclass +class DocumentExtractResult: + input_type: str + source_name: str + tables: list[ExtractedTable] + failure_reason: str | None = None + failure_detail: str | None = None + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "input_type": self.input_type, + "source_name": self.source_name, + "tables": [t.to_dict() for t in self.tables], + } + if self.failure_reason: + payload["failure_reason"] = self.failure_reason + payload["failure_detail"] = self.failure_detail or self.failure_reason + return payload + + +def extract_document_tables_from_base64(file_base64: str, source_name: str) -> DocumentExtractResult: + try: + raw = base64.b64decode(file_base64) + except Exception as exc: + raise ValueError(f"invalid document base64: {exc}") from exc + return extract_document_tables_from_bytes(raw, source_name) + + +def extract_document_tables(path: str | Path) -> DocumentExtractResult: + file_path = Path(path) + return extract_document_tables_from_bytes(file_path.read_bytes(), file_path.name) + + +def extract_document_tables_from_bytes(raw: bytes, source_name: str) -> DocumentExtractResult: + ext = Path(source_name).suffix.lower() + if ext not in SUPPORTED_DOCUMENT_EXTENSIONS: + raise ValueError(f"unsupported document format: {ext or ''}") + + if ext == ".docx": + tables = _extract_docx_tables(raw) + elif ext == ".pptx": + tables = _extract_pptx_tables(raw) + else: + tables_or_failure = _extract_pdf_tables(raw) + if isinstance(tables_or_failure, tuple): + reason, detail = tables_or_failure + return DocumentExtractResult( + input_type="document", + source_name=source_name, + tables=[], + failure_reason=reason, + failure_detail=detail, + ) + tables = tables_or_failure + + if not tables: + return DocumentExtractResult( + input_type="document", + source_name=source_name, + tables=[], + failure_reason="표 없음", + failure_detail="문서에서 테이블 구조를 찾지 못했습니다.", + ) + return DocumentExtractResult(input_type="document", source_name=source_name, tables=tables) + + +def table_to_analysis_request(result: DocumentExtractResult, table_index: int) -> dict[str, Any]: + if not result.tables: + raise ValueError(result.failure_detail or result.failure_reason or "표 없음") + if table_index < 0 or table_index >= len(result.tables): + raise ValueError(f"table_index out of range: {table_index}") + table = result.tables[table_index] + return { + "input_type": "document", + "source_name": result.source_name, + "normalized_csv_text": table.to_csv(), + "meta": { + "table_id": table.table_id, + "table_index": table_index, + "row_count": table.row_count, + "column_count": table.column_count, + "header_inferred": table.header_inferred, + "missing_ratio": table.missing_ratio, + "confidence": table.confidence, + }, + } + + +def _extract_docx_tables(raw: bytes) -> list[ExtractedTable]: + ns = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"} + with zipfile.ZipFile(io.BytesIO(raw)) as zf: + if "word/document.xml" not in zf.namelist(): + return [] + root = ET.fromstring(zf.read("word/document.xml")) + + tables: list[ExtractedTable] = [] + for ti, tbl in enumerate(root.findall(".//w:tbl", ns), start=1): + rows: list[list[str]] = [] + for tr in tbl.findall("w:tr", ns): + row: list[str] = [] + for tc in tr.findall("w:tc", ns): + text = "".join((t.text or "") for t in tc.findall(".//w:t", ns)).strip() + row.append(text) + if any(c.strip() for c in row): + rows.append(row) + normalized = _normalize_rows(rows) + if normalized: + tables.append(_build_table(f"docx_table_{ti}", "docx", normalized)) + return tables + + +def _extract_pptx_tables(raw: bytes) -> list[ExtractedTable]: + with zipfile.ZipFile(io.BytesIO(raw)) as zf: + slide_paths = sorted( + p for p in zf.namelist() if p.startswith("ppt/slides/slide") and p.endswith(".xml") + ) + tables: list[ExtractedTable] = [] + for slide_idx, slide_path in enumerate(slide_paths, start=1): + root = ET.fromstring(zf.read(slide_path)) + tbl_nodes = root.findall('.//{http://schemas.openxmlformats.org/drawingml/2006/main}tbl') + for tbl_idx, tbl in enumerate(tbl_nodes, start=1): + rows: list[list[str]] = [] + for tr in tbl.findall('{http://schemas.openxmlformats.org/drawingml/2006/main}tr'): + row: list[str] = [] + for tc in tr.findall('{http://schemas.openxmlformats.org/drawingml/2006/main}tc'): + text = ''.join((t.text or '') for t in tc.iter('{http://schemas.openxmlformats.org/drawingml/2006/main}t')).strip() + row.append(text) + if any(c.strip() for c in row): + rows.append(row) + normalized = _normalize_rows(rows) + if normalized: + table_id = f"pptx_s{slide_idx}_t{tbl_idx}" + tables.append(_build_table(table_id, "pptx", normalized)) + return tables + + +def _extract_pdf_tables(raw: bytes) -> list[ExtractedTable] | tuple[str, str]: + if b"/Encrypt" in raw: + return ("암호화", "암호화된 PDF는 텍스트 추출이 제한됩니다.") + + if b"/Subtype /Image" in raw and b"BT" not in raw: + return ("스캔 이미지", "스캔 이미지 기반 PDF로 감지되어 OCR 없이는 표 추출이 어렵습니다.") + + text = raw.decode("latin-1", errors="ignore") + lines = [ln.strip() for ln in text.splitlines() if ln.strip()] + candidates: list[list[str]] = [] + for line in lines: + if "|" in line: + parts = [p.strip() for p in line.split("|")] + elif "\t" in line: + parts = [p.strip() for p in line.split("\t")] + elif line.count(",") >= 2: + parts = [p.strip() for p in line.split(",")] + else: + continue + if len(parts) >= 2: + candidates.append(parts) + + if not candidates: + return ("표 없음", "PDF에서 테이블 형태 텍스트를 찾지 못했습니다.") + + normalized = _normalize_rows(candidates) + if not normalized: + return ("표 없음", "PDF 테이블 후보를 정규화하지 못했습니다.") + return [_build_table("pdf_table_1", "pdf", normalized)] + + +def _normalize_rows(rows: list[list[str]]) -> list[list[str]]: + if not rows: + return [] + width = max(len(r) for r in rows) + if width == 0: + return [] + normalized = [r + [""] * (width - len(r)) for r in rows] + if not any(any(c.strip() for c in r) for r in normalized): + return [] + return normalized + + +def _estimate_header(row: list[str]) -> bool: + filled = [c for c in row if c.strip()] + if not filled: + return False + numeric_like = 0 + for cell in filled: + v = cell.strip().replace(",", "") + if re.fullmatch(r"[-+]?\d+(\.\d+)?", v): + numeric_like += 1 + unique_ratio = len(set(filled)) / len(filled) + numeric_ratio = numeric_like / len(filled) + return unique_ratio >= 0.8 and numeric_ratio <= 0.4 + + +def _calc_missing_ratio(rows: list[list[str]]) -> float: + if not rows: + return 1.0 + total = len(rows) * max(len(r) for r in rows) + if total == 0: + return 1.0 + missing = sum(1 for row in rows for cell in row if not str(cell).strip()) + return missing / total + + +def _calc_confidence(row_count: int, col_count: int, header_inferred: bool, missing_ratio: float) -> float: + row_factor = min(row_count / 8.0, 1.0) + col_factor = min(col_count / 6.0, 1.0) + header_bonus = 1.0 if header_inferred else 0.55 + missing_penalty = max(0.0, 1.0 - min(missing_ratio, 1.0)) + score = (0.3 * row_factor) + (0.25 * col_factor) + (0.25 * header_bonus) + (0.2 * missing_penalty) + return max(0.0, min(score, 1.0)) + + +def _build_table(table_id: str, source: str, rows: list[list[str]]) -> ExtractedTable: + header_inferred = _estimate_header(rows[0]) if rows else False + missing_ratio = _calc_missing_ratio(rows) + confidence = _calc_confidence(len(rows), max((len(r) for r in rows), default=0), header_inferred, missing_ratio) + return ExtractedTable( + table_id=table_id, + source=source, + rows=rows, + header_inferred=header_inferred, + missing_ratio=missing_ratio, + confidence=confidence, + ) diff --git a/bitnet_tools/ui/app.js b/bitnet_tools/ui/app.js index ac34e6d..159c0dd 100644 --- a/bitnet_tools/ui/app.js +++ b/bitnet_tools/ui/app.js @@ -76,6 +76,7 @@ function getInputTypeForFile(file) { if (selected !== 'auto') return selected; const name = String(file?.name || '').toLowerCase(); if (name.endsWith('.xlsx') || name.endsWith('.xls')) return 'excel'; + if (name.endsWith('.pdf') || name.endsWith('.docx') || name.endsWith('.pptx')) return 'document'; return 'csv'; } @@ -92,21 +93,41 @@ async function readFileAsBase64(file) { async function fetchSheetsForFile(file) { const inputType = getInputTypeForFile(file); - if (inputType !== 'excel') { - appState.detectedInputType = 'csv'; - if (UI.sheetSelect) UI.sheetSelect.innerHTML = ''; + const fileBase64 = await readFileAsBase64(file); + + if (inputType === 'excel') { + const res = await postJson('/api/sheets', { + input_type: 'excel', + source_name: file.name, + file_base64: fileBase64, + }, 'Excel 시트 목록 조회'); + appState.detectedInputType = 'excel'; + const names = Array.isArray(res.sheet_names) ? res.sheet_names : []; + const opts = ['', ...names.map((n) => ``)].join(''); + if (UI.sheetSelect) UI.sheetSelect.innerHTML = opts; return; } - const fileBase64 = await readFileAsBase64(file); - const res = await postJson('/api/sheets', { - input_type: 'excel', - source_name: file.name, - file_base64: fileBase64, - }, 'Excel 시트 목록 조회'); - appState.detectedInputType = 'excel'; - const names = Array.isArray(res.sheet_names) ? res.sheet_names : []; - const opts = ['', ...names.map((n) => ``)].join(''); - if (UI.sheetSelect) UI.sheetSelect.innerHTML = opts; + + if (inputType === 'document') { + const res = await postJson('/api/document/extract', { + input_type: 'document', + source_name: file.name, + file_base64: fileBase64, + }, '문서 표 추출'); + appState.detectedInputType = 'document'; + const tables = Array.isArray(res.tables) ? res.tables : []; + const opts = tables.length + ? tables.map((tb, idx) => ``).join('') + : ''; + if (UI.sheetSelect) UI.sheetSelect.innerHTML = opts; + if (!tables.length && res.failure_detail) { + showError('문서 표 추출 실패', res.failure_detail); + } + return; + } + + appState.detectedInputType = 'csv'; + if (UI.sheetSelect) UI.sheetSelect.innerHTML = ''; } async function buildAnalyzeRequest() { @@ -134,6 +155,17 @@ async function buildAnalyzeRequest() { }; } + if (inputType === 'document') { + const base64 = await readFileAsBase64(file); + return { + input_type: 'document', + source_name: file.name, + file_base64: base64, + table_index: Number(UI.sheetSelect?.value || 0), + question, + }; + } + return { input_type: 'csv', source_name: file.name, @@ -153,6 +185,13 @@ async function buildMultiPayloadFiles(files) { file_base64: await readFileAsBase64(f), sheet_name: UI.sheetSelect?.value || '', }); + } else if (inputType === 'document') { + payloadFiles.push({ + name: f.name, + input_type: 'document', + file_base64: await readFileAsBase64(f), + table_index: Number(UI.sheetSelect?.value || 0), + }); } else { payloadFiles.push({ name: f.name, diff --git a/bitnet_tools/ui/index.html b/bitnet_tools/ui/index.html index e9d9b3c..643c13b 100644 --- a/bitnet_tools/ui/index.html +++ b/bitnet_tools/ui/index.html @@ -28,16 +28,17 @@

2) 입력

+ - - + +
- +
@@ -97,7 +98,7 @@

BitNet 응답

고급: 멀티 CSV/Excel 분석

- +
diff --git a/bitnet_tools/web.py b/bitnet_tools/web.py index 9f7c5f0..d97acb9 100644 --- a/bitnet_tools/web.py +++ b/bitnet_tools/web.py @@ -20,6 +20,7 @@ from urllib.parse import urlparse from .analysis import build_analysis_payload_from_request +from .document_extract import extract_document_tables_from_base64, table_to_analysis_request from .multi_csv import analyze_multiple_csv from .visualize import create_multi_charts @@ -40,6 +41,17 @@ def _coerce_csv_text_from_file_payload(file_payload: dict[str, Any]) -> tuple[st source_name = str(file_payload.get('name', '')) meta: dict[str, Any] = {'source_name': source_name, 'input_type': input_type} + if input_type == 'document': + raw_b64 = str(file_payload.get('file_base64', '')).strip() + if not raw_b64: + raise ValueError('document file_base64 is required') + extract_result = extract_document_tables_from_base64(raw_b64, source_name) + table_index = int(file_payload.get('table_index', 0) or 0) + request_payload = table_to_analysis_request(extract_result, table_index) + normalized_text = str(request_payload.get('normalized_csv_text', '')) + meta.update(request_payload.get('meta', {})) + return source_name, normalized_text, meta + if input_type == 'excel': raw_b64 = str(file_payload.get('file_base64', '')).strip() if not raw_b64: @@ -321,6 +333,17 @@ def do_POST(self) -> None: sheet_names = _extract_sheet_names(file_base64) return self._send_json({'sheet_names': sheet_names}) + if route == '/api/document/extract': + input_type = str(payload.get('input_type', 'document') or 'document').strip().lower() + if input_type != 'document': + return self._send_json(self._error_payload('input_type must be document', input_type=input_type, preprocessing_stage='input_validation'), HTTPStatus.BAD_REQUEST) + file_base64 = str(payload.get('file_base64', '')).strip() + source_name = str(payload.get('source_name', 'document') or 'document') + if not file_base64: + return self._send_json(self._error_payload('document file is required', 'file_base64 is empty', input_type='document', preprocessing_stage='input_validation'), HTTPStatus.BAD_REQUEST) + result = extract_document_tables_from_base64(file_base64, source_name) + return self._send_json(result.to_dict()) + if route == "/api/analyze": question = str(payload.get("question", "")).strip() if not question: @@ -336,6 +359,25 @@ def do_POST(self) -> None: str(payload.get("sheet_name", "") or "").strip() or None, ) meta = {**meta, "sheet_name": str(payload.get("sheet_name", "") or "").strip() or ""} + elif input_type == "document": + extract_result = extract_document_tables_from_base64( + str(payload.get("file_base64", "") or ""), + source_name, + ) + if not extract_result.tables: + return self._send_json( + self._error_payload( + "document table extraction failed", + extract_result.failure_detail or extract_result.failure_reason or "표 추출 실패", + input_type="document", + preprocessing_stage="table_extraction", + ), + HTTPStatus.BAD_REQUEST, + ) + selected_index = int(payload.get("table_index", 0) or 0) + request_payload_for_table = table_to_analysis_request(extract_result, selected_index) + normalized_csv_text = request_payload_for_table["normalized_csv_text"] + meta = {**meta, **request_payload_for_table.get("meta", {})} request_payload = { "input_type": input_type, diff --git a/tests/test_cli.py b/tests/test_cli.py index 00027c6..dc0d11b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,6 @@ from pathlib import Path +import io +import zipfile from bitnet_tools import cli @@ -235,3 +237,39 @@ def fake_analyze(csv, question, group_column=None, target_column=None, use_cache assert code == 0 assert called["max_workers"] == 3 + + +def _write_docx(path: Path) -> None: + xml = """ + +c1c2 +a1 +""" + mem = io.BytesIO() + with zipfile.ZipFile(mem, 'w') as zf: + zf.writestr('word/document.xml', xml) + path.write_bytes(mem.getvalue()) + + +def test_cli_analyze_document_list_tables(tmp_path, capsys): + doc_path = tmp_path / 'sample.docx' + _write_docx(doc_path) + + code = cli.main(['analyze', str(doc_path), '--question', '요약', '--list-tables']) + + assert code == 0 + out = capsys.readouterr().out + assert 'docx_table_1' in out + + +def test_cli_analyze_document_to_payload(tmp_path): + doc_path = tmp_path / 'sample.docx' + out_path = tmp_path / 'out.json' + _write_docx(doc_path) + + code = cli.main(['analyze', str(doc_path), '--question', '요약', '--out', str(out_path)]) + + assert code == 0 + body = out_path.read_text(encoding='utf-8') + assert '"input_type": "document"' in body + assert '"table_id": "docx_table_1"' in body diff --git a/tests/test_document_extract.py b/tests/test_document_extract.py new file mode 100644 index 0000000..030036f --- /dev/null +++ b/tests/test_document_extract.py @@ -0,0 +1,55 @@ +import base64 +import io +import zipfile + +from bitnet_tools.document_extract import extract_document_tables_from_base64, table_to_analysis_request + + +def _make_docx_with_table() -> bytes: + document_xml = ''' + +namescore +A10 +''' + mem = io.BytesIO() + with zipfile.ZipFile(mem, 'w') as zf: + zf.writestr('word/document.xml', document_xml) + return mem.getvalue() + + +def test_extract_docx_tables_and_request_payload(): + raw = _make_docx_with_table() + b64 = base64.b64encode(raw).decode('ascii') + + result = extract_document_tables_from_base64(b64, 'sample.docx') + + assert len(result.tables) == 1 + table = result.tables[0] + assert table.row_count == 2 + assert table.column_count == 2 + assert 0.0 <= table.confidence <= 1.0 + + request = table_to_analysis_request(result, 0) + assert request['input_type'] == 'document' + assert 'name,score' in request['normalized_csv_text'] + assert request['meta']['table_id'] == 'docx_table_1' + + +def test_extract_pdf_failure_reason_encrypted(): + fake_pdf = b'%PDF-1.4\n1 0 obj\n<< /Encrypt 2 0 R >>\nendobj\n' + b64 = base64.b64encode(fake_pdf).decode('ascii') + + result = extract_document_tables_from_base64(b64, 'locked.pdf') + + assert result.tables == [] + assert result.failure_reason == '암호화' + + +def test_extract_pdf_failure_reason_scan_image(): + fake_pdf = b'%PDF-1.4\n<< /Subtype /Image >>\n' + b64 = base64.b64encode(fake_pdf).decode('ascii') + + result = extract_document_tables_from_base64(b64, 'scan.pdf') + + assert result.tables == [] + assert result.failure_reason == '스캔 이미지' diff --git a/tests/test_web.py b/tests/test_web.py index 6f4d860..6eb1f45 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -1,6 +1,10 @@ import time from pathlib import Path +import base64 +import io +import zipfile + import bitnet_tools.web as web @@ -34,3 +38,29 @@ def fake_create_multi_charts(csv_paths, out_dir): def test_get_chart_job_not_found(): result = web.get_chart_job("missing") assert result["status"] == "not_found" + + +def _make_docx_b64() -> str: + xml = """ + +h1h2 +v1v2 +""" + mem = io.BytesIO() + with zipfile.ZipFile(mem, 'w') as zf: + zf.writestr('word/document.xml', xml) + return base64.b64encode(mem.getvalue()).decode('ascii') + + +def test_coerce_document_payload_to_csv_text(): + b64 = _make_docx_b64() + source, csv_text, meta = web._coerce_csv_text_from_file_payload({ + 'input_type': 'document', + 'name': 'sample.docx', + 'file_base64': b64, + 'table_index': 0, + }) + + assert source == 'sample.docx' + assert 'h1,h2' in csv_text + assert meta['table_id'] == 'docx_table_1'