1766 lines
53 KiB
Python
1766 lines
53 KiB
Python
"""
|
||
InsightFlow - 性能优化与扩展模块
|
||
Phase 7 Task 8: Performance Optimization & Scaling
|
||
|
||
功能模块:
|
||
1. CacheManager - Redis 缓存层(热点数据、TTL、LRU、缓存预热)
|
||
2. DatabaseSharding - 数据库分片策略
|
||
3. TaskQueue - 异步任务队列(Celery + Redis)
|
||
4. PerformanceMonitor - 性能监控(API响应、查询性能、缓存命中率)
|
||
"""
|
||
|
||
import hashlib
|
||
import json
|
||
import os
|
||
import sqlite3
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from collections import OrderedDict
|
||
from collections.abc import Callable
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from functools import wraps
|
||
from typing import Any
|
||
|
||
# 尝试导入 Redis
|
||
try:
|
||
import redis
|
||
|
||
REDIS_AVAILABLE = True
|
||
except ImportError:
|
||
REDIS_AVAILABLE = False
|
||
|
||
# 尝试导入 Celery
|
||
try:
|
||
from celery import Celery
|
||
from celery.result import AsyncResult
|
||
|
||
CELERY_AVAILABLE = True
|
||
except ImportError:
|
||
CELERY_AVAILABLE = False
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
|
||
@dataclass
class CacheStats:
    """Aggregate hit/miss counters for the cache layer."""

    total_requests: int = 0
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expired: int = 0
    hit_rate: float = 0.0

    def update_hit_rate(self) -> None:
        """Recompute ``hit_rate`` from ``hits / total_requests``.

        Leaves ``hit_rate`` untouched when no requests have been counted
        yet, so the rate never divides by zero.
        """
        if not self.total_requests:
            return
        self.hit_rate = round(self.hits / self.total_requests, 4)
|
||
|
||
|
||
@dataclass
class CacheEntry:
    """A single entry of the in-memory LRU cache backend."""

    key: str
    value: Any
    created_at: float  # epoch seconds at insertion
    expires_at: float | None  # epoch seconds; None means never expires
    access_count: int = 0  # number of successful get() hits
    last_accessed: float = 0  # epoch seconds of the most recent hit
    size_bytes: int = 0  # estimated serialized size, used for the LRU byte budget
|
||
|
||
|
||
@dataclass
|
||
class PerformanceMetric:
|
||
"""性能指标数据模型"""
|
||
|
||
id: str
|
||
metric_type: str # api_response, db_query, cache_operation
|
||
endpoint: str | None
|
||
duration_ms: float
|
||
timestamp: str
|
||
metadata: dict = field(default_factory = dict)
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"id": self.id,
|
||
"metric_type": self.metric_type,
|
||
"endpoint": self.endpoint,
|
||
"duration_ms": self.duration_ms,
|
||
"timestamp": self.timestamp,
|
||
"metadata": self.metadata,
|
||
}
|
||
|
||
|
||
@dataclass
|
||
class TaskInfo:
|
||
"""任务信息数据模型"""
|
||
|
||
id: str
|
||
task_type: str
|
||
status: str # pending, running, success, failed, retrying
|
||
payload: dict
|
||
created_at: str
|
||
started_at: str | None = None
|
||
completed_at: str | None = None
|
||
result: Any | None = None
|
||
error_message: str | None = None
|
||
retry_count: int = 0
|
||
max_retries: int = 3
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"id": self.id,
|
||
"task_type": self.task_type,
|
||
"status": self.status,
|
||
"payload": self.payload,
|
||
"created_at": self.created_at,
|
||
"started_at": self.started_at,
|
||
"completed_at": self.completed_at,
|
||
"result": self.result,
|
||
"error_message": self.error_message,
|
||
"retry_count": self.retry_count,
|
||
"max_retries": self.max_retries,
|
||
}
|
||
|
||
|
||
@dataclass
class ShardInfo:
    """Routing metadata for one database shard."""

    shard_id: str
    shard_key_range: tuple[str, str]  # inclusive (start, end) hex characters routed here
    db_path: str  # SQLite file backing this shard
    entity_count: int = 0  # distinct projects in this shard (see _update_shard_stats)
    is_active: bool = True
    created_at: str = ""  # ISO timestamp
    last_accessed: str = ""  # ISO timestamp of the last get_shard_connection()
