Skip to content

Commit b7f5153

Browse files
Merge pull request #57 from Simple-Efficient/feature/tools-cache
Feature/tools cache
2 parents b04da93 + b7c7ba4 commit b7f5153

File tree

16 files changed

+836
-101
lines changed

16 files changed

+836
-101
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,8 @@ outputs
129129
tensorboard_log
130130
ckpt
131131

132-
.hopeignore
132+
.hopeignore
133+
134+
# cache
135+
cache_data/
136+
*.cache

envs/storage/CacheMe.md

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# RL Factory 本地缓存方案
2+
3+
## 目录
4+
- [缓存组件设计](#缓存组件设计)
5+
- [方案调研](#方案调研)
6+
- [缓存淘汰策略](#缓存淘汰策略)
7+
- [存储数据结构](#存储数据结构)
8+
- [持久化方案](#持久化方案)
9+
- [兜底场景](#兜底场景)
10+
- [使用方式](#使用方式)
11+
12+
## 缓存组件设计
13+
14+
### 核心需求
15+
- **易用性**:服务于训练过程中产生的大量Tools调用结果缓存,避免相同条件查询带来的IO和性能损耗
16+
- **高性能**:本地缓存组件降低远程调用延时,需要本地内存访问
17+
- **高并发**:训练过程需要支持cache多线程并发安全访问
18+
- **持久化**:训练可能存在多轮,需要具备缓存结果序列化反序列化能力
19+
- **内存限制**:实现合理的淘汰策略
20+
- **功能扩展**:自带丰富API或具备开发扩展能力
21+
22+
## 方案调研
23+
24+
### 缓存组件对比分析
25+
26+
| 缓存组件 | 优点 | 缺点 | 适用场景 |
27+
|---------|------|------|----------|
28+
| **cachebox** | • 满足本地高性能缓存需求<br>• 支持内存管理和并发访问<br>• 使用灵活,API丰富<br>• 易于集成和扩展<br>• 适合训练或高频场景 | • 受限于可用内存大小<br>• 无法持久化缓存<br>• 需要自行实现持久化<br>• 新开源组件,稳定性待验证 | 需要高性能内存缓存的场景 |
29+
| **functools** | • Python 3.9+ 标准库自带<br>• LRU策略实现<br>• 装饰器形式使用<br>• 线程安全 | • 功能相对简单<br>• 内存限制<br>• 过期策略有限 | 简单的函数结果缓存 |
30+
| **cachetools** | • 多种缓存策略支持<br>• 高度灵活性<br>• API简洁<br>• 支持自定义扩展 | • 仅内存存储<br>• 需额外安装<br>• 持久化需自行实现 | 需要多种缓存策略的场景 |
31+
| **diskcache** | • 磁盘持久化<br>• 功能丰富<br>• 类字典接口<br>• 线程和进程安全 | • I/O开销较大<br>• 需额外安装<br>• 配置管理复杂 | 需要持久化的场景 |
32+
| **joblib.Memory** | • 针对函数输出优化<br>• 透明缓存机制<br>• 磁盘持久化<br>• 大型数据友好 | • 特定场景限制<br>• 依赖Scipy生态<br>• 需额外安装 | 科学计算和机器学习场景 |
33+
34+
### 方案选择
35+
基于以上分析,我们选择采用 **cachebox + 自定义序列化/反序列化** 的方案,原因如下:
36+
1. 满足高性能和并发需求
37+
2. 提供丰富的API支持
38+
3. 易于扩展和定制
39+
4. 适合训练场景
40+
41+
## 缓存淘汰策略
42+
43+
### 策略类型
44+
1. **LRU (Least Recently Used)**
45+
- 基于访问时间
46+
- 适合大多数场景
47+
- 实现简单,效果稳定
48+
49+
2. **LFU (Least Frequently Used)**
50+
- 基于访问频率
51+
- 适合访问模式稳定的场景
52+
- 需要额外统计信息
53+
54+
3. **FIFO (First In First Out)**
55+
- 基于插入时间
56+
- 适合数据重要性随时间降低的场景
57+
- 实现最简单
58+
59+
4. **TTL (Time To Live)**
60+
- 基于过期时间
61+
- 适合数据时效性要求高的场景
62+
- 需要额外的时间管理
63+
64+
## 存储数据结构
65+
66+
### Key设计
67+
```python
68+
def generate_cache_key(method_name: str, input_params: dict) -> str:
69+
"""
70+
生成缓存键
71+
Args:
72+
method_name: 方法名
73+
input_params: 输入参数
74+
Returns:
75+
str: 缓存键
76+
"""
77+
key_str = f"{method_name}:{json.dumps(input_params, sort_keys=True)}"
78+
return hashlib.md5(key_str.encode()).hexdigest()
79+
```
80+
81+
### Value设计
82+
```python
83+
{
84+
"result": Any, # 实际结果
85+
"metadata": {
86+
"created_at": datetime,
87+
"expires_at": datetime,
88+
"access_count": int,
89+
"last_access": datetime
90+
}
91+
}
92+
```
93+
94+
## 持久化方案
95+
96+
### 文件组织
97+
- **命名格式**`{method_hash}_{date}_{partition}.cache`
98+
- **存储格式**`gzip(serialize({key: value}))`
99+
- **目录结构**
100+
```
101+
cache/
102+
├── 2024-03/
103+
│ ├── method1_20240301_0.cache
104+
│ └── method1_20240301_1.cache
105+
└── 2024-04/
106+
└── method1_20240401_0.cache
107+
```
108+
109+
### 加载机制
110+
1. 训练启动时通过 `-loadCache` 参数指定缓存文件
111+
2. 支持增量加载和全量加载
112+
3. 支持多节点数据同步
113+
114+
### 数据一致性
115+
- 采用定期回溯策略
116+
- 基于文件名归并缓存数据
117+
- 使用缓存key hash设计确保多节点数据一致
118+
119+
## 兜底场景
120+
121+
### 缓存未命中处理
122+
1. 允许缓存未命中
123+
2. 异步更新缓存
124+
3. 降级处理机制
125+
126+
### 分布式场景
127+
1. 单节点独立构建缓存
128+
2. 独立持久化
129+
3. 预留分布式扩展接口
130+
131+
## 使用方式
132+
133+
### 命令行参数
134+
```bash
135+
# 基础用法
136+
python storage_test.py
137+
138+
# 高级配置
139+
python storage_test.py \
140+
--cache single|multi \
141+
--persist \
142+
--eviction lru|lfu|fifo|ttl \
143+
--max-size 1000 \
144+
--persist-dir /path/to/cache \
145+
--load-cache cache_file.cache
146+
```
147+
148+
### 配置说明
149+
| 参数 | 说明 | 默认值 |
150+
|------|------|--------|
151+
| `--cache` | 缓存模式:single/multi | single |
152+
| `--persist` | 是否启用持久化 | false |
153+
| `--eviction` | 缓存淘汰策略 | lru |
154+
| `--max-size` | 最大缓存条目数 | 1000 |
155+
| `--persist-dir` | 持久化目录 | ./cache |
156+
| `--load-cache` | 加载缓存文件 | null |
157+
158+
### 使用示例
159+
```python
160+
from cache_manager import CacheManager
161+
162+
# 初始化缓存管理器
163+
cache = CacheManager(
164+
mode="single",
165+
eviction="lru",
166+
max_size=1000,
167+
persist=True,
168+
persist_dir="./cache"
169+
)
170+
171+
# 使用缓存
172+
@cache.cached
173+
def expensive_operation(param1, param2):
174+
# 实际计算逻辑
175+
return result
176+
```

envs/storage/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .cache import CacheBase
2+
from .persist import PersistBase
3+
4+
__all__ = ['CacheBase', 'PersistBase']
5+
6+
CACHE_REGISTRY = {
7+
'cache': CacheBase,
8+
'persist': PersistBase
9+
}
10+

envs/storage/cache/cache_base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Optional, Dict
3+
from enum import Enum
4+
5+
class CacheMode(Enum):
6+
SINGLE = "single" # 单机模式
7+
MULTI = "multi" # 分布式模式
8+
9+
class EvictionPolicy(Enum):
10+
LRU = "lru" # 最近最少使用
11+
LFU = "lfu" # 最不经常使用
12+
FIFO = "fifo" # 先进先出
13+
TTL = "ttl" # 基于时间过期
14+
15+
class CacheBase(ABC):
16+
@abstractmethod
17+
def get(self, key: str) -> Optional[Any]:
18+
"""获取缓存值"""
19+
pass
20+
21+
@abstractmethod
22+
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
23+
"""设置缓存值
24+
25+
Args:
26+
key: 缓存键
27+
value: 缓存值
28+
ttl: 过期时间(秒),None表示永不过期
29+
"""
30+
pass
31+
32+
@abstractmethod
33+
def delete(self, key: str) -> None:
34+
"""删除缓存值"""
35+
pass
36+
37+
@abstractmethod
38+
def clear(self) -> None:
39+
"""清空缓存"""
40+
pass
41+
42+
@abstractmethod
43+
def has(self, key: str) -> bool:
44+
"""检查键是否存在"""
45+
pass
46+
47+
@abstractmethod
48+
def get_mode(self) -> CacheMode:
49+
"""获取缓存模式"""
50+
pass
51+
52+
@abstractmethod
53+
def get_stats(self) -> Dict[str, Any]:
54+
"""获取缓存统计信息"""
55+
pass
56+
57+
@abstractmethod
58+
def get_eviction_policy(self) -> EvictionPolicy:
59+
"""获取缓存淘汰策略"""
60+
pass
61+
62+
@abstractmethod
63+
def set_eviction_policy(self, policy: EvictionPolicy) -> None:
64+
"""设置缓存淘汰策略"""
65+
pass
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import hashlib
2+
import json
3+
import time
4+
from typing import Any, Optional, Dict
5+
from cachebox import CacheBox
6+
from .cache_base import CacheBase, CacheMode, EvictionPolicy
7+
8+
class CacheBoxCache(CacheBase):
9+
def __init__(self, max_size: int = 1000, mode: CacheMode = CacheMode.SINGLE,
10+
eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
11+
"""初始化缓存
12+
13+
Args:
14+
max_size: 最大缓存条目数
15+
mode: 缓存模式
16+
eviction_policy: 缓存淘汰策略
17+
"""
18+
self.cache = CacheBox(max_size=max_size)
19+
self.mode = mode
20+
self.eviction_policy = eviction_policy
21+
self.stats = {
22+
"hits": 0,
23+
"misses": 0,
24+
"size": 0,
25+
"evictions": 0,
26+
"ttl_expirations": 0
27+
}
28+
self.ttl_map = {} # 用于存储TTL信息
29+
30+
def _hash_key(self, method_name: str, params: dict) -> str:
31+
"""生成缓存键
32+
33+
Args:
34+
method_name: 方法名
35+
params: 参数字典
36+
37+
Returns:
38+
str: 哈希后的键
39+
"""
40+
key_str = f"{method_name}:{json.dumps(params, sort_keys=True)}"
41+
return hashlib.md5(key_str.encode()).hexdigest()
42+
43+
def get(self, key: str) -> Optional[Any]:
44+
# 检查TTL
45+
if key in self.ttl_map:
46+
if time.time() > self.ttl_map[key]:
47+
del self.ttl_map[key]
48+
self.stats["ttl_expirations"] += 1
49+
return None
50+
51+
value = self.cache.get(key)
52+
if value is not None:
53+
self.stats["hits"] += 1
54+
else:
55+
self.stats["misses"] += 1
56+
return value
57+
58+
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
59+
self.cache.set(key, value)
60+
self.stats["size"] = len(self.cache)
61+
62+
# 设置TTL
63+
if ttl is not None:
64+
self.ttl_map[key] = time.time() + ttl
65+
66+
def delete(self, key: str) -> None:
67+
self.cache.delete(key)
68+
if key in self.ttl_map:
69+
del self.ttl_map[key]
70+
self.stats["size"] = len(self.cache)
71+
72+
def clear(self) -> None:
73+
self.cache.clear()
74+
self.ttl_map.clear()
75+
self.stats["size"] = 0
76+
77+
def has(self, key: str) -> bool:
78+
if key in self.ttl_map and time.time() > self.ttl_map[key]:
79+
del self.ttl_map[key]
80+
return False
81+
return key in self.cache
82+
83+
def get_mode(self) -> CacheMode:
84+
return self.mode
85+
86+
def get_stats(self) -> Dict[str, Any]:
87+
return self.stats
88+
89+
def get_eviction_policy(self) -> EvictionPolicy:
90+
return self.eviction_policy
91+
92+
def set_eviction_policy(self, policy: EvictionPolicy) -> None:
93+
self.eviction_policy = policy
94+
# TODO: 实现不同淘汰策略的具体逻辑
313 KB
Loading

0 commit comments

Comments
 (0)