327 lines
8.8 KiB
Python
327 lines
8.8 KiB
Python
#!/usr/bin/env python
|
||
# coding=utf-8
|
||
|
||
import os
|
||
import time
|
||
import requests
|
||
import zipfile
|
||
import json
|
||
from config import (
|
||
MINERU_API_URL,
|
||
MINERU_TOKEN,
|
||
OSS_ACCESS_KEY_ID,
|
||
OSS_ACCESS_KEY_SECRET,
|
||
OSS_BUCKET_NAME,
|
||
OSS_ENDPOINT,
|
||
TEMP_DIR,
|
||
)
|
||
import sys
|
||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from shared.oss_upload import upload_file_to_oss
|
||
|
||
|
||
def create_parse_task(url, model_version="vlm"):
|
||
"""
|
||
创建解析任务
|
||
:param url: 文件URL
|
||
:param model_version: 模型版本,默认为vlm
|
||
:return: 任务ID
|
||
"""
|
||
print(f"开始创建解析任务: {url}")
|
||
|
||
api_url = f"{MINERU_API_URL}/extract/task"
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {MINERU_TOKEN}",
|
||
}
|
||
data = {"url": url, "model_version": model_version}
|
||
|
||
try:
|
||
response = requests.post(api_url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
if result.get("code") == 0:
|
||
task_id = result.get("data", {}).get("task_id")
|
||
print(f"任务创建成功: {task_id}")
|
||
return task_id
|
||
else:
|
||
print(f"任务创建失败: {result.get('msg')}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"创建任务失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def get_task_status(task_id):
|
||
"""
|
||
获取任务状态
|
||
:param task_id: 任务ID
|
||
:return: 任务状态
|
||
"""
|
||
print(f"查询任务状态: {task_id}")
|
||
|
||
api_url = f"{MINERU_API_URL}/extract/task/{task_id}"
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {MINERU_TOKEN}",
|
||
}
|
||
|
||
try:
|
||
response = requests.get(api_url, headers=headers)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
if result.get("code") == 0:
|
||
status = result.get("data", {})
|
||
print(f"任务状态: {status.get('state')}")
|
||
return status
|
||
else:
|
||
print(f"查询状态失败: {result.get('msg')}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"查询状态失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def poll_task_status(task_id, max_retries=60, interval=5):
|
||
"""
|
||
轮询任务状态
|
||
:param task_id: 任务ID
|
||
:param max_retries: 最大重试次数
|
||
:param interval: 轮询间隔(秒)
|
||
:return: 任务完成状态
|
||
"""
|
||
print(f"开始轮询任务状态: {task_id}")
|
||
|
||
for i in range(max_retries):
|
||
status = get_task_status(task_id)
|
||
if status:
|
||
state = status.get("state")
|
||
if state == "done":
|
||
print("任务完成!")
|
||
return status
|
||
elif state == "failed":
|
||
print(f"任务失败: {status.get('err_msg')}")
|
||
return None
|
||
elif state in ["pending", "running", "converting"]:
|
||
print(f"任务正在进行中 ({state}),{interval}秒后重试...")
|
||
time.sleep(interval)
|
||
else:
|
||
print(f"未知状态: {state}")
|
||
time.sleep(interval)
|
||
else:
|
||
print(f"获取状态失败,{interval}秒后重试...")
|
||
time.sleep(interval)
|
||
|
||
print("轮询超时,任务可能仍在处理中")
|
||
return None
|
||
|
||
|
||
def download_and_extract_result(zip_url, output_dir):
|
||
"""
|
||
下载并提取解析结果
|
||
:param zip_url: 结果压缩包URL
|
||
:param output_dir: 输出目录
|
||
:return: 提取的文件列表
|
||
"""
|
||
print(f"开始下载解析结果: {zip_url}")
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 下载压缩包
|
||
zip_path = os.path.join(output_dir, "result.zip")
|
||
try:
|
||
response = requests.get(zip_url, stream=True)
|
||
response.raise_for_status()
|
||
|
||
with open(zip_path, "wb") as f:
|
||
for chunk in response.iter_content(chunk_size=8192):
|
||
f.write(chunk)
|
||
print(f"压缩包下载成功: {zip_path}")
|
||
|
||
# 提取压缩包
|
||
extracted_files = []
|
||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||
zip_ref.extractall(output_dir)
|
||
print(f"压缩包提取成功: {output_dir}")
|
||
extracted_files = zip_ref.namelist()
|
||
|
||
# 删除压缩包
|
||
os.remove(zip_path)
|
||
print(f"删除临时压缩包: {zip_path}")
|
||
|
||
return extracted_files
|
||
|
||
except Exception as e:
|
||
print(f"下载或提取失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def get_markdown_result(extracted_files, output_dir):
|
||
"""
|
||
获取Markdown格式的解析结果
|
||
:param extracted_files: 提取的文件列表
|
||
:param output_dir: 输出目录
|
||
:return: Markdown内容
|
||
"""
|
||
print("查找Markdown格式的解析结果")
|
||
|
||
for file_name in extracted_files:
|
||
if file_name.endswith(".md"):
|
||
md_path = os.path.join(output_dir, file_name)
|
||
print(f"找到Markdown文件: {md_path}")
|
||
|
||
try:
|
||
with open(md_path, "r", encoding="utf-8") as f:
|
||
md_content = f.read()
|
||
print(f"Markdown文件读取成功,长度: {len(md_content)} 字符")
|
||
return md_content
|
||
except Exception as e:
|
||
print(f"读取Markdown文件失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
print("未找到Markdown格式的解析结果")
|
||
return None
|
||
|
||
|
||
def parse_local_file(file_path, model_version="vlm"):
|
||
"""
|
||
解析本地文件
|
||
:param file_path: 本地文件路径
|
||
:param model_version: 模型版本,默认为vlm
|
||
:return: Markdown内容
|
||
"""
|
||
print(f"开始解析本地文件: {file_path}")
|
||
|
||
# 检查文件是否存在
|
||
if not os.path.exists(file_path):
|
||
print(f"文件不存在: {file_path}")
|
||
return None
|
||
|
||
# 检查文件大小
|
||
file_size = os.path.getsize(file_path)
|
||
if file_size > 200 * 1024 * 1024: # 200MB
|
||
print(f"文件大小超出限制: {file_size} bytes (最大200MB)")
|
||
return None
|
||
|
||
# 生成对象名称
|
||
timestamp = int(time.time())
|
||
file_name = os.path.basename(file_path)
|
||
object_name = f"mineru/{timestamp}_{file_name}"
|
||
|
||
# 上传文件到OSS
|
||
oss_url = upload_file_to_oss(
|
||
file_path,
|
||
OSS_BUCKET_NAME,
|
||
object_name,
|
||
OSS_ACCESS_KEY_ID,
|
||
OSS_ACCESS_KEY_SECRET,
|
||
OSS_ENDPOINT,
|
||
)
|
||
|
||
if not oss_url:
|
||
print("文件上传失败,无法继续解析")
|
||
return None
|
||
|
||
# 创建解析任务
|
||
task_id = create_parse_task(oss_url, model_version)
|
||
|
||
if not task_id:
|
||
print("任务创建失败,无法继续解析")
|
||
return None
|
||
|
||
# 轮询任务状态
|
||
task_status = poll_task_status(task_id)
|
||
|
||
if not task_status:
|
||
print("任务执行失败,无法获取解析结果")
|
||
return None
|
||
|
||
# 获取结果URL
|
||
zip_url = task_status.get("full_zip_url")
|
||
if not zip_url:
|
||
print("未找到解析结果URL")
|
||
return None
|
||
|
||
# 生成输出目录
|
||
output_dir = os.path.join(TEMP_DIR, task_id)
|
||
|
||
# 下载并提取结果
|
||
extracted_files = download_and_extract_result(zip_url, output_dir)
|
||
|
||
if not extracted_files:
|
||
print("下载或提取结果失败")
|
||
return None
|
||
|
||
# 获取Markdown结果
|
||
md_content = get_markdown_result(extracted_files, output_dir)
|
||
|
||
if not md_content:
|
||
print("未找到Markdown格式的解析结果")
|
||
return None
|
||
|
||
print("文件解析完成,成功获取Markdown格式的结果")
|
||
return {"content": md_content, "output_dir": output_dir}
|
||
|
||
|
||
def main():
|
||
"""
|
||
主函数
|
||
"""
|
||
import sys
|
||
|
||
if len(sys.argv) != 2:
|
||
print("使用方法: python mineru_parser.py <本地文件路径>")
|
||
print("示例: python mineru_parser.py ./example.pdf")
|
||
sys.exit(1)
|
||
|
||
file_path = sys.argv[1]
|
||
|
||
# 解析文件
|
||
result = parse_local_file(file_path)
|
||
|
||
if result:
|
||
md_content = result["content"]
|
||
output_dir = result["output_dir"]
|
||
|
||
# 保存Markdown结果
|
||
output_file = os.path.join(TEMP_DIR, f"{os.path.basename(file_path)}.md")
|
||
with open(output_file, "w", encoding="utf-8") as f:
|
||
f.write(md_content)
|
||
|
||
# Print JSON result for caller to parse
|
||
print(
|
||
"JSON_RESULT:"
|
||
+ json.dumps(
|
||
{
|
||
"status": "success",
|
||
"markdown_file": os.path.abspath(output_file),
|
||
"output_dir": os.path.abspath(output_dir),
|
||
}
|
||
)
|
||
)
|
||
else:
|
||
print("文件解析失败")
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|