Skip to content

Commit 0fafc07

Browse files
committed
fix: align strict structured key handling
1 parent cdb378c commit 0fafc07

3 files changed

Lines changed: 89 additions & 18 deletions

File tree

binding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ func tryBindNestedField(
696696
if rawMap, ok := entry.value.(map[string]any); ok {
697697
nestedData := make(map[string]mergedEntry)
698698
for k, v := range rawMap {
699-
nestedData[k] = mergedEntry{value: v, sourceName: entry.sourceName}
699+
nestedData[strings.ToLower(k)] = mergedEntry{value: v, sourceName: entry.sourceName}
700700
}
701701
return true, bindNested(fieldValue, nestedData, "", fieldPath)
702702
}

loader.go

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,13 @@ func (l *Loader[T]) loadInternal(ctx context.Context, store bool) (*T, *Provenan
143143

144144
// Step 3: In strict mode, detect unknown keys
145145
if l.strict {
146-
validKeys := collectValidKeys(rootType, "")
147-
validKeyTypes := collectValidKeyTypes(rootType, "")
148-
dynamicMapKeyPatterns := collectDynamicMapKeyPatterns(rootType, "")
146+
strictMetadataCache := make(map[strictValidationMetadataKey]strictValidationMetadata)
147+
strictMetadata := getStrictValidationMetadata(rootType, "", strictMetadataCache)
149148

150149
// Check for unknown keys
151150
var unknownKeyErrors []FieldError
152151
for key, entry := range mergedData {
153-
if !validKeys[key] && !matchesDynamicMapKeyPattern(key, dynamicMapKeyPatterns) {
152+
if !strictMetadata.validKeys[key] && !matchesDynamicMapKeyPattern(key, strictMetadata.dynamicMapKeyPatterns) {
154153
unknownKeyErrors = append(unknownKeyErrors, FieldError{
155154
FieldPath: key,
156155
Code: ErrCodeUnknownKey,
@@ -159,8 +158,8 @@ func (l *Loader[T]) loadInternal(ctx context.Context, store bool) (*T, *Provenan
159158
continue
160159
}
161160

162-
if fieldType, ok := validKeyTypes[key]; ok {
163-
unknownKeyErrors = append(unknownKeyErrors, collectStructuredUnknownKeyErrors(entry.value, fieldType, key)...)
161+
if fieldType, ok := strictMetadata.validKeyTypes[key]; ok {
162+
unknownKeyErrors = append(unknownKeyErrors, collectStructuredUnknownKeyErrors(entry.value, fieldType, key, strictMetadataCache)...)
164163
}
165164
}
166165

@@ -355,6 +354,43 @@ func collectValidKeyTypes(t reflect.Type, prefix string) map[string]reflect.Type
355354
return validKeyTypes
356355
}
357356

357+
type strictValidationMetadata struct {
358+
validKeys map[string]bool
359+
validKeyTypes map[string]reflect.Type
360+
dynamicMapKeyPatterns []string
361+
}
362+
363+
type strictValidationMetadataKey struct {
364+
targetType reflect.Type
365+
prefix string
366+
}
367+
368+
func getStrictValidationMetadata(
369+
targetType reflect.Type,
370+
prefix string,
371+
cache map[strictValidationMetadataKey]strictValidationMetadata,
372+
) strictValidationMetadata {
373+
for targetType.Kind() == reflect.Ptr {
374+
targetType = targetType.Elem()
375+
}
376+
377+
cacheKey := strictValidationMetadataKey{
378+
targetType: targetType,
379+
prefix: prefix,
380+
}
381+
if metadata, ok := cache[cacheKey]; ok {
382+
return metadata
383+
}
384+
385+
metadata := strictValidationMetadata{
386+
validKeys: collectValidKeys(targetType, prefix),
387+
validKeyTypes: collectValidKeyTypes(targetType, prefix),
388+
dynamicMapKeyPatterns: collectDynamicMapKeyPatterns(targetType, prefix),
389+
}
390+
cache[cacheKey] = metadata
391+
return metadata
392+
}
393+
358394
func collectDynamicMapKeyPatterns(t reflect.Type, prefix string) []string {
359395
patternSet := make(map[string]struct{})
360396

@@ -414,7 +450,12 @@ func collectDynamicMapKeyPatterns(t reflect.Type, prefix string) []string {
414450
return patterns
415451
}
416452

417-
func collectStructuredUnknownKeyErrors(rawValue any, targetType reflect.Type, keyPath string) []FieldError {
453+
func collectStructuredUnknownKeyErrors(
454+
rawValue any,
455+
targetType reflect.Type,
456+
keyPath string,
457+
metadataCache map[strictValidationMetadataKey]strictValidationMetadata,
458+
) []FieldError {
418459
if rawValue == nil {
419460
return nil
420461
}
@@ -432,7 +473,7 @@ func collectStructuredUnknownKeyErrors(rawValue any, targetType reflect.Type, ke
432473
return nil
433474
}
434475

435-
return collectUnknownStructMapKeys(rawMap, targetType, keyPath)
476+
return collectUnknownStructMapKeys(rawMap, targetType, keyPath, metadataCache)
436477

437478
case reflect.Slice, reflect.Array:
438479
elemType := unwrapStrictType(targetType.Elem())
@@ -452,7 +493,7 @@ func collectStructuredUnknownKeyErrors(rawValue any, targetType reflect.Type, ke
452493
continue
453494
}
454495

455-
fieldErrors = append(fieldErrors, collectUnknownStructMapKeys(elemMap, elemType, fmt.Sprintf("%s.%d", keyPath, i))...)
496+
fieldErrors = append(fieldErrors, collectUnknownStructMapKeys(elemMap, elemType, fmt.Sprintf("%s.%d", keyPath, i), metadataCache)...)
456497
}
457498
return fieldErrors
458499

@@ -482,23 +523,26 @@ func collectStructuredUnknownKeyErrors(rawValue any, targetType reflect.Type, ke
482523
continue
483524
}
484525

485-
fieldErrors = append(fieldErrors, collectUnknownStructMapKeys(elemMap, elemType, keyPath+"."+rawKey.String())...)
526+
fieldErrors = append(fieldErrors, collectUnknownStructMapKeys(elemMap, elemType, keyPath+"."+strings.ToLower(rawKey.String()), metadataCache)...)
486527
}
487528
return fieldErrors
488529
}
489530

490531
return nil
491532
}
492533

493-
func collectUnknownStructMapKeys(rawMap map[string]any, targetType reflect.Type, keyPath string) []FieldError {
494-
validKeys := collectValidKeys(targetType, "")
495-
validKeyTypes := collectValidKeyTypes(targetType, "")
496-
dynamicMapKeyPatterns := collectDynamicMapKeyPatterns(targetType, "")
534+
func collectUnknownStructMapKeys(
535+
rawMap map[string]any,
536+
targetType reflect.Type,
537+
keyPath string,
538+
metadataCache map[strictValidationMetadataKey]strictValidationMetadata,
539+
) []FieldError {
540+
strictMetadata := getStrictValidationMetadata(targetType, "", metadataCache)
497541

498542
var fieldErrors []FieldError
499543
for rawKey, rawValue := range rawMap {
500544
normalizedKey := strings.ToLower(rawKey)
501-
if !validKeys[normalizedKey] && !matchesDynamicMapKeyPattern(normalizedKey, dynamicMapKeyPatterns) {
545+
if !strictMetadata.validKeys[normalizedKey] && !matchesDynamicMapKeyPattern(normalizedKey, strictMetadata.dynamicMapKeyPatterns) {
502546
fieldErrors = append(fieldErrors, FieldError{
503547
FieldPath: keyPath + "." + normalizedKey,
504548
Code: ErrCodeUnknownKey,
@@ -507,8 +551,8 @@ func collectUnknownStructMapKeys(rawMap map[string]any, targetType reflect.Type,
507551
continue
508552
}
509553

510-
if nestedType, ok := validKeyTypes[normalizedKey]; ok {
511-
fieldErrors = append(fieldErrors, collectStructuredUnknownKeyErrors(rawValue, nestedType, keyPath+"."+normalizedKey)...)
554+
if nestedType, ok := strictMetadata.validKeyTypes[normalizedKey]; ok {
555+
fieldErrors = append(fieldErrors, collectStructuredUnknownKeyErrors(rawValue, nestedType, keyPath+"."+normalizedKey, metadataCache)...)
512556
}
513557
}
514558

loader_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,33 @@ func TestLoad_NestedStruct(t *testing.T) {
914914
if cfg.Database.Port != 5432 {
915915
t.Errorf("expected Database.Port=5432, got %d", cfg.Database.Port)
916916
}
917+
918+
t.Run("direct nested map normalizes inner keys like strict validation", func(t *testing.T) {
919+
type ConfigWithDirectMap struct {
920+
Database Database
921+
}
922+
923+
source := &mockSource{
924+
data: map[string]any{
925+
"database": map[string]any{
926+
"Host": "localhost",
927+
"Port": 5432,
928+
},
929+
},
930+
}
931+
932+
cfg, err := NewLoader[ConfigWithDirectMap]().WithSource(source).Load(context.Background())
933+
if err != nil {
934+
t.Fatalf("Load failed: %v", err)
935+
}
936+
937+
if cfg.Database.Host != "localhost" {
938+
t.Errorf("expected Database.Host=localhost, got %s", cfg.Database.Host)
939+
}
940+
if cfg.Database.Port != 5432 {
941+
t.Errorf("expected Database.Port=5432, got %d", cfg.Database.Port)
942+
}
943+
})
917944
}
918945

919946
func TestLoad_NestedCollections(t *testing.T) {

0 commit comments

Comments
 (0)