From 4a131290d61047f74ae9c9cd36ce430854374860 Mon Sep 17 00:00:00 2001 From: Yun Date: Wed, 17 Sep 2025 19:12:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9A=E6=97=B6=E5=99=A8=E6=94=B9=E7=89=88?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- locks.go | 4 +- locks_test.go | 2 +- lockx.go | 222 ++++++++++++++++++++++++++++++++++---------------- lockx_test.go | 15 ++-- options.go | 40 +++++---- 5 files changed, 189 insertions(+), 94 deletions(-) diff --git a/locks.go b/locks.go index 3652e90..f34d019 100644 --- a/locks.go +++ b/locks.go @@ -20,9 +20,9 @@ func Init(ctx context.Context, redis redis.UniversalClient, opts ...Option) erro // 新起一个锁对象 // 先Init后New再Lock -func New(ctx context.Context, uniqueKey string) (*globalLock, error) { +func New(ctx context.Context, uniqueKey string) (*GlobalLock, error) { if redisConn == nil { return nil, fmt.Errorf("redis client is nil") } - return NewGlobalLock(ctx, redisConn, uniqueKey), nil + return NewGlobalLock(ctx, redisConn, uniqueKey, globalOpts...) } diff --git a/locks_test.go b/locks_test.go index 726c340..6c30feb 100644 --- a/locks_test.go +++ b/locks_test.go @@ -27,7 +27,7 @@ func TestSimpleLock(t *testing.T) { t.Log(err) return } - if l.Lock() { + if b, _ := l.Lock(); b { fmt.Println("lock success") l.Unlock() } diff --git a/lockx.go b/lockx.go index 0acda8a..090eab1 100644 --- a/lockx.go +++ b/lockx.go @@ -2,6 +2,8 @@ package lockx import ( "context" + "fmt" + "sync" "time" "github.com/go-redis/redis/v8" @@ -9,98 +11,158 @@ import ( ) // 全局锁 -type globalLock struct { - redis redis.UniversalClient - ctx context.Context - cancel context.CancelFunc - uniqueKey string - value string +type GlobalLock struct { + redis redis.UniversalClient + ctx context.Context + cancel context.CancelFunc + uniqueKey string + value string + isClosed bool + closeLock sync.RWMutex + options *option + stopRefresh chan struct{} + wg sync.WaitGroup } -func NewGlobalLock(ctx context.Context, red redis.UniversalClient, uniqueKey string) *globalLock { - ctx, cancel := context.WithTimeout(ctx, opt.lockTimeout) - - u, _ := uuid.NewV7() - - return &globalLock{ - redis: red, - ctx: ctx, - cancel: cancel, - uniqueKey: uniqueKey, - value: u.String(), +func NewGlobalLock(ctx context.Context, red redis.UniversalClient, uniqueKey string, opts ...Option) (*GlobalLock, error) { + options := defaultOption() + for _, opt := range opts { + opt(options) } + + ctx, cancel := context.WithTimeout(ctx, options.lockTimeout) + + u, err := uuid.NewV7() + if err != nil { + cancel() + return nil, fmt.Errorf("failed to generate UUID: %w", err) + } + + return &GlobalLock{ + redis: red, + ctx: ctx, + cancel: cancel, + uniqueKey: uniqueKey, + value: u.String(), + stopRefresh: make(chan struct{}), + options: options, + }, nil } // 获取上下文 -func (l *globalLock) GetCtx() context.Context { +func (l *GlobalLock) GetCtx() context.Context { return l.ctx } // 获取锁 -func (g *globalLock) Lock() bool { - +func (g *GlobalLock) Lock() (bool, error) { script := ` - local token = redis.call('get',KEYS[1]) - if token == false - then - return redis.call('set',KEYS[1],ARGV[1],'EX',ARGV[2]) + if redis.call('set', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then + return 'OK' + else + local current_val = redis.call('get', KEYS[1]) + if current_val == ARGV[1] then + redis.call('expire', KEYS[1], ARGV[2]) + return 'OK' + else + return 'ERROR' + end end - return 'ERROR' ` - resp, err := g.redis.Eval(g.ctx, script, []string{g.uniqueKey}, g.value, 5).Result() - if resp != "OK" { - opt.logger.Errorf(g.ctx, "global lock err resp:%+v err:%+v uniKey:%+v value:%+v", resp, err, g.uniqueKey, g.value) + resp, err := g.redis.Eval(g.ctx, script, []string{g.uniqueKey}, g.value, int(g.options.Expiry.Seconds())).Result() + if err != nil { + if g.options.logger != nil { + g.options.logger.Errorf(g.ctx, "global lock failed: %v, key: %s, value: %s", err, g.uniqueKey, g.value) + } + return false, err } + if resp == "OK" { - g.refresh() - return true + g.startRefresh() + return true, nil } - return false + + return false, nil } // 尝试获取锁 -func (g *globalLock) Try(limitTimes int) bool { - for i := 0; i < limitTimes; i++ { - if g.Lock() { - return true +func (g *GlobalLock) Try() (bool, error) { + for i := 0; i < g.options.MaxRetryTimes; i++ { + success, err := g.Lock() + if err != nil { + return false, err + } + if success { + return true, nil + } + + select { + case <-time.After(g.options.RetryInterval): + continue + case <-g.ctx.Done(): + return false, g.ctx.Err() } - time.Sleep(time.Millisecond * 100) } - return false + return false, nil } // 删除锁 -func (g *globalLock) Unlock() bool { +func (g *GlobalLock) Unlock() error { + g.closeLock.Lock() + defer g.closeLock.Unlock() + + if g.isClosed { + return nil + } + g.isClosed = true + + // 停止刷新goroutine + close(g.stopRefresh) + g.wg.Wait() script := ` - local token = redis.call('get',KEYS[1]) - if token == ARGV[1] - then - redis.call('del',KEYS[1]) - return 'OK' + if redis.call('get', KEYS[1]) == ARGV[1] then + return redis.call('del', KEYS[1]) + else + return 0 end - return 'ERROR' ` resp, err := g.redis.Eval(g.ctx, script, []string{g.uniqueKey}, g.value).Result() - if resp != "OK" { - opt.logger.Errorf(g.ctx, "global Unlock err resp:%+v err:%+v uniKey:%+v value:%+v", resp, err, g.uniqueKey, g.value) + if err != nil { + if g.options.logger != nil { + g.options.logger.Infof(g.ctx, "global unlock may have failed: %v, key: %s, value: %s", err, g.uniqueKey, g.value) + } + // 即使删除失败也继续执行,因为锁可能会自动过期 } + g.cancel() - return true + + if delCount, ok := resp.(int64); ok && delCount == 1 { + return nil + } + return fmt.Errorf("lock was already released or owned by another client") } -// 刷新锁 -func (g *globalLock) refresh() { +// 启动刷新goroutine +func (g *GlobalLock) startRefresh() { + g.wg.Add(1) go func() { - t := time.NewTicker(time.Second) + defer g.wg.Done() + + ticker := time.NewTicker(g.options.RefreshPeriod) + defer ticker.Stop() + for { select { - case <-t.C: - g.refreshExec() + case <-ticker.C: + if !g.refreshExec() { + return + } + case <-g.stopRefresh: + return case <-g.ctx.Done(): - t.Stop() g.Unlock() return } @@ -108,25 +170,45 @@ func (g *globalLock) refresh() { }() } -func (g *globalLock) refreshExec() bool { - script := ` - local token = redis.call('get',KEYS[1]) - if token == ARGV[1] - then - redis.call('set',KEYS[1],ARGV[1],'EX',ARGV[2]) - return 'OK' - end - return 'ERROR' - ` +// 执行刷新操作 +func (g *GlobalLock) refreshExec() bool { + g.closeLock.RLock() + defer g.closeLock.RUnlock() - resp, err := g.redis.Eval(g.ctx, script, []string{g.uniqueKey}, g.value, 5).Result() - if err != nil { - opt.logger.Errorf(g.ctx, "global refresh err resp:%+v err:%+v uniKey:%+v value:%+v", resp, err, g.uniqueKey, g.value) - } - if resp == "ERROR" { - opt.logger.Errorf(g.ctx, "global refresh err resp:%+v err:%+v uniKey:%+v value:%+v", resp, err, g.uniqueKey, g.value) - g.Unlock() + if g.isClosed { return false } + + script := ` + if redis.call('get', KEYS[1]) == ARGV[1] then + redis.call('expire', KEYS[1], ARGV[2]) + return 1 + else + return 0 + end + ` + + resp, err := g.redis.Eval(g.ctx, script, []string{g.uniqueKey}, g.value, int(g.options.Expiry.Seconds())).Result() + if err != nil { + if g.options.logger != nil { + g.options.logger.Errorf(g.ctx, "global refresh failed: %v, key: %s, value: %s", err, g.uniqueKey, g.value) + } + return false + } + + if refreshed, ok := resp.(int64); !ok || refreshed != 1 { + if g.options.logger != nil { + g.options.logger.Errorf(g.ctx, "global refresh failed, lock may be lost, key: %s, value: %s", g.uniqueKey, g.value) + } + return false + } + return true } + +// 检查锁是否已关闭 +func (g *GlobalLock) IsClosed() bool { + g.closeLock.RLock() + defer g.closeLock.RUnlock() + return g.isClosed +} diff --git a/lockx_test.go b/lockx_test.go index a223205..eb9927b 100644 --- a/lockx_test.go +++ b/lockx_test.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" "testing" + "time" "github.com/go-redis/redis/v8" "github.com/yuninks/lockx" @@ -43,18 +44,22 @@ func TestLockx(t *testing.T) { wg := sync.WaitGroup{} - for i := 0; i < 10000; i++ { + for i := 0; i < 20; i++ { wg.Add(1) go func(i int) { defer wg.Done() - lock := lockx.NewGlobalLock(ctx, client, "lockx:test") - defer lock.Unlock() - if !lock.Lock() { + lock, _ := lockx.NewGlobalLock(ctx, client, "lockx:test") + if b, _ := lock.Lock(); !b { fmt.Println("lock error", i) return } - fmt.Println("ssss", i) + defer lock.Unlock() + + fmt.Println("ssss2", i) + + time.Sleep(time.Second * 2) }(i) + time.Sleep(time.Second) } wg.Wait() diff --git a/options.go b/options.go index 2bf41f9..f9eeee5 100644 --- a/options.go +++ b/options.go @@ -7,47 +7,55 @@ import ( ) type option struct { - lockTimeout time.Duration // 锁的超时时间 - logger Logger // 日志 + lockTimeout time.Duration // 锁的超时时间 + Expiry time.Duration // 单次刷新有效时间 + MaxRetryTimes int // 尝试次数 + RetryInterval time.Duration // 尝试间隔 + RefreshPeriod time.Duration // 刷新间隔 + logger Logger // 日志 } func defaultOption() *option { return &option{ - lockTimeout: time.Minute * 60, - logger: &print{}, + lockTimeout: time.Minute * 60, + Expiry: 5 * time.Second, + RefreshPeriod: 1 * time.Second, + MaxRetryTimes: 3, + RetryInterval: 100 * time.Millisecond, + logger: &print{}, } } -var opt *option - -func init() { - opt = defaultOption() -} +var globalOpts []Option // 设置 func InitOption(opts ...Option) { - for _, app := range opts { - app(opt) - } + globalOpts = opts } type Option func(*option) -func SetTimeout(t time.Duration) Option { +func WithTimeout(t time.Duration) Option { return func(o *option) { o.lockTimeout = t } } -func SetLogger(logger Logger) Option { +func WithLogger(logger Logger) Option { return func(o *option) { o.logger = logger } } +func WithExpiry(expiry time.Duration) Option { + return func(o *option) { + o.Expiry = expiry + } +} + type Logger interface { Errorf(ctx context.Context, format string, v ...any) - Printf(ctx context.Context, format string, v ...any) + Infof(ctx context.Context, format string, v ...any) } type print struct{} @@ -56,6 +64,6 @@ func (*print) Errorf(ctx context.Context, format string, v ...any) { log.Printf(format, v...) } -func (*print) Printf(ctx context.Context, format string, v ...any) { +func (*print) Infof(ctx context.Context, format string, v ...any) { log.Printf(format, v...) }