From 02152b44bf5f385e618257e90d49dd64b59eddb5 Mon Sep 17 00:00:00 2001 From: Yun Date: Sat, 20 Sep 2025 21:17:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=84=E7=90=86=E6=95=B0=E7=BB=84=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E4=BD=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- structx.go | 410 ++++++++++++++++++++++++++++++++---------------- structx_test.go | 9 +- 2 files changed, 283 insertions(+), 136 deletions(-) diff --git a/structx.go b/structx.go index de094af..b64ab4e 100644 --- a/structx.go +++ b/structx.go @@ -35,9 +35,9 @@ var typeConverters = map[reflect.Kind]converterFunc{ reflect.Float32: convertFloat[float32], reflect.Float64: convertFloat[float64], reflect.String: convertString, - // reflect.Slice: convertSlice, - // reflect.Array: convertArray, - // reflect.Map: convertMap, + reflect.Slice: convertSlice, + reflect.Array: convertArray, + reflect.Map: convertMap, } // AttactToStructAny 将 map[string]interface{} 类型的值附加到结构体中 @@ -67,11 +67,13 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri // 获取结构体类型信息 t := v.Type() fieldMap := buildFieldMap(t) + fmt.Printf("字段映射2: %+v %+v\n", fieldMap,fieldMap["data"]) for mapKey, mapValue := range updateMap { fieldInfo, exists := fieldMap[mapKey] + fmt.Printf("处理字段: %s, 信息: %+v\n", mapKey, fieldInfo) if !exists { - continue // 忽略不存在的字段 + return nil, fmt.Errorf("字段 %s 不存在", mapKey) } // 安全获取字段值,处理嵌套指针 @@ -87,14 +89,22 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri // 处理指针类型 if fieldInfo.IsPtr { if field.IsNil() { - // 创建新的指针实例 newValue := reflect.New(fieldInfo.FieldType.Elem()) field.Set(newValue) } field = field.Elem() } - // 处理嵌套结构体(包括指针解引用后的结构体) + // 处理切片和数组类型 + if fieldInfo.IsSlice || fieldInfo.IsArray { + err := processSliceOrArrayField(field, mapValue, mapKey) + if err != nil { + return nil, fmt.Errorf("处理切片/数组字段 %s 失败: %w", mapKey, err) + } + continue + } + + // 处理嵌套结构体 if field.Kind() == reflect.Struct && !isBasicStructType(field.Type()) { nestedChanges, err := processNestedStruct(field, mapValue, mapKey) if err != nil { @@ -136,6 +146,122 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri return changeMap, nil } +// 处理切片和数组字段 +func processSliceOrArrayField(field reflect.Value, value string, fieldKey string) error { + // 尝试解析JSON数组 + var jsonData []interface{} + if err := json.Unmarshal([]byte(value), &jsonData); err != nil { + return fmt.Errorf("字段 %s 的值必须是JSON数组格式: %w", fieldKey, err) + } + + fieldType := field.Type() + elemType := fieldType.Elem() + + // 创建新的切片或数组 + var newContainer reflect.Value + + if fieldType.Kind() == reflect.Slice { + // 切片:动态长度 + newContainer = reflect.MakeSlice(fieldType, len(jsonData), len(jsonData)) + } else if fieldType.Kind() == reflect.Array { + // 数组:固定长度 + if len(jsonData) != fieldType.Len() { + return fmt.Errorf("字段 %s 的数组长度不匹配: 期望 %d, 实际 %d", + fieldKey, fieldType.Len(), len(jsonData)) + } + newContainer = reflect.New(fieldType).Elem() + } else { + return fmt.Errorf("字段 %s 不是切片或数组类型", fieldKey) + } + + // 填充数据 + for i, item := range jsonData { + elemValue := newContainer.Index(i) + err := setSliceElementValue(elemValue, item, elemType) + if err != nil { + return fmt.Errorf("设置切片/数组元素失败: %w", err) + } + } + + field.Set(newContainer) + return nil +} + +// 设置切片/数组元素的值 +func setSliceElementValue(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { + // 处理指针类型的元素 + if elemType.Kind() == reflect.Ptr { + if elemValue.IsNil() { + elemValue.Set(reflect.New(elemType.Elem())) + } + return setSliceElementValue(elemValue.Elem(), item, elemType.Elem()) + } + + // 根据元素类型进行转换 + switch elemType.Kind() { + case reflect.String: + if str, ok := item.(string); ok { + elemValue.SetString(str) + } else { + elemValue.SetString(fmt.Sprintf("%v", item)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if num, ok := item.(float64); ok { // JSON数字默认是float64 + elemValue.SetInt(int64(num)) + } else { + return fmt.Errorf("无法将 %v 转换为整型", item) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if num, ok := item.(float64); ok { + elemValue.SetUint(uint64(num)) + } else { + return fmt.Errorf("无法将 %v 转换为无符号整型", item) + } + + case reflect.Float32, reflect.Float64: + if num, ok := item.(float64); ok { + elemValue.SetFloat(num) + } else { + return fmt.Errorf("无法将 %v 转换为浮点型", item) + } + + case reflect.Bool: + if b, ok := item.(bool); ok { + elemValue.SetBool(b) + } else { + return fmt.Errorf("无法将 %v 转换为布尔型", item) + } + + case reflect.Struct: + // 处理结构体元素 + if isBasicStructType(elemType) { + return setBasicStructElement(elemValue, item, elemType) + } + // 对于自定义结构体,需要JSON unmarshal + jsonBytes, err := json.Marshal(item) + if err != nil { + return err + } + return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) + + default: + return fmt.Errorf("不支持的切片元素类型: %s", elemType.Kind()) + } + + return nil +} + +// 设置基本结构体元素(如time.Time) +func setBasicStructElement(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { + jsonBytes, err := json.Marshal(item) + if err != nil { + return err + } + return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) +} + // getFieldByIndexSafe 安全地通过索引路径获取字段,处理嵌套指针 func getFieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, error) { if len(index) == 0 { @@ -173,17 +299,21 @@ type fieldInfo struct { Name string IsPtr bool FieldType reflect.Type + IsSlice bool // 新增:标识是否为切片或数组 + IsArray bool // 新增:标识是否为数组 } func buildFieldMap(t reflect.Type) map[string]fieldInfo { fieldMap := make(map[string]fieldInfo) buildFieldMapRecursive(t, []int{}, fieldMap, "") + fmt.Printf("字段映射表: %+v\n", fieldMap) return fieldMap } func buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]fieldInfo, prefix string) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) + fmt.Printf("处理字段 索引: %+v %+v\n", field, field.Type.Kind()) if !field.IsExported() { continue } @@ -205,12 +335,22 @@ func buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]fie actualType = fieldType.Elem() } + // 处理切片和数组类型 - 不进行递归展开 + if actualType.Kind() == reflect.Slice || actualType.Kind() == reflect.Array { + fieldMap[fullKey] = fieldInfo{ + Index: currentIndex, + Name: field.Name, + IsPtr: isPtr, + FieldType: fieldType, + IsSlice: true, + } + continue + } + // 如果是结构体且不是基本类型,递归处理 if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { - // 对于指针或非指针的结构体都进行递归展开 buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) - // 同时记录当前字段的信息 fieldMap[fullKey] = fieldInfo{ Index: currentIndex, Name: field.Name, @@ -303,64 +443,64 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m // 检测是否为自定义类型 func isCustomType(t reflect.Type) bool { - // 排除基本类型 - if t.PkgPath() == "" { - return false - } - - // 排除已知的基本结构体类型 - if isBasicStructType(t) { - return false - } - - // 排除接口类型 - if t.Kind() == reflect.Interface { - return false - } - - return true + // 排除基本类型 + if t.PkgPath() == "" { + return false + } + + // 排除已知的基本结构体类型 + if isBasicStructType(t) { + return false + } + + // 排除接口类型 + if t.Kind() == reflect.Interface { + return false + } + + return true } // 设置字段值 // 设置字段值 func setFieldValue(field reflect.Value, value string) (interface{}, error) { - kind := field.Kind() + kind := field.Kind() - // 检测是否为自定义类型(有包路径的类型) - if isCustomType(field.Type()) { - return setCustomTypeValue(field, value) - } + // 检测是否为自定义类型(有包路径的类型) + if isCustomType(field.Type()) { + return setCustomTypeValue(field, value) + } - // 处理指针类型的基础类型 - if kind == reflect.Ptr { - return setPointerFieldValue(field, value) - } + // 处理指针类型的基础类型 + if kind == reflect.Ptr { + return setPointerFieldValue(field, value) + } - converter, exists := typeConverters[kind] - if !exists { - return nil, fmt.Errorf("不支持的类型: %s", kind.String()) - } + converter, exists := typeConverters[kind] + if !exists { + return nil, fmt.Errorf("不支持的类型: %s", kind.String()) + } - result, err := converter(field, value) - if err != nil { - return nil, err - } + result, err := converter(field, value) + if err != nil { + return nil, err + } - field.Set(reflect.ValueOf(result)) - return result, nil + field.Set(reflect.ValueOf(result)) + return result, nil } // 处理指针类型的字段 func setPointerFieldValue(field reflect.Value, value string) (interface{}, error) { - if field.IsNil() { - // 创建新的指针实例 - elemType := field.Type().Elem() - newValue := reflect.New(elemType) - field.Set(newValue) - } - - // 递归处理指针指向的值 - return setFieldValue(field.Elem(), value) + if field.IsNil() { + // 创建新的指针实例 + elemType := field.Type().Elem() + newValue := reflect.New(elemType) + field.Set(newValue) + } + + // 递归处理指针指向的值 + return setFieldValue(field.Elem(), value) } // 类型转换函数 @@ -401,14 +541,20 @@ func convertString(field reflect.Value, value string) (interface{}, error) { func convertSlice(field reflect.Value, value string) (interface{}, error) { // 实现切片转换逻辑 - // elemType := field.Type().Elem() - // 根据元素类型进行解析 - return nil, fmt.Errorf("切片转换未实现") + var result []interface{} + if err := json.Unmarshal([]byte(value), &result); err != nil { + return nil, err + } + return result, nil } func convertArray(field reflect.Value, value string) (interface{}, error) { // 实现数组转换逻辑 - return nil, fmt.Errorf("数组转换未实现") + var result []interface{} + if err := json.Unmarshal([]byte(value), &result); err != nil { + return nil, err + } + return result, nil } func convertMap(field reflect.Value, value string) (interface{}, error) { @@ -419,92 +565,92 @@ func convertMap(field reflect.Value, value string) (interface{}, error) { // 处理自定义类型 // 处理自定义类型 func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { - // 检查是否实现了TextUnmarshaler接口 - if unmarshaler, ok := field.Addr().Interface().(interface { - UnmarshalText([]byte) error - }); ok { - err := unmarshaler.UnmarshalText([]byte(value)) - if err != nil { - return nil, err - } - return field.Interface(), nil - } + // 检查是否实现了TextUnmarshaler接口 + if unmarshaler, ok := field.Addr().Interface().(interface { + UnmarshalText([]byte) error + }); ok { + err := unmarshaler.UnmarshalText([]byte(value)) + if err != nil { + return nil, err + } + return field.Interface(), nil + } - // 对于其他自定义类型,我们需要获取其基础类型并进行转换 - baseType := getBaseType(field.Type()) - if baseType == nil { - return nil, fmt.Errorf("不支持的自定义类型: %s", field.Type().String()) - } + // 对于其他自定义类型,我们需要获取其基础类型并进行转换 + baseType := getBaseType(field.Type()) + if baseType == nil { + return nil, fmt.Errorf("不支持的自定义类型: %s", field.Type().String()) + } - // 创建基础类型的值并进行转换 - baseValue := reflect.New(baseType).Elem() - converter, exists := typeConverters[baseType.Kind()] - if !exists { - return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) - } + // 创建基础类型的值并进行转换 + baseValue := reflect.New(baseType).Elem() + converter, exists := typeConverters[baseType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) + } - result, err := converter(baseValue, value) - if err != nil { - return nil, err - } + result, err := converter(baseValue, value) + if err != nil { + return nil, err + } - // 将基础类型值转换回自定义类型 - convertedValue, err := convertToCustomType(field.Type(), result) - if err != nil { - return nil, err - } + // 将基础类型值转换回自定义类型 + convertedValue, err := convertToCustomType(field.Type(), result) + if err != nil { + return nil, err + } - field.Set(reflect.ValueOf(convertedValue)) - return convertedValue, nil + field.Set(reflect.ValueOf(convertedValue)) + return convertedValue, nil } // 获取自定义类型的基础类型 func getBaseType(t reflect.Type) reflect.Type { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - // 如果是自定义类型(有包路径),获取其底层类型 - if t.PkgPath() != "" { - return t - } - - return t + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 如果是自定义类型(有包路径),获取其底层类型 + if t.PkgPath() != "" { + return t + } + + return t } // 将基础类型值转换为自定义类型 func convertToCustomType(customType reflect.Type, value interface{}) (interface{}, error) { - valueType := reflect.TypeOf(value) - customBaseType := customType - - // 处理指针类型的自定义类型 - if customType.Kind() == reflect.Ptr { - customBaseType = customType.Elem() - // 创建新的指针实例 - newValue := reflect.New(customBaseType) - elemValue := newValue.Elem() - - // 尝试将值设置到元素 - if valueType.AssignableTo(customBaseType) { - elemValue.Set(reflect.ValueOf(value)) - } else if reflect.ValueOf(value).Type().ConvertibleTo(customBaseType) { - converted := reflect.ValueOf(value).Convert(customBaseType) - elemValue.Set(converted) - } else { - return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customBaseType) - } - - return newValue.Interface(), nil - } - - // 处理非指针类型的自定义类型 - if valueType.AssignableTo(customType) { - return value, nil - } - - if reflect.ValueOf(value).Type().ConvertibleTo(customType) { - return reflect.ValueOf(value).Convert(customType).Interface(), nil - } - - return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customType) -} \ No newline at end of file + valueType := reflect.TypeOf(value) + customBaseType := customType + + // 处理指针类型的自定义类型 + if customType.Kind() == reflect.Ptr { + customBaseType = customType.Elem() + // 创建新的指针实例 + newValue := reflect.New(customBaseType) + elemValue := newValue.Elem() + + // 尝试将值设置到元素 + if valueType.AssignableTo(customBaseType) { + elemValue.Set(reflect.ValueOf(value)) + } else if reflect.ValueOf(value).Type().ConvertibleTo(customBaseType) { + converted := reflect.ValueOf(value).Convert(customBaseType) + elemValue.Set(converted) + } else { + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customBaseType) + } + + return newValue.Interface(), nil + } + + // 处理非指针类型的自定义类型 + if valueType.AssignableTo(customType) { + return value, nil + } + + if reflect.ValueOf(value).Type().ConvertibleTo(customType) { + return reflect.ValueOf(value).Convert(customType).Interface(), nil + } + + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customType) +} diff --git a/structx_test.go b/structx_test.go index cc7895c..3b41821 100644 --- a/structx_test.go +++ b/structx_test.go @@ -463,10 +463,10 @@ func TestAttactToStruct_ErrorScenarios(t *testing.T) { expectedErr: "需要是非空指针", }, { - name: "不支持的类型", + name: "自定义类型", structPtr: &struct{ Data []string }{}, - input: map[string]string{"data": "test"}, - expectedErr: "不支持的类型", + input: map[string]string{"Data": `["test","test2"]`}, + expectedErr: "", }, { name: "无效的嵌套路径", @@ -475,7 +475,7 @@ func TestAttactToStruct_ErrorScenarios(t *testing.T) { "nonexistent.field": "value", "basic.name": "test", }, - expectedErr: "", // 应该忽略不存在的字段而不报错 + expectedErr: "字段 nonexistent.field 不存在", // 应该忽略不存在的字段而不报错 }, } @@ -592,6 +592,7 @@ func BenchmarkAttactToStruct_Nested(b *testing.B) { "basic.salary": "50000.0", "basic.is_active": "true", "comment": "test", + "amount": "500.0", } b.ResetTimer()