From d6f5d1a4e66fb93397d9f441ed632f7ab0962481 Mon Sep 17 00:00:00 2001 From: Yun Date: Wed, 27 Dec 2023 19:10:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- buildSql.go | 38 ++++++++++++++++++++++++++++++++++++++ gormx.go | 27 +++++++++++++++++++++++++++ logger.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+) create mode 100644 buildSql.go create mode 100644 gormx.go create mode 100644 logger.go diff --git a/buildSql.go b/buildSql.go new file mode 100644 index 0000000..27a5a5b --- /dev/null +++ b/buildSql.go @@ -0,0 +1,38 @@ +package gormx + +import ( + "context" + "fmt" + + "gorm.io/gorm" +) + +// 根据map组装gorm的where +func BuildWhere(ctx context.Context, m *gorm.DB, where map[string]interface{}) (*gorm.DB, error) { + for key, val := range where { + switch val.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + m = m.Where(fmt.Sprintf("%s = ?", key), val) + case string: + m = m.Where(fmt.Sprintf("%s = ?", key), val) + case float32, float64: + m = m.Where(fmt.Sprintf("%s = ?", key), val) + case bool: + m = m.Where(fmt.Sprintf("%s = ?", key), val) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64: + m = m.Where(fmt.Sprintf("%s in (?)", key), val) + case []string: + m = m.Where(fmt.Sprintf("%s in (?)", key), val) + case []float32, []float64: + m = m.Where(fmt.Sprintf("%s in (?)", key), val) + case []bool: + m = m.Where(fmt.Sprintf("%s in (?)", key), val) + case []interface{}: + m = m.Where(fmt.Sprintf("%s in (?)", key), val) + default: + // byte,rune,any,nil + return nil, fmt.Errorf("unknown type:%+v %T", val, val) + } + } + return m, nil +} diff --git a/gormx.go b/gormx.go new file mode 100644 index 0000000..1928c7d --- /dev/null +++ b/gormx.go @@ -0,0 +1,27 @@ +package gormx + +import ( + "context" + "fmt" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +func NewGorm(prefix, user, password, host, database string, port int) *gorm.DB { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%v)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database) + + g, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: prefix, // 表名前缀,`Article` 的表名应该是 `it_articles` + SingularTable: true, // 使用单数表名,启用该选项,此时,`Article` 的表名应该是 `it_article` + }, + Logger: NewGormxLogger(context.TODO()), + }) + if err != nil { + panic(err) + } + + return g +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..b39c182 --- /dev/null +++ b/logger.go @@ -0,0 +1,49 @@ +package gormx + +import ( + "context" + "fmt" + "time" + + "gorm.io/gorm/logger" + + "github.com/zeromicro/go-zero/core/logx" +) + +// TODO:待进一步封装 + +type GormxLogger struct { + logx.Logger +} + +func NewGormxLogger(ctx context.Context) *GormxLogger { + return &GormxLogger{ + Logger: logx.WithContext(ctx), + } +} + +func (g *GormxLogger) LogMode(LogLevel logger.LogLevel) logger.Interface { + return g +} + +func (g *GormxLogger) Info(ctx context.Context, msg string, val ...interface{}) { + fmt.Println("info", msg, val) + g.Logger.Info(msg, val) +} + +func (g *GormxLogger) Warn(ctx context.Context, msg string, val ...interface{}) { + fmt.Println("warn", msg, val) + g.Logger.Info(msg, val) +} + +func (g *GormxLogger) Error(ctx context.Context, msg string, val ...interface{}) { + fmt.Println("error", msg, val) + g.Logger.Error(msg, val) +} + +func (g *GormxLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + + sql, rows := fc() + fmt.Printf("trace: begin:%+v, err:%+v, sql:%+v, rows:%+v\n", begin, err, sql, rows) + +}