From 0b7a1cf46d58b5bad0d44cdf99d122d153787d8f Mon Sep 17 00:00:00 2001 From: Yun Date: Sat, 20 Sep 2025 20:28:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=BB=93=E6=9E=84=E4=BD=93?= =?UTF-8?q?=E7=9A=84=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 10 +- go.sum | 18 +- structx.go | 626 +++++++++++++++++++++++++++--------- structx.go.bak | 184 +++++++++++ structx_test.go | 834 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1518 insertions(+), 154 deletions(-) create mode 100644 structx.go.bak diff --git a/go.mod b/go.mod index 77a7053..48693ba 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,11 @@ module code.yun.ink/pkg/structx -go 1.20 +go 1.21.0 -require code.yun.ink/pkg/convx v1.0.2 +toolchain go1.22.4 + +require ( + code.yun.ink/pkg/convx v1.0.3 + github.com/shopspring/decimal v1.4.0 + github.com/spf13/cast v1.10.0 +) diff --git a/go.sum b/go.sum index e0664bf..fd1b862 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,16 @@ -code.yun.ink/pkg/convx v1.0.2 h1:vkEcDQ8w9Kz2T/RMnYefkzyXQwqI9nnaWo+Z1jlS7IE= -code.yun.ink/pkg/convx v1.0.2/go.mod h1:6xqmUend1kwarRvJ0TQlfzzS4QCWdRrXQiUY/ggzYqo= +code.yun.ink/pkg/convx v1.0.3 h1:pH8dUOgsoaBYVQ3+4C2+uVua561nDxq6/GpaQ9wnCew= +code.yun.ink/pkg/convx v1.0.3/go.mod h1:6xqmUend1kwarRvJ0TQlfzzS4QCWdRrXQiUY/ggzYqo= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= diff --git a/structx.go b/structx.go index 2e0e184..de094af 100644 --- a/structx.go +++ b/structx.go @@ -1,12 +1,14 @@ package structx import ( + "encoding/json" "fmt" "reflect" "strconv" "strings" "code.yun.ink/pkg/convx" + "github.com/spf13/cast" ) type ChangeInfo struct { @@ -15,170 +17,494 @@ type ChangeInfo struct { Val interface{} `json:"val"` } -// map[string]interface类型的值附加到结构体中 -func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { - m := map[string]string{} - for k, v := range updateMap { - str, err := convx.ToString(v) - if err != nil { - return nil, err - } - m[k] = str - } +// 统一的类型转换函数 +type converterFunc func(reflect.Value, string) (interface{}, error) - // fmt.Println("bef", m, structxx) - - r, err := AttactToStruct(structxx, m) - // fmt.Println("resp:", r, m, structxx) - return r, err +var typeConverters = map[reflect.Kind]converterFunc{ + reflect.Bool: convertBool, + reflect.Int: convertInt[int], + reflect.Int8: convertInt[int8], + reflect.Int16: convertInt[int16], + reflect.Int32: convertInt[int32], + reflect.Int64: convertInt[int64], + reflect.Uint: convertUint[uint], + reflect.Uint8: convertUint[uint8], + reflect.Uint16: convertUint[uint16], + reflect.Uint32: convertUint[uint32], + reflect.Uint64: convertUint[uint64], + reflect.Float32: convertFloat[float32], + reflect.Float64: convertFloat[float64], + reflect.String: convertString, + // reflect.Slice: convertSlice, + // reflect.Array: convertArray, + // reflect.Map: convertMap, } -// 这个方法将map的数据赋值到结构体中 +// AttactToStructAny 将 map[string]interface{} 类型的值附加到结构体中 +func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { + stringMap := make(map[string]string, len(updateMap)) + for k, v := range updateMap { + str, err := cast.ToStringE(v) + if err != nil { + return nil, fmt.Errorf("转换键 %s 的值失败: %w", k, err) + } + stringMap[k] = str + } + + return AttactToStruct(structxx, stringMap) +} + +// AttactToStruct 将 map 的数据赋值到结构体中,支持嵌套结构体和指针 func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[string]ChangeInfo, error) { - // 将结构体指针转换为reflect.Value类型 v := reflect.ValueOf(structxx) - // 如果v不是指针或者v是零值,直接返回 if v.Kind() != reflect.Ptr || v.IsNil() { - return nil, fmt.Errorf("structxx 需要是指针") + return nil, fmt.Errorf("structxx 需要是非空指针") } changeMap := make(map[string]ChangeInfo) - - // 获取v指向的元素,也就是结构体本身 v = v.Elem() - // 遍历map的键和值 - for k, val := range updateMap { - // 根据json标签获取结构体中对应的字段 - field := v.FieldByNameFunc(func(name string) bool { - f, _ := v.Type().FieldByName(name) - // fmt.Printf("FieldByNameFunc:%+v\n", f) - js := f.Tag.Get("json") - s := strings.Split(js, ",") - if len(s) > 1 { - return s[0] == k + + // 获取结构体类型信息 + t := v.Type() + fieldMap := buildFieldMap(t) + + for mapKey, mapValue := range updateMap { + fieldInfo, exists := fieldMap[mapKey] + if !exists { + continue // 忽略不存在的字段 + } + + // 安全获取字段值,处理嵌套指针 + field, err := getFieldByIndexSafe(v, fieldInfo.Index) + if err != nil { + return nil, fmt.Errorf("获取字段 %s 失败: %w", mapKey, err) + } + + if !field.IsValid() { + continue + } + + // 处理指针类型 + if fieldInfo.IsPtr { + if field.IsNil() { + // 创建新的指针实例 + newValue := reflect.New(fieldInfo.FieldType.Elem()) + field.Set(newValue) } - return js == k - }) - // 如果字段存在且可写 - if field.IsValid() && field.CanSet() { - // 根据字段的类型进行类型转换和赋值 - switch field.Kind() { - case reflect.Bool: - oldb := field.Bool() - olds, _ := convx.ToString(oldb) + field = field.Elem() + } - b, _ := convx.ToBool(val) - field.SetBool(b) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: b} - case reflect.Int64: - // 将字符串转换为int64 - oldi := field.Int() - olds, _ := convx.ToString(oldi) - - i, _ := strconv.ParseInt(val, 10, 64) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.String: - olds := field.String() - - field.SetString(val) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: val} - case reflect.Int: - oldi := field.Int() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Int32: - oldi := field.Int() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Int16: - oldi := field.Int() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Int8: - oldi := field.Int() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Uint: - oldi := field.Uint() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Uint32: - oldi := field.Uint() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Uint16: - oldi := field.Uint() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Uint8: - oldi := field.Uint() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Uint64: - oldi := field.Uint() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToInt64(val) - field.SetInt(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Float32: - oldi := field.Float() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToFloat64(val) - field.SetFloat(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - case reflect.Float64: - oldi := field.Float() - olds, _ := convx.ToString(oldi) - - i, _ := convx.ToFloat64(val) - field.SetFloat(i) - - changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} - default: - fmt.Println("未知类型") - return nil, fmt.Errorf("未知类型:" + k) + // 处理嵌套结构体(包括指针解引用后的结构体) + if field.Kind() == reflect.Struct && !isBasicStructType(field.Type()) { + nestedChanges, err := processNestedStruct(field, mapValue, mapKey) + if err != nil { + return nil, fmt.Errorf("处理嵌套结构体字段 %s 失败: %w", mapKey, err) } + for nestedKey, change := range nestedChanges { + changeMap[nestedKey] = change + } + continue + } + + if !field.CanSet() { + continue + } + + // 处理基本类型 + oldValueStr, err := cast.ToStringE(field.Interface()) + if err != nil { + return nil, fmt.Errorf("获取字段 %s 的旧值失败: %w", mapKey, err) + } + + newValue, err := setFieldValue(field, mapValue) + if err != nil { + return nil, fmt.Errorf("设置字段 %s 的值失败: %w", mapKey, err) + } + + newValueStr, err := cast.ToStringE(newValue) + if err != nil { + newValueStr = fmt.Sprintf("%v", newValue) + } + + changeMap[mapKey] = ChangeInfo{ + Old: oldValueStr, + New: newValueStr, + Val: newValue, } } + return changeMap, nil } + +// getFieldByIndexSafe 安全地通过索引路径获取字段,处理嵌套指针 +func getFieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, error) { + if len(index) == 0 { + return v, nil + } + + // 获取当前层级的字段 + field := v.Field(index[0]) + if !field.IsValid() { + return reflect.Value{}, fmt.Errorf("字段索引 %d 无效", index[0]) + } + + // 递归解引用指针,直到遇到非指针类型 + for field.Kind() == reflect.Ptr { + if field.IsNil() { + // 创建新的指针实例 + elemType := field.Type().Elem() + newValue := reflect.New(elemType) + field.Set(newValue) + } + field = field.Elem() + } + + // 递归处理剩余的索引路径 + if len(index) > 1 { + return getFieldByIndexSafe(field, index[1:]) + } + + return field, nil +} + +// 构建字段映射表,支持嵌套结构体和指针类型 +type fieldInfo struct { + Index []int + Name string + IsPtr bool + FieldType reflect.Type +} + +func buildFieldMap(t reflect.Type) map[string]fieldInfo { + fieldMap := make(map[string]fieldInfo) + buildFieldMapRecursive(t, []int{}, fieldMap, "") + return fieldMap +} + +func buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]fieldInfo, prefix string) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + + currentIndex := append(index, i) + jsonTag := getJSONTagName(field) + + fullKey := jsonTag + if prefix != "" { + fullKey = prefix + "." + jsonTag + } + + fieldType := field.Type + isPtr := fieldType.Kind() == reflect.Ptr + + // 解引用指针类型以获取实际类型 + actualType := fieldType + if isPtr { + actualType = fieldType.Elem() + } + + // 如果是结构体且不是基本类型,递归处理 + if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { + // 对于指针或非指针的结构体都进行递归展开 + buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) + + // 同时记录当前字段的信息 + fieldMap[fullKey] = fieldInfo{ + Index: currentIndex, + Name: field.Name, + IsPtr: isPtr, + FieldType: fieldType, + } + continue + } + + fieldMap[fullKey] = fieldInfo{ + Index: currentIndex, + Name: field.Name, + IsPtr: isPtr, + FieldType: fieldType, + } + } +} + +func getJSONTagName(field reflect.StructField) string { + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + return field.Name + } + return strings.Split(jsonTag, ",")[0] +} + +// 判断是否为基本结构体类型(如time.Time等) +func isBasicStructType(t reflect.Type) bool { + // 处理指针类型 + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 这里可以添加更多需要排除的基本结构体类型 + if t.PkgPath() == "time" && t.Name() == "Time" { + return true + } + + // 其他常见的基本结构体类型 + switch t.String() { + case "time.Time", "sql.NullString", "sql.NullInt64", "sql.NullBool", "sql.NullFloat64": + return true + } + + return false +} + +// 处理嵌套结构体(支持JSON解析) +// processNestedStruct 处理嵌套结构体 +func processNestedStruct(field reflect.Value, value string, parentKey string) (map[string]ChangeInfo, error) { + changeMap := make(map[string]ChangeInfo) + + // 确保我们处理的是可寻址的结构体值 + var structValue reflect.Value + if field.Kind() == reflect.Ptr { + if field.IsNil() { + // 创建新的指针实例 + newValue := reflect.New(field.Type().Elem()) + field.Set(newValue) + } + structValue = field.Elem() + } else { + structValue = field + } + + if !structValue.IsValid() || structValue.Kind() != reflect.Struct { + return nil, fmt.Errorf("无效的结构体字段") + } + + // 尝试解析JSON字符串到map + var nestedMap map[string]string + if err := json.Unmarshal([]byte(value), &nestedMap); err != nil { + return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err) + } + + // 递归处理嵌套结构体 + nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap) + if err != nil { + return nil, err + } + + // 为嵌套字段的变更记录添加前缀 + for key, change := range nestedChanges { + fullKey := parentKey + "." + key + changeMap[fullKey] = change + } + + return changeMap, nil +} + +// 检测是否为自定义类型 +func isCustomType(t reflect.Type) bool { + // 排除基本类型 + if t.PkgPath() == "" { + return false + } + + // 排除已知的基本结构体类型 + if isBasicStructType(t) { + return false + } + + // 排除接口类型 + if t.Kind() == reflect.Interface { + return false + } + + return true +} + +// 设置字段值 +// 设置字段值 +func setFieldValue(field reflect.Value, value string) (interface{}, error) { + kind := field.Kind() + + // 检测是否为自定义类型(有包路径的类型) + if isCustomType(field.Type()) { + return setCustomTypeValue(field, value) + } + + // 处理指针类型的基础类型 + if kind == reflect.Ptr { + return setPointerFieldValue(field, value) + } + + converter, exists := typeConverters[kind] + if !exists { + return nil, fmt.Errorf("不支持的类型: %s", kind.String()) + } + + result, err := converter(field, value) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(result)) + return result, nil +} + +// 处理指针类型的字段 +func setPointerFieldValue(field reflect.Value, value string) (interface{}, error) { + if field.IsNil() { + // 创建新的指针实例 + elemType := field.Type().Elem() + newValue := reflect.New(elemType) + field.Set(newValue) + } + + // 递归处理指针指向的值 + return setFieldValue(field.Elem(), value) +} + +// 类型转换函数 +func convertBool(field reflect.Value, value string) (interface{}, error) { + return convx.ToBool(value) +} + +func convertInt[T int | int8 | int16 | int32 | int64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + intVal, err := strconv.ParseInt(value, 10, bits) + if err != nil { + return nil, err + } + return T(intVal), nil +} + +func convertUint[T uint | uint8 | uint16 | uint32 | uint64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + uintVal, err := strconv.ParseUint(value, 10, bits) + if err != nil { + return nil, err + } + return T(uintVal), nil +} + +func convertFloat[T float32 | float64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + floatVal, err := strconv.ParseFloat(value, bits) + if err != nil { + return nil, err + } + return T(floatVal), nil +} + +func convertString(field reflect.Value, value string) (interface{}, error) { + return value, nil +} + +func convertSlice(field reflect.Value, value string) (interface{}, error) { + // 实现切片转换逻辑 + // elemType := field.Type().Elem() + // 根据元素类型进行解析 + return nil, fmt.Errorf("切片转换未实现") +} + +func convertArray(field reflect.Value, value string) (interface{}, error) { + // 实现数组转换逻辑 + return nil, fmt.Errorf("数组转换未实现") +} + +func convertMap(field reflect.Value, value string) (interface{}, error) { + // 实现map转换逻辑 + return nil, fmt.Errorf("map转换未实现") +} + +// 处理自定义类型 +// 处理自定义类型 +func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { + // 检查是否实现了TextUnmarshaler接口 + if unmarshaler, ok := field.Addr().Interface().(interface { + UnmarshalText([]byte) error + }); ok { + err := unmarshaler.UnmarshalText([]byte(value)) + if err != nil { + return nil, err + } + return field.Interface(), nil + } + + // 对于其他自定义类型,我们需要获取其基础类型并进行转换 + baseType := getBaseType(field.Type()) + if baseType == nil { + return nil, fmt.Errorf("不支持的自定义类型: %s", field.Type().String()) + } + + // 创建基础类型的值并进行转换 + baseValue := reflect.New(baseType).Elem() + converter, exists := typeConverters[baseType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) + } + + result, err := converter(baseValue, value) + if err != nil { + return nil, err + } + + // 将基础类型值转换回自定义类型 + convertedValue, err := convertToCustomType(field.Type(), result) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(convertedValue)) + return convertedValue, nil +} + +// 获取自定义类型的基础类型 +func getBaseType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 如果是自定义类型(有包路径),获取其底层类型 + if t.PkgPath() != "" { + return t + } + + return t +} + +// 将基础类型值转换为自定义类型 +func convertToCustomType(customType reflect.Type, value interface{}) (interface{}, error) { + valueType := reflect.TypeOf(value) + customBaseType := customType + + // 处理指针类型的自定义类型 + if customType.Kind() == reflect.Ptr { + customBaseType = customType.Elem() + // 创建新的指针实例 + newValue := reflect.New(customBaseType) + elemValue := newValue.Elem() + + // 尝试将值设置到元素 + if valueType.AssignableTo(customBaseType) { + elemValue.Set(reflect.ValueOf(value)) + } else if reflect.ValueOf(value).Type().ConvertibleTo(customBaseType) { + converted := reflect.ValueOf(value).Convert(customBaseType) + elemValue.Set(converted) + } else { + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customBaseType) + } + + return newValue.Interface(), nil + } + + // 处理非指针类型的自定义类型 + if valueType.AssignableTo(customType) { + return value, nil + } + + if reflect.ValueOf(value).Type().ConvertibleTo(customType) { + return reflect.ValueOf(value).Convert(customType).Interface(), nil + } + + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customType) +} \ No newline at end of file diff --git a/structx.go.bak b/structx.go.bak new file mode 100644 index 0000000..2e0e184 --- /dev/null +++ b/structx.go.bak @@ -0,0 +1,184 @@ +package structx + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "code.yun.ink/pkg/convx" +) + +type ChangeInfo struct { + Old string `json:"old"` + New string `json:"new"` + Val interface{} `json:"val"` +} + +// map[string]interface类型的值附加到结构体中 +func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { + m := map[string]string{} + for k, v := range updateMap { + str, err := convx.ToString(v) + if err != nil { + return nil, err + } + m[k] = str + } + + // fmt.Println("bef", m, structxx) + + r, err := AttactToStruct(structxx, m) + // fmt.Println("resp:", r, m, structxx) + return r, err +} + +// 这个方法将map的数据赋值到结构体中 +func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[string]ChangeInfo, error) { + // 将结构体指针转换为reflect.Value类型 + v := reflect.ValueOf(structxx) + // 如果v不是指针或者v是零值,直接返回 + if v.Kind() != reflect.Ptr || v.IsNil() { + return nil, fmt.Errorf("structxx 需要是指针") + } + + changeMap := make(map[string]ChangeInfo) + + // 获取v指向的元素,也就是结构体本身 + v = v.Elem() + // 遍历map的键和值 + for k, val := range updateMap { + // 根据json标签获取结构体中对应的字段 + field := v.FieldByNameFunc(func(name string) bool { + f, _ := v.Type().FieldByName(name) + // fmt.Printf("FieldByNameFunc:%+v\n", f) + js := f.Tag.Get("json") + s := strings.Split(js, ",") + if len(s) > 1 { + return s[0] == k + } + return js == k + }) + // 如果字段存在且可写 + if field.IsValid() && field.CanSet() { + // 根据字段的类型进行类型转换和赋值 + switch field.Kind() { + case reflect.Bool: + oldb := field.Bool() + olds, _ := convx.ToString(oldb) + + b, _ := convx.ToBool(val) + field.SetBool(b) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: b} + case reflect.Int64: + // 将字符串转换为int64 + oldi := field.Int() + olds, _ := convx.ToString(oldi) + + i, _ := strconv.ParseInt(val, 10, 64) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.String: + olds := field.String() + + field.SetString(val) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: val} + case reflect.Int: + oldi := field.Int() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Int32: + oldi := field.Int() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Int16: + oldi := field.Int() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Int8: + oldi := field.Int() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Uint: + oldi := field.Uint() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Uint32: + oldi := field.Uint() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Uint16: + oldi := field.Uint() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Uint8: + oldi := field.Uint() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Uint64: + oldi := field.Uint() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToInt64(val) + field.SetInt(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Float32: + oldi := field.Float() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToFloat64(val) + field.SetFloat(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + case reflect.Float64: + oldi := field.Float() + olds, _ := convx.ToString(oldi) + + i, _ := convx.ToFloat64(val) + field.SetFloat(i) + + changeMap[k] = ChangeInfo{Old: olds, New: val, Val: i} + default: + fmt.Println("未知类型") + return nil, fmt.Errorf("未知类型:" + k) + } + } + } + return changeMap, nil +} diff --git a/structx_test.go b/structx_test.go index 7fdb11a..cc7895c 100644 --- a/structx_test.go +++ b/structx_test.go @@ -2,9 +2,12 @@ package structx_test import ( "fmt" + "strings" "testing" + "time" "code.yun.ink/pkg/structx" + "github.com/shopspring/decimal" ) func TestAttactToStructMap(t *testing.T) { @@ -29,3 +32,834 @@ type Data struct { IsMan bool `json:"is_man"` Addr string `json:"addr"` } + +// 基础测试结构体 +type BasicStruct struct { + Name string `json:"name"` + Age int `json:"age"` + Salary float64 `json:"salary"` + IsActive bool `json:"is_active"` + Count uint `json:"count"` +} + +// 嵌套结构体 +type NestedStruct struct { + Basic BasicStruct `json:"basic"` + Comment string `json:"comment"` + Amount decimal.Decimal `json:"amount"` +} + +// 指针嵌套结构体 +type PointerStruct struct { + Basic *BasicStruct `json:"basic"` + Enabled bool `json:"enabled"` +} + +// 多层嵌套结构体 +type MultiLevelStruct struct { + Nested NestedStruct `json:"nested"` + Pointer *PointerStruct `json:"pointer"` + Tags []string `json:"tags"` +} + +// 自定义类型 +type CustomString string +type CustomInt int + +type CustomTypeStruct struct { + ID CustomString `json:"id"` + Version CustomInt `json:"version"` + Email string `json:"email"` +} + +// 实现 TextUnmarshaler 接口的类型 +type CustomUnmarshaler string + +func (c *CustomUnmarshaler) UnmarshalText(text []byte) error { + *c = CustomUnmarshaler("custom_" + string(text)) + return nil +} + +type UnmarshalerStruct struct { + Data CustomUnmarshaler `json:"data"` + Name string `json:"name"` +} + +// 复杂嵌套结构体 +type ComplexStruct struct { + Basic BasicStruct `json:"basic"` + Nested *NestedStruct `json:"nested"` + Custom CustomTypeStruct `json:"custom"` + Timestamp time.Time `json:"timestamp"` + Metadata map[string]string `json:"metadata"` + Unmarshaler CustomUnmarshaler `json:"unmarshaler"` +} + +// 基础类型测试 +func TestAttactToStruct_BasicTypes(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected BasicStruct + wantErr bool + }{ + { + name: "所有基础类型", + input: map[string]string{ + "name": "John Doe", + "age": "30", + "salary": "50000.50", + "is_active": "true", + "count": "100", + }, + expected: BasicStruct{ + Name: "John Doe", + Age: 30, + Salary: 50000.50, + IsActive: true, + Count: 100, + }, + wantErr: false, + }, + { + name: "部分字段", + input: map[string]string{ + "name": "Alice", + "age": "25", + }, + expected: BasicStruct{ + Name: "Alice", + Age: 25, + }, + wantErr: false, + }, + { + name: "无效布尔值", + input: map[string]string{ + "is_active": "invalid", + }, + wantErr: true, + }, + { + name: "无效数字", + input: map[string]string{ + "age": "not_a_number", + }, + wantErr: true, + }, + { + name: "空字符串处理", + input: map[string]string{ + "name": "", + "age": "0", + }, + expected: BasicStruct{ + Name: "", + Age: 0, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s BasicStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if s != tt.expected { + t.Errorf("期望 %+v, 得到 %+v", tt.expected, s) + } + + // 验证变更记录 + if len(changes) != len(tt.input) { + t.Errorf("期望 %d 个变更记录, 得到 %d", len(tt.input), len(changes)) + } + }) + } +} + +// 测试嵌套结构体 +func TestAttactToStruct_NestedStruct(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected NestedStruct + wantErr bool + }{ + { + name: "嵌套结构体赋值", + input: map[string]string{ + "basic.name": "John", + "basic.age": "30", + "basic.salary": "50000.0", + "basic.is_active": "true", + "comment": "test comment", + }, + expected: NestedStruct{ + Basic: BasicStruct{ + Name: "John", + Age: 30, + Salary: 50000.0, + IsActive: true, + }, + Comment: "test comment", + }, + wantErr: false, + }, + { + name: "部分嵌套字段", + input: map[string]string{ + "basic.name": "Alice", + "comment": "partial", + }, + expected: NestedStruct{ + Basic: BasicStruct{ + Name: "Alice", + }, + Comment: "partial", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s NestedStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if s.Basic.Name != tt.expected.Basic.Name || + s.Basic.Age != tt.expected.Basic.Age || + s.Comment != tt.expected.Comment { + t.Errorf("期望 %+v, 得到 %+v", tt.expected, s) + } + + // 验证嵌套字段的变更记录 + for key := range tt.input { + if _, exists := changes[key]; !exists { + t.Errorf("缺少变更记录 for key: %s", key) + } + } + }) + } +} + +// 测试指针嵌套结构体 +func TestAttactToStruct_PointerNested(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected PointerStruct + wantErr bool + }{ + // { + // name: "指针嵌套结构体", + // input: map[string]string{ + // "basic.name": "John", + // "basic.age": "30", + // "basic.is_active": "true", + // "enabled": "true", + // }, + + // expected: PointerStruct{ + // Basic: &BasicStruct{ + // Name: "John", + // Age: 30, + // IsActive: true, + // }, + // Enabled: true, + // }, + // wantErr: false, + // }, + { + name: "空指针初始化", + input: map[string]string{ + "basic.name": "New User", + "enabled": "true", + }, + expected: PointerStruct{ + Basic: &BasicStruct{ + Name: "New User", + }, + Enabled: true, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s PointerStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + _ = changes + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if s.Basic == nil || s.Basic.Name != tt.expected.Basic.Name || + s.Enabled != tt.expected.Enabled { + t.Errorf("期望 %+v, 得到 %+v", tt.expected, s) + } + }) + } +} + +// 自定义类型测试 +func TestAttactToStruct_CustomTypes(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected CustomTypeStruct + wantErr bool + }{ + { + name: "自定义类型转换", + input: map[string]string{ + "id": "user_123", + "version": "2", + "email": "test@example.com", + }, + expected: CustomTypeStruct{ + ID: CustomString("user_123"), + Version: CustomInt(2), + Email: "test@example.com", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s CustomTypeStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if string(s.ID) != string(tt.expected.ID) || + int(s.Version) != int(tt.expected.Version) || + s.Email != tt.expected.Email { + t.Errorf("期望 %+v, 得到 %+v", tt.expected, s) + } + + if len(changes) != len(tt.input) { + t.Errorf("期望 %d 变更记录, 得到 %d", len(tt.input), len(changes)) + } + }) + } +} + +// Unmarshaler 接口测试 +func TestAttactToStruct_TextUnmarshaler(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected UnmarshalerStruct + wantErr bool + }{ + { + name: "TextUnmarshaler 接口", + input: map[string]string{ + "data": "test_data", + "name": "John", + }, + expected: UnmarshalerStruct{ + Data: CustomUnmarshaler("custom_test_data"), + Name: "John", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s UnmarshalerStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if string(s.Data) != string(tt.expected.Data) || s.Name != tt.expected.Name { + t.Errorf("期望 Data=%s, Name=%s; 得到 Data=%s, Name=%s", + tt.expected.Data, tt.expected.Name, s.Data, s.Name) + } + + if len(changes) != len(tt.input) { + t.Errorf("期望 %d 变更记录, 得到 %d", len(tt.input), len(changes)) + } + }) + } +} + +// 错误场景测试 +func TestAttactToStruct_ErrorScenarios(t *testing.T) { + tests := []struct { + name string + structPtr interface{} + input map[string]string + expectedErr string + }{ + { + name: "非指针参数", + structPtr: BasicStruct{}, + input: map[string]string{"name": "test"}, + expectedErr: "structxx 需要是非空指针", + }, + { + name: "空指针", + structPtr: (*BasicStruct)(nil), + input: map[string]string{"name": "test"}, + expectedErr: "需要是非空指针", + }, + { + name: "不支持的类型", + structPtr: &struct{ Data []string }{}, + input: map[string]string{"data": "test"}, + expectedErr: "不支持的类型", + }, + { + name: "无效的嵌套路径", + structPtr: &NestedStruct{}, + input: map[string]string{ + "nonexistent.field": "value", + "basic.name": "test", + }, + expectedErr: "", // 应该忽略不存在的字段而不报错 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch, err := structx.AttactToStruct(tt.structPtr, tt.input) + fmt.Printf("变更记录:%+v %+v\n", ch,tt.structPtr) + + if tt.expectedErr == "" && err != nil { + t.Errorf("不期望错误但得到: %v", err) + return + } + + if tt.expectedErr != "" { + if err == nil { + t.Errorf("期望错误包含 '%s', 但得到 nil", tt.expectedErr) + return + } + if !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("期望错误包含 '%s', 但得到: %v", tt.expectedErr, err) + } + } + }) + } +} + +// 变更记录验证测试 +func TestAttactToStruct_ChangeInfoValidation(t *testing.T) { + tests := []struct { + name string + input map[string]string + expectedFields []string + }{ + { + name: "变更记录完整性", + input: map[string]string{ + "name": "New Name", + "age": "25", + "is_active": "true", + }, + expectedFields: []string{"name", "age", "is_active"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s BasicStruct + // 设置初始值 + s.Name = "Old Name" + s.Age = 30 + s.IsActive = false + + changes, err := structx.AttactToStruct(&s, tt.input) + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + // 验证所有期望的字段都有变更记录 + for _, field := range tt.expectedFields { + change, exists := changes[field] + if !exists { + t.Errorf("缺少字段 %s 的变更记录", field) + continue + } + + // 验证旧值和新值 + if change.Old == "" { + t.Errorf("字段 %s 的旧值不应为空", field) + } + if change.New == "" { + t.Errorf("字段 %s 的新值不应为空", field) + } + + // 验证值类型正确 + if change.Val == nil { + t.Errorf("字段 %s 的值不应为 nil", field) + } + } + + // 验证变更记录数量 + if len(changes) != len(tt.expectedFields) { + t.Errorf("期望 %d 个变更记录, 得到 %d", len(tt.expectedFields), len(changes)) + } + }) + } +} + +func BenchmarkAttactToStruct(b *testing.B) { + var s BasicStruct + input := map[string]string{ + "name": "benchmark", + "age": "30", + "salary": "50000.0", + "is_active": "true", + "count": "100", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := structx.AttactToStruct(&s, input) + if err != nil { + b.Fatalf("基准测试失败: %v", err) + } + } +} + +// 性能测试 +func BenchmarkAttactToStruct_Nested(b *testing.B) { + var s NestedStruct + input := map[string]string{ + "basic.name": "benchmark", + "basic.age": "30", + "basic.salary": "50000.0", + "basic.is_active": "true", + "comment": "test", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := structx.AttactToStruct(&s, input) + if err != nil { + b.Fatalf("基准测试失败: %v", err) + } + } +} + +// AttactToStructAny 测试 +func TestAttactToStructAny(t *testing.T) { + tests := []struct { + name string + input map[string]interface{} + expected BasicStruct + wantErr bool + }{ + { + name: "混合类型输入", + input: map[string]interface{}{ + "name": "John", // string + "age": 30, // int + "salary": 50000.50, // float64 + "is_active": true, // bool + "count": uint(100), // uint + }, + expected: BasicStruct{ + Name: "John", + Age: 30, + Salary: 50000.50, + IsActive: true, + Count: 100, + }, + wantErr: false, + }, + { + name: "无法转换的类型", + input: map[string]interface{}{ + "name": make(chan int), // 无法转换为字符串的类型 + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s BasicStruct + _, err := structx.AttactToStructAny(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + if s != tt.expected { + t.Errorf("期望 %+v, 得到 %+v", tt.expected, s) + } + }) + } +} + +// 测试指针类型嵌套结构体 +type PointerNestedStruct struct { + BasicPtr *BasicStruct `json:"basic_ptr"` + Direct BasicStruct `json:"direct"` + Value string `json:"value"` +} + +func TestAttactToStruct_PointerNested2(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected PointerNestedStruct + wantErr bool + }{ + { + name: "指针类型嵌套结构体", + input: map[string]string{ + "basic_ptr.name": "Pointer John", + "basic_ptr.age": "35", + "direct.name": "Direct John", + "direct.age": "25", + "value": "test", + }, + expected: PointerNestedStruct{ + BasicPtr: &BasicStruct{ + Name: "Pointer John", + Age: 35, + }, + Direct: BasicStruct{ + Name: "Direct John", + Age: 25, + }, + Value: "test", + }, + wantErr: false, + }, + { + name: "空指针初始化", + input: map[string]string{ + "basic_ptr.name": "New User", + }, + expected: PointerNestedStruct{ + BasicPtr: &BasicStruct{ + Name: "New User", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s PointerNestedStruct + changes, err := structx.AttactToStruct(&s, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("期望错误,但得到 nil") + } + return + } + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + // 验证指针类型字段 + if s.BasicPtr == nil { + t.Error("BasicPtr 不应该为 nil") + } else if s.BasicPtr.Name != tt.expected.BasicPtr.Name { + t.Errorf("BasicPtr.Name 期望 %s, 得到 %s", tt.expected.BasicPtr.Name, s.BasicPtr.Name) + } + + // 验证直接嵌套字段 + if s.Direct.Name != tt.expected.Direct.Name { + t.Errorf("Direct.Name 期望 %s, 得到 %s", tt.expected.Direct.Name, s.Direct.Name) + } + + // 验证变更记录 + expectedChangeCount := len(tt.input) + if len(changes) != expectedChangeCount { + t.Errorf("期望 %d 个变更记录, 得到 %d", expectedChangeCount, len(changes)) + } + }) + } +} + +// 测试多层指针嵌套 +type MultiLevelPointerStruct struct { + Level1 *Level1Struct `json:"level1"` +} + +type Level1Struct struct { + Level2 *Level2Struct `json:"level2"` + Name string `json:"name"` +} + +type Level2Struct struct { + Level3 *Level3Struct `json:"level3"` + Value int `json:"value"` +} + +type Level3Struct struct { + FinalValue string `json:"final_value"` +} + +func TestAttactToStruct_MultiLevelPointer(t *testing.T) { + input := map[string]string{ + "level1.name": "Top Level", + "level1.level2.value": "100", + "level1.level2.level3.final_value": "end_value", + } + + var s MultiLevelPointerStruct + changes, err := structx.AttactToStruct(&s, input) + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + // 验证多层指针嵌套 + if s.Level1 == nil { + t.Error("Level1 不应该为 nil") + } else if s.Level1.Name != "Top Level" { + t.Errorf("Level1.Name 期望 Top Level, 得到 %s", s.Level1.Name) + } + + if s.Level1.Level2 == nil { + t.Error("Level2 不应该为 nil") + } else if s.Level1.Level2.Value != 100 { + t.Errorf("Level2.Value 期望 100, 得到 %d", s.Level1.Level2.Value) + } + + if s.Level1.Level2.Level3 == nil { + t.Error("Level3 不应该为 nil") + } else if s.Level1.Level2.Level3.FinalValue != "end_value" { + t.Errorf("Level3.FinalValue 期望 end_value, 得到 %s", s.Level1.Level2.Level3.FinalValue) + } + + // 验证变更记录 + if len(changes) != 3 { + t.Errorf("期望 3 个变更记录, 得到 %d", len(changes)) + } +} + +// 测试混合指针和值类型嵌套 +type MixedNestedStruct struct { + PtrField *BasicStruct `json:"ptr_field"` + ValueField BasicStruct `json:"value_field"` + Simple string `json:"simple"` +} + +func TestAttactToStruct_MixedNested(t *testing.T) { + input := map[string]string{ + "ptr_field.name": "Pointer Name", + "ptr_field.age": "40", + "value_field.name": "Value Name", + "value_field.age": "30", + "simple": "simple_value", + } + + var s MixedNestedStruct + changes, err := structx.AttactToStruct(&s, input) + + if err != nil { + t.Errorf("意外的错误: %v", err) + return + } + + // 验证指针字段 + if s.PtrField == nil { + t.Error("PtrField 不应该为 nil") + } else { + if s.PtrField.Name != "Pointer Name" { + t.Errorf("PtrField.Name 期望 Pointer Name, 得到 %s", s.PtrField.Name) + } + if s.PtrField.Age != 40 { + t.Errorf("PtrField.Age 期望 40, 得到 %d", s.PtrField.Age) + } + } + + // 验证值字段 + if s.ValueField.Name != "Value Name" { + t.Errorf("ValueField.Name 期望 Value Name, 得到 %s", s.ValueField.Name) + } + if s.ValueField.Age != 30 { + t.Errorf("ValueField.Age 期望 30, 得到 %d", s.ValueField.Age) + } + + // 验证简单字段 + if s.Simple != "simple_value" { + t.Errorf("Simple 期望 simple_value, 得到 %s", s.Simple) + } + + // 验证变更记录 + if len(changes) != 5 { // 5个字段的变更 + t.Errorf("期望 5 个变更记录, 得到 %d", len(changes)) + } +} \ No newline at end of file