Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 24 additions & 30 deletions docparse/jsonschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,24 @@ func setTags(name, fName string, p *Schema, tags []string) error {
return nil
}

// extractGenericIdent resolves the type expression on the left-hand side of a
// generic instantiation (e.g. Foo[T] or pkg.Foo[T]) to its identifier and
// package name.
func extractGenericIdent(x ast.Expr, curPkg string) (*ast.Ident, string, error) {
switch x := x.(type) {
case *ast.Ident:
return x, curPkg, nil
case *ast.SelectorExpr:
pkgSel, ok := x.X.(*ast.Ident)
if !ok {
return nil, "", fmt.Errorf("unknown generic type selector: %T", x.X)
}
return x.Sel, pkgSel.Name, nil
default:
return nil, "", fmt.Errorf("unknown generic type: %T", x)
}
}

// Convert a struct field to JSON schema.
func fieldToSchema(
prog *Program,
Expand Down Expand Up @@ -459,21 +477,9 @@ start:

// Generic types
case *ast.IndexExpr:
var genericsPkg string
var genericsIdent *ast.Ident
switch x := typ.X.(type) {
case *ast.Ident:
genericsIdent = x
genericsPkg = ref.Package
case *ast.SelectorExpr:
pkgSel, ok := x.X.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("unknown generic type selector: %T", x.X)
}
genericsIdent = x.Sel
genericsPkg = pkgSel.Name
default:
return nil, fmt.Errorf("unknown generic type: %T", typ.X)
genericsIdent, genericsPkg, err := extractGenericIdent(typ.X, ref.Package)
if err != nil {
return nil, err
}
if mapped, err := genericMapTypeLookup(prog, &p, genericsPkg, genericsIdent.Name, ref, typ.Index); err != nil {
return nil, err
Expand All @@ -486,21 +492,9 @@ start:
return &p, nil

case *ast.IndexListExpr:
var genericsPkg string
var genericsIdent *ast.Ident
switch x := typ.X.(type) {
case *ast.Ident:
genericsIdent = x
genericsPkg = ref.Package
case *ast.SelectorExpr:
pkgSel, ok := x.X.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("unknown generic type selector: %T", x.X)
}
genericsIdent = x.Sel
genericsPkg = pkgSel.Name
default:
return nil, fmt.Errorf("unknown generic type: %T", typ.X)
genericsIdent, genericsPkg, err := extractGenericIdent(typ.X, ref.Package)
if err != nil {
return nil, err
}
if mapped, err := genericMapTypeLookup(prog, &p, genericsPkg, genericsIdent.Name, ref, typ.Indices...); err != nil {
return nil, err
Expand Down
Loading