|
||
|
||
|
||
# ==================== Redis 缓存层 ====================
|
||
|
||
|
||
class CacheManager:
    """
    Cache manager.

    Features:
    - hot-data caching (entities, relations, transcripts)
    - invalidation policies (TTL expiry + LRU eviction)
    - cache warm-up
    - cache statistics and monitoring

    Supports two backends:
    1. Redis mode (recommended for production)
    2. in-memory LRU mode (development / testing)
    """

    def __init__(
        self,
        redis_url: str | None = None,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB budget for memory mode
        default_ttl: int = 3600,  # seconds (1 hour)
        db_path: str = "insightflow.db",
    ) -> None:
        """
        Args:
            redis_url: Redis connection URL; memory mode is used when None
                or when the connection cannot be established.
            max_memory_size: byte budget for the in-memory LRU cache.
            default_ttl: default time-to-live (seconds) for set() calls.
            db_path: SQLite database used to persist cache statistics.
        """
        self.db_path = db_path
        self.default_ttl = default_ttl
        self.max_memory_size = max_memory_size
        self.current_memory_size = 0

        # Redis client (stays None while in memory mode).
        self.redis_client = None
        self.use_redis = False

        if REDIS_AVAILABLE and redis_url:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                self.redis_client.ping()
                self.use_redis = True
                print(f"Redis 缓存已连接: {redis_url}")
            except Exception as e:
                print(f"Redis 连接失败,使用内存缓存: {e}")

        # In-memory LRU cache: OrderedDict kept in least- to
        # most-recently-used order; guarded by a re-entrant lock.
        self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self.cache_lock = threading.RLock()

        # Hit/miss counters.
        self.stats = CacheStats()

        # Create the SQLite stats/metrics tables.
        self._init_cache_tables()

    def _init_cache_tables(self) -> None:
        """Create the cache-stats and performance-metrics tables (idempotent)."""
        conn = sqlite3.connect(self.db_path)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                total_requests INTEGER DEFAULT 0,
                hits INTEGER DEFAULT 0,
                misses INTEGER DEFAULT 0,
                hit_rate REAL DEFAULT 0.0,
                memory_usage INTEGER DEFAULT 0
            )
        """)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS performance_metrics (
                id TEXT PRIMARY KEY,
                metric_type TEXT NOT NULL,
                endpoint TEXT,
                duration_ms REAL,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        """)

        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_type ON performance_metrics(metric_type)"
        )
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_time ON performance_metrics(timestamp)"
        )

        conn.commit()
        conn.close()

    def _get_entry_size(self, value: Any) -> int:
        """Estimate a value's in-cache size as its UTF-8 JSON length."""
        try:
            return len(json.dumps(value, ensure_ascii=False).encode("utf-8"))
        except (TypeError, ValueError):
            return 1024  # non-JSON-serializable values get a flat estimate

    def _evict_lru(self, required_space: int = 0) -> None:
        """Evict least-recently-used entries until required_space fits the budget."""
        with self.cache_lock:
            while (
                self.current_memory_size + required_space > self.max_memory_size
                and self.memory_cache
            ):
                # popitem(last=False) pops the LRU end of the OrderedDict.
                _, oldest_entry = self.memory_cache.popitem(last=False)
                self.current_memory_size -= oldest_entry.size_bytes
                self.stats.evictions += 1

    def get(self, key: str) -> Any | None:
        """
        Fetch a cached value.

        Args:
            key: cache key

        Returns:
            The cached value, or None on miss / expiry / backend error.
        """
        self.stats.total_requests += 1

        if self.use_redis:
            try:
                value = self.redis_client.get(key)
                # json.dumps never yields an empty string, so truthiness is a
                # safe existence check here (absent keys come back as None).
                if value:
                    self.stats.hits += 1
                    return json.loads(value)
                else:
                    self.stats.misses += 1
                    return None
            except Exception as e:
                print(f"Redis get 失败: {e}")
                return None
        else:
            with self.cache_lock:
                entry = self.memory_cache.get(key)

                if entry:
                    # Lazy TTL: drop the entry if it expired since insertion.
                    if entry.expires_at and time.time() > entry.expires_at:
                        del self.memory_cache[key]
                        self.current_memory_size -= entry.size_bytes
                        self.stats.expired += 1
                        self.stats.misses += 1
                        return None

                    # Record the hit and move the entry to the MRU end.
                    entry.access_count += 1
                    entry.last_accessed = time.time()
                    self.memory_cache.move_to_end(key)

                    self.stats.hits += 1
                    return entry.value
                else:
                    self.stats.misses += 1
                    return None

    def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
        """
        Store a value in the cache.

        Args:
            key: cache key
            value: value to cache (must be JSON-serializable for Redis mode)
            ttl: time-to-live in seconds; None (or 0) falls back to default_ttl

        Returns:
            bool: True on success.
        """
        ttl = ttl or self.default_ttl

        if self.use_redis:
            try:
                serialized = json.dumps(value, ensure_ascii=False)
                self.redis_client.setex(key, ttl, serialized)
                return True
            except Exception as e:
                print(f"Redis set 失败: {e}")
                return False
        else:
            with self.cache_lock:
                size = self._get_entry_size(value)

                # Make room before inserting if the budget would overflow.
                if self.current_memory_size + size > self.max_memory_size:
                    self._evict_lru(size)

                now = time.time()
                entry = CacheEntry(
                    key=key,
                    value=value,
                    created_at=now,
                    expires_at=now + ttl if ttl > 0 else None,
                    size_bytes=size,
                    last_accessed=now,
                )

                # Replacing an existing key: release its old byte accounting.
                if key in self.memory_cache:
                    self.current_memory_size -= self.memory_cache[key].size_bytes

                self.memory_cache[key] = entry
                self.memory_cache.move_to_end(key)
                self.current_memory_size += size

                return True

    def delete(self, key: str) -> bool:
        """Remove a key; returns True if something was deleted."""
        if self.use_redis:
            try:
                return self.redis_client.delete(key) > 0
            except Exception as e:
                print(f"Redis delete 失败: {e}")
                return False
        else:
            with self.cache_lock:
                if key in self.memory_cache:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    return True
                return False

    def clear(self) -> bool:
        """Drop every cached entry.

        WARNING: in Redis mode this calls flushdb(), which clears the entire
        Redis logical database, not just this application's keys.
        """
        if self.use_redis:
            try:
                self.redis_client.flushdb()
                return True
            except Exception as e:
                print(f"Redis clear 失败: {e}")
                return False
        else:
            with self.cache_lock:
                self.memory_cache.clear()
                self.current_memory_size = 0
                return True

    def get_many(self, keys: list[str]) -> dict[str, Any]:
        """Fetch several keys at once; misses are simply absent from the result."""
        results = {}

        if self.use_redis:
            try:
                values = self.redis_client.mget(keys)
                for key, value in zip(keys, values):
                    if value:
                        results[key] = json.loads(value)
                        self.stats.hits += 1
                    else:
                        self.stats.misses += 1
                    self.stats.total_requests += 1
            except Exception as e:
                print(f"Redis mget 失败: {e}")
        else:
            # Memory mode delegates to get(), which already updates stats.
            for key in keys:
                value = self.get(key)
                if value is not None:
                    results[key] = value

        return results

    def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool:
        """Store several key/value pairs, all with the same TTL."""
        ttl = ttl or self.default_ttl

        if self.use_redis:
            try:
                # One pipelined round-trip instead of N setex calls.
                pipe = self.redis_client.pipeline()
                for key, value in mapping.items():
                    serialized = json.dumps(value, ensure_ascii=False)
                    pipe.setex(key, ttl, serialized)
                pipe.execute()
                return True
            except Exception as e:
                print(f"Redis mset 失败: {e}")
                return False
        else:
            for key, value in mapping.items():
                self.set(key, value, ttl)
            return True

    def get_stats(self) -> dict:
        """Return a snapshot of cache statistics (plus memory usage in memory mode)."""
        self.stats.update_hit_rate()

        stats = {
            "total_requests": self.stats.total_requests,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": self.stats.hit_rate,
            "evictions": self.stats.evictions,
            "expired": self.stats.expired,
            "backend": "redis" if self.use_redis else "memory",
        }

        if not self.use_redis:
            stats.update(
                {
                    "memory_size_bytes": self.current_memory_size,
                    "max_memory_size_bytes": self.max_memory_size,
                    "memory_usage_percent": round(
                        self.current_memory_size / self.max_memory_size * 100, 2
                    ),
                    "cache_entries": len(self.memory_cache),
                }
            )

        return stats

    def save_stats(self) -> None:
        """Append the current statistics snapshot to the cache_stats table."""
        conn = sqlite3.connect(self.db_path)

        self.stats.update_hit_rate()

        conn.execute(
            """
            INSERT INTO cache_stats
            (timestamp, total_requests, hits, misses, hit_rate, memory_usage)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                datetime.now().isoformat(),
                self.stats.total_requests,
                self.stats.hits,
                self.stats.misses,
                self.stats.hit_rate,
                self.current_memory_size,
            ),
        )

        conn.commit()
        conn.close()

    def warm_up(self, project_id: str) -> dict:
        """
        Warm the cache with a project's hot data.

        Args:
            project_id: project identifier

        Returns:
            Dict: counts of warmed entities / relations / transcripts.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        stats = {"entities": 0, "relations": 0, "transcripts": 0}

        # Warm the most-mentioned entities first.
        entities = conn.execute(
            """SELECT e.*,
               (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count
               FROM entities e
               WHERE e.project_id = ?
               ORDER BY mention_count DESC
               LIMIT 100""",
            (project_id,),
        ).fetchall()

        for entity in entities:
            key = f"entity:{entity['id']}"
            self.set(key, dict(entity), ttl=7200)  # 2 hours
            stats["entities"] += 1

        # Warm relations (with resolved endpoint names).
        relations = conn.execute(
            """SELECT r.*,
               e1.name as source_name, e2.name as target_name
               FROM entity_relations r
               JOIN entities e1 ON r.source_entity_id = e1.id
               JOIN entities e2 ON r.target_entity_id = e2.id
               WHERE r.project_id = ?
               LIMIT 200""",
            (project_id,),
        ).fetchall()

        for relation in relations:
            key = f"relation:{relation['id']}"
            self.set(key, dict(relation), ttl=3600)
            stats["relations"] += 1

        # Warm the most recent transcripts (metadata only, not full text).
        transcripts = conn.execute(
            """SELECT * FROM transcripts
               WHERE project_id = ?
               ORDER BY created_at DESC
               LIMIT 10""",
            (project_id,),
        ).fetchall()

        for transcript in transcripts:
            # Fix: sqlite3.Row has no .get(); convert to a dict first so a
            # missing "type" column falls back to "audio" instead of raising.
            row = dict(transcript)
            key = f"transcript:{row['id']}"
            meta = {
                "id": row["id"],
                "filename": row["filename"],
                "type": row.get("type", "audio"),
                "created_at": row["created_at"],
            }
            self.set(key, meta, ttl=1800)  # 30 minutes
            stats["transcripts"] += 1

        # Warm the project-level knowledge-base summary.
        entity_count = conn.execute(
            "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,)
        ).fetchone()[0]

        relation_count = conn.execute(
            "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,)
        ).fetchone()[0]

        summary = {
            "project_id": project_id,
            "entity_count": entity_count,
            "relation_count": relation_count,
            "cached_at": datetime.now().isoformat(),
        }
        self.set(f"project_summary:{project_id}", summary, ttl=3600)

        conn.close()

        return stats

    def invalidate_project(self, project_id: str) -> int:
        """
        Invalidate every cached entry belonging to a project.

        Args:
            project_id: project identifier

        Returns:
            int: number of entries removed.
        """
        count = 0

        if self.use_redis:
            try:
                # Match any key containing the project id. The previous
                # pattern "*:{project_id}:*" required a trailing colon and
                # therefore missed keys such as "project_summary:<id>";
                # this also mirrors the memory backend's substring match.
                pattern = f"*{project_id}*"
                for key in self.redis_client.scan_iter(match=pattern):
                    self.redis_client.delete(key)
                    count += 1
            except Exception as e:
                print(f"Redis 缓存失效失败: {e}")
        else:
            # Memory mode: substring match over all keys.
            with self.cache_lock:
                keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key]
                for key in keys_to_delete:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    count += 1

        return count
