diff --git a/structx.go b/structx.go index b64ab4e..87f17e2 100644 --- a/structx.go +++ b/structx.go @@ -67,7 +67,7 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri // 获取结构体类型信息 t := v.Type() fieldMap := buildFieldMap(t) - fmt.Printf("字段映射2: %+v %+v\n", fieldMap,fieldMap["data"]) + fmt.Printf("字段映射2: %+v %+v\n", fieldMap, fieldMap["data"]) for mapKey, mapValue := range updateMap { fieldInfo, exists := fieldMap[mapKey] @@ -197,6 +197,11 @@ func setSliceElementValue(elemValue reflect.Value, item interface{}, elemType re return setSliceElementValue(elemValue.Elem(), item, elemType.Elem()) } + // 检查是否实现了 UnmarshalJSON 方法 + if hasUnmarshalJSON(elemType) { + return setUnmarshalJSONValue(elemValue, item) + } + // 根据元素类型进行转换 switch elemType.Kind() { case reflect.String: @@ -384,6 +389,11 @@ func isBasicStructType(t reflect.Type) bool { t = t.Elem() } + // 检查是否为 decimal.Decimal + if isDecimalType(t) { + return true + } + // 这里可以添加更多需要排除的基本结构体类型 if t.PkgPath() == "time" && t.Name() == "Time" { return true @@ -398,6 +408,85 @@ func isBasicStructType(t reflect.Type) bool { return false } +// 判断是否为 decimal.Decimal 类型 +func isDecimalType(t reflect.Type) bool { + return t.PkgPath() == "github.com/shopspring/decimal" && t.Name() == "Decimal" +} + +// 检查类型是否实现了 UnmarshalJSON 方法 +func hasUnmarshalJSON(t reflect.Type) bool { + // 处理指针类型 + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 检查是否实现了 json.Unmarshaler 接口 + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + return t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) +} + +// 检查类型是否实现了 UnmarshalText 方法 +func hasUnmarshalText(t reflect.Type) bool { + // 处理指针类型 + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // 检查是否实现了 encoding.TextUnmarshaler 接口 + textUnmarshalerType := reflect.TypeOf((*interface { + UnmarshalText([]byte) error + })(nil)).Elem() + return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) +} + +// 使用 UnmarshalJSON 方法设置值 +func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { + jsonBytes, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("序列化值失败: %w", err) + } + + // 获取字段的地址(确保我们可以调用指针方法) + var fieldAddr reflect.Value + if field.CanAddr() { + fieldAddr = field.Addr() + } else { + // 对于不可寻址的字段,创建一个临时变量 + temp := reflect.New(field.Type()) + temp.Elem().Set(field) + fieldAddr = temp + } + + // 调用 UnmarshalJSON 方法 + if unmarshaler, ok := fieldAddr.Interface().(json.Unmarshaler); ok { + return unmarshaler.UnmarshalJSON(jsonBytes) + } + + return fmt.Errorf("类型 %s 未实现 json.Unmarshaler", field.Type()) +} + +// 使用 UnmarshalText 方法设置值 +func setUnmarshalTextValue(field reflect.Value, value string) error { + // 获取字段的地址 + var fieldAddr reflect.Value + if field.CanAddr() { + fieldAddr = field.Addr() + } else { + temp := reflect.New(field.Type()) + temp.Elem().Set(field) + fieldAddr = temp + } + + // 调用 UnmarshalText 方法 + if unmarshaler, ok := fieldAddr.Interface().(interface { + UnmarshalText([]byte) error + }); ok { + return unmarshaler.UnmarshalText([]byte(value)) + } + + return fmt.Errorf("类型 %s 未实现 UnmarshalText", field.Type()) +} + // 处理嵌套结构体(支持JSON解析) // processNestedStruct 处理嵌套结构体 func processNestedStruct(field reflect.Value, value string, parentKey string) (map[string]ChangeInfo, error) { @@ -420,25 +509,106 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m return nil, fmt.Errorf("无效的结构体字段") } - // 尝试解析JSON字符串到map - var nestedMap map[string]string - if err := json.Unmarshal([]byte(value), &nestedMap); err != nil { - return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err) + // 首先尝试解析为JSON对象(用于真正的嵌套结构体) + var nestedMap map[string]interface{} + if err := json.Unmarshal([]byte(value), &nestedMap); err == nil { + // 成功解析为JSON对象,说明是真正的嵌套结构体 + stringMap := make(map[string]string, len(nestedMap)) + for k, v := range nestedMap { + str, err := cast.ToStringE(v) + if err != nil { + return nil, fmt.Errorf("转换嵌套字段 %s 的值失败: %w", k, err) + } + stringMap[k] = str + } + + // 递归处理嵌套结构体 + nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), stringMap) + if err != nil { + return nil, err + } + + // 为嵌套字段的变更记录添加前缀 + for key, change := range nestedChanges { + fullKey := parentKey + "." + key + changeMap[fullKey] = change + } + return changeMap, nil } - // 递归处理嵌套结构体 - nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap) - if err != nil { - return nil, err + // 如果解析JSON对象失败,尝试直接设置整个结构体的值 + // 这可能是一个基本结构体类型(如decimal.Decimal)或者实现了Unmarshaler接口的类型 + if hasUnmarshalJSON(structValue.Type()) { + // 对于实现了UnmarshalJSON的类型,直接使用JSON解析 + err := json.Unmarshal([]byte(value), structValue.Addr().Interface()) + if err != nil { + return nil, fmt.Errorf("解析嵌套结构体值失败: %w", err) + } + + // 记录变更 + oldValueStr, _ := cast.ToStringE(field.Interface()) + newValueStr := value + changeMap[parentKey] = ChangeInfo{ + Old: oldValueStr, + New: newValueStr, + Val: structValue.Interface(), + } + return changeMap, nil } - // 为嵌套字段的变更记录添加前缀 - for key, change := range nestedChanges { - fullKey := parentKey + "." + key - changeMap[fullKey] = change + // 对于其他基本结构体类型,尝试直接设置 + if isBasicStructType(structValue.Type()) { + // 保存旧值 + oldValueStr, _ := cast.ToStringE(field.Interface()) + + // 尝试设置新值 + err := setBasicStructValue(structValue, value) + if err != nil { + return nil, fmt.Errorf("设置基本结构体值失败: %w", err) + } + + // 记录变更 + newValueStr, _ := cast.ToStringE(field.Interface()) + changeMap[parentKey] = ChangeInfo{ + Old: oldValueStr, + New: newValueStr, + Val: structValue.Interface(), + } + return changeMap, nil } - return changeMap, nil + return nil, fmt.Errorf("嵌套结构体值必须是有效的JSON格式或基本结构体类型") + + // // 尝试解析JSON字符串到map + // var nestedMap map[string]string + // if err := json.Unmarshal([]byte(value), &nestedMap); err != nil { + // return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err) + // } + + // // 递归处理嵌套结构体 + // nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap) + // if err != nil { + // return nil, err + // } + + // // 为嵌套字段的变更记录添加前缀 + // for key, change := range nestedChanges { + // fullKey := parentKey + "." + key + // changeMap[fullKey] = change + // } + + // return changeMap, nil +} + +// 设置基本结构体的值 +func setBasicStructValue(field reflect.Value, value string) error { + // 检查是否实现了UnmarshalText + if hasUnmarshalText(field.Type()) { + return setUnmarshalTextValue(field, value) + } + + // 对于其他基本结构体类型,尝试JSON解析 + return json.Unmarshal([]byte(value), field.Addr().Interface()) } // 检测是否为自定义类型 @@ -466,6 +636,24 @@ func isCustomType(t reflect.Type) bool { func setFieldValue(field reflect.Value, value string) (interface{}, error) { kind := field.Kind() + // 优先检查是否实现了 UnmarshalJSON 方法 + if hasUnmarshalJSON(field.Type()) { + err := setUnmarshalJSONValue(field, value) + if err != nil { + return nil, fmt.Errorf("UnmarshalJSON 失败: %w", err) + } + return field.Interface(), nil + } + + // 其次检查是否实现了 UnmarshalText 方法 + if hasUnmarshalText(field.Type()) { + err := setUnmarshalTextValue(field, value) + if err != nil { + return nil, fmt.Errorf("UnmarshalText 失败: %w", err) + } + return field.Interface(), nil + } + // 检测是否为自定义类型(有包路径的类型) if isCustomType(field.Type()) { return setCustomTypeValue(field, value) @@ -566,16 +754,13 @@ func convertMap(field reflect.Value, value string) (interface{}, error) { // 处理自定义类型 func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { // 检查是否实现了TextUnmarshaler接口 - if unmarshaler, ok := field.Addr().Interface().(interface { - UnmarshalText([]byte) error - }); ok { - err := unmarshaler.UnmarshalText([]byte(value)) + if hasUnmarshalText(field.Type()) { + err := setUnmarshalTextValue(field, value) if err != nil { return nil, err } return field.Interface(), nil } - // 对于其他自定义类型,我们需要获取其基础类型并进行转换 baseType := getBaseType(field.Type()) if baseType == nil {