Skip to content

Commit f239082

Browse files
committed
feat: support sqlc.embed
1 parent bfa71a9 commit f239082

File tree

1 file changed

+93
-9
lines changed

1 file changed

+93
-9
lines changed

‎internal/gen.go‎

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ type Field struct{
5353
Namestring
5454
TypepyType
5555
Commentstring
56+
// EmbedFields contains the embedded fields that require scanning.
57+
EmbedFields []Field
5658
}
5759

5860
typeStructstruct{
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node{
105107
call:=&pyast.Call{
106108
Func: v.Annotation(),
107109
}
108-
fori, f:=rangev.Struct.Fields{
109-
call.Keywords=append(call.Keywords, &pyast.Keyword{
110-
Arg: f.Name,
111-
Value: subscriptNode(
110+
rowIndex:=0// We need to keep track of the index in the row variable.
111+
for_, f:=rangev.Struct.Fields{
112+
113+
varvalueNode*pyast.Node
114+
// Check if we are using sqlc.embed, if so we need to create a new object.
115+
iflen(f.EmbedFields) >0{
116+
// We keep this separate so we can easily add all arguments.
117+
embed_call:=&pyast.Call{Func: f.Type.Annotation()}
118+
119+
// Now add all field Initializers for the embedded model that index into the original row.
120+
fori, embedField:=rangef.EmbedFields{
121+
embed_call.Keywords=append(embed_call.Keywords, &pyast.Keyword{
122+
Arg: embedField.Name,
123+
Value: subscriptNode(
124+
rowVar,
125+
constantInt(rowIndex+i),
126+
),
127+
})
128+
}
129+
130+
valueNode=&pyast.Node{
131+
Node: &pyast.Node_Call{
132+
Call: embed_call,
133+
},
134+
}
135+
136+
rowIndex+=len(f.EmbedFields)
137+
} else{
138+
valueNode=subscriptNode(
112139
rowVar,
113-
constantInt(i),
114-
),
115-
})
140+
constantInt(rowIndex),
141+
)
142+
rowIndex++
143+
}
144+
145+
call.Keywords=append(call.Keywords, &pyast.Keyword{Arg: f.Name, Value: valueNode})
116146
}
117147
return&pyast.Node{
118148
Node: &pyast.Node_Call{
@@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string{
336366
typepyColumnstruct{
337367
idint32
338368
*plugin.Column
369+
embed*pyEmbed
370+
}
371+
372+
typepyEmbedstruct{
373+
modelTypestring
374+
modelNamestring
375+
fields []Field
376+
}
377+
378+
// Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
379+
// look through all the structs and attempt to find a matching one to embed
380+
// We need the name of the struct and its field names.
381+
funcnewGoEmbed(embed*plugin.Identifier, structs []Struct, defaultSchemastring) *pyEmbed{
382+
ifembed==nil{
383+
returnnil
384+
}
385+
386+
for_, s:=rangestructs{
387+
embedSchema:=defaultSchema
388+
ifembed.Schema!=""{
389+
embedSchema=embed.Schema
390+
}
391+
392+
// compare the other attributes
393+
ifembed.Catalog!=s.Table.Catalog||embed.Name!=s.Table.Name||embedSchema!=s.Table.Schema{
394+
continue
395+
}
396+
397+
fields:=make([]Field, len(s.Fields))
398+
fori, f:=ranges.Fields{
399+
fields[i] =f
400+
}
401+
402+
return&pyEmbed{
403+
modelType: s.Name,
404+
modelName: s.Name,
405+
fields: fields,
406+
}
407+
}
408+
409+
returnnil
339410
}
340411

341412
funccolumnsToStruct(req*plugin.CodeGenRequest, namestring, columns []pyColumn) *Struct{
@@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn
359430
ifsuffix>0{
360431
fieldName=fmt.Sprintf("%s_%d", fieldName, suffix)
361432
}
362-
gs.Fields=append(gs.Fields, Field{
433+
434+
f:=Field{
363435
Name: fieldName,
364436
Type: makePyType(req, c.Column),
365-
})
437+
}
438+
439+
ifc.embed!=nil{
440+
f.Type=pyType{
441+
InnerType: "models."+modelName(c.embed.modelType, req.Settings),
442+
IsArray: false,
443+
IsNull: false,
444+
}
445+
f.EmbedFields=c.embed.fields
446+
}
447+
448+
gs.Fields=append(gs.Fields, f)
366449
seen[colName]++
367450
}
368451
return&gs
@@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([]
476559
columns=append(columns, pyColumn{
477560
id: int32(i),
478561
Column: c,
562+
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
479563
})
480564
}
481565
gs=columnsToStruct(req, query.Name+"Row", columns)

0 commit comments

Comments
(0)