208 lines
6.6 KiB
Python
208 lines
6.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
InsightFlow Backend - Phase 1 MVP with 阿里听悟
|
||
ASR: 阿里云听悟 (TingWu)
|
||
Speaker Diarization: 听悟内置
|
||
LLM: Kimi API for entity extraction
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import httpx
|
||
import time
|
||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
from datetime import datetime
|
||
from alibabacloud_tingwu20230930 import models as tingwu_models
|
||
from alibabacloud_tingwu20230930.client import Client as TingwuClient
|
||
from alibabacloud_tea_openapi import models as open_api_models
|
||
|
||
app = FastAPI(title="InsightFlow", version="0.1.0")

# Allowed CORS origins, as a comma-separated list in CORS_ALLOW_ORIGINS.
# Defaults to "*" (any origin), matching the previous behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS requests are not meant to be combined with a wildcard
    # origin, so credentials are only allowed when an explicit origin list is
    # configured. Nothing in this module uses cookies or auth headers.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# Models
class Entity(BaseModel):
    """An entity (project, technology, person, ...) extracted from a transcript."""

    id: str  # identifier such as "ent_1"
    name: str  # surface text of the entity
    type: str  # category label; the LLM is asked for PROJECT/TECH/PERSON/ORG/OTHER, but any string is accepted
    start: int  # character offset of the entity in the transcript text
    end: int  # character offset where the entity ends (exclusive vs. inclusive is not enforced — TODO confirm)
    definition: Optional[str] = None  # optional one-sentence definition
|
||
|
||
class TranscriptSegment(BaseModel):
    """One transcribed utterance with its time span and speaker label."""

    start: float  # segment start time (presumably seconds — confirm against the ASR output)
    end: float  # segment end time, same unit as start
    text: str  # transcribed text of this segment
    speaker: Optional[str] = "Speaker A"  # speaker label; defaults to "Speaker A"
|
||
|
||
class AnalysisResult(BaseModel):
    """Analysis of one uploaded audio file, as stored and returned by the API."""

    transcript_id: str  # random hex id; also the key under which the result is stored
    segments: List[TranscriptSegment]  # transcript split into segments
    entities: List[Entity]  # entities extracted from full_text
    full_text: str  # complete transcript text
    created_at: str  # ISO-8601 timestamp string of when the analysis was created
|
||
|
||
# In-memory store of analyses keyed by transcript_id. Nothing here persists it
# or evicts entries, so contents live only as long as the process.
storage: "dict[str, AnalysisResult]" = {}

# API keys, read from the environment. An empty string means "not configured".
ALI_ACCESS_KEY = os.getenv("ALI_ACCESS_KEY", "")
ALI_SECRET_KEY = os.getenv("ALI_SECRET_KEY", "")
KIMI_API_KEY = os.getenv("KIMI_API_KEY", "")

# Base URL of the Kimi API; override with KIMI_BASE_URL. A trailing slash is
# stripped because request paths such as "/v1/chat/completions" are appended.
KIMI_BASE_URL = os.getenv("KIMI_BASE_URL", "https://api.kimi.com/coding").rstrip("/")
|
||
|
||
def create_tingwu_client(endpoint: str = "tingwu.cn-beijing.aliyuncs.com"):
    """Create an Aliyun TingWu SDK client from the module-level credentials.

    Args:
        endpoint: API endpoint host. Defaults to the cn-beijing endpoint that
            was previously hard-coded; pass a different host to target another
            endpoint (availability of other regions is not checked here).

    Returns:
        A ``TingwuClient`` authenticated with ALI_ACCESS_KEY / ALI_SECRET_KEY.
    """
    config = open_api_models.Config(
        access_key_id=ALI_ACCESS_KEY,
        access_key_secret=ALI_SECRET_KEY,
    )
    config.endpoint = endpoint
    return TingwuClient(config)
