Files
XCDesktop/tools/doubao/main.py
2026-03-08 01:34:54 +08:00

159 lines
5.3 KiB
Python

import os
import sys
import argparse
import time
from openai import OpenAI
# Task registry: maps a task name to its prompt configuration.
TASKS = {
    "fix_markdown": {
        # System prompt (in Chinese): fix Markdown formatting issues only,
        # without changing document content; output the repaired Markdown
        # directly, with no explanations or ```markdown fences. Notes that
        # the input is one chunk of a longer document.
        "system": "你是一个 Markdown 格式专家。请修复以下 Markdown 文档片段的格式问题。不要修改文档的原始内容,只调整格式(如标题、列表、代码块等的规范化)。直接输出修复后的 Markdown 内容,不要包含任何解释或 ```markdown 标记。注意:这是长文档的一部分,请保持上下文连贯。",
    }
}
# Maximum chunk size (in characters).
MAX_CHUNK_SIZE = 3000


def split_markdown(text, max_length=MAX_CHUNK_SIZE):
    """
    Split Markdown text into chunks of roughly ``max_length`` characters,
    keeping paragraphs and fenced code blocks intact.

    Args:
        text: The Markdown document to split.
        max_length: Soft upper bound on chunk size, in characters. A single
            over-long line or a code block spanning the boundary is kept
            whole, so individual chunks may exceed the bound.

    Returns:
        A list of chunk strings; joining them with a newline reproduces
        ``text`` exactly.
    """
    chunks = []
    current_chunk = []
    current_length = 0
    in_code_block = False
    for line in text.split('\n'):
        was_in_code_block = in_code_block
        # Toggle fenced-code-block state on ``` lines.
        if line.strip().startswith('```'):
            in_code_block = not in_code_block
        line_len = len(line) + 1  # +1 for the newline separator
        # Start a new chunk when the size limit would be exceeded, but never
        # inside a code block. The state is checked both before and after the
        # fence toggle so that a closing ``` stays attached to its code block
        # (the previous version could flush the block and orphan the fence).
        if (current_length + line_len > max_length
                and not in_code_block and not was_in_code_block):
            if current_chunk:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_length = 0
        # A single line longer than max_length (rare) is kept whole.
        current_chunk.append(line)
        current_length += line_len
    if current_chunk:
        chunks.append('\n'.join(current_chunk))
    return chunks
def process_chunk(client, content, task_config, model="doubao-seed-1-8-251228"):
    """
    Send one text chunk to the chat-completion API and return the result.

    Args:
        client: An OpenAI-compatible client exposing ``chat.completions.create``.
        content: The chunk text to process (sent as the user message).
        task_config: Task definition dict; its ``"system"`` entry is used as
            the system prompt.
        model: Model identifier to request.

    Returns:
        The assistant message content for this chunk.

    Raises:
        Exception: Any API error propagates to the caller, which implements
            the retry logic.
    """
    # No try/except here: the original wrapped this call only to re-raise
    # with ``raise e``, which rewrites the traceback. Letting the exception
    # propagate untouched preserves full context for the caller.
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": task_config["system"]},
            {"role": "user", "content": content},
        ],
        max_tokens=4096,  # keep a generous output-token budget
    )
    return completion.choices[0].message.content
def main():
    """
    CLI entry point: read Markdown from stdin, run the selected AI task on it
    chunk by chunk, and write the merged result to stdout.

    Progress and error messages go to stderr so they never pollute the stdout
    result. Exits with status 1 when the API key is missing or stdin is empty.
    """
    parser = argparse.ArgumentParser(description="Doubao AI Task Executor")
    parser.add_argument("--task", required=True, help="Task name", choices=TASKS.keys())
    args = parser.parse_args()

    # The API key must come from the environment. The original code embedded
    # a hard-coded fallback key here, which both leaked a secret into source
    # control and made the emptiness check below unreachable.
    api_key = os.getenv('ARK_API_KEY')
    if not api_key:
        print("Error: ARK_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    client = OpenAI(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=api_key,
    )

    # Force UTF-8 on Windows, where the console encoding may differ.
    if sys.platform == 'win32':
        import io
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors='ignore')
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='ignore')

    # Read the whole document, dropping any bytes that are not valid UTF-8.
    content = sys.stdin.read()
    content = content.encode('utf-8', 'ignore').decode('utf-8')
    if not content:
        print("Error: No input content provided via stdin.", file=sys.stderr)
        sys.exit(1)

    task_config = TASKS[args.task]

    # 1. Split the document into model-sized chunks.
    chunks = split_markdown(content)

    # 2. Process chunks sequentially, with bounded retries per chunk.
    results = []
    total_chunks = len(chunks)
    print(f"Processing {total_chunks} chunks...", file=sys.stderr)
    for i, chunk in enumerate(chunks):
        try:
            retry_count = 0
            max_retries = 3
            result = None
            while retry_count < max_retries:
                try:
                    result = process_chunk(client, chunk, task_config)
                    break
                except Exception as e:
                    retry_count += 1
                    print(f"Chunk {i+1}/{total_chunks} failed (attempt {retry_count}): {e}", file=sys.stderr)
                    time.sleep(2)  # back off before retrying
            if result is None:
                print(f"Error: Failed to process chunk {i+1} after {max_retries} attempts.", file=sys.stderr)
                # Keep the original text on failure to avoid data loss.
                results.append(chunk)
            else:
                results.append(result)
            # Small pause between chunks to stay under rate limits.
            if i < total_chunks - 1:
                time.sleep(0.5)
        except Exception as e:
            print(f"Critical error on chunk {i+1}: {e}", file=sys.stderr)
            results.append(chunk)

    # 3. Merge all chunk results and emit the final document on stdout.
    final_output = '\n'.join(results)
    print(final_output)


if __name__ == "__main__":
    main()