package connpoolx import ( "context" "encoding/json" "fmt" "sync" "time" "xiaoxin-plus/pkg/langx" uuid "github.com/satori/go.uuid" "github.com/zeromicro/go-zero/core/logc" ) // 长连接的连接池管理 var connList = sync.Map{} type Conn struct { ctx context.Context cancel context.CancelFunc // 取消事件 stream StreamInterface // 连接 userInfo UserInfo // 用户信息 callback NotifyFunc mux sync.Mutex // 互斥锁 createTime time.Time // 连接事件 swapTime time.Time // 最后互相传递消息时间 isClose bool // 是否关闭 } type UserInfo struct { TenantId int64 `json:"tenant_id"` ClientId int64 `json:"client_id"` UserId int64 `json:"user_id"` Token string `json:"token"` TraceId string `json:"trace_id"` // WS的用户UID } type Message struct { MsgId string `json:"msg_id"` // 每条消息独立的UUID (自动赋值) Action string `json:"action"` // 操作 Code int `json:"code"` // 响应码 Message string `json:"message"` // 提示信息 Data string `json:"data"` // 数据信息 MsgTime int64 `json:"msg_time"` // 消息时间 ConnId string `json:"conn_id"` // 连接的uuid OriMsgId string `json:"ori_msg_id"` // 响应原来请求的uuid(对端传来) } // 事件 const ( ActionEvent string = "event" // 指令操作(要求完成某件事) ActionPing string = "ping" // 心跳 ActionTips string = "tips" // 提示消息 ActionMessage string = "message" // 普通消息 ActionLogin string = "login" // 登陆消息 ) type StreamInterface interface { Send(string) error // 发送 Read() (string, error) // 读取 Close() // 关闭 } // 回调方法 type NotifyFunc interface { Register(wc *Conn) error // 注册用户 Unregister(uid string) error // 取消注册 DealEvent(uid string, data string) (string, error) // 处理事件 DealRead(uid string, data string) error // 处理读取的消息 Heartbeat(uid string) // 心跳事件 BeforeSend(uid string, data string) error // 消息发送前回调(可能需要拦截) } func Accept(ctx context.Context, cancel context.CancelFunc, stream StreamInterface, callback NotifyFunc, user UserInfo) error { user.TraceId = uuid.NewV4().String() logc.Infof(ctx, "Accept: %+v", user) conn := &Conn{ ctx: ctx, cancel: cancel, stream: stream, createTime: time.Now(), swapTime: time.Now(), userInfo: user, callback: callback, } connList.Store(user.TraceId, conn) logc.Infof(ctx, "Accept store succ") // 发送消息 // go conn.write() logc.Infof(ctx, "Accept Register begin") err := conn.callback.Register(conn) if err != nil { conn.close() logc.Errorf(ctx, "Accept Register error: %v", err) return err } // 读取消息 go conn.read() // 登录响应 b := conn.buildMessage(ActionLogin, langx.GetCode(langx.Success), user.TraceId, langx.Success, "") conn.send(b) logc.Infof(ctx, "Accept finish") return nil } func (c *Conn) read() { for { select { case <-c.ctx.Done(): return default: message, err := c.stream.Read() // logc.Infof(c.ctx, "pool.read stream traceId:%+v message:%+v", c.userInfo.TraceId, message) if err != nil { c.close() logc.Errorf(c.ctx, "pool.read stream err:%+v", err) return } if message == "" { continue } c.mux.Lock() c.swapTime = time.Now() c.mux.Unlock() // 文字消息 if string(message) == "ping" { c.send("pong") continue } msg, err := c.parseMessage(message) // logc.Infof(c.ctx, "pool.read parseMessage: message:%+v msg:%+v err:%+v\n", message, msg, err) if err != nil { // 无法识别 str := c.buildMessage(ActionTips, langx.GetCode(langx.ErrorMsgUnparse), "ori_data:"+message, langx.ErrorMsgUnparse+" err:"+err.Error(), "") c.send(str) continue } if msg.Action == ActionEvent { data, err := c.callback.DealEvent(c.userInfo.TraceId, msg.Data) if err != nil { str := c.buildMessage(ActionTips, langx.GetCode(langx.ErrorMsgUnparse), "ori_data:"+message, "err:"+err.Error(), msg.MsgId) c.send(str) continue } else { str := c.buildMessage(ActionTips, langx.GetCode(langx.Success), data, langx.Success, "") c.send(str) continue } } else if msg.Action == ActionPing { if msg.Data == "pong" { str := c.buildMessage(ActionPing, langx.GetCode(langx.Success), "pong", langx.Success, "") c.send(str) continue } } else { // 其他消息 err := c.callback.DealRead(c.userInfo.TraceId, message) if err != nil { str := c.buildMessage(ActionTips, langx.GetCode(langx.Error), "ori_data:"+message, "方法错误", msg.MsgId) c.send(str) continue } } } } } func (c *Conn) Send(msg Message) error { str := c.buildMessage(msg.Action, msg.Code, msg.Data, msg.Message, msg.OriMsgId) return c.send(str) } func (c *Conn) send(data string) error { err := c.callback.BeforeSend(c.userInfo.TraceId, data) if err != nil { c.close() logc.Errorf(c.ctx, "Conn send BeforeSend error: %v", err) return err } err = c.stream.Send(data) if err != nil { c.close() logc.Errorf(c.ctx, "Conn send error: %v", err) return err } return nil } func (c *Conn) Close(msg *Message) { if msg != nil { str := c.buildMessage(msg.Action, msg.Code, msg.Data, msg.Message, msg.OriMsgId) c.send(str) } c.close() logc.Infof(c.ctx, "Conn Close: %+v", c.userInfo) } func (c *Conn) GetUserInfo() (uuid string, user UserInfo) { c.mux.Lock() defer c.mux.Unlock() return c.userInfo.TraceId, c.userInfo } func (c *Conn) close() { c.mux.Lock() c.isClose = true c.mux.Unlock() c.callback.Unregister(c.userInfo.TraceId) time.Sleep(time.Second * 1) c.stream.Close() logc.Infof(c.ctx, "Conn close: %+v", c.userInfo) c.cancel() } // 构造消息 func (c *Conn) buildMessage(action string, code int, data string, msg string, oriMsgId string) string { d := Message{ MsgId: uuid.NewV4().String(), Action: action, Code: code, Message: msg, Data: data, MsgTime: time.Now().UnixMilli(), ConnId: c.userInfo.TraceId, OriMsgId: oriMsgId, } b, _ := json.Marshal(d) return string(b) } // 消息转换 func (c *Conn) parseMessage(data string) (*Message, error) { msg := Message{} err := json.Unmarshal([]byte(data), &msg) if err != nil { return nil, err } fmt.Printf("parseMessage data:%+v msg:%+v\n", data, msg) return &msg, nil } func init() { go func() { for { Heartbeat(context.Background()) time.Sleep(time.Second * 5) } }() } func Heartbeat(ctx context.Context) bool { connList.Range(func(key, value interface{}) bool { go func(key interface{}, value interface{}) { uid, ok := key.(string) if !ok { return } c, ok := value.(*Conn) if !ok { return } c.mux.Lock() isClose := c.isClose swapTime := c.swapTime c.mux.Unlock() // 已关闭的链接就清除掉 if isClose { connList.Delete(uid) return } // 超过15s就关闭连接 if swapTime.Add(time.Minute * 2).Before(time.Now()) { c.close() logc.Infof(ctx, "Heartbeat close: %+v", c.userInfo) return } // 回调事件 c.callback.Heartbeat(uid) // // 15s内有发过消息不需要发心跳 // if swapTime.Add(time.Second * 60).After(time.Now()) { // return true // } // // 发送消息 // b, _ := buildMessage(ActionPing, lang.GetCode(lang.Success), "ping", lang.Success, "") // wc.sendMessage(SyncTypeAsync, b) }(key, value) return true }) return true }