- 新增 search_manager.py 搜索管理模块
- FullTextSearch: 全文搜索引擎 (FTS5)
- SemanticSearch: 语义搜索引擎 (sentence-transformers)
- EntityPathDiscovery: 实体关系路径发现 (BFS/DFS)
- KnowledgeGapDetector: 知识缺口检测器
- 新增 performance_manager.py 性能管理模块
- CacheManager: Redis 缓存层 (支持内存回退)
- DatabaseSharding: 数据库分片管理
- TaskQueue: 异步任务队列 (Celery + Redis)
- PerformanceMonitor: 性能监控器
- 更新 schema.sql 添加新表
- search_indexes, embeddings, fts_transcripts
- cache_stats, task_queue, performance_metrics, shard_mappings
- 更新 main.py 添加 API 端点
- 搜索: /search/fulltext, /search/semantic, /entities/{id}/paths
- 性能: /cache/stats, /performance/metrics, /tasks, /health
- 更新 requirements.txt 添加依赖
- sentence-transformers==2.5.1
- redis==5.0.1
- celery==5.3.6
- 创建测试脚本和文档
- test_phase7_task6_8.py
- docs/PHASE7_TASK6_8_SUMMARY.md
Phase 7 全部完成!
1713 lines
54 KiB
Python
"""
InsightFlow - 性能优化与扩展模块

Phase 7 Task 8: Performance Optimization & Scaling

功能模块:
1. CacheManager - Redis 缓存层(热点数据、TTL、LRU、缓存预热)
2. DatabaseSharding - 数据库分片策略
3. TaskQueue - 异步任务队列(Celery + Redis)
4. PerformanceMonitor - 性能监控(API响应、查询性能、缓存命中率)
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import hashlib
|
||
import sqlite3
|
||
import threading
|
||
from dataclasses import dataclass, field, asdict
|
||
from typing import Dict, List, Optional, Any, Callable, Tuple, Set
|
||
from datetime import datetime, timedelta
|
||
from collections import OrderedDict, defaultdict
|
||
from functools import wraps
|
||
import uuid
|
||
|
||
# 尝试导入 Redis
|
||
try:
|
||
import redis
|
||
REDIS_AVAILABLE = True
|
||
except ImportError:
|
||
REDIS_AVAILABLE = False
|
||
|
||
# 尝试导入 Celery
|
||
try:
|
||
from celery import Celery, Task
|
||
from celery.result import AsyncResult
|
||
CELERY_AVAILABLE = True
|
||
except ImportError:
|
||
CELERY_AVAILABLE = False
|
||
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
@dataclass
class CacheStats:
    """Running hit/miss counters for the cache layer."""
    total_requests: int = 0
    hits: int = 0
    misses: int = 0
    evictions: int = 0   # entries removed by LRU pressure
    expired: int = 0     # entries removed because their TTL elapsed
    hit_rate: float = 0.0

    def update_hit_rate(self):
        """Recompute hit_rate (rounded to 4 places); no-op before any request."""
        if not self.total_requests:
            return
        self.hit_rate = round(self.hits / self.total_requests, 4)
|
||
|
||
|
||
@dataclass
class CacheEntry:
    """One value stored in the in-memory LRU cache."""
    key: str
    value: Any
    created_at: float             # epoch seconds when the entry was stored
    expires_at: Optional[float]   # epoch seconds; None means never expires
    access_count: int = 0         # number of cache hits on this entry
    last_accessed: float = 0      # epoch seconds of the most recent hit
    size_bytes: int = 0           # serialized-size estimate, used for the memory budget
|
||
|
||
|
||
@dataclass
class PerformanceMetric:
    """A single timed measurement (API call, DB query, cache operation)."""
    id: str
    metric_type: str  # api_response, db_query, cache_operation
    endpoint: Optional[str]
    duration_ms: float
    timestamp: str
    metadata: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Plain-dict view of every field (metadata is shared, not copied)."""
        names = ("id", "metric_type", "endpoint", "duration_ms",
                 "timestamp", "metadata")
        return {name: getattr(self, name) for name in names}
|
||
|
||
|
||
@dataclass
class TaskInfo:
    """State record for one queued task."""
    id: str
    task_type: str
    status: str  # pending | running | success | failed | retrying
    payload: Dict
    created_at: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    result: Optional[Any] = None
    error_message: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3

    def to_dict(self) -> Dict:
        """Plain-dict view of every field (payload/result are shared refs)."""
        keys = ("id", "task_type", "status", "payload", "created_at",
                "started_at", "completed_at", "result", "error_message",
                "retry_count", "max_retries")
        return {k: getattr(self, k) for k in keys}
|
||
|
||
|
||
@dataclass
class ShardInfo:
    """Metadata for one SQLite shard file."""
    shard_id: str
    shard_key_range: Tuple[str, str]  # inclusive (start, end) of the first-char hex range
    db_path: str
    # NOTE(review): DatabaseSharding._update_shard_stats stores
    # COUNT(DISTINCT project_id) here despite the name — confirm intent.
    entity_count: int = 0
    is_active: bool = True
    created_at: str = ""       # ISO timestamp set at shard creation
    last_accessed: str = ""    # ISO timestamp updated on each connection