|
||
|
||
def transcribe_with_tingwu(audio_data: bytes, filename: str) -> dict:
    """Transcribe audio with speaker separation via Aliyun TingWu (currently mocked).

    NOTE(review): the real TingWu call is not wired up yet. The offline task
    needs a file URL, so the audio must first be uploaded somewhere reachable
    (e.g. OSS). Until then this returns a fixed sample result and
    ``audio_data`` / ``filename`` are unused.

    Args:
        audio_data: Raw bytes of the uploaded audio (unused by the mock).
        filename: Original upload filename (unused by the mock).

    Returns:
        A dict with ``"full_text"`` (str) and ``"segments"`` (list of dicts
        with ``start``, ``end``, ``text`` and ``speaker`` keys).

    Raises:
        HTTPException: 500 if the Aliyun credentials are not configured.
    """
    if not ALI_ACCESS_KEY or not ALI_SECRET_KEY:
        raise HTTPException(status_code=500, detail="Aliyun credentials not configured")

    # TODO: upload audio_data (e.g. to OSS), submit an offline CreateTask via
    # create_tingwu_client(), then poll for the transcription result.
    # The previous inline request was never submitted (file_url was empty).
    # It used helpers named tingwu_models.Input / Parameters / Transcription /
    # Summarization. Check these against the SDK's generated model classes
    # (e.g. CreateTaskRequestInput) before reviving it.

    sample_text = "这是一个示例转录文本,包含 Project Alpha 和 K8s 等术语。"
    return {
        "full_text": sample_text,
        "segments": [
            {"start": 0.0, "end": 5.0, "text": sample_text, "speaker": "Speaker A"},
        ],
    }
|
||
|
||
def _build_entity_prompt(text: str) -> str:
    """Build the Kimi prompt asking for a JSON array of entities found in *text*."""
    return f"""请从以下会议文本中提取关键实体(专有名词、项目名、技术术语、人名等),并以 JSON 格式返回:

文本:{text[:3000]}

要求:
1. 每个实体包含:name(名称), type(类型: PROJECT/TECH/PERSON/ORG/OTHER), start(起始字符位置), end(结束字符位置), definition(一句话定义)
2. 只返回 JSON 数组,不要其他内容
3. 确保 start/end 是字符在文本中的位置

示例输出:
[
{{"name": "Project Alpha", "type": "PROJECT", "start": 23, "end": 35, "definition": "Q3季度的核心项目"}},
{{"name": "K8s", "type": "TECH", "start": 37, "end": 40, "definition": "Kubernetes的缩写"}}
]
"""


def _parse_entity_array(content: str) -> list:
    """Return the JSON array in an LLM reply, or [] if none can be parsed.

    Tries the whole reply first. It then falls back to the span from the first
    "[" to the last "]", which tolerates code fences, surrounding prose, a
    wrapping object, and "]" characters inside entity definitions.
    """
    candidates = [content]
    first, last = content.find("["), content.rfind("]")
    if first != -1 and last > first:
        candidates.append(content[first:last + 1])
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, list):
            return data
    return []


def _locate_span(text: str, name: str, start, end):
    """Return (start, end) offsets of *name* in *text*, correcting LLM guesses.

    The model's offsets are kept when they actually select *name*. Otherwise the
    first occurrence of *name* in *text* is used. If *name* is not found, the
    model's values are returned unchanged.
    """
    if (
        isinstance(start, int)
        and isinstance(end, int)
        and 0 <= start <= end
        and text[start:end] == name
    ):
        return start, end
    found = text.find(name)
    if found != -1:
        return found, found + len(name)
    return start, end


