commit 1a7a7609cf28f14e986a2ef94577cee1b06b5c7f Author: Yun Date: Sat May 16 19:54:12 2026 +0800 初始化缓存管理器 diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..90d9d6e --- /dev/null +++ b/cache.go @@ -0,0 +1,51 @@ +package cache_manager + +import ( + "fmt" + "time" + + "code.yun.ink/pkg/cache_manager/manager" + "gorm.io/gorm" +) + +// T必须实现IModel接口,提供TableName方法来获取表名,用于缓存键的生成 +type commonTableCache[T any] struct { + tableName string + cache manager.ICache + ttl time.Duration + *manager.EventManager[T] // 对外提供调用(自动缓存) + cacheManager *manager.CacheManager[T] // 操作缓存管理器 +} + +type IModel interface { + TableName() string +} + +// newCommonCache 创建通用缓存结构 +// T 是实现了 IModel 接口的具体数据类型,实现了 IModel 接口的 TableName 方法来获取表名 +// cache 是缓存管理器实例,ttl 是缓存的过期时间,getId 是一个函数,用于从数据实例中提取唯一标识符(ID) +// 适合于监听表数据变化并更新缓存的场景,提供了一个事件管理器来处理数据更新事件 +func NewCommonTableCache[T any](tx *gorm.DB, cache manager.ICache, ttl time.Duration, getId func(data *T) int64) (*commonTableCache[T], error) { + var zero T + + _, ok := any(&zero).(IModel) // 断言,确保 T 实现了 IModel 接口 + if !ok { + return nil, fmt.Errorf("type %T does not implement IModel", zero) + } + + tableName := any(&zero).(IModel).TableName() + + o := &commonTableCache[T]{ + tableName: tableName, + // cache: cache, + ttl: ttl, + } + o.cacheManager = manager.NewCacheManager(tx, cache, tableName, getId) + + o.EventManager = manager.NewEventManager(tx, o.ttl, o.cacheManager, getId) + + // 注册事件监听器,用于处理数据更新和删除操作 + // global.EventRegister(o.tableName, o.cacheManager.Remove) + + return o, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3c1d827 --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module code.yun.ink/pkg/cache_manager + +go 1.25.1 + +require ( + github.com/bytedance/sonic v1.15.1 + github.com/redis/go-redis/v9 v9.19.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.5.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.20.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7702037 --- /dev/null +++ b/go.sum @@ -0,0 +1,54 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.1 h1:nJD5PmM0vY7J8CT6MxoqbVAAMhkSmV2HgRAUrrpLoOw= +github.com/bytedance/sonic v1.15.1/go.mod h1:mT2NbXunuaEbnZ+mRIX/vYqKISmgEuHFDI4UzmKx2SA= +github.com/bytedance/sonic/loader v0.5.1 h1:Ygpfa9zwRCCKSlrp5bBP/b/Xzc3VxsAW+5NIYXrOOpI= +github.com/bytedance/sonic/loader v0.5.1/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/manager/cache.go b/manager/cache.go new file mode 100644 index 0000000..3bca772 --- /dev/null +++ b/manager/cache.go @@ -0,0 +1,710 @@ +package manager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "slices" + "sync" + "time" + + "github.com/bytedance/sonic" + "github.com/redis/go-redis/v9" +) + +// 缓存适配层 + +// ICache 定义缓存的标准行为接口 +// Get 方法的 value 参数必须为指针类型,用于反序列化结果。 +// Set/SetLocal/SetRedis 方法 ttl <= 0 时,默认过期时间为 24 小时。 +type ICache interface { + Set(ctx context.Context, key string, value any, ttl time.Duration, ids []int64) error + Get(ctx context.Context, key string, value any) error // value 必须为指针类型 + Del(ctx context.Context, key string) error + Remove(ctx context.Context, ids []int64) error + BatchDel(ctx context.Context, keys []string) error +} + +var ( + ErrCacheNil = errors.New("cache value is nil") +) + +// CacheRedisHash 以 Redis Hash 结构实现的缓存方案 +// 实现原理: +// - 所有缓存数据存储在同一个 Redis Hash(keyName)下,不同业务 key 作为 field。 +// - 只支持整个 hash 的过期时间(通过 Expire 设置),不支持 field 级别的过期。 +// - 适合 field 数量有限、生命周期一致的场景。 +// - 若 field 很多且生命周期不一致,建议使用 CacheRedis。 +// - 频繁访问某个 field 会导致整个 hash 过期时间被刷新,其他 field 也不会过期,需注意内存风险。 +// +// 适用场景: +// - 业务 key 数量有限,生命周期一致,且需要批量操作 hash 的场景。 +type CacheRedisHash struct { + options *managerOptions + redisClient redis.UniversalClient + keyName string + hashTableName string // 记录Key=>ids 映射关系的Hash表,field为key,value为ids的json字符串 + ttl time.Duration +} + +// NewCacheRedisHash 创建新的 Redis Hash 缓存实例 +// ttl <= 0 时,默认过期时间为 24 小时 +// keyName 是 Redis Hash 的 key,实际缓存项存储在该 Hash 中,field 为具体的缓存 key +// 适用于需要在 Redis 中存储大量相关缓存项的场景,避免过多的 Redis key 导致性能问题 +// 这种 hash 结构适合 field 数量有限、生命周期一致的场景。如果 field 很多且生命周期不一致,建议直接用 string 类型的 key-value。 +// 例如,可以存储配置项等,一直不需要过期的数据。 +func NewCacheRedisHash(redis redis.UniversalClient, keyName string, ttl time.Duration, ops ...OptionFunc) *CacheRedisHash { + + options := defaultManagerOptions() + for _, op := range ops { + op(options) + } + + if ttl <= 0 { + ttl = 24 * time.Hour + } + + return &CacheRedisHash{ + options: options, + redisClient: redis, + keyName: keyName, + hashTableName: keyName + ":hashTable", + ttl: ttl, + } +} + +func (l *CacheRedisHash) Get(ctx context.Context, key string, value any) error { + + val, err := l.redisClient.HGet(ctx, l.keyName, key).Result() + if err != nil { + if err == redis.Nil { + return ErrCacheNil // 明确区分缓存未命中和其他错误 + } + l.options.logger.Errorf(ctx, "cache redis get error: %v", err) + return err + } + + err = sonic.Unmarshal([]byte(val), value) + if err != nil { + l.options.logger.Errorf(ctx, "cache redis unmarshal error: %v", err) + return err + } + + l.redisClient.Expire(ctx, l.keyName, l.ttl) + + return nil +} + +func (r *CacheRedisHash) Set(ctx context.Context, key string, value any, ttl time.Duration, ids []int64) error { + + jsonBy, err := sonic.Marshal(value) + if err != nil { + r.options.logger.Errorf(ctx, "cache redis marshal error: %v", err) + return err + } + if ttl <= 0 { + ttl = r.ttl // 默认过期时间 + } + + idsBy, err := sonic.Marshal(ids) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Set redis marshal ids key:%s error: %v", key, err) + return err + } + p := r.redisClient.Pipeline() + p.HSet(ctx, r.keyName, key, jsonBy) + p.Expire(ctx, r.keyName, ttl) // 设置整个 hash 的过期时间,注意这会影响到 hash 中的所有 field + p.HSet(ctx, r.hashTableName, key, idsBy) + p.Expire(ctx, r.hashTableName, ttl) + + cmders, err := p.Exec(ctx) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Set redis pipeline exec key:%s error: %v", key, err) + return err + } + + // 检查 pipeline 中每个命令的错误 + for i, cmd := range cmders { + if err := cmd.Err(); err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Set redis pipeline cmd[%d] error: %v", i, err) + return err + } + } + + return nil +} + +func (r *CacheRedisHash) Del(ctx context.Context, key string) error { + + p := r.redisClient.Pipeline() + + p.HDel(ctx, r.keyName, key) + p.HDel(ctx, r.hashTableName, key) + cmders, err := p.Exec(ctx) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Del redis pipeline exec key:%s error: %v", key, err) + return err + } + + // 检查 pipeline 中每个命令的错误 + for i, cmd := range cmders { + if err := cmd.Err(); err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Del redis pipeline cmd[%d] error: %v", i, err) + return err + } + } + + return nil +} + +func (r *CacheRedisHash) Remove(ctx context.Context, ids []int64) error { + // 获取所有 field=>ids 映射关系 + hashTable, err := r.redisClient.HGetAll(ctx, r.hashTableName).Result() + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Remove redis get keyName:%s error: %v", r.hashTableName, err) + return err + } + + for k, v := range hashTable { + var inIds []int64 + err = sonic.Unmarshal([]byte(v), &inIds) + if err != nil { + continue + } + for _, id := range ids { + if slices.Contains(inIds, id) { + if err := r.Del(ctx, k); err != nil { + r.options.logger.Errorf(ctx, "CacheRedisHash Remove Del key:%s error: %v", k, err) + } + break + } + } + } + return nil +} + +func (r *CacheRedisHash) BatchDel(ctx context.Context, keys []string) error { + if len(keys) == 0 { + return nil + } + err := r.redisClient.HDel(ctx, r.keyName, keys...).Err() + if err != nil { + r.options.logger.Errorf(ctx, "cache redis batch del error: %v", err) + return err + } + return nil +} + +func BatchGetRedisHash[T any](r *CacheRedisHash, ctx context.Context, keys []string) (map[string]*T, error) { + if len(keys) == 0 { // 批量获取的key为空时,返回空map + return map[string]*T{}, nil + } + + res, err := r.redisClient.HMGet(ctx, r.keyName, keys...).Result() + if err != nil && !errors.Is(err, redis.Nil) { + r.options.logger.Errorf(ctx, "cache redis get error: %v", err) + return nil, err + } + + r.options.logger.Infof(ctx, "cache redis get keyName:%s keys: %v len:%d", r.keyName, keys, len(res)) + + result := make(map[string]*T) + for i, v := range res { + if v == nil { + continue + } + var t T + err := sonic.Unmarshal([]byte(v.(string)), &t) + if err != nil { + r.options.logger.Errorf(ctx, "cache redis unmarshal error: %v", err) + return nil, err + } + result[keys[i]] = &t + } + return result, nil +} + +// CacheRedis 以 Redis String 结构实现的缓存方案 +// 实现原理: +// - 每个缓存 key 独立存储为一个 Redis String,支持单独设置过期时间。 +// - 通过 prefix 区分不同业务,避免 key 冲突。 +// - 支持泛型批量获取(BatchGetRedis),底层用 pipeline 提高性能。 +// - 适合 key 数量较多、生命周期不一致的场景。 +// +// 适用场景: +// - 业务 key 数量较多,生命周期不一致,或需要单独控制每个 key 过期时间。 +type CacheRedis struct { + options *managerOptions + redisClient redis.UniversalClient + prefix string // key前缀,避免不同业务之间的key冲突 + hashTableName string // 记录Key=>ids 映射关系的Hash表,field为key,value为ids的json字符串 + sortSetName string // 记录key的过期时间的有序集合,number为过期的毫秒时间戳,value为key +} + +// NewCacheRedis 创建新的redis缓存实例 +func NewCacheRedis(ctx context.Context, redis redis.UniversalClient, prefix string, ops ...OptionFunc) *CacheRedis { + + options := defaultManagerOptions() + for _, op := range ops { + op(options) + } + + r := &CacheRedis{ + redisClient: redis, + prefix: prefix, + hashTableName: fmt.Sprintf("%s:hashTable", prefix), + sortSetName: fmt.Sprintf("%s:sortSet", prefix), + options: options, + } + + go func() { + ticker := time.NewTicker(time.Hour) // 定时检查过期key + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + r.clean(ctx) + } + } + }() + + return r +} + +func (r *CacheRedis) clean(ctx context.Context) error { + // 从sortSet查询已过期的key,并删除 + now := time.Now().UnixMilli() + expiredKeys, err := r.redisClient.ZRangeByScore(ctx, r.sortSetName, &redis.ZRangeBy{ + Min: "0", + Max: fmt.Sprintf("%d", now), + }).Result() + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedis clean redis ZRangeByScore keyName:%s error: %v", r.sortSetName, err) + return err + } + for _, key := range expiredKeys { + r.Del(ctx, key) + } + return nil + +} + +func (r *CacheRedis) Set(ctx context.Context, key string, value any, ttl time.Duration, ids []int64) error { + + key = fmt.Sprintf("%s:%s", r.prefix, key) + // global.Logger.Infof(ctx, "cache redis set key: %s", key) + + jsonBy, err := sonic.Marshal(value) + if err != nil { + r.options.logger.Errorf(ctx, "cache redis marshal error: %v", err) + return err + } + + if ttl <= 0 { + ttl = time.Hour * 24 // 默认过期时间 + } + + idsBy, err := sonic.Marshal(ids) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedis Set redis marshal ids key:%s error: %v", key, err) + return err + } + + p := r.redisClient.Pipeline() + p.Set(ctx, key, jsonBy, ttl) + p.ZAdd(ctx, r.sortSetName, redis.Z{ + Score: float64(time.Now().Add(ttl).UnixMilli()), + Member: key, + }) + p.HSet(ctx, r.hashTableName, key, idsBy) + p.Expire(ctx, r.hashTableName, time.Hour*24*30) // 一个月 + cmders, err := p.Exec(ctx) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedis Set redis pipeline exec key:%s error: %v", key, err) + return err + } + + // 检查 pipeline 中每个命令的错误 + for i, cmd := range cmders { + if err := cmd.Err(); err != nil { + r.options.logger.Errorf(ctx, "CacheRedis Set redis pipeline cmd[%d] error: %v", i, err) + return err + } + } + + return nil +} + +func (r *CacheRedis) Get(ctx context.Context, key string, value any) error { + + key = fmt.Sprintf("%s:%s", r.prefix, key) + + val, err := r.redisClient.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return ErrCacheNil // 明确区分缓存未命中和其他错误 + } + r.options.logger.Errorf(ctx, "cache redis get error: %v", err) + return err + } + + err = sonic.Unmarshal([]byte(val), value) + if err != nil { + r.options.logger.Errorf(ctx, "cache redis unmarshal error: %v", err) + return err + } + + return nil +} + +func (l *CacheRedis) Remove(ctx context.Context, ids []int64) error { + // 获取所有 field=>ids 映射关系 + hashTable, err := l.redisClient.HGetAll(ctx, l.hashTableName).Result() + if err != nil { + l.options.logger.Errorf(ctx, "CacheRedis Remove redis get keyName:%s error: %v", l.hashTableName, err) + return err + } + + for k, v := range hashTable { + var inIds []int64 + err = json.Unmarshal([]byte(v), &inIds) + if err != nil { + continue + } + for _, id := range ids { + if slices.Contains(inIds, id) { + if err := l.Del(ctx, k); err != nil { + l.options.logger.Errorf(ctx, "CacheRedis Remove Del key:%s error: %v", k, err) + } + break + } + } + } + return nil +} + +// BatchGetRedis 批量获取Redis缓存,支持泛型 +func BatchGetRedis[T any](r *CacheRedis, ctx context.Context, keys []string) (map[string]*T, error) { + + result := make(map[string]*T) + if len(keys) == 0 { + return result, nil + } + + keyMap := make(map[string]string, len(keys)) + redisKeys := make([]string, 0, len(keys)) + cmder, err := r.redisClient.Pipeline().Pipelined(ctx, func(pipe redis.Pipeliner) error { + for _, key := range keys { + rediskey := fmt.Sprintf("%s:%s", r.prefix, key) + redisKeys = append(redisKeys, rediskey) + keyMap[rediskey] = key + // global.Logger.Infof(ctx, "cache redis get key: %s", key) + pipe.Get(ctx, rediskey) + } + return nil + }) + if err != nil && !errors.Is(err, redis.Nil) { + r.options.logger.Errorf(ctx, "BatchGetRedis pipeline exec error: %v", err) + return nil, err + } + + for i, cmd := range cmder { + + stringCom, ok := cmd.(*redis.StringCmd) + // r.options.logger.Infof(ctx, "cache redis get key: %s ok:%+v val:%+v", redisKeys[i], ok, stringCom) + if !ok { + continue + } + if stringCom.Err() != nil { + continue + } + val := stringCom.Val() + key := keyMap[redisKeys[i]] + var t T + err := sonic.Unmarshal([]byte(val), &t) + if err != nil { + // 反序列化失败时跳过该项,整体不报错 + continue + } + result[key] = &t + } + return result, nil +} + +// CacheRedis 删除缓存 保证Key删除的幂等性和一致性 +func (r *CacheRedis) Del(ctx context.Context, key string) error { + key = fmt.Sprintf("%s:%s", r.prefix, key) + + p := r.redisClient.Pipeline() + // 删除key + p.Del(ctx, key) + // 从过期时间的有序集合中删除 + p.ZRem(ctx, r.sortSetName, key) + // 从hashTable中删除 + p.HDel(ctx, r.hashTableName, key) + cmders, err := p.Exec(ctx) + if err != nil { + r.options.logger.Errorf(ctx, "CacheRedis Del redis pipeline exec key:%s error: %v", key, err) + return err + } + + // 检查 pipeline 中每个命令的错误 + for i, cmd := range cmders { + if err := cmd.Err(); err != nil { + r.options.logger.Errorf(ctx, "CacheRedis Del redis pipeline cmd[%d] error: %v", i, err) + return err + } + } + + r.options.logger.Infof(ctx, "CacheRedis Del redis del key:%s", key) + + return nil +} + +func (r *CacheRedis) BatchDel(ctx context.Context, keys []string) error { + if len(keys) == 0 { + return nil + } + redisKeys := make([]string, 0, len(keys)) + for _, key := range keys { + redisKeys = append(redisKeys, fmt.Sprintf("%s:%s", r.prefix, key)) + } + //global.Logger.Infof(ctx, "CacheRedis BatchDel redis del keys:%+v", redisKeys) + err := r.redisClient.Del(ctx, redisKeys...).Err() + if err != nil { + r.options.logger.Errorf(ctx, "cache redis batch del error: %v", err) + return err + } + return nil +} + +// CacheLocal 本地内存缓存实现 +// 实现原理: +// - 使用 map[string]item 存储所有缓存,item 结构体包含值和过期时间。 +// - 通过互斥锁保证并发安全。 +// - 启动后台协程定期清理过期 key。 +// - 适合单机场景,缓存容量受限于本地内存。 +// +// 适用场景: +// - 适用于高性能、低延迟、单机进程内缓存需求,如热点数据、本地会话等。 +// - 不适合分布式场景或大容量缓存。 +type CacheLocal struct { + options *managerOptions + cacheMap sync.Map + keyTableMap sync.Map // 记录Key=>ids 映射关系的Map,field为key,value为ids的数组 + quit chan struct{} // 用于停止后台清理协程 +} + +// item 内部存储单元,包含值和过期时间 +type item struct { + value any + expiry int64 // 过期时间戳 (Unix Nano) +} + +func NewCacheLocal(cleanupInterval time.Duration, ops ...OptionFunc) *CacheLocal { + + options := defaultManagerOptions() + for _, op := range ops { + op(options) + } + + cache := &CacheLocal{ + options: options, + cacheMap: sync.Map{}, + keyTableMap: sync.Map{}, + quit: make(chan struct{}), + } + + // 启动后台清理协程 + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + cache.cleanup(context.Background()) + case <-cache.quit: + return + } + } + }() + + return cache +} + +// cleanup 定期清理过期的 Key +func (c *CacheLocal) cleanup(ctx context.Context) { + + now := time.Now().UnixNano() + + c.cacheMap.Range(func(key, value any) bool { + it, ok := value.(item) + if !ok { + return true + } + if it.expiry > 0 && now > it.expiry { + c.cacheMap.Delete(key) // 删除过时的缓存项 + } + return true + }) + c.options.logger.Infof(ctx, "CacheLocal cleanup") + return +} + +// Stop 停止后台清理协程(通常在程序退出时调用) +func (c *CacheLocal) Stop() { + select { + case <-c.quit: + // 已关闭 + return + default: + close(c.quit) + } +} + +// Set 将 value 存储为 interface{},建议 value 为可安全复制的类型 +func (l *CacheLocal) Set(ctx context.Context, key string, value any, ttl time.Duration, ids []int64) error { + + if ttl <= 0 { + ttl = time.Hour * 24 // 默认过期时间 + } + + // 存储深拷贝,防止外部修改 + var v any + switch vv := value.(type) { + case nil: + v = nil + case string, int, int64, float64, bool: + v = vv + default: + // 对于结构体、切片、map等,序列化再反序列化实现深拷贝 + by, err := sonic.Marshal(vv) + if err != nil { + return err + } + err = sonic.Unmarshal(by, &v) + if err != nil { + return err + } + } + l.cacheMap.Store(key, item{v, time.Now().Add(ttl).UnixNano()}) + l.keyTableMap.Store(key, ids) + return nil +} + +// Get 直接类型断言赋值,提升类型安全 +func (l *CacheLocal) Get(ctx context.Context, key string, value any) error { + + itemData, exists := l.cacheMap.Load(key) + if !exists { + return ErrCacheNil // 不存在,返回nil表示缓存未命中 + } + + item, ok := itemData.(item) + if !ok { + return ErrCacheNil // 类型断言失败,返回nil表示缓存未命中 + } + + if time.Now().UnixNano() > item.expiry { + l.cacheMap.Delete(key) // 删除过期的缓存 + return ErrCacheNil + } + + // 直接类型断言赋值 + switch v := value.(type) { + case *string: + vv, ok := item.value.(string) + if !ok { + return errors.New("cache type mismatch: expect string") + } + *v = vv + case *int: + vv, ok := item.value.(int) + if !ok { + return errors.New("cache type mismatch: expect int") + } + *v = vv + case *int64: + vv, ok := item.value.(int64) + if !ok { + return errors.New("cache type mismatch: expect int64") + } + *v = vv + case *float64: + vv, ok := item.value.(float64) + if !ok { + return errors.New("cache type mismatch: expect float64") + } + *v = vv + case *bool: + vv, ok := item.value.(bool) + if !ok { + return errors.New("cache type mismatch: expect bool") + } + *v = vv + default: + // 结构体、切片、map等,序列化再反序列化 + by, err := sonic.Marshal(item.value) + if err != nil { + return err + } + err = sonic.Unmarshal(by, value) + if err != nil { + return err + } + } + return nil +} + +// BatchGetLocal 批量获取本地缓存,支持泛型 +func BatchGetLocal[T any](c *CacheLocal, ctx context.Context, keys []string) (map[string]*T, error) { + result := make(map[string]*T) + for _, key := range keys { + var t T + err := c.Get(ctx, key, &t) + if err != nil && !errors.Is(err, ErrCacheNil) { + return nil, err + } + if errors.Is(err, ErrCacheNil) { + continue + } + result[key] = &t + } + return result, nil +} + +func (l *CacheLocal) BatchDel(ctx context.Context, keys []string) error { + return nil +} + +func (l *CacheLocal) Del(ctx context.Context, key string) error { + l.cacheMap.Delete(key) // 删除指定的缓存 + l.keyTableMap.Delete(key) + return nil +} + +func (l *CacheLocal) Remove(ctx context.Context, ids []int64) error { + l.keyTableMap.Range(func(key, value any) bool { + idsArr, ok := value.([]int64) + if !ok { + return true + } + keyStr, ok := key.(string) + if !ok { + return true + } + for _, id := range idsArr { + for _, id2 := range ids { + if id == id2 { + l.Del(ctx, keyStr) + } + } + } + return true + }) + + return nil +} diff --git a/manager/cache_benchmark_test.go b/manager/cache_benchmark_test.go new file mode 100644 index 0000000..017346b --- /dev/null +++ b/manager/cache_benchmark_test.go @@ -0,0 +1,129 @@ +package manager_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "code.yun.ink/pkg/cache_manager/manager" + + "github.com/redis/go-redis/v9" +) + +// 本地缓存压测 +type testStruct struct { + A int + B string +} + +func getRedisClient() *redis.Client { + // 支持通过环境变量覆盖 redis 地址 + addr := os.Getenv("TEST_REDIS_ADDR") + if addr == "" { + addr = "10.40.92.54:30379" + } + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: "123000", + DB: 0, + }) + return client +} + +func BenchmarkCacheLocal_SetGet(b *testing.B) { + cache := manager.NewCacheLocal(time.Minute) + + ctx := context.Background() + defer cache.Stop() + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Minute, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + } +} + +func BenchmarkCacheLocal_Parallel(b *testing.B) { + cache := manager.NewCacheLocal(time.Minute) + ctx := context.Background() + defer cache.Stop() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Minute, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + i++ + } + }) +} + +// Redis Hash 缓存压测 +func BenchmarkCacheRedisHash_SetGet(b *testing.B) { + client := getRedisClient() + cache := manager.NewCacheRedisHash(client, "bench_hash", time.Hour) + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Hour, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + } +} + +func BenchmarkCacheRedisHash_Parallel(b *testing.B) { + client := getRedisClient() + cache := manager.NewCacheRedisHash(client, "bench_hash", time.Hour) + ctx := context.Background() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Hour, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + i++ + } + }) +} + +// Redis String 缓存压测 +func BenchmarkCacheRedis_SetGet(b *testing.B) { + client := getRedisClient() + ctx := context.Background() + cache := manager.NewCacheRedis(ctx, client, "bench_str") + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Hour, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + } +} + +func BenchmarkCacheRedis_Parallel(b *testing.B) { + client := getRedisClient() + ctx := context.Background() + cache := manager.NewCacheRedis(ctx, client, "bench_str") + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := "key" + fmt.Sprintf("%d", i) + val := testStruct{A: i, B: "val"} + _ = cache.Set(ctx, key, val, time.Hour, []int64{int64(i)}) + var out testStruct + _ = cache.Get(ctx, key, &out) + i++ + } + }) +} diff --git a/manager/cache_manager.go b/manager/cache_manager.go new file mode 100644 index 0000000..60480ef --- /dev/null +++ b/manager/cache_manager.go @@ -0,0 +1,201 @@ +package manager + +import ( + "bytes" + "context" + "errors" + "fmt" + "sort" + "time" + + "gorm.io/gorm" +) + +// 缓存抽象管理器 +// +// 作用 +// 1. 组装好缓存上下Key,简化调用负担 +// 2. 封装了批量操作的逻辑,提升性能 + +// CacheManager 针对特定类型的缓存管理器 +type CacheManager[T any] struct { + tx *gorm.DB + options *managerOptions + cache ICache + prefix string + idGetter func(*T) int64 +} + +// NewCacheManager 创建特定类型的管理器 +func NewCacheManager[T any](tx *gorm.DB, c ICache, prefix string, idGetter func(*T) int64, ops ...OptionFunc) *CacheManager[T] { + options := defaultManagerOptions() + for _, op := range ops { + op(options) + } + + return &CacheManager[T]{ + options: options, + cache: c, + prefix: prefix, + idGetter: idGetter, + tx: tx, + } +} + +// Get 获取数据(带类型断言) +func (m *CacheManager[T]) Get(ctx context.Context, key string) (*T, error) { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + t := new(T) + err := m.cache.Get(ctx, fullKey, t) + if err != nil { + return nil, err + } + return t, nil +} + +func (m *CacheManager[T]) GetList(ctx context.Context, key string) ([]T, error) { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + t := make([]T, 0) + err := m.cache.Get(ctx, fullKey, &t) + if err != nil { + return nil, err + } + return t, nil +} + +func (m *CacheManager[T]) SetList(ctx context.Context, key string, data []T, ttl time.Duration) error { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + return m.cache.Set(ctx, fullKey, data, ttl, []int64{}) +} + +// Set 设置数据 +func (m *CacheManager[T]) Set(ctx context.Context, key string, data *T, ttl time.Duration, ids []int64) error { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + return m.cache.Set(ctx, fullKey, data, ttl, ids) +} + +// GetByAny data注意必须指针 +func (m *CacheManager[T]) GetByAny(ctx context.Context, key string, data any) error { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + return m.cache.Get(ctx, fullKey, data) +} + +// SetByAny 设置数据 +func (m *CacheManager[T]) SetByAny(ctx context.Context, key string, data any, ttl time.Duration, ids []int64) error { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + return m.cache.Set(ctx, fullKey, data, ttl, ids) +} + +// Del 删除数据 +func (m *CacheManager[T]) Del(ctx context.Context, key string) error { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + return m.cache.Del(ctx, fullKey) +} + +func (m *CacheManager[T]) Remove(ctx context.Context, ids []int64) error { + return m.cache.Remove(ctx, ids) +} + +func (m *CacheManager[T]) BatchDel(ctx context.Context, keys []string) error { + fullKeys := make([]string, 0, len(keys)) + for _, key := range keys { + fullKey := fmt.Sprintf("%s:%s", m.prefix, key) + m.options.logger.Infof(ctx, "delete redis key:%s", fullKey) + fullKeys = append(fullKeys, fullKey) + } + return m.cache.BatchDel(ctx, fullKeys) +} + +// BatchGet 批量获取,只会获取到存在的key,返回的map中不存在的key表示缓存未命中 +func (m *CacheManager[T]) BatchGet(ctx context.Context, keys []string) (map[string]*T, error) { + // key转换 + keyMap := make(map[string]string) + newKeys := make([]string, 0, len(keys)) + for _, key := range keys { + newKey := fmt.Sprintf("%s:%s", m.prefix, key) + keyMap[newKey] = key + newKeys = append(newKeys, newKey) + } + m.options.logger.Infof(ctx, "BatchGet keyMap:%+v", keyMap) + var resp map[string]*T + var err error + switch cache := m.cache.(type) { + case *CacheRedis: + resp, err = BatchGetRedis[T](cache, ctx, newKeys) + case *CacheLocal: + resp, err = BatchGetLocal[T](cache, ctx, newKeys) + case *CacheRedisHash: + resp, err = BatchGetRedisHash[T](cache, ctx, newKeys) + default: + return nil, errors.New("unsupported cache type for batch get") + } + if err != nil { + m.options.logger.Errorf(ctx, "BatchGet error: %v", err) + return nil, err + } + + // m.options.logger.Infof(ctx, "BatchGet resp:%+v", resp) + + var result = make(map[string]*T) + for k, v := range resp { + // m.options.logger.Infof(ctx, "BatchGet k:%v,v:%v", k, keyMap[k]) + result[keyMap[k]] = v + } + return result, err +} + +func (l *CacheManager[T]) GetByFields(ctx context.Context, fields map[string]interface{}, ttl time.Duration) ([]*T, error) { + + key := l.fieldsToKey(fields) + resp := make([]*T, 0) + err := l.cache.Get(ctx, key, &resp) + if err != nil { + return nil, err + } + return resp, nil +} + +func (l *CacheManager[T]) UpdateByFields(ctx context.Context, fields map[string]interface{}, ttl time.Duration) ([]*T, error) { + var data []*T + err := l.tx.WithContext(ctx).Where(fields).Find(&data).Error + if err != nil { + l.options.logger.Errorf(ctx, "db query error: %v", err) + return nil, err + } + + ids := make([]int64, 0, len(data)) + for _, v := range data { + ids = append(ids, l.idGetter(v)) + } + + key := l.fieldsToKey(fields) + // l.options.logger.Infof(ctx, "UpdateByFields key:%v", key) + l.cache.Set(ctx, key, data, ttl, ids) + return data, nil +} + +func (l *CacheManager[T]) DelByFields(ctx context.Context, fields map[string]interface{}) error { + return l.cache.Del(ctx, l.fieldsToKey(fields)) +} + +func (l *CacheManager[T]) fieldsToKey(fields map[string]interface{}) string { + // 需要取出来,排序,然后组装成字符串 + keys := make([]string, 0, len(fields)) + for k := range fields { + keys = append(keys, k) + } + sort.Strings(keys) + var buf bytes.Buffer + + buf.WriteString(l.prefix + ":{") + for i, key := range keys { + if i != 0 { + buf.WriteString(",") + } + buf.WriteString(key) + buf.WriteString(":") + buf.WriteString(fmt.Sprintf("%v", fields[key])) + } + buf.WriteString("}") + return buf.String() +} diff --git a/manager/cache_manager_test.go b/manager/cache_manager_test.go new file mode 100644 index 0000000..5b3e34a --- /dev/null +++ b/manager/cache_manager_test.go @@ -0,0 +1,444 @@ +package manager + +import ( + "context" + "encoding/json" + "fmt" + "os" + "testing" + "time" + "wallet-pay-api/common/global" + "wallet-pay-api/common/model" + "wallet-pay-api/pkg/loggerx" + + "github.com/redis/go-redis/v9" +) + +func newTestRedisCache() *CacheRedis { + // 支持通过环境变量覆盖 redis 地址 + addr := os.Getenv("TEST_REDIS_ADDR") + if addr == "" { + addr = "10.40.92.54:30379" + } + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: "123000", + DB: 0, + }) + return NewCacheRedis(context.Background(), client, "test") +} +func TestBatchGetRedis(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + keys := []string{"k1", "k2", "k3"} + resp, err := BatchGetRedis[model.YuebaoUserAsset](cache, ctx, keys) + if err != nil { + t.Fatal(err) + } + t.Logf("%+v", resp) + +} + +func TestCacheRedis_Types(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + // int + err := cache.Set(ctx, "int", 123, time.Second*10, []int64{123}) + if err != nil { + t.Fatal(err) + } + var i int + err = cache.Get(ctx, "int", &i) + if err != nil || i != 123 { + t.Errorf("int get failed: %v, %d", err, i) + } + // string + err = cache.Set(ctx, "str", "abc", time.Second*10, []int64{124}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "str", &s) + if err != nil || s != "abc" { + t.Errorf("string get failed: %v, %s", err, s) + } + // struct + st := TestStruct{Name: "foo"} + err = cache.Set(ctx, "struct", st, time.Second*10, []int64{125}) + if err != nil { + t.Fatal(err) + } + var st2 TestStruct + err = cache.Get(ctx, "struct", &st2) + if err != nil || st2.Name != "foo" { + t.Errorf("struct get failed: %v, %+v", err, st2) + } + // map + m := map[string]int{"a": 1} + err = cache.Set(ctx, "map", m, time.Second*10, []int64{126}) + if err != nil { + t.Fatal(err) + } + var m2 map[string]int + err = cache.Get(ctx, "map", &m2) + if err != nil || m2["a"] != 1 { + t.Errorf("map get failed: %v, %+v", err, m2) + } + // slice + sl := []string{"x", "y"} + err = cache.Set(ctx, "slice", sl, time.Second*10, []int64{127}) + if err != nil { + t.Fatal(err) + } + var sl2 []string + err = cache.Get(ctx, "slice", &sl2) + if err != nil || len(sl2) != 2 || sl2[0] != "x" { + t.Errorf("slice get failed: %v, %+v", err, sl2) + } +} + +func TestCacheRedis_Expire(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + err := cache.Set(ctx, "expire", "v", time.Millisecond*200, []int64{128}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "expire", &s) + if err != nil || s != "v" { + t.Errorf("expire get failed: %v, %s", err, s) + } + time.Sleep(time.Millisecond * 250) + err = cache.Get(ctx, "expire", &s) + if err == nil { + t.Errorf("should expire, got: %v", s) + } +} + +func TestCacheRedis_Del(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + err := cache.Set(ctx, "del", 1, time.Second*10, []int64{129}) + if err != nil { + t.Fatal(err) + } + err = cache.Del(ctx, "del") + if err != nil { + t.Fatal(err) + } + var i int + err = cache.Get(ctx, "del", &i) + if err == nil { + t.Errorf("should not get deleted key") + } +} + +func TestCacheRedis_TypeMismatch(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + err := cache.Set(ctx, "mismatch", 123, time.Second*10, []int64{123}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "mismatch", &s) + if err == nil { + t.Errorf("should type mismatch") + } +} + +func TestCacheRedis_BatchGet(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + cache.Set(ctx, "ba", 1, time.Second*10, []int64{1}) + cache.Set(ctx, "bb", 2, time.Second*10, []int64{2}) + cache.Set(ctx, "bc", 3, time.Second*10, []int64{3}) + keys := []string{"bj", "ba", "bb", "bx"} + m, err := BatchGetRedis[int](cache, ctx, keys) + if err != nil { + t.Fatal(err) + } + b, _ := json.Marshal(m) + fmt.Println(string(b)) + if len(m) != 2 || *m["ba"] != 1 || *m["bb"] != 2 { + t.Errorf("batch get failed: %+v", m) + } +} + +func TestCacheManager_Redis_Generic(t *testing.T) { + ctx := context.Background() + cache := newTestRedisCache() + m := NewCacheManager[TestStruct](cache, "pref", func(t *TestStruct) int64 { return 0 }) + data := TestStruct{Name: "bar"} + err := m.Set(ctx, "k1", &data, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + got, err := m.Get(ctx, "k1") + if err != nil || got.Name != "bar" { + t.Errorf("manager redis get failed: %v, %+v", err, got) + } + err = m.Del(ctx, "k1") + if err != nil { + t.Fatal(err) + } + got, err = m.Get(ctx, "k1") + if err == nil { + t.Errorf("manager redis del failed") + } +} + +func TestICacheInterface(t *testing.T) { + ctx := context.Background() + var cache ICache = NewCacheLocal(time.Second * 10) + err := cache.Set(ctx, "iface", "val", time.Second*10, []int64{130}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "iface", &s) + if err != nil || s != "val" { + t.Errorf("iface get failed: %v, %s", err, s) + } + err = cache.Del(ctx, "iface") + if err != nil { + t.Fatal(err) + } + err = cache.Get(ctx, "iface", &s) + if err == nil { + t.Errorf("iface del failed") + } +} + +func TestCacheManager_Generic(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + m := NewCacheManager[TestStruct](cache, "pref", func(t *TestStruct) int64 { return 0 }) + data := TestStruct{Name: "bar"} + err := m.Set(ctx, "k1", &data, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + got, err := m.Get(ctx, "k1") + if err != nil || got.Name != "bar" { + t.Errorf("manager get failed: %v, %+v", err, got) + } + err = m.Del(ctx, "k1") + if err != nil { + t.Fatal(err) + } + got, err = m.Get(ctx, "k1") + if err == nil { + t.Errorf("manager del failed") + } +} + +func TestCacheLocal_Boundary(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + + // 空 key + err := cache.Set(ctx, "", "v", time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "", &s) + if err != nil || s != "v" { + t.Errorf("empty key failed: %v, %s", err, s) + } + // nil value + err = cache.Set(ctx, "nil", nil, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var x any + err = cache.Get(ctx, "nil", &x) + if err != nil || x != nil { + t.Errorf("nil value failed: %v, %v", err, x) + } + // 极短 ttl + err = cache.Set(ctx, "short", "v", time.Millisecond*1, []int64{0}) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 5) + err = cache.Get(ctx, "short", &s) + if err == nil { + t.Errorf("short ttl should expire") + } + // 重复 set + err = cache.Set(ctx, "dup", "v1", time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + err = cache.Set(ctx, "dup", "v2", time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + err = cache.Get(ctx, "dup", &s) + if err != nil || s != "v2" { + t.Errorf("dup set failed: %v, %s", err, s) + } + // Stop 后行为 + cache.Stop() + // Stop 后再次 Stop 不应 panic + cache.Stop() +} + +func TestCacheLocal_Concurrent(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + + n := 100 + done := make(chan struct{}, n*2) + for i := 0; i < n; i++ { + go func(i int) { + key := "k" + string(rune(i)) + cache.Set(ctx, key, i, time.Second*10, []int64{int64(i)}) + done <- struct{}{} + }(i) + go func(i int) { + key := "k" + string(rune(i)) + var v int + cache.Get(ctx, key, &v) + done <- struct{}{} + }(i) + } + for i := 0; i < n*2; i++ { + <-done + } +} + +func TestCacheLocal_Types(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + // int + err := cache.Set(ctx, "int", 123, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var i int + err = cache.Get(ctx, "int", &i) + if err != nil || i != 123 { + t.Errorf("int get failed: %v, %d", err, i) + } + // string + err = cache.Set(ctx, "str", "abc", time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "str", &s) + if err != nil || s != "abc" { + t.Errorf("string get failed: %v, %s", err, s) + } + // struct + st := TestStruct{Name: "foo"} + err = cache.Set(ctx, "struct", st, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var st2 TestStruct + err = cache.Get(ctx, "struct", &st2) + if err != nil || st2.Name != "foo" { + t.Errorf("struct get failed: %v, %+v", err, st2) + } + // map + m := map[string]int{"a": 1} + err = cache.Set(ctx, "map", m, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var m2 map[string]int + err = cache.Get(ctx, "map", &m2) + if err != nil || m2["a"] != 1 { + t.Errorf("map get failed: %v, %+v", err, m2) + } + // slice + sl := []string{"x", "y"} + err = cache.Set(ctx, "slice", sl, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var sl2 []string + err = cache.Get(ctx, "slice", &sl2) + if err != nil || len(sl2) != 2 || sl2[0] != "x" { + t.Errorf("slice get failed: %v, %+v", err, sl2) + } +} + +func TestCacheLocal_Expire(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Millisecond * 50) + + err := cache.Set(ctx, "expire", "v", time.Millisecond*100, []int64{0}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "expire", &s) + if err != nil || s != "v" { + t.Errorf("expire get failed: %v, %s", err, s) + } + time.Sleep(time.Millisecond * 120) + err = cache.Get(ctx, "expire", &s) + if err == nil { + t.Errorf("should expire, got: %v", s) + } +} + +func TestCacheLocal_Del(t *testing.T) { + ctx := context.Background() + + cache := NewCacheLocal(time.Second * 10) + err := cache.Set(ctx, "del", 1, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + err = cache.Del(ctx, "del") + if err != nil { + t.Fatal(err) + } + var i int + err = cache.Get(ctx, "del", &i) + if err == nil { + t.Errorf("should not get deleted key") + } +} + +func TestCacheLocal_TypeMismatch(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + err := cache.Set(ctx, "mismatch", 123, time.Second*10, []int64{0}) + if err != nil { + t.Fatal(err) + } + var s string + err = cache.Get(ctx, "mismatch", &s) + if err == nil { + t.Errorf("should type mismatch") + } +} + +func TestCacheLocal_BatchGet(t *testing.T) { + ctx := context.Background() + cache := NewCacheLocal(time.Second * 10) + + cache.Set(ctx, "a", 1, time.Second*10, []int64{0}) + cache.Set(ctx, "b", 2, time.Second*10, []int64{0}) + cache.Set(ctx, "c", 3, time.Second*10, []int64{0}) + keys := []string{"a", "b", "x"} + m, err := BatchGetLocal[int](cache, ctx, keys) + if err != nil { + t.Fatal(err) + } + if len(m) != 2 || *m["a"] != 1 || *m["b"] != 2 { + t.Errorf("batch get failed: %+v", m) + } +} + +type TestStruct struct { + Name string `json:"name"` +} diff --git a/manager/event_manager.go b/manager/event_manager.go new file mode 100644 index 0000000..22b896e --- /dev/null +++ b/manager/event_manager.go @@ -0,0 +1,181 @@ +package manager + +import ( + "context" + "errors" + "fmt" + "time" + + "gorm.io/gorm" +) + +// 公共的事件管理器 +// +// 作用 +// 1. 对外提供统一的缓存操作接口,比如根据ID获取、更新、删除等 +// 2. 缓存读取不到的时候会自动从DB加载数据,简化业务代码的实现 +// 3. 处理缓存穿透、缓存雪崩等问题,提升系统的稳定性和性能 + +type EventManager[T any] struct { + tx *gorm.DB + options *managerOptions + ttl time.Duration + manager *CacheManager[T] + idGetter func(*T) int64 +} + +func NewEventManager[T any](tx *gorm.DB, ttl time.Duration, manager *CacheManager[T], idGetter func(*T) int64, ops ...OptionFunc) *EventManager[T] { + options := defaultManagerOptions() + for _, op := range ops { + op(options) + } + return &EventManager[T]{ + tx: tx, + options: options, + ttl: ttl, + manager: manager, + idGetter: idGetter, + } +} + +func (l *EventManager[T]) DeleteById(ctx context.Context, id int64) error { + err := l.manager.Del(ctx, fmt.Sprintf("%d", id)) + if err != nil { + l.options.logger.Errorf(ctx, "EventManager DeleteById error: %v", err) + return err + } + l.options.logger.Infof(ctx, "EventManager DeleteById id:%d success", id) + return nil +} + +func (l *EventManager[T]) DeleteByKey(ctx context.Context, key string) error { + err := l.manager.Del(ctx, key) + if err != nil { + l.options.logger.Errorf(ctx, "EventManager DeleteByKey error: %v", err) + return err + } + l.options.logger.Infof(ctx, "EventManager DeleteByKey key:%s success", key) + return nil +} + +func (e *EventManager[T]) UpdateById(ctx context.Context, id int64) error { + _, err := e.UpdateByIdWithValue(ctx, id) + return err +} + +func (e *EventManager[T]) UpdateByIdWithValue(ctx context.Context, id int64) (*T, error) { + // 查询DB & 更新缓存 + data := new(T) + err := e.tx.WithContext(ctx).Where("id = ?", id).First(data).Error + if err != nil { + e.options.logger.Errorf(ctx, "EventManager UpdateByIdWithValue id:%d db query error: %v", id, err) + return nil, err + } + err = e.manager.Set(ctx, fmt.Sprintf("%d", id), data, e.ttl, []int64{id}) + if err != nil { + e.options.logger.Errorf(ctx, "EventManager cache set id:%d error: %v", id, err) + return nil, err + } + e.options.logger.Infof(ctx, "EventManager cache update id:%d", id) + return data, nil +} + +func (e *EventManager[T]) GetById(ctx context.Context, id int64) (*T, error) { + key := fmt.Sprintf("%d", id) + data, err := e.manager.Get(ctx, key) + if err != nil && !errors.Is(err, ErrCacheNil) { + e.options.logger.Errorf(ctx, "EventManager Get id:%d error: %v", id, err) + return nil, err + } + if data != nil { + e.options.logger.Infof(ctx, "EventManager get by cache id:%d", id) + return data, nil + } + + // TODO: 缓存雪崩问题,这里需要处理一下,比如设置一个兜底数据之类的 + + d, err := e.UpdateByIdWithValue(ctx, id) + if err != nil { + e.options.logger.Errorf(ctx, "EventManager UpdateByIdWithValue id:%d err:%v", id, err) + return nil, err + } + return d, nil +} +func (e *EventManager[T]) GetByIds(ctx context.Context, ids []int64) ([]*T, error) { + resp := make([]*T, 0, len(ids)) + + keys := make([]string, 0, len(ids)) + keyMap := make(map[string]struct{}, len(ids)) + for _, id := range ids { + key := fmt.Sprintf("%d", id) + keys = append(keys, key) + keyMap[key] = struct{}{} + } + + // global.Logger.Infof(ctx, "get by cache ids:%v keys:%+v", ids, keys) + + respMap, err := e.manager.BatchGet(ctx, keys) + if err != nil { + e.options.logger.Errorf(ctx, "EventManager cache batch get ids:%v error: %v", ids, err) + return nil, err + } + // e.options.logger.Infof(ctx, "get by cache ids:%v,respMap:%+v", ids, respMap) + dbs := make([]string, 0) + for key := range keyMap { + if val, ok := respMap[key]; ok { + resp = append(resp, val) + } else { + dbs = append(dbs, key) + } + } + + if len(dbs) > 0 { + dbQuerys := make([]*T, 0, len(dbs)) + db := e.tx.WithContext(ctx).Where("id in (?)", dbs) + err := db.Find(&dbQuerys).Error + if err != nil { + e.options.logger.Errorf(ctx, "EventManager db query ids:%v error: %v", dbs, err) + return nil, err + } + for _, v := range dbQuerys { + resp = append(resp, v) + + err = e.manager.Set(ctx, fmt.Sprintf("%d", e.idGetter(v)), v, e.ttl, []int64{e.idGetter(v)}) + if err != nil { + e.options.logger.Errorf(ctx, "EventManager cache set id:%d error: %v", e.idGetter(v), err) + } + } + } + + return resp, nil +} + +func (l *EventManager[T]) GetByFields(ctx context.Context, fields map[string]interface{}) ([]*T, error) { + resp, err := l.manager.GetByFields(ctx, fields, l.ttl) + if err == nil { + return resp, nil + } + return l.manager.UpdateByFields(ctx, fields, l.ttl) +} + +func (e *EventManager[T]) DeleteCacheById(ctx context.Context, id int64) error { + key := fmt.Sprintf("%d", id) + return e.manager.Del(ctx, key) +} + +func (e *EventManager[T]) DeleteCacheByIds(ctx context.Context, ids []int64) error { + keys := make([]string, 0, len(ids)) + for _, id := range ids { + keys = append(keys, fmt.Sprintf("%d", id)) + } + return e.manager.BatchDel(ctx, keys) +} + +func (e *EventManager[T]) DeleteCacheByFields(ctx context.Context, fields map[string]interface{}) error { + key := e.manager.fieldsToKey(fields) + return e.manager.Del(ctx, fmt.Sprintf("%v", key)) +} + +func (e *EventManager[T]) BatchDel(ctx context.Context, keys []string) error { + return e.manager.BatchDel(ctx, keys) +} diff --git a/manager/options.go b/manager/options.go new file mode 100644 index 0000000..0a5b0dd --- /dev/null +++ b/manager/options.go @@ -0,0 +1,39 @@ +package manager + +import ( + "context" + "log" +) + +type managerOptions struct { + logger ILogger +} + +func defaultManagerOptions() *managerOptions { + return &managerOptions{ + logger: &DefLogger{}, + } +} + +type OptionFunc func(*managerOptions) + +func WithLogger(log ILogger) OptionFunc { + return func(o *managerOptions) { + o.logger = log + } +} + +type ILogger interface { + Infof(ctx context.Context, format string, args ...any) + Errorf(ctx context.Context, format string, args ...any) +} + +type DefLogger struct{} + +func (l *DefLogger) Infof(ctx context.Context, format string, args ...any) { + log.Printf("[info]"+format, args...) +} + +func (l *DefLogger) Errorf(ctx context.Context, format string, args ...any) { + log.Printf("[error]"+format, args...) +}