# Listing metadata (from the file viewer this was copied from): 159 lines, 5.3 KiB, Python.
import argparse
import os
import sys
import time

from openai import OpenAI


# Task definitions: maps each task name accepted by ``--task`` to its
# configuration. The "system" entry is the system prompt sent with every chunk.
TASKS = {
    "fix_markdown": {
        "system": (
            "你是一个 Markdown 格式专家。"
            "请修复以下 Markdown 文档片段的格式问题。"
            "不要修改文档的原始内容,只调整格式(如标题、列表、代码块等的规范化)。"
            "直接输出修复后的 Markdown 内容,不要包含任何解释或 ```markdown 标记。"
            "注意:这是长文档的一部分,请保持上下文连贯。"
        ),
    },
}
# Maximum chunk size, in characters.
MAX_CHUNK_SIZE = 3000


def split_markdown(text, max_length=MAX_CHUNK_SIZE):
    """Split Markdown text into chunks of roughly ``max_length`` characters.

    The text is only ever split on line boundaries, and never inside a fenced
    code block. A line whose stripped form starts with three backticks toggles
    the "inside a fence" state. A chunk can therefore exceed ``max_length``
    when a single line, or a single fenced block, is longer than the limit.

    Joining the returned chunks with a newline character reproduces ``text``
    exactly.

    Args:
        text: The Markdown document to split.
        max_length: Soft upper bound on the chunk size. Each line counts as
            its length plus one for the newline.

    Returns:
        A list of chunk strings. Empty input yields ``[""]``.
    """
    chunks = []
    current_chunk = []
    current_length = 0
    in_code_block = False

    for line in text.split('\n'):
        line_len = len(line) + 1  # +1 for the newline removed by split()

        # Decide with the fence state as it was *before* this line. That way a
        # closing fence (state still "inside") stays glued to its code block,
        # while an opening fence (state still "outside") may start a new chunk.
        # An oversized line landing on an empty chunk is simply accepted.
        if (
            current_chunk
            and not in_code_block
            and current_length + line_len > max_length
        ):
            chunks.append('\n'.join(current_chunk))
            current_chunk = []
            current_length = 0

        current_chunk.append(line)
        current_length += line_len

        # Update the fence state only after the line has been placed.
        if line.strip().startswith('```'):
            in_code_block = not in_code_block

    if current_chunk:
        chunks.append('\n'.join(current_chunk))

    return chunks
def process_chunk(client, content, task_config, model="doubao-seed-1-8-251228"):
    """Send one text chunk to the chat-completions API and return the reply.

    Args:
        client: Client object exposing ``chat.completions.create``.
        content: The chunk text, sent as the user message.
        task_config: Task entry; its ``"system"`` value is sent as the system
            message.
        model: Model identifier passed to the API.

    Returns:
        The ``message.content`` of the first choice in the response.

    Raises:
        KeyError: If ``task_config`` has no ``"system"`` entry.
        Exception: Anything raised by the client call propagates unchanged,
            so the caller can decide whether to retry.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": task_config["system"]},
            {"role": "user", "content": content},
        ],
        max_tokens=4096,  # keep a generous output token limit
    )
    return completion.choices[0].message.content
def _process_chunk_with_retry(client, chunk, task_config, label,
                              max_retries=3, retry_delay=2):
    """Run ``process_chunk`` on one chunk, retrying when it raises.

    Args:
        client: API client forwarded to ``process_chunk``.
        chunk: The text chunk to process.
        task_config: Task configuration forwarded to ``process_chunk``.
        label: Human-readable chunk position (e.g. ``"2/5"``) used in the
            failure messages written to stderr.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to wait between two attempts.

    Returns:
        Whatever ``process_chunk`` returned on the first attempt that did not
        raise, or ``None`` when every attempt raised.
    """
    for attempt in range(1, max_retries + 1):
        try:
            return process_chunk(client, chunk, task_config)
        except Exception as e:
            # Deliberately broad: any API/transport failure is logged and
            # retried instead of aborting the whole document.
            print(f"Chunk {label} failed (attempt {attempt}): {e}", file=sys.stderr)
            # Only wait when another attempt will actually follow.
            if attempt < max_retries:
                time.sleep(retry_delay)
    return None


def main():
    """Command-line entry point: read Markdown from stdin, write the result to stdout.

    The input is split with ``split_markdown``; every chunk is sent through
    ``process_chunk`` (with retries) and the results are joined with newlines
    and printed. A chunk that could not be processed is emitted unchanged so
    that no content is lost. Progress and error messages go to stderr.

    Exits with status 1 when ``ARK_API_KEY`` is not set or when stdin is empty.
    """
    parser = argparse.ArgumentParser(description="Doubao AI Task Executor")
    parser.add_argument("--task", required=True, help="Task name", choices=list(TASKS))
    args = parser.parse_args()

    # The API key must come from the environment. NOTE(review): a key used to
    # be hard-coded here as a fallback; it should be treated as leaked and
    # rotated.
    api_key = os.getenv('ARK_API_KEY')
    if not api_key:
        print("Error: ARK_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    client = OpenAI(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=api_key,
    )

    # On Windows, force UTF-8 on stdin/stdout.
    if sys.platform == 'win32':
        import io
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors='ignore')
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='ignore')

    # Read the whole document, then drop characters that cannot be encoded
    # as UTF-8.
    content = sys.stdin.read()
    content = content.encode('utf-8', 'ignore').decode('utf-8')

    if not content:
        print("Error: No input content provided via stdin.", file=sys.stderr)
        sys.exit(1)

    task_config = TASKS[args.task]

    # 1. Split the document.
    chunks = split_markdown(content)
    total_chunks = len(chunks)

    # Progress goes to stderr so that stdout carries only the final result.
    print(f"Processing {total_chunks} chunks...", file=sys.stderr)

    # 2. Process the chunks one after another.
    results = []
    for i, chunk in enumerate(chunks):
        result = _process_chunk_with_retry(
            client, chunk, task_config, label=f"{i+1}/{total_chunks}", max_retries=3
        )

        if result is None:
            print(f"Error: Failed to process chunk {i+1} after 3 attempts.", file=sys.stderr)
            # Fall back to the original text to avoid losing content.
            results.append(chunk)
        else:
            results.append(result)

        # Small pause between requests to stay under the rate limit.
        if i < total_chunks - 1:
            time.sleep(0.5)

    # 3. Merge and 4. print the final result.
    print('\n'.join(results))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()