Files
connpoolx/connpoolx.go
2026-05-03 22:37:59 +08:00

316 lines
7.6 KiB
Go

package connpoolx
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"xiaoxin-plus/pkg/langx"
uuid "github.com/satori/go.uuid"
"github.com/zeromicro/go-zero/core/logc"
)
// connList is the package-wide registry of long-lived connections.
// Keys are each connection's TraceId (a per-connection UUID assigned in
// Accept); values are *Conn.
var connList sync.Map
// Conn is a single long-lived connection managed by this package.
type Conn struct {
	// Lifecycle handles, both supplied by the caller of Accept. read stops
	// when ctx is done; close calls cancel (presumably the CancelFunc of ctx —
	// confirm with callers).
	ctx    context.Context
	cancel context.CancelFunc

	stream   StreamInterface // transport the messages travel over
	userInfo UserInfo        // peer identity; its TraceId is also the connList key
	callback NotifyFunc      // owner-supplied lifecycle and message hooks

	createTime time.Time // when Accept created the connection (NOTE(review): appears unused)

	// mux guards swapTime and isClose (GetUserInfo also holds it while copying userInfo).
	mux      sync.Mutex
	swapTime time.Time // last time a non-empty message was received; starts at accept time
	isClose  bool      // set by close; Heartbeat drops closed connections from connList
}
// UserInfo identifies the peer behind a Conn.
type UserInfo struct {
	TenantId int64  `json:"tenant_id"`
	ClientId int64  `json:"client_id"`
	UserId   int64  `json:"user_id"`
	Token    string `json:"token"`    // credential; masked by String so it never reaches logs
	TraceId  string `json:"trace_id"` // per-connection UUID assigned by Accept (the WS user UID); also the connList key
}

