From e57caf3080b3ac4ba5982267f4150effb4a0fd14 Mon Sep 17 00:00:00 2001 From: Yun Date: Sun, 21 Sep 2025 12:06:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8B=86=E5=BC=80=E6=96=87=E4=BB=B6=E5=A4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- entity.go | 77 +++++++ field_mapper.go | 78 +++++++ structx.go | 580 +----------------------------------------------- structx_test.go | 3 + utils.go | 324 +++++++++++++++++++++++++++ value_setter.go | 109 +++++++++ 6 files changed, 599 insertions(+), 572 deletions(-) create mode 100644 entity.go create mode 100644 field_mapper.go create mode 100644 utils.go create mode 100644 value_setter.go diff --git a/entity.go b/entity.go new file mode 100644 index 0000000..f473786 --- /dev/null +++ b/entity.go @@ -0,0 +1,77 @@ +package structx + +import ( + "reflect" + "sync" +) + +// ChangeInfo 变更信息 +type ChangeInfo struct { + Old string `json:"old"` + New string `json:"new"` + Val any `json:"val"` +} + +// FieldInfo 字段信息 +type FieldInfo struct { + Index []int + Name string + IsPtr bool + FieldType reflect.Type + IsSlice bool + IsArray bool +} + +// converterFunc 类型转换函数 +type converterFunc func(reflect.Value, string) (interface{}, error) + +var ( + typeConverters = map[reflect.Kind]converterFunc{ + reflect.Bool: convertBool, + reflect.Int: convertInt[int], + reflect.Int8: convertInt[int8], + reflect.Int16: convertInt[int16], + reflect.Int32: convertInt[int32], + reflect.Int64: convertInt[int64], + reflect.Uint: convertUint[uint], + reflect.Uint8: convertUint[uint8], + reflect.Uint16: convertUint[uint16], + reflect.Uint32: convertUint[uint32], + reflect.Uint64: convertUint[uint64], + reflect.Float32: convertFloat[float32], + reflect.Float64: convertFloat[float64], + reflect.String: convertString, + reflect.Slice: convertSlice, + reflect.Array: convertArray, + reflect.Map: convertMap, + } + + typeInfoCache = make(map[reflect.Type]map[string]FieldInfo) + cacheMutex = &sync.RWMutex{} + + basicStructTypes = map[string]bool{ + "time.Time": true, + "github.com/shopspring/decimal.Decimal": true, + "sql.NullString": true, + "sql.NullInt64": true, + "sql.NullBool": true, + "sql.NullFloat64": true, + } + + basicKinds = map[reflect.Kind]bool{ + reflect.String: true, + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Bool: true, + } +) diff --git a/field_mapper.go b/field_mapper.go new file mode 100644 index 0000000..7bd292d --- /dev/null +++ b/field_mapper.go @@ -0,0 +1,78 @@ +package structx + +import ( + "reflect" +) + +// FieldMapper 字段映射器接口 +type FieldMapper interface { + GetFieldMap(t reflect.Type) map[string]FieldInfo +} + +// defaultFieldMapper 默认字段映射器 +type defaultFieldMapper struct{} + +func (dm *defaultFieldMapper) GetFieldMap(t reflect.Type) map[string]FieldInfo { + cacheMutex.RLock() + if cached, exists := typeInfoCache[t]; exists { + cacheMutex.RUnlock() + return cached + } + cacheMutex.RUnlock() + + fieldMap := make(map[string]FieldInfo) + dm.buildFieldMapRecursive(t, []int{}, fieldMap, "") + + cacheMutex.Lock() + typeInfoCache[t] = fieldMap + cacheMutex.Unlock() + + return fieldMap +} + +func (dm *defaultFieldMapper) buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]FieldInfo, prefix string) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + + currentIndex := append(index, i) + jsonTag := getJSONTagName(field) + + fullKey := jsonTag + if prefix != "" { + fullKey = prefix + "." + jsonTag + } + + fieldType := field.Type + isPtr := fieldType.Kind() == reflect.Ptr + actualType := fieldType + if isPtr { + actualType = fieldType.Elem() + } + + if actualType.Kind() == reflect.Slice || actualType.Kind() == reflect.Array { + fieldMap[fullKey] = FieldInfo{ + Index: currentIndex, + Name: field.Name, + IsPtr: isPtr, + FieldType: fieldType, + IsSlice: actualType.Kind() == reflect.Slice, + IsArray: actualType.Kind() == reflect.Array, + } + continue + } + + if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { + dm.buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) + } + + fieldMap[fullKey] = FieldInfo{ + Index: currentIndex, + Name: field.Name, + IsPtr: isPtr, + FieldType: fieldType, + } + } +} diff --git a/structx.go b/structx.go index 29418f0..9beb5a8 100644 --- a/structx.go +++ b/structx.go @@ -4,93 +4,17 @@ import ( "encoding/json" "fmt" "reflect" - "strconv" - "strings" - "sync" - "code.yun.ink/pkg/convx" "github.com/spf13/cast" ) -type ChangeInfo struct { - Old string `json:"old"` - New string `json:"new"` - Val interface{} `json:"val"` +// 全局函数(保持向后兼容) +func AttactToStructAny(structxx any, updateMap map[string]any) (map[string]ChangeInfo, error) { + return NewStructProcessor().AttactToStructAny(structxx, updateMap) } -// FieldMapper 字段映射器接口 -type FieldMapper interface { - GetFieldMap(t reflect.Type) map[string]FieldInfo -} - -// ValueSetter 值设置器接口 -type ValueSetter interface { - SetFieldValue(field reflect.Value, value string) (interface{}, error) - SetSliceElementValue(elemValue reflect.Value, item interface{}, elemType reflect.Type) error -} - -// converterFunc 类型转换函数 -type converterFunc func(reflect.Value, string) (interface{}, error) - -var ( - typeConverters = map[reflect.Kind]converterFunc{ - reflect.Bool: convertBool, - reflect.Int: convertInt[int], - reflect.Int8: convertInt[int8], - reflect.Int16: convertInt[int16], - reflect.Int32: convertInt[int32], - reflect.Int64: convertInt[int64], - reflect.Uint: convertUint[uint], - reflect.Uint8: convertUint[uint8], - reflect.Uint16: convertUint[uint16], - reflect.Uint32: convertUint[uint32], - reflect.Uint64: convertUint[uint64], - reflect.Float32: convertFloat[float32], - reflect.Float64: convertFloat[float64], - reflect.String: convertString, - reflect.Slice: convertSlice, - reflect.Array: convertArray, - reflect.Map: convertMap, - } - - typeInfoCache = make(map[reflect.Type]map[string]FieldInfo) - cacheMutex = &sync.RWMutex{} - - basicStructTypes = map[string]bool{ - "time.Time": true, - "github.com/shopspring/decimal.Decimal": true, - "sql.NullString": true, - "sql.NullInt64": true, - "sql.NullBool": true, - "sql.NullFloat64": true, - } - - basicKinds = map[reflect.Kind]bool{ - reflect.String: true, - reflect.Int: true, - reflect.Int8: true, - reflect.Int16: true, - reflect.Int32: true, - reflect.Int64: true, - reflect.Uint: true, - reflect.Uint8: true, - reflect.Uint16: true, - reflect.Uint32: true, - reflect.Uint64: true, - reflect.Float32: true, - reflect.Float64: true, - reflect.Bool: true, - } -) - -// FieldInfo 字段信息 -type FieldInfo struct { - Index []int - Name string - IsPtr bool - FieldType reflect.Type - IsSlice bool - IsArray bool +func AttactToStruct(structxx any, updateMap map[string]string) (map[string]ChangeInfo, error) { + return NewStructProcessor().AttactToStruct(structxx, updateMap) } // StructProcessor 结构体处理器 @@ -108,7 +32,7 @@ func NewStructProcessor() *StructProcessor { } // AttactToStructAny 将 map[string]interface{} 转换为字符串映射并调用 AttactToStruct -func (sp *StructProcessor) AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { +func (sp *StructProcessor) AttactToStructAny(structxx any, updateMap map[string]any) (map[string]ChangeInfo, error) { stringMap := make(map[string]string, len(updateMap)) for k, v := range updateMap { str, err := cast.ToStringE(v) @@ -121,7 +45,7 @@ func (sp *StructProcessor) AttactToStructAny(structxx interface{}, updateMap map } // AttactToStruct 将映射数据赋值到结构体中 -func (sp *StructProcessor) AttactToStruct(structxx interface{}, updateMap map[string]string) (map[string]ChangeInfo, error) { +func (sp *StructProcessor) AttactToStruct(structxx any, updateMap map[string]string) (map[string]ChangeInfo, error) { v := reflect.ValueOf(structxx) if v.Kind() != reflect.Ptr || v.IsNil() { return nil, fmt.Errorf("structxx 需要是非空指针") @@ -302,7 +226,7 @@ func (sp *StructProcessor) processNestedStruct(field reflect.Value, value string // 尝试直接设置值 if hasUnmarshalJSON(structValue.Type()) || isBasicStructType(structValue.Type()) { oldValueStr, _ := cast.ToStringE(field.Interface()) - + if err := setBasicStructValue(structValue, value); err != nil { return nil, fmt.Errorf("设置基本结构体值失败: %w", err) } @@ -318,491 +242,3 @@ func (sp *StructProcessor) processNestedStruct(field reflect.Value, value string return nil, fmt.Errorf("嵌套结构体值必须是有效的JSON格式") } - -// defaultFieldMapper 默认字段映射器 -type defaultFieldMapper struct{} - -func (dm *defaultFieldMapper) GetFieldMap(t reflect.Type) map[string]FieldInfo { - cacheMutex.RLock() - if cached, exists := typeInfoCache[t]; exists { - cacheMutex.RUnlock() - return cached - } - cacheMutex.RUnlock() - - fieldMap := make(map[string]FieldInfo) - dm.buildFieldMapRecursive(t, []int{}, fieldMap, "") - - cacheMutex.Lock() - typeInfoCache[t] = fieldMap - cacheMutex.Unlock() - - return fieldMap -} - -func (dm *defaultFieldMapper) buildFieldMapRecursive(t reflect.Type, index []int, fieldMap map[string]FieldInfo, prefix string) { - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if !field.IsExported() { - continue - } - - currentIndex := append(index, i) - jsonTag := getJSONTagName(field) - - fullKey := jsonTag - if prefix != "" { - fullKey = prefix + "." + jsonTag - } - - fieldType := field.Type - isPtr := fieldType.Kind() == reflect.Ptr - actualType := fieldType - if isPtr { - actualType = fieldType.Elem() - } - - if actualType.Kind() == reflect.Slice || actualType.Kind() == reflect.Array { - fieldMap[fullKey] = FieldInfo{ - Index: currentIndex, - Name: field.Name, - IsPtr: isPtr, - FieldType: fieldType, - IsSlice: actualType.Kind() == reflect.Slice, - IsArray: actualType.Kind() == reflect.Array, - } - continue - } - - if actualType.Kind() == reflect.Struct && !isBasicStructType(actualType) { - dm.buildFieldMapRecursive(actualType, currentIndex, fieldMap, fullKey) - } - - fieldMap[fullKey] = FieldInfo{ - Index: currentIndex, - Name: field.Name, - IsPtr: isPtr, - FieldType: fieldType, - } - } -} - -// defaultValueSetter 默认值设置器 -type defaultValueSetter struct{} - -func (ds *defaultValueSetter) SetFieldValue(field reflect.Value, value string) (interface{}, error) { - fieldType := field.Type() - - if hasUnmarshalJSON(fieldType) { - if err := setUnmarshalJSONValue(field, value); err != nil { - return nil, fmt.Errorf("UnmarshalJSON失败: %w", err) - } - return field.Interface(), nil - } - - if hasUnmarshalText(fieldType) { - if err := setUnmarshalTextValue(field, value); err != nil { - return nil, fmt.Errorf("UnmarshalText失败: %w", err) - } - return field.Interface(), nil - } - - if fieldType.Kind() == reflect.Ptr { - return setPointerFieldValue(field, value) - } - - if isTypeAlias(fieldType) { - return setTypeAliasValue(field, value) - } - - if isCustomStructType(fieldType) { - return setCustomTypeValue(field, value) - } - - converter, exists := typeConverters[fieldType.Kind()] - if !exists { - return nil, fmt.Errorf("不支持的类型: %s", fieldType.Kind()) - } - - result, err := converter(field, value) - if err != nil { - return nil, err - } - - field.Set(reflect.ValueOf(result)) - return result, nil -} - -func (ds *defaultValueSetter) SetSliceElementValue(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { - if elemType.Kind() == reflect.Ptr { - if elemValue.IsNil() { - elemValue.Set(reflect.New(elemType.Elem())) - } - return ds.SetSliceElementValue(elemValue.Elem(), item, elemType.Elem()) - } - - if hasUnmarshalJSON(elemType) { - return setUnmarshalJSONValue(elemValue, item) - } - - switch elemType.Kind() { - case reflect.String: - elemValue.SetString(convertToString(item)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - num, err := convertToFloat64(item) - if err != nil { - return fmt.Errorf("无法转换为整型: %w", err) - } - elemValue.SetInt(int64(num)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - num, err := convertToFloat64(item) - if err != nil { - return fmt.Errorf("无法转换为无符号整型: %w", err) - } - elemValue.SetUint(uint64(num)) - case reflect.Float32, reflect.Float64: - num, err := convertToFloat64(item) - if err != nil { - return fmt.Errorf("无法转换为浮点型: %w", err) - } - elemValue.SetFloat(num) - case reflect.Bool: - if b, ok := item.(bool); ok { - elemValue.SetBool(b) - } else { - return fmt.Errorf("无法转换为布尔型") - } - case reflect.Struct: - if isBasicStructType(elemType) { - return setBasicStructElement(elemValue, item) - } - return setStructElement(elemValue, item) - default: - return fmt.Errorf("不支持的切片元素类型: %s", elemType.Kind()) - } - - return nil -} - -// 工具函数 -func getJSONTagName(field reflect.StructField) string { - jsonTag := field.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - return field.Name - } - return strings.Split(jsonTag, ",")[0] -} - -func isBasicStructType(t reflect.Type) bool { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return basicStructTypes[t.String()] -} - -func hasUnmarshalJSON(t reflect.Type) bool { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() - return t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) -} - -func hasUnmarshalText(t reflect.Type) bool { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - textUnmarshalerType := reflect.TypeOf((*interface { - UnmarshalText([]byte) error - })(nil)).Elem() - return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) -} - -func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { - jsonBytes, err := json.Marshal(value) - if err != nil { - return err - } - - var fieldAddr reflect.Value - if field.CanAddr() { - fieldAddr = field.Addr() - } else { - temp := reflect.New(field.Type()) - temp.Elem().Set(field) - fieldAddr = temp - } - - if unmarshaler, ok := fieldAddr.Interface().(json.Unmarshaler); ok { - return unmarshaler.UnmarshalJSON(jsonBytes) - } - return fmt.Errorf("类型未实现Unmarshaler") -} - -func setUnmarshalTextValue(field reflect.Value, value string) error { - var fieldAddr reflect.Value - if field.CanAddr() { - fieldAddr = field.Addr() - } else { - temp := reflect.New(field.Type()) - temp.Elem().Set(field) - fieldAddr = temp - } - - if unmarshaler, ok := fieldAddr.Interface().(interface { - UnmarshalText([]byte) error - }); ok { - return unmarshaler.UnmarshalText([]byte(value)) - } - return fmt.Errorf("类型未实现UnmarshalText") -} - -func setBasicStructValue(field reflect.Value, value string) error { - if hasUnmarshalText(field.Type()) { - return setUnmarshalTextValue(field, value) - } - return json.Unmarshal([]byte(value), field.Addr().Interface()) -} - -func setPointerFieldValue(field reflect.Value, value string) (interface{}, error) { - if field.Kind() != reflect.Ptr { - return nil, fmt.Errorf("期望指针类型") - } - if field.IsNil() { - field.Set(reflect.New(field.Type().Elem())) - } - - - return new(defaultValueSetter).SetFieldValue(field.Elem(), value) -} - -func isTypeAlias(t reflect.Type) bool { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t.PkgPath() != "" && basicKinds[t.Kind()] -} - -func setTypeAliasValue(field reflect.Value, value string) (interface{}, error) { - baseType := getBaseTypeFromAlias(field.Type()) - if baseType == nil { - return nil, fmt.Errorf("无法获取基础类型") - } - - converter, exists := typeConverters[baseType.Kind()] - if !exists { - return nil, fmt.Errorf("不支持的基础类型") - } - - baseValue := reflect.New(baseType).Elem() - result, err := converter(baseValue, value) - if err != nil { - return nil, err - } - - convertedValue, err := convertToTypeAlias(field.Type(), result) - if err != nil { - return nil, err - } - - field.Set(reflect.ValueOf(convertedValue)) - return convertedValue, nil -} - -func isCustomStructType(t reflect.Type) bool { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t.PkgPath() != "" && !isBasicStructType(t) && !isTypeAlias(t) && t.Kind() == reflect.Struct -} - -func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { - if hasUnmarshalText(field.Type()) { - if err := setUnmarshalTextValue(field, value); err != nil { - return nil, err - } - return field.Interface(), nil - } - - var jsonData interface{} - if err := json.Unmarshal([]byte(value), &jsonData); err != nil { - return nil, err - } - - jsonBytes, err := json.Marshal(jsonData) - if err != nil { - return nil, err - } - - if err := json.Unmarshal(jsonBytes, field.Addr().Interface()); err != nil { - return nil, err - } - - return field.Interface(), nil -} - -func setBasicStructElement(elemValue reflect.Value, item interface{}) error { - return setStructElement(elemValue, item) -} - -func setStructElement(elemValue reflect.Value, item interface{}) error { - jsonBytes, err := json.Marshal(item) - if err != nil { - return err - } - return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) -} - -func convertToString(item interface{}) string { - if str, ok := item.(string); ok { - return str - } - return fmt.Sprintf("%v", item) -} - -func convertToFloat64(item interface{}) (float64, error) { - switch v := item.(type) { - case float64: - return v, nil - case int, int32, int64: - return float64(reflect.ValueOf(v).Int()), nil - case uint, uint32, uint64: - return float64(reflect.ValueOf(v).Uint()), nil - case float32: - return float64(v), nil - default: - return 0, fmt.Errorf("无法转换为数字") - } -} - -func getBaseTypeFromAlias(aliasType reflect.Type) reflect.Type { - if aliasType.Kind() == reflect.Ptr { - aliasType = aliasType.Elem() - } - - switch aliasType.Kind() { - case reflect.String: - return reflect.TypeOf("") - case reflect.Int: - return reflect.TypeOf(int(0)) - case reflect.Int8: - return reflect.TypeOf(int8(0)) - case reflect.Int16: - return reflect.TypeOf(int16(0)) - case reflect.Int32: - return reflect.TypeOf(int32(0)) - case reflect.Int64: - return reflect.TypeOf(int64(0)) - case reflect.Uint: - return reflect.TypeOf(uint(0)) - case reflect.Uint8: - return reflect.TypeOf(uint8(0)) - case reflect.Uint16: - return reflect.TypeOf(uint16(0)) - case reflect.Uint32: - return reflect.TypeOf(uint32(0)) - case reflect.Uint64: - return reflect.TypeOf(uint64(0)) - case reflect.Float32: - return reflect.TypeOf(float32(0)) - case reflect.Float64: - return reflect.TypeOf(float64(0)) - case reflect.Bool: - return reflect.TypeOf(false) - default: - return nil - } -} - -func convertToTypeAlias(aliasType reflect.Type, value interface{}) (interface{}, error) { - if aliasType.Kind() == reflect.Ptr { - elemType := aliasType.Elem() - newValue := reflect.New(elemType) - elemValue := newValue.Elem() - - converted, err := convertValueToType(value, elemType) - if err != nil { - return nil, err - } - elemValue.Set(reflect.ValueOf(converted)) - return newValue.Interface(), nil - } - - return convertValueToType(value, aliasType) -} - -func convertValueToType(value interface{}, targetType reflect.Type) (interface{}, error) { - valueType := reflect.TypeOf(value) - if valueType.AssignableTo(targetType) { - return value, nil - } - if valueType.ConvertibleTo(targetType) { - return reflect.ValueOf(value).Convert(targetType).Interface(), nil - } - return nil, fmt.Errorf("无法转换类型") -} - -// 类型转换函数 -func convertBool(field reflect.Value, value string) (interface{}, error) { - return convx.ToBool(value) -} - -func convertInt[T int | int8 | int16 | int32 | int64](field reflect.Value, value string) (interface{}, error) { - bits := field.Type().Bits() - intVal, err := strconv.ParseInt(value, 10, bits) - if err != nil { - return nil, err - } - return T(intVal), nil -} - -func convertUint[T uint | uint8 | uint16 | uint32 | uint64](field reflect.Value, value string) (interface{}, error) { - bits := field.Type().Bits() - uintVal, err := strconv.ParseUint(value, 10, bits) - if err != nil { - return nil, err - } - return T(uintVal), nil -} - -func convertFloat[T float32 | float64](field reflect.Value, value string) (interface{}, error) { - bits := field.Type().Bits() - floatVal, err := strconv.ParseFloat(value, bits) - if err != nil { - return nil, err - } - return T(floatVal), nil -} - -func convertString(field reflect.Value, value string) (interface{}, error) { - return value, nil -} - -func convertSlice(field reflect.Value, value string) (interface{}, error) { - var result []interface{} - if err := json.Unmarshal([]byte(value), &result); err != nil { - return nil, err - } - return result, nil -} - -func convertArray(field reflect.Value, value string) (interface{}, error) { - var result []interface{} - if err := json.Unmarshal([]byte(value), &result); err != nil { - return nil, err - } - return result, nil -} - -func convertMap(field reflect.Value, value string) (interface{}, error) { - return nil, fmt.Errorf("map转换未实现") -} - -// 全局函数(保持向后兼容) -func AttactToStructAny(structxx interface{}, updateMap map[string]interface{}) (map[string]ChangeInfo, error) { - return NewStructProcessor().AttactToStructAny(structxx, updateMap) -} - -func AttactToStruct(structxx interface{}, updateMap map[string]string) (map[string]ChangeInfo, error) { - return NewStructProcessor().AttactToStruct(structxx, updateMap) -} \ No newline at end of file diff --git a/structx_test.go b/structx_test.go index e5ba815..c8177fc 100644 --- a/structx_test.go +++ b/structx_test.go @@ -225,6 +225,9 @@ func TestAttactToStruct_NestedStruct(t *testing.T) { input: map[string]string{ "basic.name": "Alice", "comment": "partial", + "amount": "123.45", + "amount2": "123.45", // 测试指针字段 + "timestamp": "2024-01-01T15:04:05Z", }, expected: NestedStruct{ Basic: BasicStruct{ diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..6e446e3 --- /dev/null +++ b/utils.go @@ -0,0 +1,324 @@ +package structx + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + + "code.yun.ink/pkg/convx" +) + +// 工具函数 +func getJSONTagName(field reflect.StructField) string { + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + return field.Name + } + return strings.Split(jsonTag, ",")[0] +} + +func isBasicStructType(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return basicStructTypes[t.String()] +} + +func hasUnmarshalJSON(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + return t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) +} + +func hasUnmarshalText(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + textUnmarshalerType := reflect.TypeOf((*interface { + UnmarshalText([]byte) error + })(nil)).Elem() + return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) +} + +func setUnmarshalJSONValue(field reflect.Value, value interface{}) error { + jsonBytes, err := json.Marshal(value) + if err != nil { + return err + } + + var fieldAddr reflect.Value + if field.CanAddr() { + fieldAddr = field.Addr() + } else { + temp := reflect.New(field.Type()) + temp.Elem().Set(field) + fieldAddr = temp + } + + if unmarshaler, ok := fieldAddr.Interface().(json.Unmarshaler); ok { + return unmarshaler.UnmarshalJSON(jsonBytes) + } + return fmt.Errorf("类型未实现Unmarshaler") +} + +func setUnmarshalTextValue(field reflect.Value, value string) error { + var fieldAddr reflect.Value + if field.CanAddr() { + fieldAddr = field.Addr() + } else { + temp := reflect.New(field.Type()) + temp.Elem().Set(field) + fieldAddr = temp + } + + if unmarshaler, ok := fieldAddr.Interface().(interface { + UnmarshalText([]byte) error + }); ok { + return unmarshaler.UnmarshalText([]byte(value)) + } + return fmt.Errorf("类型未实现UnmarshalText") +} + +func setBasicStructValue(field reflect.Value, value string) error { + if hasUnmarshalText(field.Type()) { + return setUnmarshalTextValue(field, value) + } + return json.Unmarshal([]byte(value), field.Addr().Interface()) +} + +func setPointerFieldValue(field reflect.Value, value string) (interface{}, error) { + if field.Kind() != reflect.Ptr { + return nil, fmt.Errorf("期望指针类型") + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + + return new(defaultValueSetter).SetFieldValue(field.Elem(), value) +} + +func isTypeAlias(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() != "" && basicKinds[t.Kind()] +} + +func setTypeAliasValue(field reflect.Value, value string) (interface{}, error) { + baseType := getBaseTypeFromAlias(field.Type()) + if baseType == nil { + return nil, fmt.Errorf("无法获取基础类型") + } + + converter, exists := typeConverters[baseType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的基础类型") + } + + baseValue := reflect.New(baseType).Elem() + result, err := converter(baseValue, value) + if err != nil { + return nil, err + } + + convertedValue, err := convertToTypeAlias(field.Type(), result) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(convertedValue)) + return convertedValue, nil +} + +func isCustomStructType(t reflect.Type) bool { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() != "" && !isBasicStructType(t) && !isTypeAlias(t) && t.Kind() == reflect.Struct +} + +func setCustomTypeValue(field reflect.Value, value string) (interface{}, error) { + if hasUnmarshalText(field.Type()) { + if err := setUnmarshalTextValue(field, value); err != nil { + return nil, err + } + return field.Interface(), nil + } + + var jsonData interface{} + if err := json.Unmarshal([]byte(value), &jsonData); err != nil { + return nil, err + } + + jsonBytes, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(jsonBytes, field.Addr().Interface()); err != nil { + return nil, err + } + + return field.Interface(), nil +} + +func setBasicStructElement(elemValue reflect.Value, item interface{}) error { + return setStructElement(elemValue, item) +} + +func setStructElement(elemValue reflect.Value, item interface{}) error { + jsonBytes, err := json.Marshal(item) + if err != nil { + return err + } + return json.Unmarshal(jsonBytes, elemValue.Addr().Interface()) +} + +func convertToString(item interface{}) string { + if str, ok := item.(string); ok { + return str + } + return fmt.Sprintf("%v", item) +} + +func convertToFloat64(item interface{}) (float64, error) { + switch v := item.(type) { + case float64: + return v, nil + case int, int32, int64: + return float64(reflect.ValueOf(v).Int()), nil + case uint, uint32, uint64: + return float64(reflect.ValueOf(v).Uint()), nil + case float32: + return float64(v), nil + default: + return 0, fmt.Errorf("无法转换为数字") + } +} + +func getBaseTypeFromAlias(aliasType reflect.Type) reflect.Type { + if aliasType.Kind() == reflect.Ptr { + aliasType = aliasType.Elem() + } + + switch aliasType.Kind() { + case reflect.String: + return reflect.TypeOf("") + case reflect.Int: + return reflect.TypeOf(int(0)) + case reflect.Int8: + return reflect.TypeOf(int8(0)) + case reflect.Int16: + return reflect.TypeOf(int16(0)) + case reflect.Int32: + return reflect.TypeOf(int32(0)) + case reflect.Int64: + return reflect.TypeOf(int64(0)) + case reflect.Uint: + return reflect.TypeOf(uint(0)) + case reflect.Uint8: + return reflect.TypeOf(uint8(0)) + case reflect.Uint16: + return reflect.TypeOf(uint16(0)) + case reflect.Uint32: + return reflect.TypeOf(uint32(0)) + case reflect.Uint64: + return reflect.TypeOf(uint64(0)) + case reflect.Float32: + return reflect.TypeOf(float32(0)) + case reflect.Float64: + return reflect.TypeOf(float64(0)) + case reflect.Bool: + return reflect.TypeOf(false) + default: + return nil + } +} + +func convertToTypeAlias(aliasType reflect.Type, value interface{}) (interface{}, error) { + if aliasType.Kind() == reflect.Ptr { + elemType := aliasType.Elem() + newValue := reflect.New(elemType) + elemValue := newValue.Elem() + + converted, err := convertValueToType(value, elemType) + if err != nil { + return nil, err + } + elemValue.Set(reflect.ValueOf(converted)) + return newValue.Interface(), nil + } + + return convertValueToType(value, aliasType) +} + +func convertValueToType(value interface{}, targetType reflect.Type) (interface{}, error) { + valueType := reflect.TypeOf(value) + if valueType.AssignableTo(targetType) { + return value, nil + } + if valueType.ConvertibleTo(targetType) { + return reflect.ValueOf(value).Convert(targetType).Interface(), nil + } + return nil, fmt.Errorf("无法转换类型") +} + +// 类型转换函数 +func convertBool(field reflect.Value, value string) (interface{}, error) { + return convx.ToBool(value) +} + +func convertInt[T int | int8 | int16 | int32 | int64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + intVal, err := strconv.ParseInt(value, 10, bits) + if err != nil { + return nil, err + } + return T(intVal), nil +} + +func convertUint[T uint | uint8 | uint16 | uint32 | uint64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + uintVal, err := strconv.ParseUint(value, 10, bits) + if err != nil { + return nil, err + } + return T(uintVal), nil +} + +func convertFloat[T float32 | float64](field reflect.Value, value string) (interface{}, error) { + bits := field.Type().Bits() + floatVal, err := strconv.ParseFloat(value, bits) + if err != nil { + return nil, err + } + return T(floatVal), nil +} + +func convertString(field reflect.Value, value string) (interface{}, error) { + return value, nil +} + +func convertSlice(field reflect.Value, value string) (interface{}, error) { + var result []interface{} + if err := json.Unmarshal([]byte(value), &result); err != nil { + return nil, err + } + return result, nil +} + +func convertArray(field reflect.Value, value string) (interface{}, error) { + var result []interface{} + if err := json.Unmarshal([]byte(value), &result); err != nil { + return nil, err + } + return result, nil +} + +func convertMap(field reflect.Value, value string) (interface{}, error) { + return nil, fmt.Errorf("map转换未实现") +} diff --git a/value_setter.go b/value_setter.go new file mode 100644 index 0000000..586aa97 --- /dev/null +++ b/value_setter.go @@ -0,0 +1,109 @@ +package structx + +import ( + "fmt" + "reflect" +) + +// ValueSetter 值设置器接口 +type ValueSetter interface { + SetFieldValue(field reflect.Value, value string) (interface{}, error) + SetSliceElementValue(elemValue reflect.Value, item interface{}, elemType reflect.Type) error +} + +// defaultValueSetter 默认值设置器 +type defaultValueSetter struct{} + +func (ds *defaultValueSetter) SetFieldValue(field reflect.Value, value string) (interface{}, error) { + fieldType := field.Type() + + if hasUnmarshalJSON(fieldType) { + if err := setUnmarshalJSONValue(field, value); err != nil { + return nil, fmt.Errorf("UnmarshalJSON失败: %w", err) + } + return field.Interface(), nil + } + + if hasUnmarshalText(fieldType) { + if err := setUnmarshalTextValue(field, value); err != nil { + return nil, fmt.Errorf("UnmarshalText失败: %w", err) + } + return field.Interface(), nil + } + + if fieldType.Kind() == reflect.Ptr { + return setPointerFieldValue(field, value) + } + + if isTypeAlias(fieldType) { + return setTypeAliasValue(field, value) + } + + if isCustomStructType(fieldType) { + return setCustomTypeValue(field, value) + } + + converter, exists := typeConverters[fieldType.Kind()] + if !exists { + return nil, fmt.Errorf("不支持的类型: %s", fieldType.Kind()) + } + + result, err := converter(field, value) + if err != nil { + return nil, err + } + + field.Set(reflect.ValueOf(result)) + return result, nil +} + +func (ds *defaultValueSetter) SetSliceElementValue(elemValue reflect.Value, item interface{}, elemType reflect.Type) error { + if elemType.Kind() == reflect.Ptr { + if elemValue.IsNil() { + elemValue.Set(reflect.New(elemType.Elem())) + } + return ds.SetSliceElementValue(elemValue.Elem(), item, elemType.Elem()) + } + + if hasUnmarshalJSON(elemType) { + return setUnmarshalJSONValue(elemValue, item) + } + + switch elemType.Kind() { + case reflect.String: + elemValue.SetString(convertToString(item)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法转换为整型: %w", err) + } + elemValue.SetInt(int64(num)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法转换为无符号整型: %w", err) + } + elemValue.SetUint(uint64(num)) + case reflect.Float32, reflect.Float64: + num, err := convertToFloat64(item) + if err != nil { + return fmt.Errorf("无法转换为浮点型: %w", err) + } + elemValue.SetFloat(num) + case reflect.Bool: + if b, ok := item.(bool); ok { + elemValue.SetBool(b) + } else { + return fmt.Errorf("无法转换为布尔型") + } + case reflect.Struct: + if isBasicStructType(elemType) { + return setBasicStructElement(elemValue, item) + } + return setStructElement(elemValue, item) + default: + return fmt.Errorf("不支持的切片元素类型: %s", elemType.Kind()) + } + + return nil +}