|
||
|
||
|
||
# ==================== 数据库分片 ====================
|
||
|
||
|
||
class DatabaseSharding:
    """
    Database sharding manager.

    Features:
    - per-project sharding strategy
    - shard routing
    - cross-shard query support
    - project migration between shards
    """

    def __init__(
        self,
        base_db_path: str = "insightflow.db",
        shard_db_dir: str = "./shards",
        shards_count: int = 4,
    ) -> None:
        """
        Args:
            base_db_path: path of the main (unsharded) database.
            shard_db_dir: directory holding one SQLite file per shard.
            shards_count: number of shards to create and route across.
        """
        self.base_db_path = base_db_path
        self.shard_db_dir = shard_db_dir
        self.shards_count = shards_count

        # Make sure the shard directory exists.
        os.makedirs(shard_db_dir, exist_ok=True)

        # shard_id -> ShardInfo routing table.
        self.shard_map: dict[str, ShardInfo] = {}

        # Projects explicitly migrated off their hash-assigned shard.
        # Fix: without this map, _get_shard_id kept hash-routing a migrated
        # project to its old (now empty) shard, making the moved data
        # unreachable. NOTE: in-memory only — persist it if migrations must
        # survive a restart.
        self._shard_overrides: dict[str, str] = {}

        self._init_shards()

    def _init_shards(self) -> None:
        """Partition the hex keyspace across shards and create their DB files."""
        chars = "0123456789abcdef"
        chars_per_shard = len(chars) // self.shards_count

        for i in range(self.shards_count):
            start_idx = i * chars_per_shard
            # The last shard absorbs any remainder of the keyspace.
            end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars)

            start_char = chars[start_idx]
            end_char = chars[end_idx - 1]

            shard_id = f"shard_{i}"
            db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")

            self.shard_map[shard_id] = ShardInfo(
                shard_id=shard_id,
                shard_key_range=(start_char, end_char),
                db_path=db_path,
                created_at=datetime.now().isoformat(),
            )

            if not os.path.exists(db_path):
                self._create_shard_db(db_path)

    def _create_shard_db(self, db_path: str) -> None:
        """Create a shard database with a simplified copy of the main schema."""
        conn = sqlite3.connect(db_path)

        conn.executescript("""
            CREATE TABLE IF NOT EXISTS entities (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                name TEXT NOT NULL,
                type TEXT,
                definition TEXT,
                aliases TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS entity_relations (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                source_entity_id TEXT NOT NULL,
                target_entity_id TEXT NOT NULL,
                relation_type TEXT,
                evidence TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE INDEX IF NOT EXISTS idx_entities_project ON entities(project_id);
            CREATE INDEX IF NOT EXISTS idx_relations_project ON entity_relations(project_id);
        """)

        conn.commit()
        conn.close()

    def _get_shard_id(self, project_id: str) -> str:
        """
        Resolve the shard for a project.

        Explicit migration overrides win; otherwise routing is by the
        project id's first (assumed hex) character.
        """
        if not project_id:
            return "shard_0"

        # Honor explicit migrations first.
        override = self._shard_overrides.get(project_id)
        if override is not None:
            return override

        first_char = project_id[0].lower()

        for shard_id, shard_info in self.shard_map.items():
            start, end = shard_info.shard_key_range
            if start <= first_char <= end:
                return shard_id

        # Non-hex leading character: fall back to the first shard.
        return "shard_0"

    def get_shard_connection(self, project_id: str) -> sqlite3.Connection:
        """Open a connection to the project's shard (caller must close it)."""
        shard_id = self._get_shard_id(project_id)
        shard_info = self.shard_map[shard_id]

        conn = sqlite3.connect(shard_info.db_path)
        conn.row_factory = sqlite3.Row

        shard_info.last_accessed = datetime.now().isoformat()

        return conn

    def get_all_shards(self) -> list[ShardInfo]:
        """Return the ShardInfo of every shard."""
        return list(self.shard_map.values())

    def migrate_project(self, project_id: str, target_shard_id: str) -> bool:
        """
        Move a project's data to another shard.

        Args:
            project_id: project identifier
            target_shard_id: destination shard

        Returns:
            bool: True on success (or if already on the target shard).
        """
        source_shard_id = self._get_shard_id(project_id)

        if source_shard_id == target_shard_id:
            return True  # nothing to do

        source_info = self.shard_map.get(source_shard_id)
        target_info = self.shard_map.get(target_shard_id)

        if not source_info or not target_info:
            return False

        try:
            # Read the project's rows from the source shard.
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.row_factory = sqlite3.Row

            entities = source_conn.execute(
                "SELECT * FROM entities WHERE project_id = ?", (project_id,)
            ).fetchall()

            relations = source_conn.execute(
                "SELECT * FROM entity_relations WHERE project_id = ?", (project_id,)
            ).fetchall()

            source_conn.close()

            # Copy them into the target shard. tuple(row) relies on SELECT *
            # returning columns in schema order, which holds because every
            # shard is created by _create_shard_db with the same schema.
            target_conn = sqlite3.connect(target_info.db_path)

            for entity in entities:
                target_conn.execute(
                    """
                    INSERT OR REPLACE INTO entities
                    (id, project_id, name, type, definition, aliases, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    tuple(entity),
                )

            for relation in relations:
                target_conn.execute(
                    """
                    INSERT OR REPLACE INTO entity_relations
                    (id, project_id, source_entity_id, target_entity_id, relation_type, evidence, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    tuple(relation),
                )

            target_conn.commit()
            target_conn.close()

            # Remove the copied rows from the source shard.
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,))
            source_conn.execute("DELETE FROM entity_relations WHERE project_id = ?", (project_id,))
            source_conn.commit()
            source_conn.close()

            # Fix: record the new home so _get_shard_id routes future
            # lookups to the target shard instead of the hash-assigned one.
            self._shard_overrides[project_id] = target_shard_id

            # Refresh both shards' statistics.
            self._update_shard_stats(source_shard_id)
            self._update_shard_stats(target_shard_id)

            return True

        except Exception as e:
            print(f"迁移失败: {e}")
            return False

    def _update_shard_stats(self, shard_id: str) -> None:
        """Refresh a shard's entity_count (distinct projects in the shard)."""
        shard_info = self.shard_map.get(shard_id)
        if not shard_info:
            return

        conn = sqlite3.connect(shard_info.db_path)

        count = conn.execute("SELECT COUNT(DISTINCT project_id) FROM entities").fetchone()[0]

        shard_info.entity_count = count

        conn.close()

    def cross_shard_query(self, query_func: Callable) -> list[dict]:
        """
        Run a query against every shard and merge the results.

        Args:
            query_func: callable taking an open sqlite3.Connection and
                returning an iterable of rows; failures on one shard are
                logged and skipped, not propagated.

        Returns:
            List[Dict]: merged results from all shards.
        """
        results = []

        for shard_info in self.shard_map.values():
            conn = sqlite3.connect(shard_info.db_path)
            conn.row_factory = sqlite3.Row

            try:
                shard_results = query_func(conn)
                results.extend(shard_results)
            except Exception as e:
                print(f"分片 {shard_info.shard_id} 查询失败: {e}")
            finally:
                conn.close()

        return results

    def get_shard_stats(self) -> list[dict]:
        """Return refreshed statistics for every shard."""
        stats = []

        for shard_info in self.shard_map.values():
            self._update_shard_stats(shard_info.shard_id)

            stats.append(
                {
                    "shard_id": shard_info.shard_id,
                    "key_range": shard_info.shard_key_range,
                    "db_path": shard_info.db_path,
                    "entity_count": shard_info.entity_count,
                    "is_active": shard_info.is_active,
                    "created_at": shard_info.created_at,
                    "last_accessed": shard_info.last_accessed,
                }
            )

        return stats

    def rebalance_shards(self) -> dict:
        """
        Analyze shard load balance.

        Currently analysis-only: it reports shards above 1.5x and below
        0.5x of the average load but does not move any data.

        Returns:
            Dict: rebalancing analysis summary.
        """
        stats = self.get_shard_stats()

        if not stats:
            return {"message": "No shards to rebalance"}

        avg_load = sum(s["entity_count"] for s in stats) / len(stats)

        overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5]
        underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5]

        # A production deployment would migrate projects here; kept as an
        # analysis pass for now.
        return {
            "average_load": avg_load,
            "overloaded_shards": len(overloaded),
            "underloaded_shards": len(underloaded),
            "message": "Rebalancing analysis completed",
        }
