From 614ac34883042e58a057d729bd8e5800159408cc Mon Sep 17 00:00:00 2001 From: Yun Date: Wed, 27 May 2026 22:02:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=95=B0=E6=8D=AE=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- covert.go | 68 ++- entity.go | 7 +- field_mapper.go | 27 +- structx.go | 55 ++- structx_test.go | 1225 +++++++++++++++++++++++++++++++++++++++++++++-- utils.go | 44 +- value_setter.go | 30 +- 7 files changed, 1352 insertions(+), 104 deletions(-) diff --git a/covert.go b/covert.go index 2a0b573..747d3d3 100644 --- a/covert.go +++ b/covert.go @@ -31,15 +31,20 @@ var ( reflect.Slice: convertSlice, reflect.Array: convertArray, reflect.Map: convertMap, + reflect.Interface: convertInterface, } ) // 转换为字符串 func convertToString(item interface{}) string { - if str, ok := item.(string); ok { - return str + switch v := item.(type) { + case string: + return v + case json.Number: + return string(v) + default: + return fmt.Sprintf("%v", item) } - return fmt.Sprintf("%v", item) } // 转换为数字(float64) @@ -47,10 +52,28 @@ func convertToFloat64(item interface{}) (float64, error) { switch v := item.(type) { case float64: return v, nil - case int, int32, int64: - return float64(reflect.ValueOf(v).Int()), nil - case uint, uint32, uint64: - return float64(reflect.ValueOf(v).Uint()), nil + case json.Number: + return v.Float64() + case int: + return float64(v), nil + case int8: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil case float32: return float64(v), nil default: @@ -125,25 +148,40 @@ func convertString(field reflect.Value, value string) (interface{}, error) { } func convertSlice(field reflect.Value, value string) (interface{}, error) { - var result []interface{} - if err := json.Unmarshal([]byte(value), &result); err != nil { + sliceType := field.Type() + result := reflect.New(sliceType).Elem() + if err := json.Unmarshal([]byte(value), result.Addr().Interface()); err != nil { return nil, err } - return result, nil + return result.Interface(), nil } func convertArray(field reflect.Value, value string) (interface{}, error) { - var result []interface{} + arrayType := field.Type() + result := reflect.New(arrayType).Elem() + if err := json.Unmarshal([]byte(value), result.Addr().Interface()); err != nil { + return nil, err + } + return result.Interface(), nil +} + +func convertMap(field reflect.Value, value string) (interface{}, error) { + mapType := field.Type() + result := reflect.New(mapType).Elem() + if err := json.Unmarshal([]byte(value), result.Addr().Interface()); err != nil { + return nil, err + } + return result.Interface(), nil +} + +func convertInterface(field reflect.Value, value string) (interface{}, error) { + var result interface{} if err := json.Unmarshal([]byte(value), &result); err != nil { return nil, err } return result, nil } -func convertMap(field reflect.Value, value string) (interface{}, error) { - return nil, fmt.Errorf("map转换未实现") -} - func getBaseTypeFromAlias(aliasType reflect.Type) reflect.Type { if aliasType.Kind() == reflect.Ptr { aliasType = aliasType.Elem() diff --git a/entity.go b/entity.go index 62104a2..9501805 100644 --- a/entity.go +++ b/entity.go @@ -6,9 +6,9 @@ import ( // ChangeInfo 变更信息 type ChangeInfo struct { - Old string `json:"old"` - New string `json:"new"` - Val any `json:"val"` + Old any `json:"old"` + New any `json:"new"` + Val any `json:"val"` } // FieldInfo 字段信息 @@ -24,6 +24,7 @@ type FieldInfo struct { var ( basicStructTypes = map[string]bool{ "time.Time": true, + "time.Duration": true, "github.com/shopspring/decimal.Decimal": true, "sql.NullString": true, "sql.NullInt64": true, diff --git a/field_mapper.go b/field_mapper.go index 7b1816a..adfac17 100644 --- a/field_mapper.go +++ b/field_mapper.go @@ -28,10 +28,14 @@ func (dm *defaultFieldMapper) GetFieldMap(t reflect.Type) map[string]FieldInfo { } cacheMutex.RUnlock() + cacheMutex.Lock() + // 双重检查:避免两个goroutine同时构建 + if cached, exists := typeInfoCache[t]; exists { + cacheMutex.Unlock() + return cached + } fieldMap := make(map[string]FieldInfo) dm.buildFieldMapRecursive(t, []int{}, fieldMap, "") - - cacheMutex.Lock() typeInfoCache[t] = fieldMap cacheMutex.Unlock() @@ -40,6 +44,16 @@ func (dm *defaultFieldMapper) GetFieldMap(t reflect.Type) map[string]FieldInfo { // 递归构建字段映射表 func (dm *defaultFieldMapper) buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]FieldInfo, prefix string) { + dm.buildFieldMapRecursiveWithVisited(t, index, fieldMap, prefix, make(map[reflect.Type]bool)) +} + +func (dm *defaultFieldMapper) buildFieldMapRecursiveWithVisited(t reflect.Type, index []int, fieldMap map[string]FieldInfo, prefix string, stack map[reflect.Type]bool) { + if stack[t] { + return + } + stack[t] = true + defer delete(stack, t) + for i := 0; i < t.NumField(); i++ { field := t.Field(i) if !field.IsExported() { @@ -58,10 +72,15 @@ func (dm *defaultFieldMapper) buildFieldMapRecursive(t reflect.Type, index []int isPtr := fieldType.Kind() == reflect.Ptr actualType := fieldType if isPtr { - // 解引用 actualType = fieldType.Elem() } + // 嵌入式结构体字段提升——其字段注册到父级别 + if field.Anonymous && actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { + dm.buildFieldMapRecursiveWithVisited(actualType, currentIndex, fieldMap, prefix, stack) + continue + } + if actualType.Kind() == reflect.Slice || actualType.Kind() == reflect.Array { fieldMap[fullKey] = FieldInfo{ Index: currentIndex, @@ -75,7 +94,7 @@ func (dm *defaultFieldMapper) buildFieldMapRecursive(t reflect.Type, index []int } if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { - dm.buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) + dm.buildFieldMapRecursiveWithVisited(actualType, currentIndex, fieldMap, fullKey, stack) } fieldMap[fullKey] = FieldInfo{ diff --git a/structx.go b/structx.go index 785e756..053b1f8 100644 --- a/structx.go +++ b/structx.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "github.com/spf13/cast" ) @@ -120,16 +121,15 @@ func (sp *StructProcessor) AttactToStruct(structxx any, updateMap map[string]str continue } - oldValueStr, _ := cast.ToStringE(field.Interface()) + oldValue := field.Interface() newValue, err := sp.valueSetter.SetFieldValue(field, mapValue) if err != nil { return nil, fmt.Errorf("设置字段 %s 的值失败: %w", mapKey, err) } - newValueStr, _ := cast.ToStringE(newValue) changeMap[mapKey] = ChangeInfo{ - Old: oldValueStr, - New: newValueStr, + Old: oldValue, + New: newValue, Val: newValue, } } @@ -148,15 +148,14 @@ func (sp *StructProcessor) getFieldByPath(v reflect.Value, index []int) (reflect return reflect.Value{}, fmt.Errorf("字段索引 %d 无效", index[0]) } - // 处理指针 - for field.Kind() == reflect.Ptr { - if field.IsNil() { - field.Set(reflect.New(field.Type().Elem())) - } - field = field.Elem() - } - if len(index) > 1 { + // 只有中间节点才自动解引用指针 + for field.Kind() == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + } return sp.getFieldByPath(field, index[1:]) } @@ -166,7 +165,9 @@ func (sp *StructProcessor) getFieldByPath(v reflect.Value, index []int) (reflect // processSliceOrArrayField 处理切片和数组字段 func (sp *StructProcessor) processSliceOrArrayField(field reflect.Value, value string, fieldKey string) error { var jsonData []interface{} - if err := json.Unmarshal([]byte(value), &jsonData); err != nil { + decoder := json.NewDecoder(strings.NewReader(value)) + decoder.UseNumber() + if err := decoder.Decode(&jsonData); err != nil { return fmt.Errorf("字段 %s 的值必须是JSON数组格式: %w", fieldKey, err) } @@ -215,7 +216,9 @@ func (sp *StructProcessor) processNestedStruct(field reflect.Value, value string // 尝试解析为JSON对象 var nestedMap map[string]interface{} - if err := json.Unmarshal([]byte(value), &nestedMap); err == nil { + decoder := json.NewDecoder(strings.NewReader(value)) + decoder.UseNumber() + if err := decoder.Decode(&nestedMap); err == nil { stringMap := make(map[string]string, len(nestedMap)) for k, v := range nestedMap { str, err := cast.ToStringE(v) @@ -237,17 +240,31 @@ func (sp *StructProcessor) processNestedStruct(field reflect.Value, value string } // 尝试直接设置值 - if hasUnmarshalJSON(structValue.Type()) || isBasicStructType(structValue.Type()) { - oldValueStr, _ := cast.ToStringE(field.Interface()) + if hasUnmarshalJSON(structValue.Type()) { + oldValue := field.Interface() + + if err := setUnmarshalJSONValue(structValue, value); err != nil { + return nil, fmt.Errorf("设置UnmarshalJSON值失败: %w", err) + } + + changeMap[parentKey] = ChangeInfo{ + Old: oldValue, + New: structValue.Interface(), + Val: structValue.Interface(), + } + return changeMap, nil + } + + if isBasicStructType(structValue.Type()) { + oldValue := field.Interface() if err := setBasicStructValue(structValue, value); err != nil { return nil, fmt.Errorf("设置基本结构体值失败: %w", err) } - newValueStr, _ := cast.ToStringE(field.Interface()) changeMap[parentKey] = ChangeInfo{ - Old: oldValueStr, - New: newValueStr, + Old: oldValue, + New: structValue.Interface(), Val: structValue.Interface(), } return changeMap, nil diff --git a/structx_test.go b/structx_test.go index e7e6184..1678166 100644 --- a/structx_test.go +++ b/structx_test.go @@ -1,8 +1,11 @@ package structx_test import ( + "encoding/json" "fmt" + "math" "strings" + "sync" "testing" "time" @@ -200,26 +203,26 @@ func TestAttactToStruct_NestedStruct(t *testing.T) { expected NestedStruct wantErr bool }{ - // { - // name: "嵌套结构体赋值", - // input: map[string]string{ - // "basic.name": "John", - // "basic.age": "30", - // "basic.salary": "50000.0", - // "basic.is_active": "true", - // "comment": "test comment", - // }, - // expected: NestedStruct{ - // Basic: BasicStruct{ - // Name: "John", - // Age: 30, - // Salary: 50000.0, - // IsActive: true, - // }, - // Comment: "test comment", - // }, - // wantErr: false, - // }, + { + name: "嵌套结构体赋值", + input: map[string]string{ + "basic.name": "John", + "basic.age": "30", + "basic.salary": "50000.0", + "basic.is_active": "true", + "comment": "test comment", + }, + expected: NestedStruct{ + Basic: BasicStruct{ + Name: "John", + Age: 30, + Salary: 50000.0, + IsActive: true, + }, + Comment: "test comment", + }, + wantErr: false, + }, { name: "部分嵌套字段", input: map[string]string{ @@ -280,25 +283,25 @@ func TestAttactToStruct_PointerNested(t *testing.T) { expected PointerStruct wantErr bool }{ - // { - // name: "指针嵌套结构体", - // input: map[string]string{ - // "basic.name": "John", - // "basic.age": "30", - // "basic.is_active": "true", - // "enabled": "true", - // }, + { + name: "指针嵌套结构体", + input: map[string]string{ + "basic.name": "John", + "basic.age": "30", + "basic.is_active": "true", + "enabled": "true", + }, - // expected: PointerStruct{ - // Basic: &BasicStruct{ - // Name: "John", - // Age: 30, - // IsActive: true, - // }, - // Enabled: true, - // }, - // wantErr: false, - // }, + expected: PointerStruct{ + Basic: &BasicStruct{ + Name: "John", + Age: 30, + IsActive: true, + }, + Enabled: true, + }, + wantErr: false, + }, { name: "空指针初始化", input: map[string]string{ @@ -548,11 +551,11 @@ func TestAttactToStruct_ChangeInfoValidation(t *testing.T) { } // 验证旧值和新值 - if change.Old == "" { - t.Errorf("字段 %s 的旧值不应为空", field) + if change.Old == nil { + t.Errorf("字段 %s 的旧值不应为 nil", field) } - if change.New == "" { - t.Errorf("字段 %s 的新值不应为空", field) + if change.New == nil { + t.Errorf("字段 %s 的新值不应为 nil", field) } // 验证值类型正确 @@ -870,4 +873,1142 @@ func TestAttactToStruct_MixedNested(t *testing.T) { if len(changes) != 5 { // 5个字段的变更 t.Errorf("期望 5 个变更记录, 得到 %d", len(changes)) } +} + +// ===== 新增综合测试 ===== + +// 全类型边界测试结构体 +type AllKindsStruct struct { + Str string `json:"str"` + I int `json:"i"` + I8 int8 `json:"i8"` + I16 int16 `json:"i16"` + I32 int32 `json:"i32"` + I64 int64 `json:"i64"` + Ui uint `json:"ui"` + Ui8 uint8 `json:"ui8"` + Ui16 uint16 `json:"ui16"` + Ui32 uint32 `json:"ui32"` + Ui64 uint64 `json:"ui64"` + F32 float32 `json:"f32"` + F64 float64 `json:"f64"` + B bool `json:"b"` +} + +// 测试所有基本类型及边界值 +func TestAttactToStruct_AllBasicKinds(t *testing.T) { + tests := []struct { + name string + input map[string]string + wantErr bool + check func(*testing.T, AllKindsStruct) + }{ + { + name: "所有类型正常值", + input: map[string]string{ + "str": "hello", "i": "42", "i8": "127", "i16": "32767", + "i32": "2147483647", "i64": "9223372036854775807", + "ui": "42", "ui8": "255", "ui16": "65535", + "ui32": "4294967295", "ui64": "18446744073709551615", + "f32": "3.14", "f64": "3.141592653589793", "b": "true", + }, + check: func(t *testing.T, s AllKindsStruct) { + if s.Str != "hello" || s.I != 42 || s.I8 != 127 || s.I16 != 32767 || s.I32 != 2147483647 || s.I64 != 9223372036854775807 { + t.Errorf("int类型不匹配: %+v", s) + } + if s.Ui != 42 || s.Ui8 != 255 || s.Ui16 != 65535 || s.Ui32 != 4294967295 || s.Ui64 != 18446744073709551615 { + t.Errorf("uint类型不匹配: %+v", s) + } + if math.Abs(float64(s.F32)-3.14) > 0.001 || math.Abs(s.F64-3.141592653589793) > 1e-14 { + t.Errorf("float类型不匹配: %+v", s) + } + if s.B != true { + t.Errorf("bool不匹配: %+v", s) + } + }, + }, + { + name: "int8溢出错误", + input: map[string]string{ + "i8": "128", + }, + wantErr: true, + }, + { + name: "uint8溢出错误", + input: map[string]string{ + "ui16": "70000", + }, + wantErr: true, + }, + { + name: "负整数", + input: map[string]string{ + "i": "-1", "i8": "-128", "i64": "-9223372036854775808", + }, + check: func(t *testing.T, s AllKindsStruct) { + if s.I != -1 || s.I8 != -128 || s.I64 != -9223372036854775808 { + t.Errorf("负数不匹配: %+v", s) + } + }, + }, + { + name: "uint负数错误", + input: map[string]string{ + "ui": "-1", + }, + wantErr: true, + }, + { + name: "浮点数科学计数法", + input: map[string]string{ + "f64": "1e10", + "f32": "1e10", + }, + check: func(t *testing.T, s AllKindsStruct) { + if math.Abs(s.F64-1e10) > 1 { + t.Errorf("科学计数法不匹配: %+v", s) + } + }, + }, + { + name: "布尔值多种表达", + input: map[string]string{ + "b": "true", + }, + check: func(t *testing.T, s AllKindsStruct) { + if s.B != true { + t.Errorf("true不匹配") + } + }, + }, + { + name: "字符串含特殊字符", + input: map[string]string{ + "str": "hello\nworld\tunicodë", + }, + check: func(t *testing.T, s AllKindsStruct) { + if s.Str != "hello\nworld\tunicodë" { + t.Errorf("特殊字符串不匹配: %q", s.Str) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s AllKindsStruct + _, err := structx.AttactToStruct(&s, tt.input) + if tt.wantErr { + if err == nil { + t.Error("期望错误但得到nil") + } + return + } + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if tt.check != nil { + tt.check(t, s) + } + }) + } +} + +// ===== Map 字段测试 ===== + +type MapStruct struct { + Metadata map[string]string `json:"metadata"` + Counts map[string]int `json:"counts"` + Empty map[string]string `json:"empty"` +} + +func TestAttactToStruct_Maps(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected MapStruct + wantErr bool + }{ + { + name: "map[string]string", + input: map[string]string{ + "metadata": `{"key1":"val1","key2":"val2"}`, + }, + expected: MapStruct{ + Metadata: map[string]string{"key1": "val1", "key2": "val2"}, + }, + }, + { + name: "map[string]int", + input: map[string]string{ + "counts": `{"a":1,"b":2,"c":3}`, + }, + expected: MapStruct{ + Counts: map[string]int{"a": 1, "b": 2, "c": 3}, + }, + }, + { + name: "空map", + input: map[string]string{ + "metadata": `{}`, + }, + expected: MapStruct{ + Metadata: map[string]string{}, + }, + }, + { + name: "无效JSON错误", + input: map[string]string{ + "metadata": `not-json`, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s MapStruct + _, err := structx.AttactToStruct(&s, tt.input) + if tt.wantErr { + if err == nil { + t.Error("期望错误但得到nil") + } + return + } + if err != nil { + t.Fatalf("意外错误: %v", err) + } + for k, v := range tt.expected.Metadata { + if s.Metadata[k] != v { + t.Errorf("Metadata[%s] 期望 %s, 得到 %s", k, v, s.Metadata[k]) + } + } + for k, v := range tt.expected.Counts { + if s.Counts[k] != v { + t.Errorf("Counts[%s] 期望 %d, 得到 %d", k, v, s.Counts[k]) + } + } + }) + } +} + +// ===== 切片/数组字段测试 ===== + +type SliceArrayStruct struct { + Tags []string `json:"tags"` + Scores []int64 `json:"scores"` + Empty []string `json:"empty"` + Floats []float64 `json:"floats"` + Coords [3]int `json:"coords"` +} + +type SliceStructElemStruct struct { + Items []BasicStruct `json:"items"` + PtrItems []*BasicStruct `json:"ptr_items"` + Times []time.Time `json:"times"` +} + +func TestAttactToStruct_SlicesArrays(t *testing.T) { + t.Run("字符串切片", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "tags": `["go","json","test"]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Tags) != 3 || s.Tags[0] != "go" || s.Tags[1] != "json" || s.Tags[2] != "test" { + t.Fatalf("期望 [go json test], 得到 %v", s.Tags) + } + }) + + t.Run("整数切片大数值精度", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "scores": `[9007199254740993,9223372036854775807]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Scores) != 2 || s.Scores[0] != 9007199254740993 || s.Scores[1] != 9223372036854775807 { + t.Fatalf("大整数精度丢失: %v", s.Scores) + } + }) + + t.Run("浮点数切片", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "floats": `[3.14,2.718,1.0]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Floats) != 3 || math.Abs(s.Floats[0]-3.14) > 0.001 { + t.Fatalf("浮点数切片不匹配: %v", s.Floats) + } + }) + + t.Run("空切片", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "empty": `[]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Empty) != 0 { + t.Fatalf("期望空切片, 得到 %v", s.Empty) + } + }) + + t.Run("固定数组正常", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "coords": `[1,2,3]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Coords != [3]int{1, 2, 3} { + t.Fatalf("期望 [1 2 3], 得到 %v", s.Coords) + } + }) + + t.Run("固定数组长度不匹配错误", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "coords": `[1,2]`, + }) + if err == nil { + t.Fatal("期望数组长度不匹配错误但得到nil") + } + }) + + t.Run("切片元素为结构体", func(t *testing.T) { + var s SliceStructElemStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "items": `[{"name":"Alice","age":30,"salary":1000.5,"is_active":true,"count":1}]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Items) != 1 || s.Items[0].Name != "Alice" || s.Items[0].Age != 30 { + t.Fatalf("结构体切片不匹配: %+v", s.Items) + } + }) + + t.Run("切片元素为结构体指针", func(t *testing.T) { + var s SliceStructElemStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "ptr_items": `[{"name":"Bob","age":25}]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.PtrItems) != 1 || s.PtrItems[0] == nil || s.PtrItems[0].Name != "Bob" { + t.Fatalf("指针结构体切片不匹配: %+v", s.PtrItems) + } + }) + + t.Run("切片元素为time.Time", func(t *testing.T) { + var s SliceStructElemStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "times": `["2024-01-01T00:00:00Z","2025-06-15T12:30:00Z"]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Times) != 2 { + t.Fatalf("期望2个时间, 得到 %d", len(s.Times)) + } + expected, _ := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z") + if !s.Times[0].Equal(expected) { + t.Fatalf("时间不匹配: %v", s.Times[0]) + } + }) + + t.Run("无效JSON数组错误", func(t *testing.T) { + var s SliceArrayStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "tags": `not-an-array`, + }) + if err == nil { + t.Fatal("期望JSON数组解析错误但得到nil") + } + }) +} + +// ===== json.Unmarshaler 接口测试 ===== + +type CustomJSONType struct { + Value string +} + +func (c *CustomJSONType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + c.Value = "json_" + s + return nil +} + +type JSONUnmarshalerStruct struct { + Data CustomJSONType `json:"data"` +} + +func TestAttactToStruct_JSONUnmarshaler(t *testing.T) { + t.Run("json.Unmarshaler字符串值", func(t *testing.T) { + var s JSONUnmarshalerStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "data": "hello", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Data.Value != "json_hello" { + t.Fatalf("期望 json_hello, 得到 %s", s.Data.Value) + } + }) + + t.Run("json.Unmarshaler JSON字符串值", func(t *testing.T) { + var s JSONUnmarshalerStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "data": `"quoted"`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Data.Value != "json_quoted" { + t.Fatalf("期望 json_quoted, 得到 %s", s.Data.Value) + } + }) +} + +// ===== time.Duration 测试 ===== + +type DurationStruct struct { + Timeout time.Duration `json:"timeout"` + Delay time.Duration `json:"delay"` +} + +func TestAttactToStruct_Duration(t *testing.T) { + t.Run("Duration字符串解析", func(t *testing.T) { + var s DurationStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "timeout": "5s", + "delay": "100ms", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Timeout != 5*time.Second { + t.Fatalf("期望 5s, 得到 %v", s.Timeout) + } + if s.Delay != 100*time.Millisecond { + t.Fatalf("期望 100ms, 得到 %v", s.Delay) + } + }) + + t.Run("Duration无效格式错误", func(t *testing.T) { + var s DurationStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "timeout": "invalid", + }) + if err == nil { + t.Fatal("期望Duration解析错误但得到nil") + } + }) +} + +// ===== 嵌套结构体JSON对象字符串测试 ===== + +type JSONNestedParent struct { + Basic BasicStruct `json:"basic"` + Name string `json:"name"` +} + +func TestAttactToStruct_NestedJSONObject(t *testing.T) { + t.Run("JSON对象字符串赋值嵌套结构体", func(t *testing.T) { + var s JSONNestedParent + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic": `{"name":"John","age":"30","salary":"50000.5","is_active":"true","count":"5"}`, + "name": "parent", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Basic.Name != "John" || s.Basic.Age != 30 || s.Basic.Salary != 50000.5 || !s.Basic.IsActive || s.Basic.Count != 5 { + t.Fatalf("嵌套结构体不匹配: %+v", s.Basic) + } + if s.Name != "parent" { + t.Fatalf("父字段不匹配: %s", s.Name) + } + }) + + t.Run("无效JSON对象错误", func(t *testing.T) { + var s JSONNestedParent + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic": "not-json-object", + }) + if err == nil { + t.Fatal("期望无效JSON对象错误但得到nil") + } + }) +} + +// ===== 覆盖已有值测试 ===== + +type OverwriteStruct struct { + Name string `json:"name"` + Age int `json:"age"` + Salary float64 `json:"salary"` + Active bool `json:"active"` +} + +func TestAttactToStruct_OverwriteValues(t *testing.T) { + t.Run("覆盖已有字符串值", func(t *testing.T) { + s := OverwriteStruct{Name: "OldName", Age: 10, Salary: 100.0, Active: false} + changes, err := structx.AttactToStruct(&s, map[string]string{ + "name": "NewName", + "age": "20", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Name != "NewName" || s.Age != 20 || s.Salary != 100.0 || s.Active != false { + t.Fatalf("覆盖不匹配: %+v", s) + } + // 验证ChangeInfo记录旧值 + if changes["name"].Old != "OldName" || changes["name"].New != "NewName" { + t.Fatalf("ChangeInfo.name 不匹配: %+v", changes["name"]) + } + if changes["age"].Old != 10 || changes["age"].New != 20 { + t.Fatalf("ChangeInfo.age 不匹配: %+v", changes["age"]) + } + }) + + t.Run("覆盖嵌套结构体", func(t *testing.T) { + s := NestedStruct{ + Basic: BasicStruct{Name: "Old", Age: 1, Salary: 1.0, IsActive: false, Count: 0}, + } + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic.name": "New", + "basic.age": "99", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Basic.Name != "New" || s.Basic.Age != 99 || s.Basic.Salary != 1.0 { + t.Fatalf("嵌套覆盖不匹配: %+v", s.Basic) + } + }) +} + +// ===== 边界与错误场景测试 ===== + +type SimpleStruct struct { + Name string `json:"name"` +} + +func TestAttactToStruct_EmptyEdgeCases(t *testing.T) { + t.Run("空输入map", func(t *testing.T) { + var s BasicStruct + changes, err := structx.AttactToStruct(&s, map[string]string{}) + if err != nil { + t.Fatalf("空map不应错误: %v", err) + } + if len(changes) != 0 { + t.Fatalf("空map应无变更记录, 得到 %d", len(changes)) + } + }) + + t.Run("AllowUnknownFields选项", func(t *testing.T) { + s := structx.NewStructProcessor(structx.AllowUnknownFields()) + var v SimpleStruct + _, err := s.AttactToStruct(&v, map[string]string{ + "name": "hello", + "unknown_field": "ignored", + "another_unknown": "ignored", + }) + if err != nil { + t.Fatalf("AllowUnknownFields不应错误: %v", err) + } + if v.Name != "hello" { + t.Fatalf("期望 hello, 得到 %s", v.Name) + } + }) + + t.Run("默认不允许未知字段", func(t *testing.T) { + var v SimpleStruct + _, err := structx.AttactToStruct(&v, map[string]string{ + "name": "hello", + "bad_field": "error", + }) + if err == nil { + t.Fatal("期望未知字段错误但得到nil") + } + }) +} + +// ===== 全指针字段结构体测试 ===== + +type AllPointerStruct struct { + Name *string `json:"name"` + Age *int `json:"age"` +} + +func TestAttactToStruct_AllPointerFields(t *testing.T) { + t.Run("全指针字段赋值", func(t *testing.T) { + var s AllPointerStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "name": "Alice", + "age": "30", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Name == nil || *s.Name != "Alice" { + t.Fatalf("name指针不匹配: %v", s.Name) + } + if s.Age == nil || *s.Age != 30 { + t.Fatalf("age指针不匹配: %v", s.Age) + } + }) +} + +// ===== 同一结构体类型多次出现测试 ===== + +type SameTypeReusedStruct struct { + First BasicStruct `json:"first"` + Second BasicStruct `json:"second"` +} + +func TestAttactToStruct_SameTypeReused(t *testing.T) { + var s SameTypeReusedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "first.name": "First", + "first.age": "10", + "second.name": "Second", + "second.age": "20", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.First.Name != "First" || s.First.Age != 10 { + t.Fatalf("First不匹配: %+v", s.First) + } + if s.Second.Name != "Second" || s.Second.Age != 20 { + t.Fatalf("Second不匹配: %+v", s.Second) + } +} + +// ===== ChangeInfo.Val 类型验证 ===== + +type TypedValStruct struct { + Name string `json:"name"` + Age int `json:"age"` + Salary float64 `json:"salary"` + Active bool `json:"active"` + Count uint `json:"count"` +} + +func TestAttactToStruct_ChangeInfo_ValTypes(t *testing.T) { + var s TypedValStruct + changes, err := structx.AttactToStruct(&s, map[string]string{ + "name": "test", + "age": "25", + "salary": "1000.5", + "active": "true", + "count": "42", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + + typeCheck := func(field string, val interface{}, expectedType string) { + if val == nil { + t.Fatalf("字段 %s 的Val为nil", field) + } + actualType := fmt.Sprintf("%T", val) + if actualType != expectedType { + t.Errorf("字段 %s 的Val类型期望 %s, 得到 %s", field, expectedType, actualType) + } + } + + typeCheck("name", changes["name"].Val, "string") + typeCheck("age", changes["age"].Val, "int") + typeCheck("salary", changes["salary"].Val, "float64") + typeCheck("active", changes["active"].Val, "bool") + typeCheck("count", changes["count"].Val, "uint") +} + +// ===== AttactToStructAny 综合测试 ===== + +func TestAttactToStructAny_Comprehensive(t *testing.T) { + t.Run("各种类型混合输入", func(t *testing.T) { + var s BasicStruct + _, err := structx.AttactToStructAny(&s, map[string]interface{}{ + "name": "John", + "age": 30, + "salary": 50000.50, + "is_active": true, + "count": uint(100), + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Name != "John" || s.Age != 30 || s.Salary != 50000.50 || !s.IsActive || s.Count != 100 { + t.Fatalf("值不匹配: %+v", s) + } + }) + + t.Run("int8/int16等小类型", func(t *testing.T) { + var s AllKindsStruct + _, err := structx.AttactToStructAny(&s, map[string]interface{}{ + "str": "hello", "i": 42, "i8": int8(127), "i16": int16(32767), + "i32": int32(100), "i64": int64(999), + "ui": uint(1), "ui8": uint8(255), "ui16": uint16(1), + "ui32": uint32(1), "ui64": uint64(1), + "f32": float32(1.5), "f64": float64(3.14), "b": true, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.I8 != 127 || s.I16 != 32767 || s.Ui8 != 255 { + t.Fatalf("小类型不匹配: %+v", s) + } + }) + + t.Run("嵌套map输入", func(t *testing.T) { + var s PointerNestedStruct + _, err := structx.AttactToStructAny(&s, map[string]interface{}{ + "basic_ptr.name": "FromAny", + "direct.name": "DirectAny", + "value": "val", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.BasicPtr == nil || s.BasicPtr.Name != "FromAny" { + t.Fatalf("BasicPtr不匹配: %+v", s.BasicPtr) + } + if s.Direct.Name != "DirectAny" { + t.Fatalf("Direct不匹配: %s", s.Direct.Name) + } + }) +} + +// ===== 并发访问测试 ===== + +func TestAttactToStruct_ConcurrentAccess(t *testing.T) { + var wg sync.WaitGroup + errChan := make(chan error, 20) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + var s BasicStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "name": fmt.Sprintf("goroutine_%d", idx), + "age": fmt.Sprintf("%d", 20+idx), + "salary": "1000.0", + "is_active": "true", + "count": "1", + }) + if err != nil { + errChan <- err + } + }(i) + } + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + var s NestedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic.name": fmt.Sprintf("nested_%d", idx), + "basic.age": fmt.Sprintf("%d", idx), + "comment": "concurrent", + "amount": "1.23", + "amount2": "4.56", + "timestamp": "2024-01-01T00:00:00Z", + }) + if err != nil { + errChan <- err + } + }(i) + } + + wg.Wait() + close(errChan) + + for err := range errChan { + t.Errorf("并发访问错误: %v", err) + } +} + +// ===== decimal.Decimal 指针字段测试 ===== + +func TestAttactToStruct_DecimalPointer(t *testing.T) { + t.Run("*decimal.Decimal初始化与赋值", func(t *testing.T) { + var s NestedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "amount": "123.45", + "amount2": "678.90", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if !s.Amount.Equal(decimal.RequireFromString("123.45")) { + t.Fatalf("amount期望 123.45, 得到 %s", s.Amount) + } + if s.Amount2 == nil { + t.Fatal("amount2不应为nil") + } + if !s.Amount2.Equal(decimal.RequireFromString("678.90")) { + t.Fatalf("amount2期望 678.90, 得到 %s", s.Amount2) + } + }) +} + +// ===== time.Time 字段测试 ===== + +func TestAttactToStruct_Time(t *testing.T) { + t.Run("time.Time赋值", func(t *testing.T) { + var s NestedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "timestamp": "2024-06-15T10:30:00Z", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + expected, _ := time.Parse(time.RFC3339, "2024-06-15T10:30:00Z") + if !s.Timestamp.Equal(expected) { + t.Fatalf("期望 %v, 得到 %v", expected, s.Timestamp) + } + }) + + t.Run("time.Time RFC3339Nano格式", func(t *testing.T) { + var s NestedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "timestamp": "2024-01-01T15:04:05.999999999Z", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + expected, _ := time.Parse(time.RFC3339Nano, "2024-01-01T15:04:05.999999999Z") + if !s.Timestamp.Equal(expected) { + t.Fatalf("期望 %v, 得到 %v", expected, s.Timestamp) + } + }) +} + +// ===== 自定义类型别名各基础类型全覆盖测试 ===== + +type CustomBool bool +type CustomFloat64 float64 +type CustomUint64 uint64 + +type FullAliasStruct struct { + Str CustomString `json:"str"` + I CustomInt `json:"i"` + Ui64 CustomUint64 `json:"ui64"` + F64 CustomFloat64 `json:"f64"` + B CustomBool `json:"b"` +} + +func TestAttactToStruct_FullAliasTypes(t *testing.T) { + var s FullAliasStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "str": "alias", + "i": "42", + "ui64": "18446744073709551615", + "f64": "3.14159", + "b": "true", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if string(s.Str) != "alias" || int(s.I) != 42 || uint64(s.Ui64) != 18446744073709551615 { + t.Fatalf("别名不匹配: str=%s i=%d ui64=%d", s.Str, s.I, s.Ui64) + } + if float64(s.F64) != 3.14159 || bool(s.B) != true { + t.Fatalf("别名不匹配: f64=%f b=%v", s.F64, s.B) + } +} + +// ===== ComplexStruct 全字段覆盖测试 ===== + +func TestAttactToStruct_ComplexStruct(t *testing.T) { + var s ComplexStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic.name": "Complex", + "basic.age": "50", + "nested.comment": "deep", + "nested.amount": "99.99", + "nested.timestamp": "2024-12-25T00:00:00Z", + "custom.id": "cid_001", + "custom.version": "3", + "custom.email": "c@test.com", + "timestamp": "2024-01-01T00:00:00Z", + "metadata": `{"env":"prod","region":"us-east-1"}`, + "unmarshaler": "raw_data", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Basic.Name != "Complex" || s.Basic.Age != 50 { + t.Errorf("Basic不匹配: %+v", s.Basic) + } + if s.Nested == nil { + t.Fatal("Nested不应为nil") + } + if s.Nested.Comment != "deep" { + t.Errorf("Nested.Comment不匹配: %s", s.Nested.Comment) + } + if s.Custom.ID != CustomString("cid_001") || s.Custom.Version != CustomInt(3) { + t.Errorf("Custom不匹配: %+v", s.Custom) + } + if s.Metadata["env"] != "prod" || s.Metadata["region"] != "us-east-1" { + t.Errorf("Metadata不匹配: %+v", s.Metadata) + } + if string(s.Unmarshaler) != "custom_raw_data" { + t.Errorf("Unmarshaler不匹配: %s", s.Unmarshaler) + } +} + +// ===== processNestedStruct UseNumber 精度测试 ===== + +func TestAttactToStruct_NestedJSONObject_LargeNumber(t *testing.T) { + // 验证嵌套JSON对象中的大整数不会因float64丢失精度 + var s JSONNestedParent + _, err := structx.AttactToStruct(&s, map[string]string{ + "basic": `{"name":"test","age":9007199254740993}`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + // age是int类型,9007199254740993 > 2^53,如果经过float64会变成9007199254740992 + if s.Basic.Age != 9007199254740993 { + t.Fatalf("大数精度丢失: 期望 9007199254740993, 得到 %d", s.Basic.Age) + } +} + +// ===== SetSliceElementValue bool 字符串值测试 ===== + +type SliceBoolStruct struct { + Flags []bool `json:"flags"` +} + +func TestAttactToStruct_SliceBoolString(t *testing.T) { + t.Run("布尔值字符串true/false", func(t *testing.T) { + var s SliceBoolStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "flags": `["true","false","true"]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Flags) != 3 || s.Flags[0] != true || s.Flags[1] != false || s.Flags[2] != true { + t.Fatalf("布尔切片不匹配: %v", s.Flags) + } + }) + + t.Run("布尔值数字1/0", func(t *testing.T) { + var s SliceBoolStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "flags": `[true,false]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if len(s.Flags) != 2 || s.Flags[0] != true || s.Flags[1] != false { + t.Fatalf("布尔切片不匹配: %v", s.Flags) + } + }) +} + +// ===== 嵌入式结构体测试 ===== + +type EmbeddedBase struct { + BaseName string `json:"base_name"` + BaseAge int `json:"base_age"` +} + +type EmbeddedStruct struct { + EmbeddedBase + OwnField string `json:"own_field"` +} + +func TestAttactToStruct_EmbeddedStruct(t *testing.T) { + t.Run("嵌入式字段直接访问", func(t *testing.T) { + var s EmbeddedStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "base_name": "embedded", + "base_age": "25", + "own_field": "own", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.BaseName != "embedded" || s.BaseAge != 25 || s.OwnField != "own" { + t.Fatalf("嵌入式结构体不匹配: %+v", s) + } + }) + + t.Run("嵌入式指针结构体字段提升", func(t *testing.T) { + type Inner struct { + Val string `json:"val"` + } + type Outer struct { + *Inner + Name string `json:"name"` + } + + var s Outer + _, err := structx.AttactToStruct(&s, map[string]string{ + "val": "inner_value", + "name": "outer_name", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Inner == nil { + t.Fatal("Inner不应为nil") + } + if s.Val != "inner_value" || s.Name != "outer_name" { + t.Fatalf("指针嵌入不匹配: Val=%s, Name=%s", s.Val, s.Name) + } + }) +} + +// ===== interface{} 字段测试 ===== + +type InterfaceStruct struct { + Data interface{} `json:"data"` + Name string `json:"name"` +} + +func TestAttactToStruct_InterfaceField(t *testing.T) { + t.Run("interface{}字符串值", func(t *testing.T) { + var s InterfaceStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "data": `"hello"`, + "name": "test", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Name != "test" { + t.Fatalf("Name不匹配: %s", s.Name) + } + }) + + t.Run("interface{}数字值", func(t *testing.T) { + var s InterfaceStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "data": "12345", + "name": "num", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Name != "num" { + t.Fatalf("Name不匹配: %s", s.Name) + } + }) + + t.Run("interface{}对象值", func(t *testing.T) { + var s InterfaceStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "data": `{"key":"value"}`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + result, ok := s.Data.(map[string]interface{}) + if !ok { + t.Fatalf("期望map[string]interface{}, 得到 %T", s.Data) + } + if result["key"] != "value" { + t.Fatalf("data.key不匹配: %v", result["key"]) + } + }) +} + +// ===== *[]Type 指针指向切片测试 ===== + +type PtrSliceStruct struct { + Tags *[]string `json:"tags"` +} + +func TestAttactToStruct_PtrToSlice(t *testing.T) { + var s PtrSliceStruct + _, err := structx.AttactToStruct(&s, map[string]string{ + "tags": `["a","b","c"]`, + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if s.Tags == nil { + t.Fatal("Tags不应为nil") + } + if len(*s.Tags) != 3 || (*s.Tags)[0] != "a" { + t.Fatalf("Tags不匹配: %v", *s.Tags) + } +} + +// ===== ChangeInfo Old/New 类型验证(any) ===== + +func TestAttactToStruct_ChangeInfo_OldNewTypes(t *testing.T) { + s := OverwriteStruct{Name: "old", Age: 10, Salary: 1.5, Active: true} + changes, err := structx.AttactToStruct(&s, map[string]string{ + "name": "new", + "age": "20", + "salary": "2.5", + "active": "false", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + + // 验证Old是原始Go类型,而非字符串 + if _, ok := changes["name"].Old.(string); !ok { + t.Errorf("name.Old期望string, 得到 %T", changes["name"].Old) + } + if _, ok := changes["age"].Old.(int); !ok { + t.Errorf("age.Old期望int, 得到 %T", changes["age"].Old) + } + if _, ok := changes["salary"].Old.(float64); !ok { + t.Errorf("salary.Old期望float64, 得到 %T", changes["salary"].Old) + } + if _, ok := changes["active"].Old.(bool); !ok { + t.Errorf("active.Old期望bool, 得到 %T", changes["active"].Old) + } +} + +// ===== ChangeInfo Nested Old/New 验证 ===== + +func TestAttactToStruct_ChangeInfo_NestedOldNew(t *testing.T) { + var s NestedStruct + s.Basic = BasicStruct{Name: "old_name", Age: 10} + changes, err := structx.AttactToStruct(&s, map[string]string{ + "basic.name": "new_name", + "basic.age": "99", + }) + if err != nil { + t.Fatalf("意外错误: %v", err) + } + if changes["basic.name"].Old != "old_name" || changes["basic.name"].New != "new_name" { + t.Errorf("basic.name Old/New 不匹配: Old=%v, New=%v", + changes["basic.name"].Old, changes["basic.name"].New) + } + if changes["basic.age"].Old != 10 || changes["basic.age"].New != 99 { + t.Errorf("basic.age Old/New 不匹配: Old=%v, New=%v", + changes["basic.age"].Old, changes["basic.age"].New) + } } \ No newline at end of file diff --git a/utils.go b/utils.go index a308aa1..0bf5bdf 100644 --- a/utils.go +++ b/utils.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strings" + "time" ) // 工具函数 @@ -39,9 +40,21 @@ func hasUnmarshalText(t reflect.Type) bool { // 设置json.Unmarshaler接口的值 func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { - jsonBytes, err := json.Marshal(value) - if err != nil { - return err + var jsonBytes []byte + if str, ok := value.(string); ok { + // 尝试直接将字符串作为JSON解析;如果是合法JSON则直接使用 + if json.Valid([]byte(str)) { + jsonBytes = []byte(str) + } else { + // 不是合法JSON,作为JSON字符串值处理(加引号) + jsonBytes = []byte(`"` + str + `"`) + } + } else { + var err error + jsonBytes, err = json.Marshal(value) + if err != nil { + return err + } } var fieldAddr reflect.Value @@ -91,6 +104,16 @@ func setBasicStructValue(field reflect.Value, value string) error { if hasUnmarshalText(field.Type()) { return setUnmarshalTextValue(field, value) } + + if field.Type().String() == "time.Duration" { + d, err := time.ParseDuration(value) + if err != nil { + return err + } + field.SetInt(int64(d)) + return nil + } + return json.Unmarshal([]byte(value), field.Addr().Interface()) } @@ -111,6 +134,9 @@ func isTypeAlias(t reflect.Type) bool { if t.Kind() == reflect.Ptr { t = t.Elem() } + if basicStructTypes[t.String()] { + return false + } return t.PkgPath() != "" && basicKinds[t.Kind()] } @@ -158,17 +184,7 @@ func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) return field.Interface(), nil } - var jsonData interface{} - if err := json.Unmarshal([]byte(value), &jsonData); err != nil { - return nil, err - } - - jsonBytes, err := json.Marshal(jsonData) - if err != nil { - return nil, err - } - - if err := json.Unmarshal(jsonBytes, field.Addr().Interface()); err != nil { + if err := json.Unmarshal([]byte(value), field.Addr().Interface()); err != nil { return nil, err } diff --git a/value_setter.go b/value_setter.go index 586aa97..31dd6a7 100644 --- a/value_setter.go +++ b/value_setter.go @@ -3,6 +3,7 @@ package structx import ( "fmt" "reflect" + "strconv" ) // ValueSetter 值设置器接口 @@ -35,6 +36,13 @@ func (ds *defaultValueSetter) SetFieldValue(field reflect.Value, value string) ( return setPointerFieldValue(field, value) } + if isBasicStructType(fieldType) { + if err := setBasicStructValue(field, value); err != nil { + return nil, err + } + return field.Interface(), nil + } + if isTypeAlias(fieldType) { return setTypeAliasValue(field, value) } @@ -73,28 +81,36 @@ func (ds *defaultValueSetter) SetSliceElementValue(elemValue reflect.Value, item case reflect.String: elemValue.SetString(convertToString(item)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - num, err := convertToFloat64(item) + itemStr := convertToString(item) + intVal, err := strconv.ParseInt(itemStr, 10, elemType.Bits()) if err != nil { return fmt.Errorf("无法转换为整型: %w", err) } - elemValue.SetInt(int64(num)) + elemValue.SetInt(intVal) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - num, err := convertToFloat64(item) + itemStr := convertToString(item) + uintVal, err := strconv.ParseUint(itemStr, 10, elemType.Bits()) if err != nil { return fmt.Errorf("无法转换为无符号整型: %w", err) } - elemValue.SetUint(uint64(num)) + elemValue.SetUint(uintVal) case reflect.Float32, reflect.Float64: - num, err := convertToFloat64(item) + itemStr := convertToString(item) + floatVal, err := strconv.ParseFloat(itemStr, elemType.Bits()) if err != nil { return fmt.Errorf("无法转换为浮点型: %w", err) } - elemValue.SetFloat(num) + elemValue.SetFloat(floatVal) case reflect.Bool: if b, ok := item.(bool); ok { elemValue.SetBool(b) } else { - return fmt.Errorf("无法转换为布尔型") + itemStr := convertToString(item) + b, err := strconv.ParseBool(itemStr) + if err != nil { + return fmt.Errorf("无法转换为布尔型: %w", err) + } + elemValue.SetBool(b) } case reflect.Struct: if isBasicStructType(elemType) {