def extract_entities_with_llm(text: str) -> List[Entity]:
    """Extract key entities from *text* with the Kimi chat-completions API.

    Best-effort: returns [] when KIMI_API_KEY is unset, *text* is empty, or the
    request or its JSON response fails. The error is printed, not raised.
    Individual malformed entities in the reply are skipped rather than
    discarding the whole result.

    Args:
        text: Full transcript text. Only the first 3000 characters are sent.

    Returns:
        Entities with ids ``ent_1``, ``ent_2``, ... in reply order. ``start`` and
        ``end`` are character offsets into *text*. When the model's offsets do
        not select the entity name, they are replaced by the name's first
        occurrence in *text*.
    """
    if not KIMI_API_KEY or not text:
        return []

    try:
        response = httpx.post(
            f"{KIMI_BASE_URL}/v1/chat/completions",
            headers={"Authorization": f"Bearer {KIMI_API_KEY}", "Content-Type": "application/json"},
            json={
                "model": "k2p5",
                "messages": [{"role": "user", "content": _build_entity_prompt(text)}],
                "temperature": 0.1,
            },
            timeout=60.0,
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
    except Exception as e:  # best-effort boundary: entity extraction must not fail the upload
        print(f"LLM extraction failed: {e}")
        return []

    entities: List[Entity] = []
    for item in _parse_entity_array(content if isinstance(content, str) else ""):
        if not isinstance(item, dict) or not item.get("name"):
            continue
        name = str(item["name"])
        start, end = _locate_span(text, name, item.get("start"), item.get("end"))
        try:
            entity = Entity(
                id=f"ent_{len(entities) + 1}",
                name=name,
                type=item.get("type") or "OTHER",
                start=start,
                end=end,
                definition=item.get("definition", ""),
            )
        except ValueError as e:  # pydantic's ValidationError subclasses ValueError
            print(f"Skipping malformed entity {item!r}: {e}")
            continue
        entities.append(entity)
    return entities
|
||
|
||
@app.post("/api/v1/upload", response_model=AnalysisResult)
async def upload_audio(file: UploadFile = File(...)):
    """Upload audio, transcribe it, extract entities, and store the analysis.

    Raises:
        HTTPException: 400 if the uploaded file is empty. Any HTTPException from
            transcription (e.g. 500 for missing Aliyun credentials) propagates.
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Transcription and entity extraction are synchronous functions (entity
    # extraction makes a blocking HTTP request). Run them in the threadpool so
    # they don't stall the event loop for other requests.
    print(f"Transcribing with Tingwu: {file.filename}")
    tw_result = await run_in_threadpool(transcribe_with_tingwu, content, file.filename)
    full_text = tw_result["full_text"]

    # Fall back to a single segment covering the whole text if none came back.
    segments = [
        TranscriptSegment(**seg) for seg in tw_result.get("segments") or []
    ] or [TranscriptSegment(start=0, end=0, text=full_text, speaker="Speaker A")]

    print("Extracting entities with LLM...")
    entities = await run_in_threadpool(extract_entities_with_llm, full_text)

    analysis = AnalysisResult(
        transcript_id=os.urandom(8).hex(),
        segments=segments,
        entities=entities,
        full_text=full_text,
        created_at=datetime.now().isoformat(),
    )

    storage[analysis.transcript_id] = analysis
    print(f"Analysis complete: {analysis.transcript_id}, {len(entities)} entities found")
    return analysis
|
||
|
||
@app.get("/api/v1/transcripts/{transcript_id}", response_model=AnalysisResult)
async def get_transcript(transcript_id: str):
    """Look up one stored analysis by its id.

    Raises:
        HTTPException: 404 when no analysis with that id has been stored.
    """
    try:
        return storage[transcript_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Transcript not found") from None
|
||
|
||
@app.get("/api/v1/transcripts")
async def list_transcripts():
    """Return all stored analyses as a list, in the order they were stored."""
    return [*storage.values()]
|
||
|
||
# Serve the static frontend from ./frontend (resolved relative to the current
# working directory). With html=True, index.html is served for directory paths.
# NOTE(review): this catch-all mount at "/" is presumably registered last so the
# API routes above take precedence — keep it at the end of the module.
_frontend_files = StaticFiles(directory="frontend", html=True)
app.mount("/", _frontend_files, name="frontend")
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # HOST and PORT may be overridden from the environment. The defaults match
    # the previously hard-coded values (all interfaces, port 8000).
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )
|