|
||
|
||
|
||
# ==================== 异步任务队列 ====================
|
||
|
||
|
||
class TaskQueue:
    """
    Asynchronous task-queue manager.

    Features:
    - Celery + Redis task queue (when available)
    - async audio-analysis processing
    - async report generation
    - task status tracking and retry handling

    Falls back to an in-process thread-based queue when Celery/Redis are
    not configured.
    """

    def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db") -> None:
        """
        Args:
            redis_url: broker/backend URL for Celery; memory mode when None.
            db_path: SQLite database used to persist task records.
        """
        self.db_path = db_path
        self.redis_url = redis_url
        self.celery_app = None
        self.use_celery = False

        # In-memory task storage (non-Celery mode), guarded by a lock.
        self.tasks: dict[str, TaskInfo] = {}
        self.task_handlers: dict[str, Callable] = {}
        self.task_lock = threading.RLock()

        # Create the task_queue table.
        self._init_task_tables()

        if CELERY_AVAILABLE and redis_url:
            try:
                self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url)
                self.use_celery = True
                print("Celery 任务队列已初始化")
            except Exception as e:
                print(f"Celery 初始化失败,使用内存任务队列: {e}")

    def _init_task_tables(self) -> None:
        """Create the task_queue table and its indexes (idempotent)."""
        conn = sqlite3.connect(self.db_path)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS task_queue (
                id TEXT PRIMARY KEY,
                task_type TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                payload TEXT,
                result TEXT,
                error_message TEXT,
                retry_count INTEGER DEFAULT 0,
                max_retries INTEGER DEFAULT 3,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP
            )
        """)

        conn.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON task_queue(status)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_task_type ON task_queue(task_type)")

        conn.commit()
        conn.close()

    def is_available(self) -> bool:
        """The queue is always available: memory mode is the fallback."""
        # Simplified from "self.use_celery or True", which always evaluated
        # to True anyway.
        return True

    def register_handler(self, task_type: str, handler: Callable) -> None:
        """Register the callable that executes tasks of task_type (memory mode)."""
        self.task_handlers[task_type] = handler

    def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str:
        """
        Submit a task for asynchronous execution.

        Args:
            task_type: task kind (must have a registered handler in memory mode)
            payload: task data
            max_retries: maximum retry attempts

        Returns:
            str: task id.
        """
        task_id = str(uuid.uuid4())[:16]

        task = TaskInfo(
            id=task_id,
            task_type=task_type,
            status="pending",
            payload=payload,
            created_at=datetime.now().isoformat(),
            max_retries=max_retries,
        )

        if self.use_celery:
            try:
                # Simplified dispatch; a real deployment registers concrete
                # Celery tasks under insightflow.tasks.*.
                result = self.celery_app.send_task(
                    f"insightflow.tasks.{task_type}",
                    args=[payload],
                    task_id=task_id,
                    retry=True,
                    retry_policy={
                        "max_retries": max_retries,
                        "interval_start": 10,
                        "interval_step": 10,
                        "interval_max": 60,
                    },
                )
                task.id = result.id
            except Exception as e:
                print(f"Celery 任务提交失败: {e}")
                # Fall back to the in-memory queue.
                self.use_celery = False

        if not self.use_celery:
            with self.task_lock:
                self.tasks[task_id] = task
            # Execute asynchronously on a daemon thread.
            threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()

        self._save_task(task)

        return task_id

    def _execute_task(self, task_id: str) -> None:
        """Run one task (memory mode), retrying with linear backoff on failure."""
        with self.task_lock:
            task = self.tasks.get(task_id)
            if not task:
                return

            task.status = "running"
            task.started_at = datetime.now().isoformat()

        self._update_task_status(task)

        handler = self.task_handlers.get(task.task_type)

        if not handler:
            task.status = "failed"
            task.error_message = f"No handler for task type: {task.task_type}"
        else:
            try:
                result = handler(task.payload)
                task.status = "success"
                task.result = result
            except Exception as e:
                task.retry_count += 1

                if task.retry_count <= task.max_retries:
                    task.status = "retrying"
                    task.error_message = str(e)
                    with self.task_lock:
                        self.tasks[task_id] = task
                    self._update_task_status(task)
                    # Linear backoff: 10s, 20s, 30s ...
                    threading.Timer(
                        10 * task.retry_count, self._execute_task, args=(task_id,)
                    ).start()
                    # Fix: previously fell through and stamped completed_at
                    # while a retry was still pending.
                    return

                task.status = "failed"
                task.error_message = str(e)

        task.completed_at = datetime.now().isoformat()

        with self.task_lock:
            self.tasks[task_id] = task

        self._update_task_status(task)

    def _save_task(self, task: TaskInfo) -> None:
        """Insert or replace the task's row in the task_queue table."""
        conn = sqlite3.connect(self.db_path)

        conn.execute(
            """
            INSERT OR REPLACE INTO task_queue
            (id, task_type, status, payload, result, error_message,
             retry_count, max_retries, created_at, started_at, completed_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                task.id,
                task.task_type,
                task.status,
                json.dumps(task.payload, ensure_ascii=False),
                json.dumps(task.result, ensure_ascii=False) if task.result else None,
                task.error_message,
                task.retry_count,
                task.max_retries,
                task.created_at,
                task.started_at,
                task.completed_at,
            ),
        )

        conn.commit()
        conn.close()

    def _update_task_status(self, task: TaskInfo) -> None:
        """Persist the task's mutable fields back to the task_queue table."""
        conn = sqlite3.connect(self.db_path)

        conn.execute(
            """
            UPDATE task_queue SET
                status = ?,
                result = ?,
                error_message = ?,
                retry_count = ?,
                started_at = ?,
                completed_at = ?
            WHERE id = ?
            """,
            (
                task.status,
                json.dumps(task.result, ensure_ascii=False) if task.result else None,
                task.error_message,
                task.retry_count,
                task.started_at,
                task.completed_at,
                task.id,
            ),
        )

        conn.commit()
        conn.close()

    def get_status(self, task_id: str) -> TaskInfo | None:
        """Return the task's current state (Celery or memory mode)."""
        if self.use_celery:
            try:
                result = AsyncResult(task_id, app=self.celery_app)

                status_map = {
                    "PENDING": "pending",
                    "STARTED": "running",
                    "SUCCESS": "success",
                    "FAILURE": "failed",
                    "RETRY": "retrying",
                }

                # Celery does not expose the original submission details, so
                # only status/result/error are meaningful here.
                return TaskInfo(
                    id=task_id,
                    task_type="celery_task",
                    status=status_map.get(result.status, "unknown"),
                    payload={},
                    created_at="",
                    result=result.result if result.successful() else None,
                    error_message=str(result.result) if result.failed() else None,
                )
            except Exception as e:
                print(f"获取 Celery 任务状态失败: {e}")

        # Memory mode, or fallback after a Celery error.
        with self.task_lock:
            return self.tasks.get(task_id)

    def list_tasks(
        self, status: str | None = None, task_type: str | None = None, limit: int = 100
    ) -> list[TaskInfo]:
        """List persisted tasks, optionally filtered by status and/or type."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        where_clauses = []
        params = []

        if status:
            where_clauses.append("status = ?")
            params.append(status)

        if task_type:
            where_clauses.append("task_type = ?")
            params.append(task_type)

        where_str = " AND ".join(where_clauses) if where_clauses else "1 = 1"

        # where_str is built only from the fixed clause strings above;
        # user-supplied values go through the parameter list.
        rows = conn.execute(
            f"""
            SELECT * FROM task_queue
            WHERE {where_str}
            ORDER BY created_at DESC
            LIMIT ?
            """,
            params + [limit],
        ).fetchall()

        conn.close()

        tasks = []
        for row in rows:
            tasks.append(
                TaskInfo(
                    id=row["id"],
                    task_type=row["task_type"],
                    status=row["status"],
                    payload=json.loads(row["payload"]) if row["payload"] else {},
                    created_at=row["created_at"],
                    started_at=row["started_at"],
                    completed_at=row["completed_at"],
                    result=json.loads(row["result"]) if row["result"] else None,
                    error_message=row["error_message"],
                    retry_count=row["retry_count"],
                    max_retries=row["max_retries"],
                )
            )

        return tasks

    def cancel(self, task_id: str) -> bool:
        """Cancel a pending/running task; returns True if it was cancelled."""
        if self.use_celery:
            try:
                self.celery_app.control.revoke(task_id, terminate=True)
                return True
            except Exception as e:
                print(f"取消 Celery 任务失败: {e}")

        with self.task_lock:
            task = self.tasks.get(task_id)
            if task and task.status in ["pending", "running"]:
                task.status = "cancelled"
                task.completed_at = datetime.now().isoformat()
                self._update_task_status(task)
                return True

        return False

    def retry(self, task_id: str) -> bool:
        """Re-run a failed task from scratch (resets its retry counter)."""
        task = self.get_status(task_id)

        if not task or task.status != "failed":
            return False

        task.status = "pending"
        task.retry_count = 0
        task.error_message = None
        task.completed_at = None

        if not self.use_celery:
            with self.task_lock:
                self.tasks[task_id] = task
            threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()

        self._update_task_status(task)
        return True

    def get_stats(self) -> dict:
        """Summarize the persisted queue: counts by status/type and recent volume."""
        conn = sqlite3.connect(self.db_path)

        status_counts = conn.execute("""
            SELECT status, COUNT(*) as count
            FROM task_queue
            GROUP BY status
        """).fetchall()

        type_counts = conn.execute("""
            SELECT task_type, COUNT(*) as count
            FROM task_queue
            GROUP BY task_type
        """).fetchall()

        recent_count = conn.execute("""
            SELECT COUNT(*) as count
            FROM task_queue
            WHERE created_at > datetime('now', '-1 day')
        """).fetchone()[0]

        conn.close()

        return {
            "by_status": {r[0]: r[1] for r in status_counts},
            "by_type": {r[0]: r[1] for r in type_counts},
            "recent_24h": recent_count,
            "backend": "celery" if self.use_celery else "memory",
        }
|
||
|
||
|
||
# ==================== 性能监控 ====================
|
||
|
||
|
||
class PerformanceMonitor:
    """
    Performance monitor.

    Features:
    - API response-time statistics
    - Database query performance analysis
    - Cache hit-rate monitoring
    - Threshold-based alerting
    """

    def __init__(
        self,
        db_path: str = "insightflow.db",
        slow_query_threshold: int = 1000,  # milliseconds
        alert_threshold: int = 5000,  # milliseconds
    ) -> None:
        self.db_path = db_path
        self.slow_query_threshold = slow_query_threshold
        self.alert_threshold = alert_threshold

        # In-memory metric buffer, flushed to SQLite when it overflows.
        self.metrics_buffer: list[PerformanceMetric] = []
        # RLock: record_metric holds it while _flush_metrics re-acquires it.
        self.buffer_lock = threading.RLock()
        self.buffer_size = 100

        # Alert callbacks, invoked with a dict payload on threshold breach.
        self.alert_handlers: list[Callable] = []

    def record_metric(
        self,
        metric_type: str,
        duration_ms: float,
        endpoint: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """
        Record a performance metric.

        Args:
            metric_type: metric category (api_response, db_query, cache_operation)
            duration_ms: elapsed time in milliseconds
            endpoint: endpoint / query identifier
            metadata: extra metadata
        """
        metric = PerformanceMetric(
            id=str(uuid.uuid4())[:16],
            metric_type=metric_type,
            endpoint=endpoint,
            duration_ms=duration_ms,
            timestamp=datetime.now().isoformat(),
            metadata=metadata or {},
        )

        # Buffer the metric; persist once the buffer overflows.
        with self.buffer_lock:
            self.metrics_buffer.append(metric)
            if len(self.metrics_buffer) > self.buffer_size:
                self._flush_metrics()

        # Alert when the configured threshold is exceeded.
        if duration_ms > self.alert_threshold:
            self._trigger_alert(metric)

        # Log slow database queries separately.
        if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
            self._record_slow_query(metric)

    def _flush_metrics(self) -> None:
        """Write buffered metrics to the database.

        Thread-safe: holds buffer_lock for the whole drain, because callers
        such as get_stats() invoke this concurrently with record_metric().
        """
        with self.buffer_lock:
            if not self.metrics_buffer:
                return

            conn = sqlite3.connect(self.db_path)
            try:
                # Batch insert instead of one execute() per metric.
                conn.executemany(
                    """
                    INSERT INTO performance_metrics
                    (id, metric_type, endpoint, duration_ms, timestamp, metadata)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    [
                        (
                            m.id,
                            m.metric_type,
                            m.endpoint,
                            m.duration_ms,
                            m.timestamp,
                            json.dumps(m.metadata, ensure_ascii=False),
                        )
                        for m in self.metrics_buffer
                    ],
                )
                conn.commit()
            finally:
                # Always close the connection, even if the insert fails.
                conn.close()

            self.metrics_buffer = []

    def _record_slow_query(self, metric: PerformanceMetric) -> None:
        """Log a slow query (could be routed to a dedicated slow-query log/monitoring system)."""
        print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")

    def _trigger_alert(self, metric: PerformanceMetric) -> None:
        """Invoke every registered alert handler with a threshold-breach payload."""
        alert_data = {
            "type": "performance_alert",
            # PerformanceMetric is a plain dataclass without a to_dict()
            # method (the old call crashed here); vars() yields its shallow
            # field dict.
            "metric": vars(metric),
            "threshold": self.alert_threshold,
            "message": f"{metric.metric_type} exceeded threshold: {metric.duration_ms}ms > {self.alert_threshold}ms",
        }

        for handler in self.alert_handlers:
            try:
                handler(alert_data)
            except Exception as e:
                # A failing handler must not break the others.
                print(f"告警处理失败: {e}")

    def register_alert_handler(self, handler: Callable) -> None:
        """Register an alert handler; it receives the alert payload dict."""
        self.alert_handlers.append(handler)

    def get_stats(self, hours: int = 24) -> dict:
        """
        Return performance statistics.

        Args:
            hours: look-back window in hours

        Returns:
            Dict: overall/by-type/by-endpoint aggregates plus slowest queries
        """
        # Flush the buffer first so recent metrics are included.
        self._flush_metrics()

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        # Overall aggregates.
        overall = conn.execute(
            """
            SELECT
                COUNT(*) as total,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration,
                MIN(duration_ms) as min_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            """,
            (f"-{hours} hours",),
        ).fetchone()

        # Aggregates per metric type.
        by_type = conn.execute(
            """
            SELECT
                metric_type,
                COUNT(*) as count,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            GROUP BY metric_type
            """,
            (f"-{hours} hours",),
        ).fetchall()

        # Slowest API endpoints (top 20 by average duration).
        by_endpoint = conn.execute(
            """
            SELECT
                endpoint,
                COUNT(*) as count,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
              AND metric_type = 'api_response'
            GROUP BY endpoint
            ORDER BY avg_duration DESC
            LIMIT 20
            """,
            (f"-{hours} hours",),
        ).fetchall()

        # Ten slowest individual operations above the slow-query threshold.
        slow_queries = conn.execute(
            """
            SELECT
                metric_type,
                endpoint,
                duration_ms,
                timestamp
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
              AND duration_ms > ?
            ORDER BY duration_ms DESC
            LIMIT 10
            """,
            (f"-{hours} hours", self.slow_query_threshold),
        ).fetchall()

        conn.close()

        return {
            "period_hours": hours,
            "overall": {
                "total_requests": overall["total"] or 0,
                "avg_duration_ms": round(overall["avg_duration"] or 0, 2),
                "max_duration_ms": overall["max_duration"] or 0,
                "min_duration_ms": overall["min_duration"] or 0,
            },
            "by_type": [
                {
                    "type": r["metric_type"],
                    "count": r["count"],
                    "avg_duration_ms": round(r["avg_duration"], 2),
                    "max_duration_ms": r["max_duration"],
                }
                for r in by_type
            ],
            "by_endpoint": [
                {
                    "endpoint": r["endpoint"],
                    "count": r["count"],
                    "avg_duration_ms": round(r["avg_duration"], 2),
                    "max_duration_ms": r["max_duration"],
                }
                for r in by_endpoint
            ],
            "slow_queries": [
                {
                    "type": r["metric_type"],
                    "endpoint": r["endpoint"],
                    "duration_ms": r["duration_ms"],
                    "timestamp": r["timestamp"],
                }
                for r in slow_queries
            ],
        }

    def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict:
        """Return API performance details, optionally for a single endpoint."""
        self._flush_metrics()

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        where_clause = "metric_type = 'api_response'"
        params: list = []

        if endpoint:
            where_clause += " AND endpoint = ?"
            params.append(endpoint)

        # Placeholder order must match the SQL: the endpoint filter comes
        # before the timestamp bound (the old code bound them swapped and
        # silently matched nothing when endpoint was given).
        params.append(f"-{hours} hours")

        # Per-endpoint aggregates.
        per_endpoint = conn.execute(
            f"""
            SELECT
                endpoint,
                COUNT(*) as count,
                AVG(duration_ms) as avg,
                MIN(duration_ms) as min,
                MAX(duration_ms) as max
            FROM performance_metrics
            WHERE {where_clause}
              AND timestamp > datetime('now', ?)
            GROUP BY endpoint
            ORDER BY avg DESC
            """,
            params,
        ).fetchall()

        conn.close()

        return {
            "endpoint": endpoint or "all",
            "period_hours": hours,
            "endpoints": [
                {
                    "endpoint": r["endpoint"],
                    "count": r["count"],
                    "avg_ms": round(r["avg"], 2),
                    "min_ms": r["min"],
                    "max_ms": r["max"],
                }
                for r in per_endpoint
            ],
        }

    def cleanup_old_metrics(self, days: int = 30) -> int:
        """
        Delete old performance-metric rows.

        Args:
            days: retention window in days

        Returns:
            int: number of deleted rows
        """
        conn = sqlite3.connect(self.db_path)

        cursor = conn.execute(
            """
            DELETE FROM performance_metrics
            WHERE timestamp < datetime('now', ?)
            """,
            (f"-{days} days",),
        )

        deleted = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted
|
||
|
||
|
||
# ==================== 性能装饰器 ====================
|
||
|
||
|
||
def cached(
    cache_manager: CacheManager,
    key_prefix: str = "",
    ttl: int = 3600,
    key_func: Callable | None = None,
) -> Callable:
    """
    Caching decorator.

    Looks the call up in the cache first and only invokes the wrapped
    function on a miss, storing the result afterwards.

    Args:
        cache_manager: cache manager instance (must provide get/set)
        key_prefix: cache-key prefix
        ttl: cache expiry in seconds
        key_func: custom cache-key builder

    Returns:
        Callable: the decorator (the old ``-> None`` annotations were wrong).
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Build the cache key: custom builder, or a hash of the call.
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # NOTE: repr-based keys assume args/kwargs have stable str();
                # md5 is used only to compact the key, not for security.
                key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}"
                cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}"

            # Cache hit — return immediately. NOTE: a cached value of None is
            # indistinguishable from a miss and will be recomputed.
            cached_value = cache_manager.get(cache_key)
            if cached_value is not None:
                return cached_value

            # Miss: execute and populate the cache.
            result = func(*args, **kwargs)
            cache_manager.set(cache_key, result, ttl)

            return result

        return wrapper

    return decorator
|
||
|
||
|
||
def monitored(monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None) -> Callable:
    """
    Performance-monitoring decorator.

    Records the wrapped call's duration via ``monitor.record_metric`` — even
    when the call raises.

    Args:
        monitor: performance monitor instance
        metric_type: metric category
        endpoint: endpoint identifier (defaults to the function name)

    Returns:
        Callable: the decorator (the old ``-> None`` annotations were wrong).
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # perf_counter is monotonic; time.time() can jump backwards on
            # clock adjustments and produce negative durations.
            start_time = time.perf_counter()

            try:
                return func(*args, **kwargs)
            finally:
                # Record in `finally` so failed calls are measured too.
                duration_ms = (time.perf_counter() - start_time) * 1000
                ep = endpoint or func.__name__
                monitor.record_metric(metric_type, duration_ms, ep)

        return wrapper

    return decorator