|
||
|
||
|
||
# ==================== Redis 缓存层 ====================
|
||
|
||
class CacheManager:
    """
    Hot-data cache with two interchangeable backends.

    Modes:
      1. Redis (recommended in production) — used when ``redis_url`` is given,
         the ``redis`` package imports, and the server answers a PING.
      2. In-process LRU dict (development/test fallback).

    Also creates and feeds the ``cache_stats`` and ``performance_metrics``
    tables in the SQLite database at ``db_path``.
    """

    def __init__(self,
                 redis_url: Optional[str] = None,
                 max_memory_size: int = 100 * 1024 * 1024,  # 100 MB cap for memory mode
                 default_ttl: int = 3600,  # seconds, used when set() gets no ttl
                 db_path: str = "insightflow.db"):
        self.db_path = db_path
        self.default_ttl = default_ttl
        self.max_memory_size = max_memory_size
        self.current_memory_size = 0

        # Redis client; validated with a PING, otherwise we stay in memory mode.
        self.redis_client = None
        self.use_redis = False

        if REDIS_AVAILABLE and redis_url:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                self.redis_client.ping()
                self.use_redis = True
                print(f"Redis 缓存已连接: {redis_url}")
            except Exception as e:
                print(f"Redis 连接失败,使用内存缓存: {e}")

        # In-memory LRU store, ordered oldest-access -> newest-access.
        self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self.cache_lock = threading.RLock()

        # Process-local hit/miss counters (kept even in Redis mode).
        self.stats = CacheStats()

        # Ensure the stats/metrics tables exist.
        self._init_cache_tables()

    def _init_cache_tables(self):
        """Create the cache_stats and performance_metrics tables if missing."""
        conn = sqlite3.connect(self.db_path)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                total_requests INTEGER DEFAULT 0,
                hits INTEGER DEFAULT 0,
                misses INTEGER DEFAULT 0,
                hit_rate REAL DEFAULT 0.0,
                memory_usage INTEGER DEFAULT 0
            )
        """)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS performance_metrics (
                id TEXT PRIMARY KEY,
                metric_type TEXT NOT NULL,
                endpoint TEXT,
                duration_ms REAL,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        """)

        conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_type ON performance_metrics(metric_type)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_time ON performance_metrics(timestamp)")

        conn.commit()
        conn.close()

    def _get_entry_size(self, value: Any) -> int:
        """Estimate the serialized size of a value in bytes.

        Non-JSON-serializable values get a flat 1 KiB estimate.
        """
        try:
            return len(json.dumps(value, ensure_ascii=False).encode('utf-8'))
        except (TypeError, ValueError):  # was a bare except; narrowed to json errors
            return 1024

    def _evict_lru(self, required_space: int = 0):
        """Evict least-recently-used entries until required_space fits."""
        with self.cache_lock:
            while (self.current_memory_size + required_space > self.max_memory_size
                   and self.memory_cache):
                # popitem(last=False) removes the least recently accessed entry.
                oldest_key, oldest_entry = self.memory_cache.popitem(last=False)
                self.current_memory_size -= oldest_entry.size_bytes
                self.stats.evictions += 1

    def get(self, key: str) -> Optional[Any]:
        """
        Fetch a cached value.

        Args:
            key: cache key

        Returns:
            The cached value, or None when absent/expired/errored.
        """
        self.stats.total_requests += 1

        if self.use_redis:
            try:
                value = self.redis_client.get(key)
                # decode_responses=True means a present key is a (non-empty)
                # JSON string and an absent key is None.
                if value is not None:
                    self.stats.hits += 1
                    return json.loads(value)
                else:
                    self.stats.misses += 1
                    return None
            except Exception as e:
                print(f"Redis get 失败: {e}")
                return None
        else:
            with self.cache_lock:
                entry = self.memory_cache.get(key)

                if entry:
                    # Lazy TTL check: expired entries are removed on access.
                    if entry.expires_at and time.time() > entry.expires_at:
                        del self.memory_cache[key]
                        self.current_memory_size -= entry.size_bytes
                        self.stats.expired += 1
                        self.stats.misses += 1
                        return None

                    # Refresh LRU bookkeeping.
                    entry.access_count += 1
                    entry.last_accessed = time.time()
                    self.memory_cache.move_to_end(key)

                    self.stats.hits += 1
                    return entry.value
                else:
                    self.stats.misses += 1
                    return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Store a value.

        Args:
            key: cache key
            value: JSON-serializable value
            ttl: lifetime in seconds; None -> default_ttl, <=0 -> no expiry
                 (memory mode only; Redis always receives the resolved ttl)

        Returns:
            True on success.
        """
        ttl = ttl or self.default_ttl

        if self.use_redis:
            try:
                serialized = json.dumps(value, ensure_ascii=False)
                self.redis_client.setex(key, ttl, serialized)
                return True
            except Exception as e:
                print(f"Redis set 失败: {e}")
                return False
        else:
            with self.cache_lock:
                size = self._get_entry_size(value)

                # Make room before inserting.
                if self.current_memory_size + size > self.max_memory_size:
                    self._evict_lru(size)

                now = time.time()
                entry = CacheEntry(
                    key=key,
                    value=value,
                    created_at=now,
                    expires_at=now + ttl if ttl > 0 else None,
                    size_bytes=size,
                    last_accessed=now
                )

                # Replacing an existing key: release its accounted size first.
                if key in self.memory_cache:
                    self.current_memory_size -= self.memory_cache[key].size_bytes

                self.memory_cache[key] = entry
                self.memory_cache.move_to_end(key)
                self.current_memory_size += size

                return True

    def delete(self, key: str) -> bool:
        """Remove one key. Returns True when something was deleted."""
        if self.use_redis:
            try:
                return self.redis_client.delete(key) > 0
            except Exception as e:
                print(f"Redis delete 失败: {e}")
                return False
        else:
            with self.cache_lock:
                if key in self.memory_cache:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    return True
                return False

    def clear(self) -> bool:
        """Drop every cached entry (Redis: flushes the whole current DB)."""
        if self.use_redis:
            try:
                self.redis_client.flushdb()
                return True
            except Exception as e:
                print(f"Redis clear 失败: {e}")
                return False
        else:
            with self.cache_lock:
                self.memory_cache.clear()
                self.current_memory_size = 0
                return True

    def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Fetch several keys; absent keys are simply omitted from the result."""
        results = {}

        if self.use_redis:
            try:
                values = self.redis_client.mget(keys)
                for key, value in zip(keys, values):
                    if value is not None:
                        results[key] = json.loads(value)
                        self.stats.hits += 1
                    else:
                        self.stats.misses += 1
                    self.stats.total_requests += 1
            except Exception as e:
                print(f"Redis mget 失败: {e}")
        else:
            for key in keys:
                value = self.get(key)
                if value is not None:
                    results[key] = value

        return results

    def set_many(self, mapping: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Store several key/value pairs with one shared ttl."""
        ttl = ttl or self.default_ttl

        if self.use_redis:
            try:
                # Pipeline the SETEX calls into a single round trip.
                pipe = self.redis_client.pipeline()
                for key, value in mapping.items():
                    serialized = json.dumps(value, ensure_ascii=False)
                    pipe.setex(key, ttl, serialized)
                pipe.execute()
                return True
            except Exception as e:
                print(f"Redis mset 失败: {e}")
                return False
        else:
            for key, value in mapping.items():
                self.set(key, value, ttl)
            return True

    def get_stats(self) -> Dict:
        """Return the current counters plus memory-usage details in memory mode."""
        self.stats.update_hit_rate()

        stats = {
            "total_requests": self.stats.total_requests,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": self.stats.hit_rate,
            "evictions": self.stats.evictions,
            "expired": self.stats.expired,
            "backend": "redis" if self.use_redis else "memory",
        }

        if not self.use_redis:
            stats.update({
                "memory_size_bytes": self.current_memory_size,
                "max_memory_size_bytes": self.max_memory_size,
                "memory_usage_percent": round(
                    self.current_memory_size / self.max_memory_size * 100, 2
                ),
                "cache_entries": len(self.memory_cache)
            })

        return stats

    def save_stats(self):
        """Append a snapshot of the counters to the cache_stats table."""
        conn = sqlite3.connect(self.db_path)

        self.stats.update_hit_rate()

        conn.execute("""
            INSERT INTO cache_stats
            (timestamp, total_requests, hits, misses, hit_rate, memory_usage)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            datetime.now().isoformat(),
            self.stats.total_requests,
            self.stats.hits,
            self.stats.misses,
            self.stats.hit_rate,
            self.current_memory_size
        ))

        conn.commit()
        conn.close()

    def warm_up(self, project_id: str) -> Dict:
        """
        Pre-load a project's hot data into the cache.

        Caches the 100 most-mentioned entities (2h TTL), up to 200 relations
        (1h TTL), the 10 newest transcripts' metadata (30min TTL), and a
        project summary (1h TTL).

        NOTE(review): these keys (e.g. "entity:{id}") do not embed the
        project_id, so invalidate_project() will never match them — the
        warmed entries only age out via TTL.

        Args:
            project_id: project whose data to warm

        Returns:
            Dict with counts of cached entities/relations/transcripts.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        stats = {"entities": 0, "relations": 0, "transcripts": 0}

        # Hot entities, ranked by mention count.
        entities = conn.execute(
            """SELECT e.*,
               (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count
               FROM entities e
               WHERE e.project_id = ?
               ORDER BY mention_count DESC
               LIMIT 100""",
            (project_id,)
        ).fetchall()

        for entity in entities:
            key = f"entity:{entity['id']}"
            self.set(key, dict(entity), ttl=7200)  # 2 hours
            stats["entities"] += 1

        # Relations with resolved endpoint names.
        relations = conn.execute(
            """SELECT r.*,
               e1.name as source_name, e2.name as target_name
               FROM entity_relations r
               JOIN entities e1 ON r.source_entity_id = e1.id
               JOIN entities e2 ON r.target_entity_id = e2.id
               WHERE r.project_id = ?
               LIMIT 200""",
            (project_id,)
        ).fetchall()

        for relation in relations:
            key = f"relation:{relation['id']}"
            self.set(key, dict(relation), ttl=3600)
            stats["relations"] += 1

        # Newest transcripts — metadata only, not the full text.
        transcripts = conn.execute(
            """SELECT * FROM transcripts
               WHERE project_id = ?
               ORDER BY created_at DESC
               LIMIT 10""",
            (project_id,)
        ).fetchall()

        for transcript in transcripts:
            # Bug fix: sqlite3.Row has no .get(); convert to dict first so a
            # missing 'type' column falls back to 'audio' instead of raising.
            row = dict(transcript)
            key = f"transcript:{row['id']}"
            meta = {
                "id": row['id'],
                "filename": row['filename'],
                "type": row.get('type', 'audio'),
                "created_at": row['created_at']
            }
            self.set(key, meta, ttl=1800)  # 30 minutes
            stats["transcripts"] += 1

        # Project-level knowledge-base summary.
        entity_count = conn.execute(
            "SELECT COUNT(*) FROM entities WHERE project_id = ?",
            (project_id,)
        ).fetchone()[0]

        relation_count = conn.execute(
            "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?",
            (project_id,)
        ).fetchone()[0]

        summary = {
            "project_id": project_id,
            "entity_count": entity_count,
            "relation_count": relation_count,
            "cached_at": datetime.now().isoformat()
        }
        self.set(f"project_summary:{project_id}", summary, ttl=3600)

        conn.close()

        return stats

    def invalidate_project(self, project_id: str) -> int:
        """
        Drop every cached key associated with a project.

        Redis mode matches the pattern "*:{project_id}:*"; memory mode drops
        any key containing the project_id substring. (See the NOTE on
        warm_up: its entity/relation/transcript keys match neither.)

        Args:
            project_id: project whose keys to drop

        Returns:
            Number of keys removed.
        """
        count = 0

        if self.use_redis:
            try:
                # SCAN instead of KEYS so we do not block the Redis server.
                pattern = f"*:{project_id}:*"
                for key in self.redis_client.scan_iter(match=pattern):
                    self.redis_client.delete(key)
                    count += 1
            except Exception as e:
                print(f"Redis 缓存失效失败: {e}")
        else:
            with self.cache_lock:
                keys_to_delete = [
                    key for key in self.memory_cache.keys()
                    if project_id in key
                ]
                for key in keys_to_delete:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    count += 1

        return count
|
||
|
||
|
||
# ==================== 数据库分片 ====================
|
||
|
||
class DatabaseSharding:
    """
    Database sharding manager.

    Responsibilities:
    - per-project sharding strategy (hash on the project id's first hex char)
    - shard routing
    - fan-out queries across all shards
    - project migration between shard files

    NOTE(review): routing in _get_shard_id is purely hash-based, so after
    migrate_project() moves data, reads still route to the original shard.
    A persistent mapping (e.g. the shard_mappings table mentioned in the
    schema) would be needed for migration to take effect — confirm intent.
    """

    def __init__(self,
                 base_db_path: str = "insightflow.db",
                 shard_db_dir: str = "./shards",
                 shards_count: int = 4):
        self.base_db_path = base_db_path
        self.shard_db_dir = shard_db_dir
        self.shards_count = shards_count

        # Make sure the shard directory exists.
        os.makedirs(shard_db_dir, exist_ok=True)

        # shard_id -> ShardInfo routing table (in-memory only).
        self.shard_map: Dict[str, ShardInfo] = {}

        # Build the routing table and create missing shard DB files.
        self._init_shards()

    def _init_shards(self):
        """Partition the hex alphabet into key ranges and create shard DBs."""
        # Each shard owns a contiguous slice of the 16 hex characters.
        # NOTE(review): assumes shards_count <= 16 — larger values make
        # chars_per_shard 0 and produce degenerate ranges; confirm callers.
        chars = "0123456789abcdef"
        chars_per_shard = len(chars) // self.shards_count

        for i in range(self.shards_count):
            start_idx = i * chars_per_shard
            # The last shard absorbs any remainder characters.
            end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars)

            start_char = chars[start_idx]
            end_char = chars[end_idx - 1]

            shard_id = f"shard_{i}"
            db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")

            self.shard_map[shard_id] = ShardInfo(
                shard_id=shard_id,
                shard_key_range=(start_char, end_char),
                db_path=db_path,
                created_at=datetime.now().isoformat()
            )

            # Create the shard database file on first use.
            if not os.path.exists(db_path):
                self._create_shard_db(db_path)

    def _create_shard_db(self, db_path: str):
        """Create a new shard DB with a simplified copy of the main schema."""
        conn = sqlite3.connect(db_path)

        conn.executescript("""
            CREATE TABLE IF NOT EXISTS entities (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                name TEXT NOT NULL,
                type TEXT,
                definition TEXT,
                aliases TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS entity_relations (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                source_entity_id TEXT NOT NULL,
                target_entity_id TEXT NOT NULL,
                relation_type TEXT,
                evidence TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE INDEX IF NOT EXISTS idx_entities_project ON entities(project_id);
            CREATE INDEX IF NOT EXISTS idx_relations_project ON entity_relations(project_id);
        """)

        conn.commit()
        conn.close()

    def _get_shard_id(self, project_id: str) -> str:
        """
        Map a project id to a shard id.

        Uses the project id's first character (lower-cased) against each
        shard's hex range. Empty ids and characters outside 0-9a-f fall
        back to shard_0.
        """
        if not project_id:
            return "shard_0"

        first_char = project_id[0].lower()

        for shard_id, shard_info in self.shard_map.items():
            start, end = shard_info.shard_key_range
            if start <= first_char <= end:
                return shard_id

        return "shard_0"

    def get_shard_connection(self, project_id: str) -> sqlite3.Connection:
        """Open a connection (row_factory=Row) to the project's shard.

        Caller owns the connection and must close it.
        """
        shard_id = self._get_shard_id(project_id)
        shard_info = self.shard_map[shard_id]

        conn = sqlite3.connect(shard_info.db_path)
        conn.row_factory = sqlite3.Row

        # Track last access for the stats endpoint.
        shard_info.last_accessed = datetime.now().isoformat()

        return conn

    def get_all_shards(self) -> List[ShardInfo]:
        """Return the ShardInfo records for every configured shard."""
        return list(self.shard_map.values())

    def migrate_project(self, project_id: str, target_shard_id: str) -> bool:
        """
        Copy a project's rows to another shard, then delete them at the source.

        Not atomic across the two DB files: a crash between the target commit
        and the source delete leaves the rows duplicated (INSERT OR REPLACE
        makes a re-run safe).

        NOTE(review): the inserts pass tuple(row) positionally, which relies
        on SELECT * returning columns in exactly the INSERT's column order —
        true for the schema created above, fragile if it ever changes.

        Args:
            project_id: project to move
            target_shard_id: destination shard id

        Returns:
            True on success (including the no-op "already there" case).
        """
        source_shard_id = self._get_shard_id(project_id)

        if source_shard_id == target_shard_id:
            return True  # already on the target shard

        source_info = self.shard_map.get(source_shard_id)
        target_info = self.shard_map.get(target_shard_id)

        if not source_info or not target_info:
            return False

        try:
            # Read the project's rows from the source shard.
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.row_factory = sqlite3.Row

            entities = source_conn.execute(
                "SELECT * FROM entities WHERE project_id = ?",
                (project_id,)
            ).fetchall()

            relations = source_conn.execute(
                "SELECT * FROM entity_relations WHERE project_id = ?",
                (project_id,)
            ).fetchall()

            source_conn.close()

            # Write them into the target shard.
            target_conn = sqlite3.connect(target_info.db_path)

            for entity in entities:
                target_conn.execute("""
                    INSERT OR REPLACE INTO entities
                    (id, project_id, name, type, definition, aliases, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, tuple(entity))

            for relation in relations:
                target_conn.execute("""
                    INSERT OR REPLACE INTO entity_relations
                    (id, project_id, source_entity_id, target_entity_id, relation_type, evidence, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, tuple(relation))

            target_conn.commit()
            target_conn.close()

            # Only after the target commit: remove the rows from the source.
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,))
            source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id,))
            source_conn.commit()
            source_conn.close()

            # Refresh both shards' counters.
            self._update_shard_stats(source_shard_id)
            self._update_shard_stats(target_shard_id)

            return True

        except Exception as e:
            print(f"迁移失败: {e}")
            return False

    def _update_shard_stats(self, shard_id: str):
        """Refresh a shard's counter from its DB file.

        Stores COUNT(DISTINCT project_id) into ShardInfo.entity_count
        (i.e. a project count, despite the field name).
        """
        shard_info = self.shard_map.get(shard_id)
        if not shard_info:
            return

        conn = sqlite3.connect(shard_info.db_path)

        count = conn.execute(
            "SELECT COUNT(DISTINCT project_id) FROM entities"
        ).fetchone()[0]

        shard_info.entity_count = count

        conn.close()

    def cross_shard_query(self, query_func: Callable) -> List[Dict]:
        """
        Run a query against every shard and concatenate the results.

        A failing shard is logged and skipped; its connection is always closed.

        Args:
            query_func: callable receiving an open sqlite3.Connection
                        (row_factory=Row) and returning an iterable of rows

        Returns:
            Concatenated results from all shards, in shard order.
        """
        results = []

        for shard_info in self.shard_map.values():
            conn = sqlite3.connect(shard_info.db_path)
            conn.row_factory = sqlite3.Row

            try:
                shard_results = query_func(conn)
                results.extend(shard_results)
            except Exception as e:
                print(f"分片 {shard_info.shard_id} 查询失败: {e}")
            finally:
                conn.close()

        return results

    def get_shard_stats(self) -> List[Dict]:
        """Refresh and return per-shard statistics (opens each shard DB)."""
        stats = []

        for shard_info in self.shard_map.values():
            self._update_shard_stats(shard_info.shard_id)

            stats.append({
                "shard_id": shard_info.shard_id,
                "key_range": shard_info.shard_key_range,
                "db_path": shard_info.db_path,
                "entity_count": shard_info.entity_count,
                "is_active": shard_info.is_active,
                "created_at": shard_info.created_at,
                "last_accessed": shard_info.last_accessed
            })

        return stats

    def rebalance_shards(self) -> Dict:
        """
        Analyze shard load balance.

        Currently analysis-only: reports shards above 150% / below 50% of
        the average load but performs no data movement.

        Returns:
            Dict with the average load and over/underloaded shard counts.
        """
        stats = self.get_shard_stats()

        if not stats:
            return {"message": "No shards to rebalance"}

        avg_load = sum(s["entity_count"] for s in stats) / len(stats)

        overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5]
        underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5]

        # Simplified rebalancing logic — production would need a real
        # migration plan driven by these numbers.

        return {
            "average_load": avg_load,
            "overloaded_shards": len(overloaded),
            "underloaded_shards": len(underloaded),
            "message": "Rebalancing analysis completed"
        }
|
||
|
||
|
||
# ==================== 异步任务队列 ====================
|
||
|
||
class TaskQueue:
|
||
"""
|
||
异步任务队列管理器
|
||
|
||
功能:
|
||
- 基于 Celery + Redis 的任务队列
|
||
- 音频分析异步处理
|
||
- 报告生成异步处理
|
||
- 任务状态追踪和重试机制
|
||
"""
|
||
|
||
def __init__(self,
|
||
redis_url: Optional[str] = None,
|
||
db_path: str = "insightflow.db"):
|
||
self.db_path = db_path
|
||
self.redis_url = redis_url
|
||
self.celery_app = None
|
||
self.use_celery = False
|
||
|
||
# 内存任务存储(非 Celery 模式)
|
||
self.tasks: Dict[str, TaskInfo] = {}
|
||
self.task_handlers: Dict[str, Callable] = {}
|
||
self.task_lock = threading.RLock()
|
||
|
||
# 初始化任务队列表
|
||
self._init_task_tables()
|
||
|
||
# 初始化 Celery
|
||
if CELERY_AVAILABLE and redis_url:
|
||
try:
|
||
self.celery_app = Celery(
|
||
'insightflow',
|
||
broker=redis_url,
|
||
backend=redis_url
|
||
)
|
||
self.use_celery = True
|
||
print(f"Celery 任务队列已初始化")
|
||
except Exception as e:
|
||
print(f"Celery 初始化失败,使用内存任务队列: {e}")
|
||
|
||
def _init_task_tables(self):
|
||
"""初始化任务队列表"""
|
||
conn = sqlite3.connect(self.db_path)
|
||
|
||
conn.execute("""
|
||
CREATE TABLE IF NOT EXISTS task_queue (
|
||
id TEXT PRIMARY KEY,
|
||
task_type TEXT NOT NULL,
|
||
status TEXT DEFAULT 'pending',
|
||
payload TEXT,
|
||
result TEXT,
|
||
error_message TEXT,
|
||
retry_count INTEGER DEFAULT 0,
|
||
max_retries INTEGER DEFAULT 3,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||
started_at TIMESTAMP,
|
||
completed_at TIMESTAMP
|
||
)
|
||
""")
|
||
|
||
conn.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON task_queue(status)")
|
||
conn.execute("CREATE INDEX IF NOT EXISTS idx_task_type ON task_queue(task_type)")
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
def is_available(self) -> bool:
|
||
"""检查任务队列是否可用"""
|
||
return self.use_celery or True # 内存模式也可用
|
||
|
||
def register_handler(self, task_type: str, handler: Callable):
|
||
"""注册任务处理器"""
|
||
self.task_handlers[task_type] = handler
|
||
|
||
def submit(self, task_type: str, payload: Dict,
|
||
max_retries: int = 3) -> str:
|
||
"""
|
||
提交任务
|
||
|
||
Args:
|
||
task_type: 任务类型
|
||
payload: 任务数据
|
||
max_retries: 最大重试次数
|
||
|
||
Returns:
|
||
str: 任务ID
|
||
"""
|
||
task_id = str(uuid.uuid4())[:16]
|
||
|
||
task = TaskInfo(
|
||
id=task_id,
|
||
task_type=task_type,
|
||
status="pending",
|
||
payload=payload,
|
||
created_at=datetime.now().isoformat(),
|
||
max_retries=max_retries
|
||
)
|
||
|
||
if self.use_celery:
|
||
# 使用 Celery
|
||
try:
|
||
# 这里简化处理,实际应该定义具体的 Celery 任务
|
||
result = self.celery_app.send_task(
|
||
f'insightflow.tasks.{task_type}',
|
||
args=[payload],
|
||
task_id=task_id,
|
||
retry=True,
|
||
retry_policy={
|
||
'max_retries': max_retries,
|
||
'interval_start': 10,
|
||
'interval_step': 10,
|
||
'interval_max': 60
|
||
}
|
||
)
|
||
task.id = result.id
|
||
except Exception as e:
|
||
print(f"Celery 任务提交失败: {e}")
|
||
# 回退到内存模式
|
||
self.use_celery = False
|
||
|
||
if not self.use_celery:
|
||
# 内存模式
|
||
with self.task_lock:
|
||
self.tasks[task_id] = task
|
||
# 异步执行
|
||
threading.Thread(
|
||
target=self._execute_task,
|
||
args=(task_id,),
|
||
daemon=True
|
||
).start()
|
||
|
||
# 保存到数据库
|
||
self._save_task(task)
|
||
|
||
return task_id
|
||
|
||
def _execute_task(self, task_id: str):
|
||
"""执行任务(内存模式)"""
|
||
with self.task_lock:
|
||
task = self.tasks.get(task_id)
|
||
if not task:
|
||
return
|
||
|
||
task.status = "running"
|
||
task.started_at = datetime.now().isoformat()
|
||
|
||
self._update_task_status(task)
|
||
|
||
# 获取处理器
|
||
handler = self.task_handlers.get(task.task_type)
|
||
|
||
if not handler:
|
||
task.status = "failed"
|
||
task.error_message = f"No handler for task type: {task.task_type}"
|
||
else:
|
||
try:
|
||
result = handler(task.payload)
|
||
task.status = "success"
|
||
task.result = result
|
||
except Exception as e:
|
||
task.retry_count += 1
|
||
|
||
if task.retry_count <= task.max_retries:
|
||
task.status = "retrying"
|
||
# 延迟重试
|
||
threading.Timer(
|
||
10 * task.retry_count,
|
||
self._execute_task,
|
||
args=(task_id,)
|
||
).start()
|
||
else:
|
||
task.status = "failed"
|
||
task.error_message = str(e)
|
||
|
||
task.completed_at = datetime.now().isoformat()
|
||
|
||
with self.task_lock:
|
||
self.tasks[task_id] = task
|
||
|
||
self._update_task_status(task)
|
||
|
||
def _save_task(self, task: TaskInfo):
|
||
"""保存任务到数据库"""
|
||
conn = sqlite3.connect(self.db_path)
|
||
|
||
conn.execute("""
|
||
INSERT OR REPLACE INTO task_queue
|
||
(id, task_type, status, payload, result, error_message,
|
||
retry_count, max_retries, created_at, started_at, completed_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
""", (
|
||
task.id, task.task_type, task.status,
|
||
json.dumps(task.payload, ensure_ascii=False),
|
||
json.dumps(task.result, ensure_ascii=False) if task.result else None,
|
||
task.error_message,
|
||
task.retry_count, task.max_retries,
|
||
task.created_at, task.started_at, task.completed_at
|
||
))
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
def _update_task_status(self, task: TaskInfo):
|
||
"""更新任务状态"""
|
||
conn = sqlite3.connect(self.db_path)
|
||
|
||
conn.execute("""
|
||
UPDATE task_queue SET
|
||
status = ?,
|
||
result = ?,
|
||
error_message = ?,
|
||
retry_count = ?,
|
||
started_at = ?,
|
||
completed_at = ?
|
||
WHERE id = ?
|
||
""", (
|
||
task.status,
|
||
json.dumps(task.result, ensure_ascii=False) if task.result else None,
|
||
task.error_message,
|
||
task.retry_count,
|
||
task.started_at,
|
||
task.completed_at,
|
||
task.id
|
||
))
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
def get_status(self, task_id: str) -> Optional[TaskInfo]:
|
||
"""获取任务状态"""
|
||
if self.use_celery:
|
||
try:
|
||
result = AsyncResult(task_id, app=self.celery_app)
|
||
|
||
status_map = {
|
||
'PENDING': 'pending',
|
||
'STARTED': 'running',
|
||
'SUCCESS': 'success',
|
||
'FAILURE': 'failed',
|
||
'RETRY': 'retrying'
|
||
}
|
||
|
||
return TaskInfo(
|
||
id=task_id,
|
||
task_type="celery_task",
|
||
status=status_map.get(result.status, 'unknown'),
|
||
payload={},
|
||
created_at="",
|
||
result=result.result if result.successful() else None,
|
||
error_message=str(result.result) if result.failed() else None
|
||
)
|
||
except Exception as e:
|
||
print(f"获取 Celery 任务状态失败: {e}")
|
||
|
||
# 内存模式或回退
|
||
with self.task_lock:
|
||
return self.tasks.get(task_id)
|
||
|
||
def list_tasks(self, status: Optional[str] = None,
|
||
task_type: Optional[str] = None,
|
||
limit: int = 100) -> List[TaskInfo]:
|
||
"""列出任务"""
|
||
conn = sqlite3.connect(self.db_path)
|
||
conn.row_factory = sqlite3.Row
|
||
|
||
where_clauses = []
|
||
params = []
|
||
|
||
if status:
|
||
where_clauses.append("status = ?")
|
||
params.append(status)
|
||
|
||
if task_type:
|
||
where_clauses.append("task_type = ?")
|
||
params.append(task_type)
|
||
|
||
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||
|
||
rows = conn.execute(f"""
|
||
SELECT * FROM task_queue
|
||
WHERE {where_str}
|
||
ORDER BY created_at DESC
|
||
LIMIT ?
|
||
""", params + [limit]).fetchall()
|
||
|
||
conn.close()
|
||
|
||
tasks = []
|
||
for row in rows:
|
||
tasks.append(TaskInfo(
|
||
id=row['id'],
|
||
task_type=row['task_type'],
|
||
status=row['status'],
|
||
payload=json.loads(row['payload']) if row['payload'] else {},
|
||
created_at=row['created_at'],
|
||
started_at=row['started_at'],
|
||
completed_at=row['completed_at'],
|
||
result=json.loads(row['result']) if row['result'] else None,
|
||
error_message=row['error_message'],
|
||
retry_count=row['retry_count'],
|
||
max_retries=row['max_retries']
|
||
))
|
||
|
||
return tasks
|
||
|
||
def cancel(self, task_id: str) -> bool:
|
||
"""取消任务"""
|
||
if self.use_celery:
|
||
try:
|
||
self.celery_app.control.revoke(task_id, terminate=True)
|
||
return True
|
||
except Exception as e:
|
||
print(f"取消 Celery 任务失败: {e}")
|
||
|
||
with self.task_lock:
|
||
task = self.tasks.get(task_id)
|
||
if task and task.status in ['pending', 'running']:
|
||
task.status = 'cancelled'
|
||
task.completed_at = datetime.now().isoformat()
|
||
self._update_task_status(task)
|
||
return True
|
||
|
||
return False
|
||
|
||
def retry(self, task_id: str) -> bool:
|
||
"""重试失败的任务"""
|
||
task = self.get_status(task_id)
|
||
|
||
if not task or task.status != 'failed':
|
||
return False
|
||
|
||
task.status = 'pending'
|
||
task.retry_count = 0
|
||
task.error_message = None
|
||
task.completed_at = None
|
||
|
||
if not self.use_celery:
|
||
with self.task_lock:
|
||
self.tasks[task_id] = task
|
||
threading.Thread(
|
||
target=self._execute_task,
|
||
args=(task_id,),
|
||
daemon=True
|
||
).start()
|
||
|
||
self._update_task_status(task)
|
||
return True
|
||
|
||
def get_stats(self) -> Dict:
    """Return task-queue statistics.

    Collects counts grouped by status and by task type, the number of
    tasks created in the last 24 hours, and the active backend name.
    """
    conn = sqlite3.connect(self.db_path)
    try:
        # Task counts grouped by status.
        status_rows = conn.execute("""
            SELECT status, COUNT(*) as count
            FROM task_queue
            GROUP BY status
        """).fetchall()

        # Task counts grouped by type.
        type_rows = conn.execute("""
            SELECT task_type, COUNT(*) as count
            FROM task_queue
            GROUP BY task_type
        """).fetchall()

        # Tasks created within the last 24 hours.
        recent = conn.execute("""
            SELECT COUNT(*) as count
            FROM task_queue
            WHERE created_at > datetime('now', '-1 day')
        """).fetchone()[0]
    finally:
        conn.close()

    return {
        "by_status": dict(status_rows),
        "by_type": dict(type_rows),
        "recent_24h": recent,
        "backend": "celery" if self.use_celery else "memory"
    }
|
||
|
||
|
||
# ==================== 性能监控 ====================
|
||
|
||
class PerformanceMonitor:
    """
    Performance monitor.

    Responsibilities:
    - API response-time statistics
    - database query performance analysis
    - cache operation monitoring (via recorded cache_operation metrics)
    - threshold-based performance alerting

    Metrics are buffered in memory and flushed to the
    ``performance_metrics`` table in batches.
    """

    def __init__(self, db_path: str = "insightflow.db",
                 slow_query_threshold: int = 1000,  # milliseconds
                 alert_threshold: int = 5000):      # milliseconds
        self.db_path = db_path
        self.slow_query_threshold = slow_query_threshold
        self.alert_threshold = alert_threshold

        # In-memory metric buffer, flushed in batches. An RLock is used so
        # _flush_metrics() is safe both when the lock is already held
        # (record_metric) and when it is not (get_stats / get_api_performance).
        self.metrics_buffer: List[PerformanceMetric] = []
        self.buffer_lock = threading.RLock()
        self.buffer_size = 100

        # Callbacks invoked when a metric exceeds alert_threshold.
        self.alert_handlers: List[Callable] = []

    def record_metric(self, metric_type: str, duration_ms: float,
                      endpoint: Optional[str] = None,
                      metadata: Optional[Dict] = None):
        """
        Record a single performance metric.

        Args:
            metric_type: metric category (api_response, db_query, cache_operation)
            duration_ms: elapsed time in milliseconds
            endpoint: endpoint / query identifier
            metadata: extra metadata
        """
        metric = PerformanceMetric(
            id=str(uuid.uuid4())[:16],
            metric_type=metric_type,
            endpoint=endpoint,
            duration_ms=duration_ms,
            timestamp=datetime.now().isoformat(),
            metadata=metadata or {}
        )

        # Buffer the metric; flush once the buffer overflows.
        with self.buffer_lock:
            self.metrics_buffer.append(metric)
            if len(self.metrics_buffer) > self.buffer_size:
                self._flush_metrics()

        # Alert on metrics exceeding the hard threshold.
        if duration_ms > self.alert_threshold:
            self._trigger_alert(metric)

        # Separately log slow database queries.
        if metric_type == 'db_query' and duration_ms > self.slow_query_threshold:
            self._record_slow_query(metric)

    def _flush_metrics(self):
        """Write buffered metrics to the database.

        BUGFIX: the buffer swap now happens under ``buffer_lock``. The
        previous version mutated ``metrics_buffer`` without the lock when
        called from get_stats()/get_api_performance(), racing with
        record_metric() and potentially dropping metrics. The lock is
        reentrant, so the call from record_metric (lock already held)
        remains safe. The DB write happens outside the lock to keep the
        critical section short.
        """
        with self.buffer_lock:
            if not self.metrics_buffer:
                return
            pending, self.metrics_buffer = self.metrics_buffer, []

        conn = sqlite3.connect(self.db_path)
        try:
            # Batch insert — one executemany instead of N single inserts.
            conn.executemany("""
                INSERT INTO performance_metrics
                (id, metric_type, endpoint, duration_ms, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            """, [
                (m.id, m.metric_type, m.endpoint,
                 m.duration_ms, m.timestamp,
                 json.dumps(m.metadata, ensure_ascii=False))
                for m in pending
            ])
            conn.commit()
        finally:
            conn.close()

    def _record_slow_query(self, metric: PerformanceMetric):
        """Log a slow query (could be routed to a dedicated slow-query log)."""
        print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")

    def _trigger_alert(self, metric: PerformanceMetric):
        """Invoke all registered alert handlers for an over-threshold metric."""
        alert_data = {
            "type": "performance_alert",
            "metric": metric.to_dict(),
            "threshold": self.alert_threshold,
            "message": f"{metric.metric_type} exceeded threshold: {metric.duration_ms}ms > {self.alert_threshold}ms"
        }

        for handler in self.alert_handlers:
            try:
                handler(alert_data)
            except Exception as e:
                # A failing handler must not break metric recording.
                print(f"告警处理失败: {e}")

    def register_alert_handler(self, handler: Callable):
        """Register a callback invoked with the alert payload dict."""
        self.alert_handlers.append(handler)

    def get_stats(self, hours: int = 24) -> Dict:
        """
        Aggregate performance statistics.

        Args:
            hours: look-back window in hours

        Returns:
            Dict: overall numbers, per-type / per-endpoint breakdowns
            and the slowest recorded operations.
        """
        # Flush the buffer first so the queries see fresh data.
        self._flush_metrics()

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            # Overall aggregates.
            overall = conn.execute("""
                SELECT
                    COUNT(*) as total,
                    AVG(duration_ms) as avg_duration,
                    MAX(duration_ms) as max_duration,
                    MIN(duration_ms) as min_duration
                FROM performance_metrics
                WHERE timestamp > datetime('now', ?)
            """, (f'-{hours} hours',)).fetchone()

            # Breakdown by metric type.
            by_type = conn.execute("""
                SELECT
                    metric_type,
                    COUNT(*) as count,
                    AVG(duration_ms) as avg_duration,
                    MAX(duration_ms) as max_duration
                FROM performance_metrics
                WHERE timestamp > datetime('now', ?)
                GROUP BY metric_type
            """, (f'-{hours} hours',)).fetchall()

            # Breakdown by API endpoint (slowest first).
            by_endpoint = conn.execute("""
                SELECT
                    endpoint,
                    COUNT(*) as count,
                    AVG(duration_ms) as avg_duration,
                    MAX(duration_ms) as max_duration
                FROM performance_metrics
                WHERE timestamp > datetime('now', ?)
                  AND metric_type = 'api_response'
                GROUP BY endpoint
                ORDER BY avg_duration DESC
                LIMIT 20
            """, (f'-{hours} hours',)).fetchall()

            # Slowest individual operations over the slow-query threshold.
            slow_queries = conn.execute("""
                SELECT
                    metric_type,
                    endpoint,
                    duration_ms,
                    timestamp
                FROM performance_metrics
                WHERE timestamp > datetime('now', ?)
                  AND duration_ms > ?
                ORDER BY duration_ms DESC
                LIMIT 10
            """, (f'-{hours} hours', self.slow_query_threshold)).fetchall()
        finally:
            conn.close()

        return {
            "period_hours": hours,
            "overall": {
                "total_requests": overall['total'] or 0,
                "avg_duration_ms": round(overall['avg_duration'] or 0, 2),
                "max_duration_ms": overall['max_duration'] or 0,
                "min_duration_ms": overall['min_duration'] or 0
            },
            "by_type": [
                {
                    "type": r['metric_type'],
                    "count": r['count'],
                    "avg_duration_ms": round(r['avg_duration'], 2),
                    "max_duration_ms": r['max_duration']
                }
                for r in by_type
            ],
            "by_endpoint": [
                {
                    "endpoint": r['endpoint'],
                    "count": r['count'],
                    "avg_duration_ms": round(r['avg_duration'], 2),
                    "max_duration_ms": r['max_duration']
                }
                for r in by_endpoint
            ],
            "slow_queries": [
                {
                    "type": r['metric_type'],
                    "endpoint": r['endpoint'],
                    "duration_ms": r['duration_ms'],
                    "timestamp": r['timestamp']
                }
                for r in slow_queries
            ]
        }

    def get_api_performance(self, endpoint: Optional[str] = None,
                            hours: int = 24) -> Dict:
        """Per-endpoint API performance details (optionally for one endpoint)."""
        self._flush_metrics()

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        where_clause = "metric_type = 'api_response'"
        params = [f'-{hours} hours']

        if endpoint:
            where_clause += " AND endpoint = ?"
            params.append(endpoint)

        try:
            # Per-endpoint aggregates, slowest average first.
            percentiles = conn.execute(f"""
                SELECT
                    endpoint,
                    COUNT(*) as count,
                    AVG(duration_ms) as avg,
                    MIN(duration_ms) as min,
                    MAX(duration_ms) as max
                FROM performance_metrics
                WHERE {where_clause}
                  AND timestamp > datetime('now', ?)
                GROUP BY endpoint
                ORDER BY avg DESC
            """, params).fetchall()
        finally:
            conn.close()

        return {
            "endpoint": endpoint or "all",
            "period_hours": hours,
            "endpoints": [
                {
                    "endpoint": r['endpoint'],
                    "count": r['count'],
                    "avg_ms": round(r['avg'], 2),
                    "min_ms": r['min'],
                    "max_ms": r['max']
                }
                for r in percentiles
            ]
        }

    def cleanup_old_metrics(self, days: int = 30) -> int:
        """
        Delete performance metrics older than the retention window.

        Args:
            days: number of days of data to keep

        Returns:
            int: number of deleted rows
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute("""
                DELETE FROM performance_metrics
                WHERE timestamp < datetime('now', ?)
            """, (f'-{days} days',))
            deleted = cursor.rowcount
            conn.commit()
        finally:
            conn.close()

        return deleted
|
||
|
||
|
||
# ==================== 性能装饰器 ====================
|
||
|
||
def cached(cache_manager: CacheManager, key_prefix: str = "",
           ttl: int = 3600, key_func: Optional[Callable] = None):
    """
    Caching decorator.

    Looks the call up in ``cache_manager`` first; on a miss the wrapped
    function runs and its result is stored under the computed key.
    NOTE: a ``None`` result is indistinguishable from a cache miss, so
    functions returning ``None`` are recomputed on every call.

    Args:
        cache_manager: cache manager instance
        key_prefix: prefix for generated cache keys
        ttl: cache expiry in seconds
        key_func: custom cache-key builder
    """
    def decorator(func: Callable) -> Callable:
        def default_key(args, kwargs):
            # Default key: function name plus a hash of the argument reprs.
            key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}"
            digest = hashlib.md5(key_data.encode()).hexdigest()[:16]
            return f"{key_prefix}:{digest}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_func(*args, **kwargs) if key_func else default_key(args, kwargs)

            # Fast path: serve from cache on a hit.
            hit = cache_manager.get(cache_key)
            if hit is not None:
                return hit

            # Miss: compute, then populate the cache.
            value = func(*args, **kwargs)
            cache_manager.set(cache_key, value, ttl)
            return value

        return wrapper
    return decorator
|
||
|
||
|
||
def monitored(monitor: PerformanceMonitor, metric_type: str,
              endpoint: Optional[str] = None):
    """
    Performance-monitoring decorator.

    Records a metric with the wrapped call's duration — also when the
    call raises, since the recording happens in ``finally``.

    Args:
        monitor: performance monitor instance
        metric_type: metric category
        endpoint: endpoint identifier (defaults to the function name)
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            started = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed_ms = (time.time() - started) * 1000
                monitor.record_metric(metric_type, elapsed_ms,
                                      endpoint or func.__name__)

        return wrapper
    return decorator
|
||
|
||
|
||
# ==================== 性能管理器 ====================
|
||
|
||
class PerformanceManager:
    """
    Performance manager — unified entry point.

    Bundles the cache layer, optional database sharding, the task queue
    and the performance monitor behind one object.
    """

    def __init__(self,
                 db_path: str = "insightflow.db",
                 redis_url: Optional[str] = None,
                 enable_sharding: bool = False):
        self.db_path = db_path

        # Wire up the individual subsystems.
        self.cache = CacheManager(redis_url=redis_url, db_path=db_path)
        self.sharding = (
            DatabaseSharding(base_db_path=db_path) if enable_sharding else None
        )
        self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path)
        self.monitor = PerformanceMonitor(db_path=db_path)

    def get_health_status(self) -> Dict:
        """Report availability and backend information for every subsystem."""
        cache_backend = "redis" if self.cache.use_redis else "memory"
        queue_backend = "celery" if self.task_queue.use_celery else "memory"
        shard_count = len(self.sharding.shard_map) if self.sharding else 0

        return {
            "cache": {
                "available": True,
                "backend": cache_backend,
                "stats": self.cache.get_stats()
            },
            "sharding": {
                "enabled": self.sharding is not None,
                "shards_count": shard_count
            },
            "task_queue": {
                "available": self.task_queue.is_available(),
                "backend": queue_backend,
                "stats": self.task_queue.get_stats()
            },
            "monitor": {
                "available": True,
                "slow_query_threshold": self.monitor.slow_query_threshold,
                "alert_threshold": self.monitor.alert_threshold
            }
        }

    def get_full_stats(self) -> Dict:
        """Collect statistics from every subsystem into a single dict."""
        combined = {
            "cache": self.cache.get_stats(),
            "task_queue": self.task_queue.get_stats(),
            "performance": self.monitor.get_stats()
        }

        # Sharding stats are only present when sharding is enabled.
        if self.sharding is not None:
            combined["sharding"] = self.sharding.get_shard_stats()

        return combined
|
||
|
||
|
||
# 单例模式
|
||
# Module-level singleton. Guarded by a lock because this module spawns
# worker threads (TaskQueue) that may race to initialize it.
_performance_manager = None
_performance_manager_lock = threading.Lock()


def get_performance_manager(
    db_path: str = "insightflow.db",
    redis_url: Optional[str] = None,
    enable_sharding: bool = False
) -> PerformanceManager:
    """Return the process-wide PerformanceManager singleton.

    Thread-safe (double-checked locking). BUGFIX: the previous version had
    an unsynchronized check-then-create, so two threads could each build a
    PerformanceManager and one instance (with its caches/queues) would be
    silently discarded.

    Note: the configuration arguments only take effect on the first call;
    later calls return the already-built instance unchanged.
    """
    global _performance_manager
    if _performance_manager is None:
        with _performance_manager_lock:
            # Re-check under the lock — another thread may have won the race.
            if _performance_manager is None:
                _performance_manager = PerformanceManager(
                    db_path=db_path,
                    redis_url=redis_url,
                    enable_sharding=enable_sharding
                )
    return _performance_manager
|