insightflow/backend/performance_manager.py
OpenClaw Bot 8492e7a0d3 fix: auto-fix code issues (cron)
- Fix missing imports: add AttributeTemplate and EntityAttribute imports to main.py
- Fix bare exception handling: replace BaseException with specific exception types
  - neo4j_manager.py: Exception
  - main.py: json.JSONDecodeError, ValueError, Exception
  - export_manager.py: AttributeError, TypeError, ValueError
  - localization_manager.py: ValueError, AttributeError
  - performance_manager.py: TypeError, ValueError
  - plugin_manager.py: OSError, IOError
- Fix some line-length issues: split long lines in security_manager.py
2026-02-28 21:14:59 +08:00

"""
InsightFlow - 性能优化与扩展模块
Phase 7 Task 8: Performance Optimization & Scaling
功能模块:
1. CacheManager - Redis 缓存层热点数据、TTL、LRU、缓存预热
2. DatabaseSharding - 数据库分片策略
3. TaskQueue - 异步任务队列Celery + Redis
4. PerformanceMonitor - 性能监控API响应、查询性能、缓存命中率
"""
import hashlib
import json
import os
import sqlite3
import threading
import time
import uuid
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from typing import Any

# Optional dependency: Redis
try:
    import redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

# Optional dependency: Celery
try:
    from celery import Celery
    from celery.result import AsyncResult

    CELERY_AVAILABLE = True
except ImportError:
    CELERY_AVAILABLE = False

# ==================== Data Models ====================


@dataclass
class CacheStats:
    """Cache statistics."""

    total_requests: int = 0
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expired: int = 0
    hit_rate: float = 0.0

    def update_hit_rate(self) -> None:
        """Recompute the hit rate from the current counters."""
        if self.total_requests > 0:
            self.hit_rate = round(self.hits / self.total_requests, 4)


@dataclass
class CacheEntry:
    """A single cache entry."""

    key: str
    value: Any
    created_at: float
    expires_at: float | None
    access_count: int = 0
    last_accessed: float = 0
    size_bytes: int = 0


@dataclass
class PerformanceMetric:
    """A single performance measurement."""

    id: str
    metric_type: str  # api_response, db_query, cache_operation
    endpoint: str | None
    duration_ms: float
    timestamp: str
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "metric_type": self.metric_type,
            "endpoint": self.endpoint,
            "duration_ms": self.duration_ms,
            "timestamp": self.timestamp,
            "metadata": self.metadata,
        }


@dataclass
class TaskInfo:
    """Queued task metadata."""

    id: str
    task_type: str
    status: str  # pending, running, success, failed, retrying
    payload: dict
    created_at: str
    started_at: str | None = None
    completed_at: str | None = None
    result: Any | None = None
    error_message: str | None = None
    retry_count: int = 0
    max_retries: int = 3

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "task_type": self.task_type,
            "status": self.status,
            "payload": self.payload,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "result": self.result,
            "error_message": self.error_message,
            "retry_count": self.retry_count,
            "max_retries": self.max_retries,
        }


@dataclass
class ShardInfo:
    """Shard metadata."""

    shard_id: str
    shard_key_range: tuple[str, str]  # (start, end)
    db_path: str
    entity_count: int = 0
    is_active: bool = True
    created_at: str = ""
    last_accessed: str = ""


# ==================== Redis Cache Layer ====================


class CacheManager:
    """
    Cache manager.

    Features:
    - Hot-data caching (entities, relations, transcripts)
    - Invalidation policies (TTL, LRU)
    - Cache warm-up
    - Cache statistics and monitoring

    Two backends are supported:
    1. Redis (recommended for production)
    2. In-memory LRU (development/testing)
    """

    def __init__(
        self,
        redis_url: str | None = None,
        max_memory_size: int = 100 * 1024 * 1024,  # 100 MB
        default_ttl: int = 3600,  # 1 hour
        db_path: str = "insightflow.db",
    ):
        self.db_path = db_path
        self.default_ttl = default_ttl
        self.max_memory_size = max_memory_size
        self.current_memory_size = 0
        # Redis client
        self.redis_client = None
        self.use_redis = False
        if REDIS_AVAILABLE and redis_url:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                self.redis_client.ping()
                self.use_redis = True
                print(f"Redis cache connected: {redis_url}")
            except Exception as e:
                print(f"Redis connection failed, falling back to in-memory cache: {e}")
        # In-memory cache (LRU)
        self.memory_cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self.cache_lock = threading.RLock()
        # Statistics
        self.stats = CacheStats()
        # Initialize the cache statistics tables
        self._init_cache_tables()

    def _init_cache_tables(self) -> None:
        """Create the cache statistics tables if they do not exist."""
        conn = sqlite3.connect(self.db_path)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                total_requests INTEGER DEFAULT 0,
                hits INTEGER DEFAULT 0,
                misses INTEGER DEFAULT 0,
                hit_rate REAL DEFAULT 0.0,
                memory_usage INTEGER DEFAULT 0
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS performance_metrics (
                id TEXT PRIMARY KEY,
                metric_type TEXT NOT NULL,
                endpoint TEXT,
                duration_ms REAL,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        """)
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_type ON performance_metrics(metric_type)"
        )
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_time ON performance_metrics(timestamp)"
        )
        conn.commit()
        conn.close()

    def _get_entry_size(self, value: Any) -> int:
        """Estimate the serialized size of a cache entry in bytes."""
        try:
            return len(json.dumps(value, ensure_ascii=False).encode("utf-8"))
        except (TypeError, ValueError):
            return 1024  # fallback estimate for non-serializable values

    def _evict_lru(self, required_space: int = 0) -> None:
        """Evict least-recently-used entries until there is enough room."""
        with self.cache_lock:
            while (
                self.current_memory_size + required_space > self.max_memory_size
                and self.memory_cache
            ):
                # Drop the least recently accessed entry
                oldest_key, oldest_entry = self.memory_cache.popitem(last=False)
                self.current_memory_size -= oldest_entry.size_bytes
                self.stats.evictions += 1

    def get(self, key: str) -> Any | None:
        """
        Fetch a cached value.

        Args:
            key: cache key

        Returns:
            Optional[Any]: the cached value, or None if absent
        """
        self.stats.total_requests += 1
        if self.use_redis:
            try:
                value = self.redis_client.get(key)
                if value:
                    self.stats.hits += 1
                    return json.loads(value)
                else:
                    self.stats.misses += 1
                    return None
            except Exception as e:
                print(f"Redis get failed: {e}")
                return None
        else:
            # In-memory cache
            with self.cache_lock:
                entry = self.memory_cache.get(key)
                if entry:
                    # Expired?
                    if entry.expires_at and time.time() > entry.expires_at:
                        del self.memory_cache[key]
                        self.current_memory_size -= entry.size_bytes
                        self.stats.expired += 1
                        self.stats.misses += 1
                        return None
                    # Record the access
                    entry.access_count += 1
                    entry.last_accessed = time.time()
                    self.memory_cache.move_to_end(key)
                    self.stats.hits += 1
                    return entry.value
                else:
                    self.stats.misses += 1
                    return None

    def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
        """
        Store a value in the cache.

        Args:
            key: cache key
            value: value to cache
            ttl: expiry in seconds (None uses the default TTL)

        Returns:
            bool: True on success
        """
        ttl = ttl or self.default_ttl
        if self.use_redis:
            try:
                serialized = json.dumps(value, ensure_ascii=False)
                self.redis_client.setex(key, ttl, serialized)
                return True
            except Exception as e:
                print(f"Redis set failed: {e}")
                return False
        else:
            # In-memory cache
            with self.cache_lock:
                size = self._get_entry_size(value)
                # Evict if the new entry would not fit
                if self.current_memory_size + size > self.max_memory_size:
                    self._evict_lru(size)
                now = time.time()
                entry = CacheEntry(
                    key=key,
                    value=value,
                    created_at=now,
                    expires_at=now + ttl if ttl > 0 else None,
                    size_bytes=size,
                    last_accessed=now,
                )
                # If the key already exists, reclaim its old size first
                if key in self.memory_cache:
                    self.current_memory_size -= self.memory_cache[key].size_bytes
                self.memory_cache[key] = entry
                self.memory_cache.move_to_end(key)
                self.current_memory_size += size
                return True

    def delete(self, key: str) -> bool:
        """Delete a cache entry."""
        if self.use_redis:
            try:
                return self.redis_client.delete(key) > 0
            except Exception as e:
                print(f"Redis delete failed: {e}")
                return False
        else:
            with self.cache_lock:
                if key in self.memory_cache:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    return True
                return False

    def clear(self) -> bool:
        """Empty the cache."""
        if self.use_redis:
            try:
                self.redis_client.flushdb()
                return True
            except Exception as e:
                print(f"Redis clear failed: {e}")
                return False
        else:
            with self.cache_lock:
                self.memory_cache.clear()
                self.current_memory_size = 0
                return True

    def get_many(self, keys: list[str]) -> dict[str, Any]:
        """Fetch multiple cache entries at once."""
        results = {}
        if self.use_redis:
            try:
                values = self.redis_client.mget(keys)
                for key, value in zip(keys, values):
                    if value:
                        results[key] = json.loads(value)
                        self.stats.hits += 1
                    else:
                        self.stats.misses += 1
                    self.stats.total_requests += 1
            except Exception as e:
                print(f"Redis mget failed: {e}")
        else:
            for key in keys:
                value = self.get(key)
                if value is not None:
                    results[key] = value
        return results

    def set_many(self, mapping: dict[str, Any], ttl: int | None = None) -> bool:
        """Store multiple cache entries at once."""
        ttl = ttl or self.default_ttl
        if self.use_redis:
            try:
                pipe = self.redis_client.pipeline()
                for key, value in mapping.items():
                    serialized = json.dumps(value, ensure_ascii=False)
                    pipe.setex(key, ttl, serialized)
                pipe.execute()
                return True
            except Exception as e:
                print(f"Redis mset failed: {e}")
                return False
        else:
            for key, value in mapping.items():
                self.set(key, value, ttl)
            return True

    def get_stats(self) -> dict:
        """Return current cache statistics."""
        self.stats.update_hit_rate()
        stats = {
            "total_requests": self.stats.total_requests,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": self.stats.hit_rate,
            "evictions": self.stats.evictions,
            "expired": self.stats.expired,
            "backend": "redis" if self.use_redis else "memory",
        }
        if not self.use_redis:
            stats.update(
                {
                    "memory_size_bytes": self.current_memory_size,
                    "max_memory_size_bytes": self.max_memory_size,
                    "memory_usage_percent": round(
                        self.current_memory_size / self.max_memory_size * 100, 2
                    ),
                    "cache_entries": len(self.memory_cache),
                }
            )
        return stats

    def save_stats(self) -> None:
        """Persist cache statistics to the database."""
        conn = sqlite3.connect(self.db_path)
        self.stats.update_hit_rate()
        conn.execute(
            """
            INSERT INTO cache_stats
            (timestamp, total_requests, hits, misses, hit_rate, memory_usage)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                datetime.now().isoformat(),
                self.stats.total_requests,
                self.stats.hits,
                self.stats.misses,
                self.stats.hit_rate,
                self.current_memory_size,
            ),
        )
        conn.commit()
        conn.close()

    def warm_up(self, project_id: str) -> dict:
        """
        Warm up the cache with a project's hot data.

        Args:
            project_id: project ID

        Returns:
            Dict: warm-up counters
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        stats = {"entities": 0, "relations": 0, "transcripts": 0}
        # Warm up entities (most-mentioned first)
        entities = conn.execute(
            """SELECT e.*,
                (SELECT COUNT(*) FROM entity_mentions m WHERE m.entity_id = e.id) as mention_count
            FROM entities e
            WHERE e.project_id = ?
            ORDER BY mention_count DESC
            LIMIT 100""",
            (project_id,),
        ).fetchall()
        for entity in entities:
            key = f"entity:{entity['id']}"
            self.set(key, dict(entity), ttl=7200)  # 2 hours
            stats["entities"] += 1
        # Warm up relations
        relations = conn.execute(
            """SELECT r.*,
                e1.name as source_name, e2.name as target_name
            FROM entity_relations r
            JOIN entities e1 ON r.source_entity_id = e1.id
            JOIN entities e2 ON r.target_entity_id = e2.id
            WHERE r.project_id = ?
            LIMIT 200""",
            (project_id,),
        ).fetchall()
        for relation in relations:
            key = f"relation:{relation['id']}"
            self.set(key, dict(relation), ttl=3600)
            stats["relations"] += 1
        # Warm up the most recent transcripts
        transcripts = conn.execute(
            """SELECT * FROM transcripts
            WHERE project_id = ?
            ORDER BY created_at DESC
            LIMIT 10""",
            (project_id,),
        ).fetchall()
        for transcript in transcripts:
            key = f"transcript:{transcript['id']}"
            # Cache only the metadata, not the full text. Note that
            # sqlite3.Row has no .get(), so check the columns explicitly.
            meta = {
                "id": transcript["id"],
                "filename": transcript["filename"],
                "type": transcript["type"] if "type" in transcript.keys() else "audio",
                "created_at": transcript["created_at"],
            }
            self.set(key, meta, ttl=1800)  # 30 minutes
            stats["transcripts"] += 1
        # Warm up the project knowledge-base summary
        entity_count = conn.execute(
            "SELECT COUNT(*) FROM entities WHERE project_id = ?", (project_id,)
        ).fetchone()[0]
        relation_count = conn.execute(
            "SELECT COUNT(*) FROM entity_relations WHERE project_id = ?", (project_id,)
        ).fetchone()[0]
        summary = {
            "project_id": project_id,
            "entity_count": entity_count,
            "relation_count": relation_count,
            "cached_at": datetime.now().isoformat(),
        }
        self.set(f"project_summary:{project_id}", summary, ttl=3600)
        conn.close()
        return stats

    def invalidate_project(self, project_id: str) -> int:
        """
        Invalidate all cache entries belonging to a project.

        Args:
            project_id: project ID

        Returns:
            int: number of entries removed
        """
        count = 0
        if self.use_redis:
            try:
                # Scan for keys containing the project ID (mirrors the
                # substring match used by the in-memory backend)
                pattern = f"*{project_id}*"
                for key in self.redis_client.scan_iter(match=pattern):
                    self.redis_client.delete(key)
                    count += 1
            except Exception as e:
                print(f"Redis invalidation failed: {e}")
        else:
            # In-memory cache: find and remove matching keys
            with self.cache_lock:
                keys_to_delete = [key for key in self.memory_cache.keys() if project_id in key]
                for key in keys_to_delete:
                    entry = self.memory_cache.pop(key)
                    self.current_memory_size -= entry.size_bytes
                    count += 1
        return count
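

# Usage sketch for CacheManager (illustrative only, not part of the module
# API; assumes a Redis server at the example URL — without one, construction
# falls back to the in-memory LRU backend automatically):
#
#     cache = CacheManager(redis_url="redis://localhost:6379/0")
#     cache.set("entity:demo", {"name": "Acme"}, ttl=600)
#     cache.get("entity:demo")        # -> {"name": "Acme"}
#     cache.get("entity:missing")     # -> None, counted as a miss
#     cache.get_stats()["hit_rate"]   # -> 0.5 (one hit out of two requests)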


# ==================== Database Sharding ====================


class DatabaseSharding:
    """
    Database sharding manager.

    Features:
    - Project-level sharding strategy
    - Shard routing
    - Cross-shard queries
    - Shard migration tooling
    """

    def __init__(
        self,
        base_db_path: str = "insightflow.db",
        shard_db_dir: str = "./shards",
        shards_count: int = 4,
    ):
        self.base_db_path = base_db_path
        self.shard_db_dir = shard_db_dir
        self.shards_count = shards_count
        # Make sure the shard directory exists
        os.makedirs(shard_db_dir, exist_ok=True)
        # Shard map
        self.shard_map: dict[str, ShardInfo] = {}
        # Initialize the shards
        self._init_shards()

    def _init_shards(self) -> None:
        """Initialize shards by splitting the hex key space evenly."""
        chars = "0123456789abcdef"
        chars_per_shard = len(chars) // self.shards_count
        for i in range(self.shards_count):
            start_idx = i * chars_per_shard
            end_idx = start_idx + chars_per_shard if i < self.shards_count - 1 else len(chars)
            start_char = chars[start_idx]
            end_char = chars[end_idx - 1]
            shard_id = f"shard_{i}"
            db_path = os.path.join(self.shard_db_dir, f"{shard_id}.db")
            self.shard_map[shard_id] = ShardInfo(
                shard_id=shard_id,
                shard_key_range=(start_char, end_char),
                db_path=db_path,
                created_at=datetime.now().isoformat(),
            )
            # Create the shard database if it does not exist yet
            if not os.path.exists(db_path):
                self._create_shard_db(db_path)

    def _create_shard_db(self, db_path: str) -> None:
        """Create a shard database with a simplified copy of the main schema."""
        conn = sqlite3.connect(db_path)
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS entities (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                name TEXT NOT NULL,
                type TEXT,
                definition TEXT,
                aliases TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS entity_relations (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                source_entity_id TEXT NOT NULL,
                target_entity_id TEXT NOT NULL,
                relation_type TEXT,
                evidence TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE INDEX IF NOT EXISTS idx_entities_project ON entities(project_id);
            CREATE INDEX IF NOT EXISTS idx_relations_project ON entity_relations(project_id);
        """)
        conn.commit()
        conn.close()

    def _get_shard_id(self, project_id: str) -> str:
        """
        Route a project ID to a shard.

        Routing is by the first (hex) character of the project ID; IDs that
        fall outside every shard's range default to shard_0.
        """
        if not project_id:
            return "shard_0"
        first_char = project_id[0].lower()
        for shard_id, shard_info in self.shard_map.items():
            start, end = shard_info.shard_key_range
            if start <= first_char <= end:
                return shard_id
        return "shard_0"

    def get_shard_connection(self, project_id: str) -> sqlite3.Connection:
        """Open a connection to the shard that owns the given project."""
        shard_id = self._get_shard_id(project_id)
        shard_info = self.shard_map[shard_id]
        conn = sqlite3.connect(shard_info.db_path)
        conn.row_factory = sqlite3.Row
        # Record the access time
        shard_info.last_accessed = datetime.now().isoformat()
        return conn

    def get_all_shards(self) -> list[ShardInfo]:
        """Return metadata for all shards."""
        return list(self.shard_map.values())

    def migrate_project(self, project_id: str, target_shard_id: str) -> bool:
        """
        Migrate a project to the given shard.

        Args:
            project_id: project ID
            target_shard_id: target shard ID

        Returns:
            bool: True on success
        """
        # Resolve the source shard
        source_shard_id = self._get_shard_id(project_id)
        if source_shard_id == target_shard_id:
            return True  # already on the target shard
        source_info = self.shard_map.get(source_shard_id)
        target_info = self.shard_map.get(target_shard_id)
        if not source_info or not target_info:
            return False
        try:
            # Read the project's data from the source shard
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.row_factory = sqlite3.Row
            entities = source_conn.execute(
                "SELECT * FROM entities WHERE project_id = ?", (project_id,)
            ).fetchall()
            relations = source_conn.execute(
                "SELECT * FROM entity_relations WHERE project_id = ?", (project_id,)
            ).fetchall()
            source_conn.close()
            # Write it into the target shard
            target_conn = sqlite3.connect(target_info.db_path)
            for entity in entities:
                target_conn.execute(
                    """
                    INSERT OR REPLACE INTO entities
                    (id, project_id, name, type, definition, aliases, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    tuple(entity),
                )
            for relation in relations:
                target_conn.execute(
                    """
                    INSERT OR REPLACE INTO entity_relations
                    (id, project_id, source_entity_id, target_entity_id,
                     relation_type, evidence, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    tuple(relation),
                )
            target_conn.commit()
            target_conn.close()
            # Delete the data from the source shard
            source_conn = sqlite3.connect(source_info.db_path)
            source_conn.execute("DELETE FROM entities WHERE project_id = ?", (project_id,))
            source_conn.execute(
                "DELETE FROM entity_relations WHERE project_id = ?", (project_id,)
            )
            source_conn.commit()
            source_conn.close()
            # Refresh shard statistics
            self._update_shard_stats(source_shard_id)
            self._update_shard_stats(target_shard_id)
            return True
        except Exception as e:
            print(f"Migration failed: {e}")
            return False

    def _update_shard_stats(self, shard_id: str) -> None:
        """Refresh the cached entity count for a shard."""
        shard_info = self.shard_map.get(shard_id)
        if not shard_info:
            return
        conn = sqlite3.connect(shard_info.db_path)
        # Count all entities on this shard (the field is entity_count)
        count = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
        shard_info.entity_count = count
        conn.close()

    def cross_shard_query(self, query_func: Callable) -> list[dict]:
        """
        Run a query against every shard and merge the results.

        Args:
            query_func: a callable that receives a connection and returns rows

        Returns:
            List[Dict]: merged query results
        """
        results = []
        for shard_info in self.shard_map.values():
            conn = sqlite3.connect(shard_info.db_path)
            conn.row_factory = sqlite3.Row
            try:
                shard_results = query_func(conn)
                results.extend(shard_results)
            except Exception as e:
                print(f"Query failed on shard {shard_info.shard_id}: {e}")
            finally:
                conn.close()
        return results

    def get_shard_stats(self) -> list[dict]:
        """Return statistics for all shards."""
        stats = []
        for shard_info in self.shard_map.values():
            self._update_shard_stats(shard_info.shard_id)
            stats.append(
                {
                    "shard_id": shard_info.shard_id,
                    "key_range": shard_info.shard_key_range,
                    "db_path": shard_info.db_path,
                    "entity_count": shard_info.entity_count,
                    "is_active": shard_info.is_active,
                    "created_at": shard_info.created_at,
                    "last_accessed": shard_info.last_accessed,
                }
            )
        return stats

    def rebalance_shards(self) -> dict:
        """
        Analyze shard balance.

        Identifies shards whose data should move to less-loaded shards.

        Returns:
            Dict: rebalancing statistics
        """
        # Gather per-shard load
        stats = self.get_shard_stats()
        if not stats:
            return {"message": "No shards to rebalance"}
        # Average load across shards
        avg_load = sum(s["entity_count"] for s in stats) / len(stats)
        # Shards more than 50% above or below the average
        overloaded = [s for s in stats if s["entity_count"] > avg_load * 1.5]
        underloaded = [s for s in stats if s["entity_count"] < avg_load * 0.5]
        # Simplified rebalancing logic; a production deployment would need a
        # more sophisticated algorithm (and actual data movement).
        return {
            "average_load": avg_load,
            "overloaded_shards": len(overloaded),
            "underloaded_shards": len(underloaded),
            "message": "Rebalancing analysis completed",
        }
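

# Usage sketch for cross_shard_query (illustrative; the lambda below is an
# example query against the shard schema above, not part of the module):
#
#     sharding = DatabaseSharding(shard_db_dir="./shards", shards_count=4)
#     rows = sharding.cross_shard_query(
#         lambda conn: [
#             dict(r)
#             for r in conn.execute(
#                 "SELECT id, name FROM entities WHERE type = 'Person'"
#             ).fetchall()
#         ]
#     )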


# ==================== Async Task Queue ====================


class TaskQueue:
    """
    Asynchronous task queue manager.

    Features:
    - Celery + Redis backed task queue
    - Async audio analysis
    - Async report generation
    - Task status tracking and retries
    """

    def __init__(self, redis_url: str | None = None, db_path: str = "insightflow.db"):
        self.db_path = db_path
        self.redis_url = redis_url
        self.celery_app = None
        self.use_celery = False
        # In-memory task store (non-Celery mode)
        self.tasks: dict[str, TaskInfo] = {}
        self.task_handlers: dict[str, Callable] = {}
        self.task_lock = threading.RLock()
        # Initialize the task queue table
        self._init_task_tables()
        # Initialize Celery
        if CELERY_AVAILABLE and redis_url:
            try:
                self.celery_app = Celery("insightflow", broker=redis_url, backend=redis_url)
                self.use_celery = True
                print("Celery task queue initialized")
            except Exception as e:
                print(f"Celery initialization failed, using in-memory task queue: {e}")

    def _init_task_tables(self) -> None:
        """Create the task queue table if it does not exist."""
        conn = sqlite3.connect(self.db_path)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS task_queue (
                id TEXT PRIMARY KEY,
                task_type TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                payload TEXT,
                result TEXT,
                error_message TEXT,
                retry_count INTEGER DEFAULT 0,
                max_retries INTEGER DEFAULT 3,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON task_queue(status)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_task_type ON task_queue(task_type)")
        conn.commit()
        conn.close()

    def is_available(self) -> bool:
        """Whether the task queue is usable (the in-memory fallback always is)."""
        return True

    def register_handler(self, task_type: str, handler: Callable) -> None:
        """Register a handler for a task type."""
        self.task_handlers[task_type] = handler

    def submit(self, task_type: str, payload: dict, max_retries: int = 3) -> str:
        """
        Submit a task.

        Args:
            task_type: task type
            payload: task payload
            max_retries: maximum number of retries

        Returns:
            str: task ID
        """
        task_id = str(uuid.uuid4())[:16]
        task = TaskInfo(
            id=task_id,
            task_type=task_type,
            status="pending",
            payload=payload,
            created_at=datetime.now().isoformat(),
            max_retries=max_retries,
        )
        if self.use_celery:
            # Celery mode. Simplified here: a real deployment would define
            # concrete Celery tasks rather than dispatch by name.
            try:
                result = self.celery_app.send_task(
                    f"insightflow.tasks.{task_type}",
                    args=[payload],
                    task_id=task_id,
                    retry=True,
                    retry_policy={
                        "max_retries": max_retries,
                        "interval_start": 10,
                        "interval_step": 10,
                        "interval_max": 60,
                    },
                )
                task.id = result.id
            except Exception as e:
                print(f"Celery task submission failed: {e}")
                # Fall back to in-memory mode
                self.use_celery = False
        if not self.use_celery:
            # In-memory mode
            with self.task_lock:
                self.tasks[task_id] = task
            # Execute asynchronously
            threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()
        # Persist the task
        self._save_task(task)
        return task_id

    def _execute_task(self, task_id: str) -> None:
        """Execute a task (in-memory mode)."""
        with self.task_lock:
            task = self.tasks.get(task_id)
        if not task:
            return
        task.status = "running"
        task.started_at = datetime.now().isoformat()
        self._update_task_status(task)
        # Look up the handler
        handler = self.task_handlers.get(task.task_type)
        if not handler:
            task.status = "failed"
            task.error_message = f"No handler for task type: {task.task_type}"
        else:
            try:
                result = handler(task.payload)
                task.status = "success"
                task.result = result
            except Exception as e:
                task.retry_count += 1
                if task.retry_count <= task.max_retries:
                    task.status = "retrying"
                    # Retry with a linear back-off
                    threading.Timer(
                        10 * task.retry_count, self._execute_task, args=(task_id,)
                    ).start()
                else:
                    task.status = "failed"
                    task.error_message = str(e)
        # Only terminal states get a completion timestamp
        if task.status in ("success", "failed"):
            task.completed_at = datetime.now().isoformat()
        with self.task_lock:
            self.tasks[task_id] = task
        self._update_task_status(task)

    def _save_task(self, task: TaskInfo) -> None:
        """Persist a task to the database."""
        conn = sqlite3.connect(self.db_path)
        conn.execute(
            """
            INSERT OR REPLACE INTO task_queue
            (id, task_type, status, payload, result, error_message,
            retry_count, max_retries, created_at, started_at, completed_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                task.id,
                task.task_type,
                task.status,
                json.dumps(task.payload, ensure_ascii=False),
                json.dumps(task.result, ensure_ascii=False) if task.result is not None else None,
                task.error_message,
                task.retry_count,
                task.max_retries,
                task.created_at,
                task.started_at,
                task.completed_at,
            ),
        )
        conn.commit()
        conn.close()

    def _update_task_status(self, task: TaskInfo) -> None:
        """Update a task's status in the database."""
        conn = sqlite3.connect(self.db_path)
        conn.execute(
            """
            UPDATE task_queue SET
                status = ?,
                result = ?,
                error_message = ?,
                retry_count = ?,
                started_at = ?,
                completed_at = ?
            WHERE id = ?
            """,
            (
                task.status,
                json.dumps(task.result, ensure_ascii=False) if task.result is not None else None,
                task.error_message,
                task.retry_count,
                task.started_at,
                task.completed_at,
                task.id,
            ),
        )
        conn.commit()
        conn.close()

    def get_status(self, task_id: str) -> TaskInfo | None:
        """Return the current status of a task."""
        if self.use_celery:
            try:
                result = AsyncResult(task_id, app=self.celery_app)
                status_map = {
                    "PENDING": "pending",
                    "STARTED": "running",
                    "SUCCESS": "success",
                    "FAILURE": "failed",
                    "RETRY": "retrying",
                }
                return TaskInfo(
                    id=task_id,
                    task_type="celery_task",
                    status=status_map.get(result.status, "unknown"),
                    payload={},
                    created_at="",
                    result=result.result if result.successful() else None,
                    error_message=str(result.result) if result.failed() else None,
                )
            except Exception as e:
                print(f"Failed to fetch Celery task status: {e}")
        # In-memory mode, or fallback after a Celery error
        with self.task_lock:
            return self.tasks.get(task_id)

    def list_tasks(
        self, status: str | None = None, task_type: str | None = None, limit: int = 100
    ) -> list[TaskInfo]:
        """List tasks, optionally filtered by status and type."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        where_clauses = []
        params = []
        if status:
            where_clauses.append("status = ?")
            params.append(status)
        if task_type:
            where_clauses.append("task_type = ?")
            params.append(task_type)
        where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
        rows = conn.execute(
            f"""
            SELECT * FROM task_queue
            WHERE {where_str}
            ORDER BY created_at DESC
            LIMIT ?
            """,
            params + [limit],
        ).fetchall()
        conn.close()
        tasks = []
        for row in rows:
            tasks.append(
                TaskInfo(
                    id=row["id"],
                    task_type=row["task_type"],
                    status=row["status"],
                    payload=json.loads(row["payload"]) if row["payload"] else {},
                    created_at=row["created_at"],
                    started_at=row["started_at"],
                    completed_at=row["completed_at"],
                    result=json.loads(row["result"]) if row["result"] else None,
                    error_message=row["error_message"],
                    retry_count=row["retry_count"],
                    max_retries=row["max_retries"],
                )
            )
        return tasks

    def cancel(self, task_id: str) -> bool:
        """Cancel a task."""
        if self.use_celery:
            try:
                self.celery_app.control.revoke(task_id, terminate=True)
                return True
            except Exception as e:
                print(f"Failed to cancel Celery task: {e}")
        with self.task_lock:
            task = self.tasks.get(task_id)
            if task and task.status in ["pending", "running"]:
                task.status = "cancelled"
                task.completed_at = datetime.now().isoformat()
                self._update_task_status(task)
                return True
        return False

    def retry(self, task_id: str) -> bool:
        """Re-run a failed task."""
        task = self.get_status(task_id)
        if not task or task.status != "failed":
            return False
        task.status = "pending"
        task.retry_count = 0
        task.error_message = None
        task.completed_at = None
        if not self.use_celery:
            with self.task_lock:
                self.tasks[task_id] = task
            threading.Thread(target=self._execute_task, args=(task_id,), daemon=True).start()
        self._update_task_status(task)
        return True

    def get_stats(self) -> dict:
        """Return task queue statistics."""
        conn = sqlite3.connect(self.db_path)
        # Task counts by status
        status_counts = conn.execute("""
            SELECT status, COUNT(*) as count
            FROM task_queue
            GROUP BY status
        """).fetchall()
        # Task counts by type
        type_counts = conn.execute("""
            SELECT task_type, COUNT(*) as count
            FROM task_queue
            GROUP BY task_type
        """).fetchall()
        # Tasks submitted in the last 24 hours
        recent_count = conn.execute("""
            SELECT COUNT(*) as count
            FROM task_queue
            WHERE created_at > datetime('now', '-1 day')
        """).fetchone()[0]
        conn.close()
        return {
            "by_status": {r[0]: r[1] for r in status_counts},
            "by_type": {r[0]: r[1] for r in type_counts},
            "recent_24h": recent_count,
            "backend": "celery" if self.use_celery else "memory",
        }
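

# Usage sketch for TaskQueue in in-memory mode (illustrative; the "echo"
# handler and payload are made up for the example):
#
#     queue = TaskQueue()
#     queue.register_handler("echo", lambda payload: payload["text"].upper())
#     task_id = queue.submit("echo", {"text": "hello"})
#     time.sleep(0.5)                      # handler runs on a daemon thread
#     info = queue.get_status(task_id)
#     print(info.status, info.result)      # -> success HELLO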


# ==================== Performance Monitoring ====================


class PerformanceMonitor:
    """
    Performance monitor.

    Features:
    - API response time statistics
    - Database query performance analysis
    - Cache hit rate monitoring
    - Performance alerting
    """

    def __init__(
        self,
        db_path: str = "insightflow.db",
        slow_query_threshold: int = 1000,  # milliseconds
        alert_threshold: int = 5000,  # milliseconds
    ):
        self.db_path = db_path
        self.slow_query_threshold = slow_query_threshold
        self.alert_threshold = alert_threshold
        # In-memory metrics buffer
        self.metrics_buffer: list[PerformanceMetric] = []
        self.buffer_lock = threading.RLock()
        self.buffer_size = 100
        # Alert callbacks
        self.alert_handlers: list[Callable] = []

    def record_metric(
        self,
        metric_type: str,
        duration_ms: float,
        endpoint: str | None = None,
        metadata: dict | None = None,
    ):
        """
        Record a performance metric.

        Args:
            metric_type: metric type (api_response, db_query, cache_operation)
            duration_ms: duration in milliseconds
            endpoint: endpoint/query identifier
            metadata: extra metadata
        """
        metric = PerformanceMetric(
            id=str(uuid.uuid4())[:16],
            metric_type=metric_type,
            endpoint=endpoint,
            duration_ms=duration_ms,
            timestamp=datetime.now().isoformat(),
            metadata=metadata or {},
        )
        # Append to the buffer, flushing when it fills up
        with self.buffer_lock:
            self.metrics_buffer.append(metric)
            if len(self.metrics_buffer) > self.buffer_size:
                self._flush_metrics()
        # Alert if the threshold was exceeded
        if duration_ms > self.alert_threshold:
            self._trigger_alert(metric)
        # Log slow queries
        if metric_type == "db_query" and duration_ms > self.slow_query_threshold:
            self._record_slow_query(metric)

    def _flush_metrics(self) -> None:
        """Write buffered metrics to the database."""
        # Take the (re-entrant) lock so flushes triggered from get_stats()
        # cannot race with record_metric()
        with self.buffer_lock:
            if not self.metrics_buffer:
                return
            buffered, self.metrics_buffer = self.metrics_buffer, []
        conn = sqlite3.connect(self.db_path)
        for metric in buffered:
            conn.execute(
                """
                INSERT INTO performance_metrics
                (id, metric_type, endpoint, duration_ms, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    metric.id,
                    metric.metric_type,
                    metric.endpoint,
                    metric.duration_ms,
                    metric.timestamp,
                    json.dumps(metric.metadata, ensure_ascii=False),
                ),
            )
        conn.commit()
        conn.close()

    def _record_slow_query(self, metric: PerformanceMetric) -> None:
        """Log a slow query."""
        # Could also be forwarded to a dedicated slow-query log or monitoring system
        print(f"[SLOW QUERY] {metric.endpoint}: {metric.duration_ms}ms")

    def _trigger_alert(self, metric: PerformanceMetric) -> None:
        """Fire the registered alert handlers."""
        alert_data = {
            "type": "performance_alert",
            "metric": metric.to_dict(),
            "threshold": self.alert_threshold,
            "message": (
                f"{metric.metric_type} exceeded threshold: "
                f"{metric.duration_ms}ms > {self.alert_threshold}ms"
            ),
        }
        for handler in self.alert_handlers:
            try:
                handler(alert_data)
            except Exception as e:
                print(f"Alert handler failed: {e}")

    def register_alert_handler(self, handler: Callable) -> None:
        """Register an alert handler."""
        self.alert_handlers.append(handler)

    def get_stats(self, hours: int = 24) -> dict:
        """
        Return performance statistics.

        Args:
            hours: size of the reporting window in hours

        Returns:
            Dict: performance statistics
        """
        # Flush the buffer first so recent metrics are included
        self._flush_metrics()
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        # Overall statistics
        overall = conn.execute(
            """
            SELECT
                COUNT(*) as total,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration,
                MIN(duration_ms) as min_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            """,
            (f"-{hours} hours",),
        ).fetchone()
        # By metric type
        by_type = conn.execute(
            """
            SELECT
                metric_type,
                COUNT(*) as count,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            GROUP BY metric_type
            """,
            (f"-{hours} hours",),
        ).fetchall()
        # By endpoint (API responses only)
        by_endpoint = conn.execute(
            """
            SELECT
                endpoint,
                COUNT(*) as count,
                AVG(duration_ms) as avg_duration,
                MAX(duration_ms) as max_duration
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            AND metric_type = 'api_response'
            GROUP BY endpoint
            ORDER BY avg_duration DESC
            LIMIT 20
            """,
            (f"-{hours} hours",),
        ).fetchall()
        # Slow queries
        slow_queries = conn.execute(
            """
            SELECT
                metric_type,
                endpoint,
                duration_ms,
                timestamp
            FROM performance_metrics
            WHERE timestamp > datetime('now', ?)
            AND duration_ms > ?
            ORDER BY duration_ms DESC
            LIMIT 10
            """,
            (f"-{hours} hours", self.slow_query_threshold),
        ).fetchall()
        conn.close()
        return {
            "period_hours": hours,
            "overall": {
                "total_requests": overall["total"] or 0,
                "avg_duration_ms": round(overall["avg_duration"] or 0, 2),
                "max_duration_ms": overall["max_duration"] or 0,
                "min_duration_ms": overall["min_duration"] or 0,
            },
            "by_type": [
                {
                    "type": r["metric_type"],
                    "count": r["count"],
                    "avg_duration_ms": round(r["avg_duration"], 2),
                    "max_duration_ms": r["max_duration"],
                }
                for r in by_type
            ],
            "by_endpoint": [
                {
                    "endpoint": r["endpoint"],
                    "count": r["count"],
                    "avg_duration_ms": round(r["avg_duration"], 2),
                    "max_duration_ms": r["max_duration"],
                }
                for r in by_endpoint
            ],
            "slow_queries": [
                {
                    "type": r["metric_type"],
                    "endpoint": r["endpoint"],
                    "duration_ms": r["duration_ms"],
                    "timestamp": r["timestamp"],
                }
                for r in slow_queries
            ],
        }

    def get_api_performance(self, endpoint: str | None = None, hours: int = 24) -> dict:
        """Return per-endpoint API performance details."""
        self._flush_metrics()
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        where_clause = "metric_type = 'api_response'"
        params = []
        if endpoint:
            where_clause += " AND endpoint = ?"
            params.append(endpoint)
        # The time-window value binds to the last placeholder, so it must be
        # appended after any endpoint filter
        params.append(f"-{hours} hours")
        # Per-endpoint aggregates
        aggregates = conn.execute(
            f"""
            SELECT
                endpoint,
                COUNT(*) as count,
                AVG(duration_ms) as avg,
                MIN(duration_ms) as min,
                MAX(duration_ms) as max
            FROM performance_metrics
            WHERE {where_clause}
            AND timestamp > datetime('now', ?)
            GROUP BY endpoint
            ORDER BY avg DESC
            """,
            params,
        ).fetchall()
        conn.close()
        return {
            "endpoint": endpoint or "all",
            "period_hours": hours,
            "endpoints": [
                {
                    "endpoint": r["endpoint"],
                    "count": r["count"],
                    "avg_ms": round(r["avg"], 2),
                    "min_ms": r["min"],
                    "max_ms": r["max"],
                }
                for r in aggregates
            ],
        }

    def cleanup_old_metrics(self, days: int = 30) -> int:
        """
        Delete old performance metrics.

        Args:
            days: number of days of data to keep

        Returns:
            int: number of deleted rows
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.execute(
            """
            DELETE FROM performance_metrics
            WHERE timestamp < datetime('now', ?)
            """,
            (f"-{days} days",),
        )
        deleted = cursor.rowcount
        conn.commit()
        conn.close()
        return deleted
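

# Usage sketch for PerformanceMonitor (illustrative; the handler just prints
# and the endpoint name is made up):
#
#     monitor = PerformanceMonitor(slow_query_threshold=500, alert_threshold=2000)
#     monitor.register_alert_handler(lambda alert: print(alert["message"]))
#     monitor.record_metric("db_query", 2350.0, endpoint="entities.search")
#     # -> fires the alert handler and the slow-query log
#     print(monitor.get_stats(hours=1)["overall"])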


# ==================== Performance Decorators ====================


def cached(
    cache_manager: CacheManager,
    key_prefix: str = "",
    ttl: int = 3600,
    key_func: Callable | None = None,
) -> Callable:
    """
    Caching decorator.

    Args:
        cache_manager: cache manager instance
        key_prefix: cache key prefix
        ttl: cache expiry in seconds
        key_func: custom cache-key builder
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # Default: function name plus a hash of the arguments
                key_data = f"{func.__name__}:{str(args)}:{str(kwargs)}"
                cache_key = f"{key_prefix}:{hashlib.md5(key_data.encode()).hexdigest()[:16]}"
            # Try the cache first
            cached_value = cache_manager.get(cache_key)
            if cached_value is not None:
                return cached_value
            # Cache miss: call the function
            result = func(*args, **kwargs)
            # Store the result
            cache_manager.set(cache_key, result, ttl)
            return result

        return wrapper

    return decorator
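

# Usage sketch for @cached (illustrative; load_entity is a made-up function):
#
#     cache = CacheManager()
#
#     @cached(cache, key_prefix="entity", ttl=600)
#     def load_entity(entity_id: str) -> dict:
#         ...  # expensive lookup; repeat calls with the same entity_id
#              # are served from the cache until the TTL expires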


def monitored(
    monitor: PerformanceMonitor, metric_type: str, endpoint: str | None = None
) -> Callable:
    """
    Performance-monitoring decorator.

    Args:
        monitor: performance monitor instance
        metric_type: metric type
        endpoint: endpoint identifier (defaults to the function name)
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                return result
            finally:
                # Record the duration even if the call raises
                duration_ms = (time.time() - start_time) * 1000
                ep = endpoint or func.__name__
                monitor.record_metric(metric_type, duration_ms, ep)

        return wrapper

    return decorator
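

# Usage sketch for @monitored (illustrative; search_entities is a made-up
# function):
#
#     monitor = PerformanceMonitor()
#
#     @monitored(monitor, "api_response", endpoint="/api/entities/search")
#     def search_entities(query: str) -> list:
#         ...  # every call is timed and recorded, including failures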


# ==================== Performance Manager ====================


class PerformanceManager:
    """
    Performance manager: the single entry point.

    Combines cache management, database sharding, the task queue, and
    performance monitoring.
    """

    def __init__(
        self,
        db_path: str = "insightflow.db",
        redis_url: str | None = None,
        enable_sharding: bool = False,
    ):
        self.db_path = db_path
        # Initialize the component modules
        self.cache = CacheManager(redis_url=redis_url, db_path=db_path)
        self.sharding = DatabaseSharding(base_db_path=db_path) if enable_sharding else None
        self.task_queue = TaskQueue(redis_url=redis_url, db_path=db_path)
        self.monitor = PerformanceMonitor(db_path=db_path)

    def get_health_status(self) -> dict:
        """Return the health of each subsystem."""
        return {
            "cache": {
                "available": True,
                "backend": "redis" if self.cache.use_redis else "memory",
                "stats": self.cache.get_stats(),
            },
            "sharding": {
                "enabled": self.sharding is not None,
                "shards_count": len(self.sharding.shard_map) if self.sharding else 0,
            },
            "task_queue": {
                "available": self.task_queue.is_available(),
                "backend": "celery" if self.task_queue.use_celery else "memory",
                "stats": self.task_queue.get_stats(),
            },
            "monitor": {
                "available": True,
                "slow_query_threshold": self.monitor.slow_query_threshold,
                "alert_threshold": self.monitor.alert_threshold,
            },
        }

    def get_full_stats(self) -> dict:
        """Return combined statistics across all subsystems."""
        stats = {
            "cache": self.cache.get_stats(),
            "task_queue": self.task_queue.get_stats(),
            "performance": self.monitor.get_stats(),
        }
        if self.sharding:
            stats["sharding"] = self.sharding.get_shard_stats()
        return stats


# Singleton
_performance_manager = None


def get_performance_manager(
    db_path: str = "insightflow.db", redis_url: str | None = None, enable_sharding: bool = False
) -> PerformanceManager:
    """Return the PerformanceManager singleton."""
    global _performance_manager
    if _performance_manager is None:
        _performance_manager = PerformanceManager(
            db_path=db_path, redis_url=redis_url, enable_sharding=enable_sharding
        )
    return _performance_manager
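

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; "perf_demo.db" is a
    # throwaway SQLite file, and with no Redis URL the in-memory cache and
    # task queue backends are used).
    manager = get_performance_manager(db_path="perf_demo.db")
    manager.cache.set("demo:key", {"hello": "world"}, ttl=60)
    print(manager.cache.get("demo:key"))
    manager.monitor.record_metric("api_response", 12.5, endpoint="/demo")
    print(manager.get_health_status()["cache"]["backend"])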