|
||
|
||
|
||
# ==================== 性能管理器 ====================
|
||
|
||
|
||
class PerformanceManager:
    """
    Performance manager — unified entry point.

    Bundles cache management, database sharding, the task queue and
    performance monitoring behind a single facade.
    """

    def __init__(
        self,
        db_path: str = "insightflow.db",
        redis_url: str | None = None,
        enable_sharding: bool = False,
    ) -> None:
        self.db_path = db_path

        # Wire up the individual subsystems; sharding is opt-in.
        self.cache = CacheManager(redis_url=redis_url, db_path=db_path)
        self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None
        self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path)
        self.monitor = PerformanceMonitor(db_path=db_path)

    def get_health_status(self) -> dict:
        """Report availability, active backend and key settings per subsystem."""
        cache_backend = "redis" if self.cache.use_redis else "memory"
        queue_backend = "celery" if self.task_queue.use_celery else "memory"
        shard_count = len(self.sharding.shard_map) if self.sharding else 0

        status: dict = {}
        status["cache"] = {
            "available": True,
            "backend": cache_backend,
            "stats": self.cache.get_stats(),
        }
        status["sharding"] = {
            "enabled": self.sharding is not None,
            "shards_count": shard_count,
        }
        status["task_queue"] = {
            "available": self.task_queue.is_available(),
            "backend": queue_backend,
            "stats": self.task_queue.get_stats(),
        }
        status["monitor"] = {
            "available": True,
            "slow_query_threshold": self.monitor.slow_query_threshold,
            "alert_threshold": self.monitor.alert_threshold,
        }
        return status

    def get_full_stats(self) -> dict:
        """Collect statistics from every enabled subsystem."""
        stats = {
            "cache": self.cache.get_stats(),
            "task_queue": self.task_queue.get_stats(),
            "performance": self.monitor.get_stats(),
        }

        # Shard statistics only exist when sharding is enabled.
        if self.sharding is not None:
            stats["sharding"] = self.sharding.get_shard_stats()

        return stats
|
||
|
||
|
||
# Process-wide singleton, shared via get_performance_manager().
_performance_manager = None


def get_performance_manager(
    db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False
) -> PerformanceManager:
    """Return the performance-manager singleton, creating it on first use.

    Note: the arguments only take effect on the first call; later calls
    return the existing instance unchanged.
    """
    global _performance_manager

    if _performance_manager is not None:
        return _performance_manager

    _performance_manager = PerformanceManager(
        db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding
    )
    return _performance_manager
|