// String implements fmt.Stringer. This package logs UserInfo with %+v, which
// formats through String, so the token is masked ("***") instead of being
// written to the logs in clear text. An empty token prints as empty. JSON
// encoding is unaffected and still includes the token.
func (u UserInfo) String() string {
	token := u.Token
	if token != "" {
		token = "***"
	}
	return fmt.Sprintf("{TenantId:%d ClientId:%d UserId:%d Token:%s TraceId:%s}",
		u.TenantId, u.ClientId, u.UserId, token, u.TraceId)
}
// Message is the JSON envelope exchanged with the peer in both directions.
// NOTE: encoding/json emits object keys in field declaration order, so
// reordering these fields changes the bytes sent on the wire.
type Message struct {
	MsgId    string `json:"msg_id"`     // UUID unique to each message (filled in by buildMessage)
	Action   string `json:"action"`     // message kind, e.g. one of the Action* constants
	Code     int    `json:"code"`       // response code
	Message  string `json:"message"`    // human-readable hint text
	Data     string `json:"data"`       // payload
	MsgTime  int64  `json:"msg_time"`   // message time in Unix milliseconds (set by buildMessage)
	ConnId   string `json:"conn_id"`    // UUID of the connection (its TraceId, set by buildMessage)
	OriMsgId string `json:"ori_msg_id"` // MsgId of the peer's request this message responds to
}
// Values carried in Message.Action.
const (
	ActionLogin   string = "login"   // login acknowledgement sent by Accept
	ActionPing    string = "ping"    // heartbeat
	ActionEvent   string = "event"   // command: asks the other side to perform an action
	ActionMessage string = "message" // ordinary message
	ActionTips    string = "tips"    // notice / reply text
)
// StreamInterface abstracts the transport underneath a Conn (presumably a
// WebSocket or similar text-frame stream — confirm with implementations).
// NOTE(review): Conn does not serialize calls to Send; the read goroutine and
// external callers of Conn.Send may invoke it concurrently.
type StreamInterface interface {
	// Read blocks for the next inbound text message. A non-nil error ends the
	// Conn's read loop and closes the connection.
	Read() (string, error)
	// Send writes one outbound text message.
	Send(data string) error
	// Close shuts the transport down; a blocked Read is expected to return.
	Close()
}
// NotifyFunc is the set of callbacks the pool owner passes to Accept.
// Every string identifier handed to these hooks is the connection's TraceId.
type NotifyFunc interface {
	// Register is called from Accept after the connection is stored in the
	// pool; a non-nil error makes Accept close the connection and fail.
	Register(wc *Conn) error
	// Unregister is called when close tears the connection down.
	Unregister(traceID string) error
	// DealEvent handles an inbound ActionEvent message; data is Message.Data
	// and the returned string is sent back as the reply's data.
	DealEvent(traceID string, data string) (string, error)
	// DealRead handles any other parsed inbound message; data is the raw JSON text.
	DealRead(traceID string, data string) error
	// Heartbeat is called by each periodic sweep for open, non-idle connections.
	Heartbeat(traceID string)
	// BeforeSend runs before every outbound message (it may intercept it);
	// a non-nil error closes the connection and the message is not sent.
	BeforeSend(traceID string, data string) error
}
// Accept turns a freshly established stream into a pooled connection.
//
// It overwrites user.TraceId with a new UUID, stores the connection in
// connList under that id, calls callback.Register, starts the read loop in its
// own goroutine, and finally sends an ActionLogin message whose Data is the
// TraceId. cancel is invoked when the connection is closed.
//
// A non-nil error means the connection has already been closed. Either
// Register failed (the connection is closed, which also calls
// callback.Unregister) or the login message could not be sent (send closes the
// connection itself).
func Accept(ctx context.Context, cancel context.CancelFunc, stream StreamInterface, callback NotifyFunc, user UserInfo) error {
	user.TraceId = uuid.NewV4().String()
	logc.Infof(ctx, "Accept: %+v", user)

	now := time.Now()
	conn := &Conn{
		ctx:        ctx,
		cancel:     cancel,
		stream:     stream,
		createTime: now,
		swapTime:   now,
		userInfo:   user,
		callback:   callback,
	}
	connList.Store(user.TraceId, conn)
	logc.Infof(ctx, "Accept store succ")

	logc.Infof(ctx, "Accept Register begin")
	if err := conn.callback.Register(conn); err != nil {
		conn.close()
		logc.Errorf(ctx, "Accept Register error: %v", err)
		return err
	}

	go conn.read()

	// Login acknowledgement. On failure send has already logged the error and
	// closed the connection, so report it instead of claiming success.
	b := conn.buildMessage(ActionLogin, langx.GetCode(langx.Success), user.TraceId, langx.Success, "")
	if err := conn.send(b); err != nil {
		return err
	}
	logc.Infof(ctx, "Accept finish")
	return nil
}
// read is the per-connection receive loop; Accept runs it in its own goroutine.
//
// It stops when c.ctx is done (checked only between reads) or when
// stream.Read fails, in which case the connection is closed. Every non-empty
// message refreshes swapTime, which Heartbeat uses for idle detection.
//
// Dispatch:
//   - raw text "ping"             -> raw text "pong"
//   - text that is not JSON       -> ActionTips error echoing the input
//   - ActionEvent                 -> callback.DealEvent; result or error as ActionTips
//   - ActionPing with Data "pong" -> ActionPing "pong" (other Data is ignored)
//   - anything else               -> callback.DealRead with the raw JSON text
//
// Replies to a parsed request carry the request's MsgId as OriMsgId so the
// peer can correlate them. Send errors are not checked here: send closes the
// connection itself on failure, and close calls cancel (presumably cancelling
// c.ctx), which ends the loop on the next iteration.
func (c *Conn) read() {
	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		message, err := c.stream.Read()
		if err != nil {
			c.close()
			logc.Errorf(c.ctx, "pool.read stream err:%+v", err)
			return
		}
		if message == "" {
			continue
		}

		c.mux.Lock()
		c.swapTime = time.Now()
		c.mux.Unlock()

		// Plain-text keepalive, answered without JSON framing.
		if message == "ping" {
			c.send("pong")
			continue
		}

		msg, err := c.parseMessage(message)
		if err != nil {
			// Unparseable input: there is no MsgId to reference.
			c.send(c.buildMessage(ActionTips, langx.GetCode(langx.ErrorMsgUnparse), "ori_data:"+message, langx.ErrorMsgUnparse+" err:"+err.Error(), ""))
			continue
		}

		switch msg.Action {
		case ActionEvent:
			data, err := c.callback.DealEvent(c.userInfo.TraceId, msg.Data)
			if err != nil {
				c.send(c.buildMessage(ActionTips, langx.GetCode(langx.ErrorMsgUnparse), "ori_data:"+message, "err:"+err.Error(), msg.MsgId))
				continue
			}
			c.send(c.buildMessage(ActionTips, langx.GetCode(langx.Success), data, langx.Success, msg.MsgId))
		case ActionPing:
			if msg.Data == "pong" {
				c.send(c.buildMessage(ActionPing, langx.GetCode(langx.Success), "pong", langx.Success, msg.MsgId))
			}
		default:
			if err := c.callback.DealRead(c.userInfo.TraceId, message); err != nil {
				c.send(c.buildMessage(ActionTips, langx.GetCode(langx.Error), "ori_data:"+message, "方法错误", msg.MsgId))
			}
		}
	}
}
// Send builds a fresh envelope from msg and writes it to the peer.
// Only Action, Code, Data, Message and OriMsgId are taken from msg. MsgId,
// MsgTime and ConnId are regenerated by buildMessage. A non-nil error means
// the connection has been closed by send.
func (c *Conn) Send(msg Message) error {
	return c.send(c.buildMessage(msg.Action, msg.Code, msg.Data, msg.Message, msg.OriMsgId))
}
// send delivers one already-serialized frame. It first lets callback.BeforeSend
// veto it, then writes it to the stream. If either step fails the error is
// logged, the whole connection is closed (close may block briefly), and the
// error is returned.
func (c *Conn) send(data string) error {
	if hookErr := c.callback.BeforeSend(c.userInfo.TraceId, data); hookErr != nil {
		c.close()
		logc.Errorf(c.ctx, "Conn send BeforeSend error: %v", hookErr)
		return hookErr
	}
	if writeErr := c.stream.Send(data); writeErr != nil {
		c.close()
		logc.Errorf(c.ctx, "Conn send error: %v", writeErr)
		return writeErr
	}
	return nil
}
// Close shuts the connection down at the owner's request. A non-nil msg is
// sent first as a farewell through Send. Its error is deliberately ignored
// because a failed send already closes the connection. The connection is then
// closed and the event logged.
func (c *Conn) Close(msg *Message) {
	if farewell := msg; farewell != nil {
		_ = c.Send(*farewell)
	}
	c.close()
	logc.Infof(c.ctx, "Conn Close: %+v", c.userInfo)
}
// GetUserInfo returns the connection's TraceId together with a copy of its
// UserInfo, taken while holding mux.
func (c *Conn) GetUserInfo() (uuid string, user UserInfo) {
	c.mux.Lock()
	user = c.userInfo
	c.mux.Unlock()
	uuid = user.TraceId
	return uuid, user
}
// close tears the connection down. It is idempotent. It is reached from
// several paths (read errors, send failures, Close, Heartbeat, Accept) that can
// race or chain into each other. Only the first call proceeds; later calls
// return immediately instead of re-running Unregister, sleeping again, closing
// the stream twice and cancelling twice.
//
// The first call marks the connection closed and removes it from connList
// right away rather than waiting for the next Heartbeat sweep. It then calls
// callback.Unregister, closes the stream and calls cancel. The call blocks for
// about one second before closing the stream.
func (c *Conn) close() {
	c.mux.Lock()
	alreadyClosed := c.isClose
	c.isClose = true
	c.mux.Unlock()
	if alreadyClosed {
		return
	}

	connList.Delete(c.userInfo.TraceId)
	if err := c.callback.Unregister(c.userInfo.TraceId); err != nil {
		logc.Errorf(c.ctx, "Conn close Unregister error: %v", err)
	}
	// NOTE(review): presumably this pause lets an in-flight final write (e.g.
	// the farewell sent by Close) flush before the stream is closed — confirm.
	time.Sleep(time.Second * 1)
	c.stream.Close()
	logc.Infof(c.ctx, "Conn close: %+v", c.userInfo)
	c.cancel()
}
// buildMessage wraps the given fields in a Message and returns its JSON text.
// The Message is stamped with a fresh UUID as MsgId, the current time in Unix
// milliseconds, and this connection's TraceId as ConnId.
func (c *Conn) buildMessage(action string, code int, data string, msg string, oriMsgId string) string {
	var envelope Message
	envelope.MsgId = uuid.NewV4().String()
	envelope.Action = action
	envelope.Code = code
	envelope.Message = msg
	envelope.Data = data
	envelope.MsgTime = time.Now().UnixMilli()
	envelope.ConnId = c.userInfo.TraceId
	envelope.OriMsgId = oriMsgId

	// Message contains only string and integer fields, so Marshal cannot fail.
	encoded, _ := json.Marshal(envelope)
	return string(encoded)
}
// parseMessage decodes one inbound JSON text frame into a Message.
// On failure the json.Unmarshal error is wrapped (errors.Is/As still see it).
// Callers may relay the error text to the peer.
// The former unconditional fmt.Printf of every payload has been removed: it
// wrote raw user data to stdout on every message, bypassing the logger.
func (c *Conn) parseMessage(data string) (*Message, error) {
	var msg Message
	if err := json.Unmarshal([]byte(data), &msg); err != nil {
		return nil, fmt.Errorf("parse message: %w", err)
	}
	return &msg, nil
}
// init starts the package's background sweep for the life of the process.
// Heartbeat runs once immediately, then again 5 seconds after each pass
// returns.
func init() {
	const pause = time.Second * 5
	go func() {
		for ; ; time.Sleep(pause) {
			Heartbeat(context.Background())
		}
	}()
}
// heartbeatIdleTimeout is how long a connection may go without receiving a
// non-empty message before Heartbeat closes it.
const heartbeatIdleTimeout = 2 * time.Minute

// Heartbeat performs one sweep over connList. For every entry it spawns a
// goroutine that does one of the following:
//   - drops the entry from connList if the connection is already closed;
//   - closes the connection if nothing has been received for
//     heartbeatIdleTimeout;
//   - otherwise calls callback.Heartbeat with the connection's TraceId.
//
// The sweep does not wait for those goroutines. ctx is used only for logging.
// It always returns true.
func Heartbeat(ctx context.Context) bool {
	connList.Range(func(key, value interface{}) bool {
		go func(key, value interface{}) {
			uid, ok := key.(string)
			if !ok {
				return
			}
			c, ok := value.(*Conn)
			if !ok {
				return
			}

			c.mux.Lock()
			isClose, swapTime := c.isClose, c.swapTime
			c.mux.Unlock()

			switch {
			case isClose:
				// Closed elsewhere: just forget it.
				connList.Delete(uid)
			case time.Since(swapTime) > heartbeatIdleTimeout:
				c.close()
				logc.Infof(ctx, "Heartbeat close: %+v", c.userInfo)
			default:
				c.callback.Heartbeat(uid)
			}
		}(key, value)
		return true
	})
	return true
}