优化对decimal.Decimal的支持

This commit is contained in:
Yun
2025-09-20 23:13:52 +08:00
parent 02152b44bf
commit 898bdf7f38
+204 -19
View File
@@ -67,7 +67,7 @@ func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[stri
// 获取结构体类型信息
t := v.Type()
fieldMap := buildFieldMap(t)
fmt.Printf("字段映射2: %+v %+v\n", fieldMap,fieldMap["data"])
fmt.Printf("字段映射2: %+v %+v\n", fieldMap, fieldMap["data"])
for mapKey, mapValue := range updateMap {
fieldInfo, exists := fieldMap[mapKey]
@@ -197,6 +197,11 @@ func setSliceElementValue(elemValue reflect.Value, item interface{}, elemType re
return setSliceElementValue(elemValue.Elem(), item, elemType.Elem())
}
// 检查是否实现了 UnmarshalJSON 方法
if hasUnmarshalJSON(elemType) {
return setUnmarshalJSONValue(elemValue, item)
}
// 根据元素类型进行转换
switch elemType.Kind() {
case reflect.String:
@@ -384,6 +389,11 @@ func isBasicStructType(t reflect.Type) bool {
t = t.Elem()
}
// 检查是否为 decimal.Decimal
if isDecimalType(t) {
return true
}
// 这里可以添加更多需要排除的基本结构体类型
if t.PkgPath() == "time" && t.Name() == "Time" {
return true
@@ -398,6 +408,85 @@ func isBasicStructType(t reflect.Type) bool {
return false
}
// 判断是否为 decimal.Decimal 类型
func isDecimalType(t reflect.Type) bool {
return t.PkgPath() == "github.com/shopspring/decimal" && t.Name() == "Decimal"
}
// 检查类型是否实现了 UnmarshalJSON 方法
func hasUnmarshalJSON(t reflect.Type) bool {
// 处理指针类型
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// 检查是否实现了 json.Unmarshaler 接口
unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
return t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType)
}
// 检查类型是否实现了 UnmarshalText 方法
func hasUnmarshalText(t reflect.Type) bool {
// 处理指针类型
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// 检查是否实现了 encoding.TextUnmarshaler 接口
textUnmarshalerType := reflect.TypeOf((*interface {
UnmarshalText([]byte) error
})(nil)).Elem()
return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType)
}
// 使用 UnmarshalJSON 方法设置值
func setUnmarshalJSONValue(field reflect.Value, value interface{}) error {
jsonBytes, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("序列化值失败: %w", err)
}
// 获取字段的地址(确保我们可以调用指针方法)
var fieldAddr reflect.Value
if field.CanAddr() {
fieldAddr = field.Addr()
} else {
// 对于不可寻址的字段,创建一个临时变量
temp := reflect.New(field.Type())
temp.Elem().Set(field)
fieldAddr = temp
}
// 调用 UnmarshalJSON 方法
if unmarshaler, ok := fieldAddr.Interface().(json.Unmarshaler); ok {
return unmarshaler.UnmarshalJSON(jsonBytes)
}
return fmt.Errorf("类型 %s 未实现 json.Unmarshaler", field.Type())
}
// 使用 UnmarshalText 方法设置值
func setUnmarshalTextValue(field reflect.Value, value string) error {
// 获取字段的地址
var fieldAddr reflect.Value
if field.CanAddr() {
fieldAddr = field.Addr()
} else {
temp := reflect.New(field.Type())
temp.Elem().Set(field)
fieldAddr = temp
}
// 调用 UnmarshalText 方法
if unmarshaler, ok := fieldAddr.Interface().(interface {
UnmarshalText([]byte) error
}); ok {
return unmarshaler.UnmarshalText([]byte(value))
}
return fmt.Errorf("类型 %s 未实现 UnmarshalText", field.Type())
}
// 处理嵌套结构体(支持JSON解析)
// processNestedStruct 处理嵌套结构体
func processNestedStruct(field reflect.Value, value string, parentKey string) (map[string]ChangeInfo, error) {
@@ -420,25 +509,106 @@ func processNestedStruct(field reflect.Value, value string, parentKey string) (m
return nil, fmt.Errorf("无效的结构体字段")
}
// 尝试解析JSON字符串到map
var nestedMap map[string]string
if err := json.Unmarshal([]byte(value), &nestedMap); err != nil {
return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err)
// 首先尝试解析JSON对象(用于真正的嵌套结构体)
var nestedMap map[string]interface{}
if err := json.Unmarshal([]byte(value), &nestedMap); err == nil {
// 成功解析为JSON对象,说明是真正的嵌套结构体
stringMap := make(map[string]string, len(nestedMap))
for k, v := range nestedMap {
str, err := cast.ToStringE(v)
if err != nil {
return nil, fmt.Errorf("转换嵌套字段 %s 的值失败: %w", k, err)
}
stringMap[k] = str
}
// 递归处理嵌套结构体
nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), stringMap)
if err != nil {
return nil, err
}
// 为嵌套字段的变更记录添加前缀
for key, change := range nestedChanges {
fullKey := parentKey + "." + key
changeMap[fullKey] = change
}
return changeMap, nil
}
// 递归处理嵌套结构体
nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap)
if err != nil {
return nil, err
// 如果解析JSON对象失败,尝试直接设置整个结构体的值
// 这可能是一个基本结构体类型(如decimal.Decimal)或者实现了Unmarshaler接口的类型
if hasUnmarshalJSON(structValue.Type()) {
// 对于实现了UnmarshalJSON的类型,直接使用JSON解析
err := json.Unmarshal([]byte(value), structValue.Addr().Interface())
if err != nil {
return nil, fmt.Errorf("解析嵌套结构体值失败: %w", err)
}
// 记录变更
oldValueStr, _ := cast.ToStringE(field.Interface())
newValueStr := value
changeMap[parentKey] = ChangeInfo{
Old: oldValueStr,
New: newValueStr,
Val: structValue.Interface(),
}
return changeMap, nil
}
// 为嵌套字段的变更记录添加前缀
for key, change := range nestedChanges {
fullKey := parentKey + "." + key
changeMap[fullKey] = change
// 对于其他基本结构体类型,尝试直接设置
if isBasicStructType(structValue.Type()) {
// 保存旧值
oldValueStr, _ := cast.ToStringE(field.Interface())
// 尝试设置新值
err := setBasicStructValue(structValue, value)
if err != nil {
return nil, fmt.Errorf("设置基本结构体值失败: %w", err)
}
// 记录变更
newValueStr, _ := cast.ToStringE(field.Interface())
changeMap[parentKey] = ChangeInfo{
Old: oldValueStr,
New: newValueStr,
Val: structValue.Interface(),
}
return changeMap, nil
}
return changeMap, nil
return nil, fmt.Errorf("嵌套结构体值必须是有效的JSON格式或基本结构体类型")
// // 尝试解析JSON字符串到map
// var nestedMap map[string]string
// if err := json.Unmarshal([]byte(value), &nestedMap); err != nil {
// return nil, fmt.Errorf("嵌套结构体值必须是JSON格式: %w", err)
// }
// // 递归处理嵌套结构体
// nestedChanges, err := AttactToStruct(structValue.Addr().Interface(), nestedMap)
// if err != nil {
// return nil, err
// }
// // 为嵌套字段的变更记录添加前缀
// for key, change := range nestedChanges {
// fullKey := parentKey + "." + key
// changeMap[fullKey] = change
// }
// return changeMap, nil
}
// 设置基本结构体的值
func setBasicStructValue(field reflect.Value, value string) error {
// 检查是否实现了UnmarshalText
if hasUnmarshalText(field.Type()) {
return setUnmarshalTextValue(field, value)
}
// 对于其他基本结构体类型,尝试JSON解析
return json.Unmarshal([]byte(value), field.Addr().Interface())
}
// 检测是否为自定义类型
@@ -466,6 +636,24 @@ func isCustomType(t reflect.Type) bool {
func setFieldValue(field reflect.Value, value string) (interface{}, error) {
kind := field.Kind()
// 优先检查是否实现了 UnmarshalJSON 方法
if hasUnmarshalJSON(field.Type()) {
err := setUnmarshalJSONValue(field, value)
if err != nil {
return nil, fmt.Errorf("UnmarshalJSON 失败: %w", err)
}
return field.Interface(), nil
}
// 其次检查是否实现了 UnmarshalText 方法
if hasUnmarshalText(field.Type()) {
err := setUnmarshalTextValue(field, value)
if err != nil {
return nil, fmt.Errorf("UnmarshalText 失败: %w", err)
}
return field.Interface(), nil
}
// 检测是否为自定义类型(有包路径的类型)
if isCustomType(field.Type()) {
return setCustomTypeValue(field, value)
@@ -566,16 +754,13 @@ func convertMap(field reflect.Value, value string) (interface{}, error) {
// 处理自定义类型
func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) {
// 检查是否实现了TextUnmarshaler接口
if unmarshaler, ok := field.Addr().Interface().(interface {
UnmarshalText([]byte) error
}); ok {
err := unmarshaler.UnmarshalText([]byte(value))
if hasUnmarshalText(field.Type()) {
err := setUnmarshalTextValue(field, value)
if err != nil {
return nil, err
}
return field.Interface(), nil
}
// 对于其他自定义类型,我们需要获取其基础类型并进行转换
baseType := getBaseType(field.Type())
if baseType == nil {