diff --git a/structx.go b/structx.go index 87f17e2..eba7f68 100644 --- a/structx.go +++ b/structx.go @@ -40,6 +40,9 @@ var typeConverters = map[reflect.Kind]converterFunc{ reflect.Map: convertMap, } +// 缓存类型信息以避免重复反射操作 +var typeInfoCache = make(map[reflect.Type]map[string]fieldInfo) + // AttactToStructAny 将 map[string]interface{} 类型的值附加到结构体中 func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { stringMap := make(map[string]string, len(updateMap)) @@ -54,6 +57,7 @@ func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) ( return AttactToStruct(structxx, stringMap) } + // AttactToStruct 将 map 的数据赋值到结构体中,支持嵌套结构体和指针 func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[string]ChangeInfo, error) { v := reflect.ValueOf(structxx) @@ -67,37 +71,54 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri // 获取结构体类型信息 t := v.Type() fieldMap := buildFieldMap(t) - fmt.Printf("字段映射2: %+v %+v\n", fieldMap, fieldMap["data"]) for mapKey, mapValue := range updateMap { fieldInfo, exists := fieldMap[mapKey] - fmt.Printf("处理字段: %s, 信息: %+v\n", mapKey, fieldInfo) if !exists { return nil, fmt.Errorf("字段 %s 不存在", mapKey) } - // 安全获取字段值,处理嵌套指针 - field, err := getFieldByIndexSafe(v, fieldInfo.Index) - if err != nil { - return nil, fmt.Errorf("获取字段 %s 失败: %w", mapKey, err) + // 获取字段值,正确处理指针 + field := v + for i, idx := range fieldInfo.Index { + // 确保当前字段是结构体类型(不是指针) + if field.Kind() == reflect.Ptr { + if field.IsNil() { + newValue := reflect.New(field.Type().Elem()) + field.Set(newValue) + } + field = field.Elem() + } + + if field.Kind() != reflect.Struct { + return nil, fmt.Errorf("字段索引 %v 不是结构体类型", fieldInfo.Index[:i+1]) + } + + field = field.Field(idx) + if !field.IsValid() { + return nil, fmt.Errorf("字段索引 %v 无效", fieldInfo.Index[:i+1]) + } } - if !field.IsValid() { - continue - } - - // 处理指针类型 - if fieldInfo.IsPtr { + // 现在 field 是最终的字段值,可能是指针或非指针 + var actualField reflect.Value + if field.Kind() == reflect.Ptr { if field.IsNil() { - newValue := reflect.New(fieldInfo.FieldType.Elem()) + newValue := reflect.New(field.Type().Elem()) field.Set(newValue) } - field = field.Elem() + actualField = field.Elem() + } else { + actualField = field + } + + if !actualField.IsValid() { + continue } // 处理切片和数组类型 if fieldInfo.IsSlice || fieldInfo.IsArray { - err := processSliceOrArrayField(field, mapValue, mapKey) + err := processSliceOrArrayField(actualField, mapValue, mapKey) if err != nil { return nil, fmt.Errorf("处理切片/数组字段 %s 失败: %w", mapKey, err) } @@ -105,8 +126,8 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri } // 处理嵌套结构体 - if field.Kind() == reflect.Struct && !isBasicStructType(field.Type()) { - nestedChanges, err := processNestedStruct(field, mapValue, mapKey) + if actualField.Kind() == reflect.Struct && !isBasicStructType(actualField.Type()) { + nestedChanges, err := processNestedStruct(actualField, mapValue, mapKey) if err != nil { return nil, fmt.Errorf("处理嵌套结构体字段 %s 失败: %w", mapKey, err) } @@ -116,17 +137,17 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri continue } - if !field.CanSet() { + if !actualField.CanSet() { continue } // 处理基本类型 - oldValueStr, err := cast.ToStringE(field.Interface()) + oldValueStr, err := cast.ToStringE(actualField.Interface()) if err != nil { - return nil, fmt.Errorf("获取字段 %s 的旧值失败: %w", mapKey, err) + oldValueStr = fmt.Sprintf("%v", actualField.Interface()) } - newValue, err := setFieldValue(field, mapValue) + newValue, err := setFieldValue(actualField, mapValue) if err != nil { return nil, fmt.Errorf("设置字段 %s 的值失败: %w", mapKey, err) } @@ -146,6 +167,7 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri return changeMap, nil } + // 处理切片和数组字段 func processSliceOrArrayField(field reflect.Value, value string, fieldKey string) error { // 尝试解析JSON数组 @@ -161,10 +183,8 @@ func processSliceOrArrayField(field reflect.Value, value string, fieldKey string var newContainer reflect.Value if fieldType.Kind() == reflect.Slice { - // 切片:动态长度 newContainer = reflect.MakeSlice(fieldType, len(jsonData), len(jsonData)) } else if fieldType.Kind() == reflect.Array { - // 数组:固定长度 if len(jsonData) != fieldType.Len() { return fmt.Errorf("字段 %s 的数组长度不匹配: 期望 %d, 实际 %d", fieldKey, fieldType.Len(), len(jsonData)) @@ -205,51 +225,41 @@ func setSliceElementValue(elemValue reflect.Value, item interface{}, elemType re // 根据元素类型进行转换 switch elemType.Kind() { case reflect.String: - if str, ok := item.(string); ok { - elemValue.SetString(str) - } else { - elemValue.SetString(fmt.Sprintf("%v", item)) - } + elemValue.SetString(convertToString(item)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if num, ok := item.(float64); ok { // JSON数字默认是float64 - elemValue.SetInt(int64(num)) - } else { - return fmt.Errorf("无法将 %v 转换为整型", item) + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法将 %v 转换为整型: %w", item, err) } + elemValue.SetInt(int64(num)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if num, ok := item.(float64); ok { - elemValue.SetUint(uint64(num)) - } else { - return fmt.Errorf("无法将 %v 转换为无符号整型", item) + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法将 %v 转换为无符号整型: %w", item, err) } + elemValue.SetUint(uint64(num)) case reflect.Float32, reflect.Float64: - if num, ok := item.(float64); ok { - elemValue.SetFloat(num) - } else { - return fmt.Errorf("无法将 %v 转换为浮点型", item) + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法将 %v 转换为浮点型: %w", item, err) } + elemValue.SetFloat(num) case reflect.Bool: - if b, ok := item.(bool); ok { - elemValue.SetBool(b) - } else { + b, ok := item.(bool) + if !ok { return fmt.Errorf("无法将 %v 转换为布尔型", item) } + elemValue.SetBool(b) case reflect.Struct: - // 处理结构体元素 if isBasicStructType(elemType) { return setBasicStructElement(elemValue, item, elemType) } - // 对于自定义结构体,需要JSON unmarshal - jsonBytes, err := json.Marshal(item) - if err != nil { - return err - } - return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) + return setStructElement(elemValue, item) default: return fmt.Errorf("不支持的切片元素类型: %s", elemType.Kind()) @@ -258,8 +268,31 @@ func setSliceElementValue(elemValue reflect.Value, item interface{}, elemType re return nil } -// 设置基本结构体元素(如time.Time) -func setBasicStructElement(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { +// 辅助转换函数 +func convertToString(item interface{}) string { + if str, ok := item.(string); ok { + return str + } + return fmt.Sprintf("%v", item) +} + +func convertToFloat64(item interface{}) (float64, error) { + switch v := item.(type) { + case float64: + return v, nil + case int, int32, int64: + return float64(reflect.ValueOf(v).Int()), nil + case uint, uint32, uint64: + return float64(reflect.ValueOf(v).Uint()), nil + case float32: + return float64(v), nil + default: + return 0, fmt.Errorf("无法转换为数字") + } +} + +// 设置结构体元素 +func setStructElement(elemValue reflect.Value, item interface{}) error { jsonBytes, err := json.Marshal(item) if err != nil { return err @@ -267,22 +300,25 @@ func setBasicStructElement(elemValue reflect.Value, item interface{}, elemType r return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) } +// 设置基本结构体元素 +func setBasicStructElement(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { + return setStructElement(elemValue, item) +} + // getFieldByIndexSafe 安全地通过索引路径获取字段,处理嵌套指针 func getFieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, error) { if len(index) == 0 { return v, nil } - // 获取当前层级的字段 field := v.Field(index[0]) if !field.IsValid() { return reflect.Value{}, fmt.Errorf("字段索引 %d 无效", index[0]) } - // 递归解引用指针,直到遇到非指针类型 + // 递归解引用指针 for field.Kind() == reflect.Ptr { if field.IsNil() { - // 创建新的指针实例 elemType := field.Type().Elem() newValue := reflect.New(elemType) field.Set(newValue) @@ -290,7 +326,6 @@ func getFieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, error) { field = field.Elem() } - // 递归处理剩余的索引路径 if len(index) > 1 { return getFieldByIndexSafe(field, index[1:]) } @@ -298,27 +333,33 @@ func getFieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, error) { return field, nil } -// 构建字段映射表,支持嵌套结构体和指针类型 +// 构建字段映射表 type fieldInfo struct { Index []int Name string IsPtr bool FieldType reflect.Type - IsSlice bool // 新增:标识是否为切片或数组 - IsArray bool // 新增:标识是否为数组 + IsSlice bool + IsArray bool } func buildFieldMap(t reflect.Type) map[string]fieldInfo { + // 检查缓存 + if cached, exists := typeInfoCache[t]; exists { + return cached + } + fieldMap := make(map[string]fieldInfo) buildFieldMapRecursive(t, []int{}, fieldMap, "") - fmt.Printf("字段映射表: %+v\n", fieldMap) + + // 缓存结果 + typeInfoCache[t] = fieldMap return fieldMap } func buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]fieldInfo, prefix string) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) - fmt.Printf("处理字段 索引: %+v %+v\n", field, field.Type.Kind()) if !field.IsExported() { continue } @@ -333,36 +374,27 @@ func buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]fie fieldType := field.Type isPtr := fieldType.Kind() == reflect.Ptr - - // 解引用指针类型以获取实际类型 actualType := fieldType if isPtr { actualType = fieldType.Elem() } - // 处理切片和数组类型 - 不进行递归展开 + // 处理切片和数组类型 if actualType.Kind() == reflect.Slice || actualType.Kind() == reflect.Array { fieldMap[fullKey] = fieldInfo{ Index: currentIndex, Name: field.Name, IsPtr: isPtr, FieldType: fieldType, - IsSlice: true, + IsSlice: actualType.Kind() == reflect.Slice, + IsArray: actualType.Kind() == reflect.Array, } continue } - // 如果是结构体且不是基本类型,递归处理 + // 处理嵌套结构体 if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) - - fieldMap[fullKey] = fieldInfo{ - Index: currentIndex, - Name: field.Name, - IsPtr: isPtr, - FieldType: fieldType, - } - continue } fieldMap[fullKey] = fieldInfo{ @@ -382,57 +414,41 @@ func getJSONTagName(field reflect.StructField) string { return strings.Split(jsonTag, ",")[0] } -// 判断是否为基本结构体类型(如time.Time等) +// 判断是否为基本结构体类型 func isBasicStructType(t reflect.Type) bool { - // 处理指针类型 if t.Kind() == reflect.Ptr { t = t.Elem() } - // 检查是否为 decimal.Decimal - if isDecimalType(t) { + // 检查常见的基本结构体类型 + switch { + case t.PkgPath() == "time" && t.Name() == "Time": return true - } - - // 这里可以添加更多需要排除的基本结构体类型 - if t.PkgPath() == "time" && t.Name() == "Time" { + case t.PkgPath() == "github.com/shopspring/decimal" && t.Name() == "Decimal": return true - } - - // 其他常见的基本结构体类型 - switch t.String() { - case "time.Time", "sql.NullString", "sql.NullInt64", "sql.NullBool", "sql.NullFloat64": + case strings.HasPrefix(t.String(), "sql.Null"): return true } return false } -// 判断是否为 decimal.Decimal 类型 -func isDecimalType(t reflect.Type) bool { - return t.PkgPath() == "github.com/shopspring/decimal" && t.Name() == "Decimal" -} - // 检查类型是否实现了 UnmarshalJSON 方法 func hasUnmarshalJSON(t reflect.Type) bool { - // 处理指针类型 if t.Kind() == reflect.Ptr { t = t.Elem() } - // 检查是否实现了 json.Unmarshaler 接口 unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() return t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) } // 检查类型是否实现了 UnmarshalText 方法 func hasUnmarshalText(t reflect.Type) bool { - // 处理指针类型 if t.Kind() == reflect.Ptr { t = t.Elem() } - // 检查是否实现了 encoding.TextUnmarshaler 接口 textUnmarshalerType := reflect.TypeOf((*interface { UnmarshalText([]byte) error })(nil)).Elem() @@ -446,18 +462,15 @@ func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { return fmt.Errorf("序列化值失败: %w", err) } - // 获取字段的地址(确保我们可以调用指针方法) var fieldAddr reflect.Value if field.CanAddr() { fieldAddr = field.Addr() } else { - // 对于不可寻址的字段,创建一个临时变量 temp := reflect.New(field.Type()) temp.Elem().Set(field) fieldAddr = temp } - // 调用 UnmarshalJSON 方法 if unmarshaler, ok := fieldAddr.Interface().(json.Unmarshaler); ok { return unmarshaler.UnmarshalJSON(jsonBytes) } @@ -467,7 +480,6 @@ func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { // 使用 UnmarshalText 方法设置值 func setUnmarshalTextValue(field reflect.Value, value string) error { - // 获取字段的地址 var fieldAddr reflect.Value if field.CanAddr() { fieldAddr = field.Addr() @@ -477,7 +489,6 @@ func setUnmarshalTextValue(field reflect.Value, value string) error { fieldAddr = temp } - // 调用 UnmarshalText 方法 if unmarshaler, ok := fieldAddr.Interface().(interface { UnmarshalText([]byte) error }); ok { @@ -487,16 +498,13 @@ func setUnmarshalTextValue(field reflect.Value, value string) error { return fmt.Errorf("类型 %s 未实现 UnmarshalText", field.Type()) } -// 处理嵌套结构体(支持JSON解析) -// processNestedStruct 处理嵌套结构体 +// 处理嵌套结构体 func processNestedStruct(field reflect.Value, value string, parentKey string) (map[string]ChangeInfo, error) { changeMap := make(map[string]ChangeInfo) - // 确保我们处理的是可寻址的结构体值 var structValue reflect.Value if field.Kind() == reflect.Ptr { if field.IsNil() { - // 创建新的指针实例 newValue := reflect.New(field.Type().Elem()) field.Set(newValue) } @@ -509,10 +517,9 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m return nil, fmt.Errorf("无效的结构体字段") } - // 首先尝试解析为JSON对象(用于真正的嵌套结构体) + // 尝试解析为JSON对象 var nestedMap map[string]interface{} if err := json.Unmarshal([]byte(value), &nestedMap); err == nil { - // 成功解析为JSON对象,说明是真正的嵌套结构体 stringMap := make(map[string]string, len(nestedMap)) for k, v := range nestedMap { str, err := cast.ToStringE(v) @@ -522,52 +529,26 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m stringMap[k] = str } - // 递归处理嵌套结构体 nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), stringMap) if err != nil { return nil, err } - // 为嵌套字段的变更记录添加前缀 for key, change := range nestedChanges { - fullKey := parentKey + "." + key - changeMap[fullKey] = change + changeMap[parentKey+"."+key] = change } return changeMap, nil } - // 如果解析JSON对象失败,尝试直接设置整个结构体的值 - // 这可能是一个基本结构体类型(如decimal.Decimal)或者实现了Unmarshaler接口的类型 - if hasUnmarshalJSON(structValue.Type()) { - // 对于实现了UnmarshalJSON的类型,直接使用JSON解析 - err := json.Unmarshal([]byte(value), structValue.Addr().Interface()) - if err != nil { - return nil, fmt.Errorf("解析嵌套结构体值失败: %w", err) - } - - // 记录变更 - oldValueStr, _ := cast.ToStringE(field.Interface()) - newValueStr := value - changeMap[parentKey] = ChangeInfo{ - Old: oldValueStr, - New: newValueStr, - Val: structValue.Interface(), - } - return changeMap, nil - } - - // 对于其他基本结构体类型,尝试直接设置 - if isBasicStructType(structValue.Type()) { - // 保存旧值 + // 尝试直接设置结构体值 + if hasUnmarshalJSON(structValue.Type()) || isBasicStructType(structValue.Type()) { oldValueStr, _ := cast.ToStringE(field.Interface()) - // 尝试设置新值 err := setBasicStructValue(structValue, value) if err != nil { return nil, fmt.Errorf("设置基本结构体值失败: %w", err) } - // 记录变更 newValueStr, _ := cast.ToStringE(field.Interface()) changeMap[parentKey] = ChangeInfo{ Old: oldValueStr, @@ -578,39 +559,214 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m } return nil, fmt.Errorf("嵌套结构体值必须是有效的JSON格式或基本结构体类型") - - // // 尝试解析JSON字符串到map - // var nestedMap map[string]string - // if err := json.Unmarshal([]byte(value), &nestedMap); err != nil { - // return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err) - // } - - // // 递归处理嵌套结构体 - // nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap) - // if err != nil { - // return nil, err - // } - - // // 为嵌套字段的变更记录添加前缀 - // for key, change := range nestedChanges { - // fullKey := parentKey + "." + key - // changeMap[fullKey] = change - // } - - // return changeMap, nil } // 设置基本结构体的值 func setBasicStructValue(field reflect.Value, value string) error { - // 检查是否实现了UnmarshalText if hasUnmarshalText(field.Type()) { return setUnmarshalTextValue(field, value) } - - // 对于其他基本结构体类型,尝试JSON解析 return json.Unmarshal([]byte(value), field.Addr().Interface()) } +// 设置字段值 +func setFieldValue(field reflect.Value, value string) (interface{}, error) { + fieldType := field.Type() + + // 优先检查接口实现 + if hasUnmarshalJSON(fieldType) { + err := setUnmarshalJSONValue(field, value) + if err != nil { + return nil, fmt.Errorf("UnmarshalJSON 失败: %w", err) + } + return field.Interface(), nil + } + + if hasUnmarshalText(fieldType) { + err := setUnmarshalTextValue(field, value) + if err != nil { + return nil, fmt.Errorf("UnmarshalText 失败: %w", err) + } + return field.Interface(), nil + } + + // 处理指针类型 + if fieldType.Kind() == reflect.Ptr { + return setPointerFieldValue(field, value) + } + + // 处理类型别名(如 type CustomString string) + if isTypeAlias(fieldType) { + return setTypeAliasValue(field, value) + } + + // 处理自定义类型(有包路径的结构体类型) + if isCustomStructType(fieldType) { + return setCustomTypeValue(field, value) + } + + // 处理基本类型 + converter, exists := typeConverters[fieldType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的类型: %s", fieldType.Kind().String()) + } + + result, err := converter(field, value) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(result)) + return result, nil +} + +// 设置类型别名的值 +func setTypeAliasValue(field reflect.Value, value string) (interface{}, error) { + fieldType := field.Type() + + // 获取基础类型 + baseType := getBaseTypeFromAlias(fieldType) + if baseType == nil { + return nil, fmt.Errorf("无法获取类型别名的基础类型: %s", fieldType.String()) + } + + // 使用基础类型的转换器 + converter, exists := typeConverters[baseType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) + } + + // 创建基础类型的临时值 + baseValue := reflect.New(baseType).Elem() + result, err := converter(baseValue, value) + if err != nil { + return nil, err + } + + // 将基础类型值转换为类型别名 + convertedValue, err := convertToTypeAlias(fieldType, result) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(convertedValue)) + return convertedValue, nil +} + +// 将基础类型值转换为类型别名 +func convertToTypeAlias(aliasType reflect.Type, value interface{}) (interface{}, error) { + // valueType := reflect.TypeOf(value) + + // 处理指针类型的类型别名 + if aliasType.Kind() == reflect.Ptr { + elemType := aliasType.Elem() + + // 创建新的指针实例 + newValue := reflect.New(elemType) + elemValue := newValue.Elem() + + // 转换值 + converted, err := convertValueToType(value, elemType) + if err != nil { + return nil, err + } + elemValue.Set(reflect.ValueOf(converted)) + + return newValue.Interface(), nil + } + + // 处理非指针类型的类型别名 + return convertValueToType(value, aliasType) +} + +// 将值转换为指定类型 +func convertValueToType(value interface{}, targetType reflect.Type) (interface{}, error) { + valueType := reflect.TypeOf(value) + + if valueType.AssignableTo(targetType) { + return value, nil + } + + if valueType.ConvertibleTo(targetType) { + return reflect.ValueOf(value).Convert(targetType).Interface(), nil + } + + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, targetType) +} + +// 检查是否为自定义结构体类型 +func isCustomStructType(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 排除基本类型 + if t.PkgPath() == "" { + return false + } + + // 排除已知的基本结构体类型 + if isBasicStructType(t) { + return false + } + + // 排除类型别名 + if isTypeAlias(t) { + return false + } + + // 排除接口类型 + if t.Kind() == reflect.Interface { + return false + } + + // 必须是结构体类型 + return t.Kind() == reflect.Struct +} + + + +// 获取类型别名的基础类型 +func getBaseTypeFromAlias(aliasType reflect.Type) reflect.Type { + if aliasType.Kind() == reflect.Ptr { + aliasType = aliasType.Elem() + } + + // 根据类型别名的种类返回对应的基础类型 + switch aliasType.Kind() { + case reflect.String: + return reflect.TypeOf("") + case reflect.Int: + return reflect.TypeOf(int(0)) + case reflect.Int8: + return reflect.TypeOf(int8(0)) + case reflect.Int16: + return reflect.TypeOf(int16(0)) + case reflect.Int32: + return reflect.TypeOf(int32(0)) + case reflect.Int64: + return reflect.TypeOf(int64(0)) + case reflect.Uint: + return reflect.TypeOf(uint(0)) + case reflect.Uint8: + return reflect.TypeOf(uint8(0)) + case reflect.Uint16: + return reflect.TypeOf(uint16(0)) + case reflect.Uint32: + return reflect.TypeOf(uint32(0)) + case reflect.Uint64: + return reflect.TypeOf(uint64(0)) + case reflect.Float32: + return reflect.TypeOf(float32(0)) + case reflect.Float64: + return reflect.TypeOf(float64(0)) + case reflect.Bool: + return reflect.TypeOf(false) + default: + return nil + } +} + // 检测是否为自定义类型 func isCustomType(t reflect.Type) bool { // 排除基本类型 @@ -628,70 +784,278 @@ func isCustomType(t reflect.Type) bool { return false } + // 排除基础类型的别名 + if isBasicTypeAlias(t) { + return false + } + return true } -// 设置字段值 -// 设置字段值 -func setFieldValue(field reflect.Value, value string) (interface{}, error) { - kind := field.Kind() - - // 优先检查是否实现了 UnmarshalJSON 方法 - if hasUnmarshalJSON(field.Type()) { - err := setUnmarshalJSONValue(field, value) - if err != nil { - return nil, fmt.Errorf("UnmarshalJSON 失败: %w", err) - } - return field.Interface(), nil +// 检查是否为类型别名(如 type CustomString string) +func isTypeAlias(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() } - // 其次检查是否实现了 UnmarshalText 方法 + // 类型别名有包路径且是基本类型 + basicKinds := map[reflect.Kind]bool{ + reflect.String: true, + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Bool: true, + } + + return t.PkgPath() != "" && basicKinds[t.Kind()] +} + +// 检查是否为基本类型的别名(如 type CustomString string) +func isBasicTypeAlias(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 检查是否为基本类型的别名 + basicKinds := map[reflect.Kind]bool{ + reflect.String: true, + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Bool: true, + } + + return basicKinds[t.Kind()] +} + +// 处理自定义类型 +func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { + // 检查是否实现了TextUnmarshaler接口 if hasUnmarshalText(field.Type()) { err := setUnmarshalTextValue(field, value) if err != nil { - return nil, fmt.Errorf("UnmarshalText 失败: %w", err) + return nil, err } return field.Interface(), nil } - // 检测是否为自定义类型(有包路径的类型) - if isCustomType(field.Type()) { - return setCustomTypeValue(field, value) + // 对于其他自定义结构体类型,尝试JSON解析 + var jsonData interface{} + if err := json.Unmarshal([]byte(value), &jsonData); err != nil { + return nil, fmt.Errorf("无法解析JSON: %w", err) } - // 处理指针类型的基础类型 - if kind == reflect.Ptr { - return setPointerFieldValue(field, value) - } - - converter, exists := typeConverters[kind] - if !exists { - return nil, fmt.Errorf("不支持的类型: %s", kind.String()) - } - - result, err := converter(field, value) + jsonBytes, err := json.Marshal(jsonData) if err != nil { return nil, err } - field.Set(reflect.ValueOf(result)) - return result, nil + if err := json.Unmarshal(jsonBytes, field.Addr().Interface()); err != nil { + return nil, err + } + + return field.Interface(), nil + + // // 对于其他自定义类型,我们需要获取其基础类型并进行转换 + // baseType := getBaseType(field.Type()) + // if baseType == nil { + // return nil, fmt.Errorf("不支持的自定义类型: %s", field.Type().String()) + // } + + // // 创建基础类型的值并进行转换 + // baseValue := reflect.New(baseType).Elem() + + // // 使用对应的类型转换器 + // converter, exists := typeConverters[baseType.Kind()] + // if !exists { + // return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) + // } + + // result, err := converter(baseValue, value) + // if err != nil { + // return nil, err + // } + + // // 将基础类型值转换回自定义类型 + // convertedValue, err := convertToCustomType(field.Type(), result) + // if err != nil { + // return nil, err + // } + + // field.Set(reflect.ValueOf(convertedValue)) + // return convertedValue, nil } +// 将基础类型值转换为自定义类型 +func convertToCustomType(customType reflect.Type, value interface{}) (interface{}, error) { + valueType := reflect.TypeOf(value) + + // 处理指针类型的自定义类型 + if customType.Kind() == reflect.Ptr { + elemType := customType.Elem() + + // 创建新的指针实例 + newValue := reflect.New(elemType) + elemValue := newValue.Elem() + + // 尝试将值设置到元素 + if valueType.AssignableTo(elemType) { + elemValue.Set(reflect.ValueOf(value)) + } else if reflect.ValueOf(value).Type().ConvertibleTo(elemType) { + converted := reflect.ValueOf(value).Convert(elemType) + elemValue.Set(converted) + } else { + // 对于自定义类型,尝试通过字符串转换 + if str, ok := value.(string); ok && elemType.Kind() == reflect.String { + elemValue.SetString(str) + } else { + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, elemType) + } + } + + return newValue.Interface(), nil + } + + // 处理非指针类型的自定义类型 + if valueType.AssignableTo(customType) { + return value, nil + } + + if reflect.ValueOf(value).Type().ConvertibleTo(customType) { + return reflect.ValueOf(value).Convert(customType).Interface(), nil + } + + // 对于字符串到自定义字符串类型的转换 + if str, ok := value.(string); ok && customType.Kind() == reflect.String { + return reflect.ValueOf(str).Convert(customType).Interface(), nil + } + + return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customType) +} + +// 获取自定义类型的基础类型 +func getBaseType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 如果是自定义类型(有包路径),获取其底层类型 + if t.PkgPath() != "" { + // 对于类型别名,我们需要获取其底层的基本类型 + switch t.Kind() { + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool: + // 这是基本类型的别名,返回对应的基本类型 + switch t.Kind() { + case reflect.String: + return reflect.TypeOf("") + case reflect.Int: + return reflect.TypeOf(int(0)) + case reflect.Int8: + return reflect.TypeOf(int8(0)) + case reflect.Int16: + return reflect.TypeOf(int16(0)) + case reflect.Int32: + return reflect.TypeOf(int32(0)) + case reflect.Int64: + return reflect.TypeOf(int64(0)) + case reflect.Uint: + return reflect.TypeOf(uint(0)) + case reflect.Uint8: + return reflect.TypeOf(uint8(0)) + case reflect.Uint16: + return reflect.TypeOf(uint16(0)) + case reflect.Uint32: + return reflect.TypeOf(uint32(0)) + case reflect.Uint64: + return reflect.TypeOf(uint64(0)) + case reflect.Float32: + return reflect.TypeOf(float32(0)) + case reflect.Float64: + return reflect.TypeOf(float64(0)) + case reflect.Bool: + return reflect.TypeOf(false) + } + default: + // 对于其他自定义类型,尝试获取其底层类型 + if t.Kind() == reflect.Struct { + // 如果是结构体,检查是否有可转换的基础类型 + if field := getBaseTypeField(t); field.Name != "" { + return field.Type + } + } + } + } + + return nil +} + +// 获取结构体中可能的基础类型字段 +func getBaseTypeField(t reflect.Type) reflect.StructField { + if t.Kind() != reflect.Struct || t.NumField() != 1 { + return reflect.StructField{} + } + + field := t.Field(0) + basicKinds := map[reflect.Kind]bool{ + reflect.String: true, + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Bool: true, + } + + if basicKinds[field.Type.Kind()] && field.IsExported() { + return field + } + + return reflect.StructField{} +} + + // 处理指针类型的字段 func setPointerFieldValue(field reflect.Value, value string) (interface{}, error) { + if field.Kind() != reflect.Ptr { + return nil, fmt.Errorf("setPointerFieldValue: 期望指针类型,得到 %s", field.Kind()) + } + if field.IsNil() { - // 创建新的指针实例 elemType := field.Type().Elem() newValue := reflect.New(elemType) field.Set(newValue) } - - // 递归处理指针指向的值 return setFieldValue(field.Elem(), value) } -// 类型转换函数 +// 类型转换函数(保持不变) func convertBool(field reflect.Value, value string) (interface{}, error) { return convx.ToBool(value) } @@ -728,7 +1092,6 @@ func convertString(field reflect.Value, value string) (interface{}, error) { } func convertSlice(field reflect.Value, value string) (interface{}, error) { - // 实现切片转换逻辑 var result []interface{} if err := json.Unmarshal([]byte(value), &result); err != nil { return nil, err @@ -737,7 +1100,6 @@ func convertSlice(field reflect.Value, value string) (interface{}, error) { } func convertArray(field reflect.Value, value string) (interface{}, error) { - // 实现数组转换逻辑 var result []interface{} if err := json.Unmarshal([]byte(value), &result); err != nil { return nil, err @@ -746,96 +1108,5 @@ func convertArray(field reflect.Value, value string) (interface{}, error) { } func convertMap(field reflect.Value, value string) (interface{}, error) { - // 实现map转换逻辑 return nil, fmt.Errorf("map转换未实现") -} - -// 处理自定义类型 -// 处理自定义类型 -func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { - // 检查是否实现了TextUnmarshaler接口 - if hasUnmarshalText(field.Type()) { - err := setUnmarshalTextValue(field, value) - if err != nil { - return nil, err - } - return field.Interface(), nil - } - // 对于其他自定义类型,我们需要获取其基础类型并进行转换 - baseType := getBaseType(field.Type()) - if baseType == nil { - return nil, fmt.Errorf("不支持的自定义类型: %s", field.Type().String()) - } - - // 创建基础类型的值并进行转换 - baseValue := reflect.New(baseType).Elem() - converter, exists := typeConverters[baseType.Kind()] - if !exists { - return nil, fmt.Errorf("不支持的基础类型: %s", baseType.Kind().String()) - } - - result, err := converter(baseValue, value) - if err != nil { - return nil, err - } - - // 将基础类型值转换回自定义类型 - convertedValue, err := convertToCustomType(field.Type(), result) - if err != nil { - return nil, err - } - - field.Set(reflect.ValueOf(convertedValue)) - return convertedValue, nil -} - -// 获取自定义类型的基础类型 -func getBaseType(t reflect.Type) reflect.Type { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - // 如果是自定义类型(有包路径),获取其底层类型 - if t.PkgPath() != "" { - return t - } - - return t -} - -// 将基础类型值转换为自定义类型 -func convertToCustomType(customType reflect.Type, value interface{}) (interface{}, error) { - valueType := reflect.TypeOf(value) - customBaseType := customType - - // 处理指针类型的自定义类型 - if customType.Kind() == reflect.Ptr { - customBaseType = customType.Elem() - // 创建新的指针实例 - newValue := reflect.New(customBaseType) - elemValue := newValue.Elem() - - // 尝试将值设置到元素 - if valueType.AssignableTo(customBaseType) { - elemValue.Set(reflect.ValueOf(value)) - } else if reflect.ValueOf(value).Type().ConvertibleTo(customBaseType) { - converted := reflect.ValueOf(value).Convert(customBaseType) - elemValue.Set(converted) - } else { - return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customBaseType) - } - - return newValue.Interface(), nil - } - - // 处理非指针类型的自定义类型 - if valueType.AssignableTo(customType) { - return value, nil - } - - if reflect.ValueOf(value).Type().ConvertibleTo(customType) { - return reflect.ValueOf(value).Convert(customType).Interface(), nil - } - - return nil, fmt.Errorf("无法将 %v 转换为 %v", valueType, customType) -} +} \ No newline at end of file diff --git a/structx_test.go b/structx_test.go index 3b41821..e5ba815 100644 --- a/structx_test.go +++ b/structx_test.go @@ -47,6 +47,8 @@ type NestedStruct struct { Basic BasicStruct `json:"basic"` Comment string `json:"comment"` Amount decimal.Decimal `json:"amount"` + Amount2 *decimal.Decimal `json:"amount2"` + Timestamp time.Time `json:"timestamp"` } // 指针嵌套结构体 @@ -593,6 +595,8 @@ func BenchmarkAttactToStruct_Nested(b *testing.B) { "basic.is_active": "true", "comment": "test", "amount": "500.0", + "amount2": "250.0", + "timestamp": "2024-01-01T12:00:00Z", } b.ResetTimer()