diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 5f108868c7..3d06cc775e 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -181,13 +181,13 @@ func pluginQueries(r *compiler.Result) []*plugin.Query { } } out = append(out, &plugin.Query{ - Name: q.Name, - Cmd: q.Cmd, + Name: q.Metadata.Name, + Cmd: q.Metadata.Cmd, Text: q.SQL, - Comments: q.Comments, + Comments: q.Metadata.Comments, Columns: columns, Params: params, - Filename: q.Filename, + Filename: q.Metadata.Filename, InsertIntoTable: iit, }) } diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index 31aa3ec33c..4f79fc8e8b 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -545,7 +545,7 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { req := codeGenRequest(result, combo) cfg := vetConfig(req) for i, query := range req.Queries { - if result.Queries[i].Flags[QueryFlagSqlcVetDisable] { + if result.Queries[i].Metadata.Flags[QueryFlagSqlcVetDisable] { if debug.Active { log.Printf("Skipping vet rules for query: %s\n", query.Name) } diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index a6744fc6d2..9f4a5170ef 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -8,10 +8,10 @@ import ( "path/filepath" "strings" - "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/migrations" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" @@ -20,7 +20,7 @@ import ( // TODO: Rename this interface Engine type Parser interface { Parse(io.Reader) ([]ast.Statement, error) - CommentSyntax() metadata.CommentSyntax + CommentSyntax() source.CommentSyntax IsReservedKeyword(string) bool } @@ -90,14 +90,15 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { merr.Add(filename, src, loc, err) continue } - if query.Name != "" { - if _, exists := set[query.Name]; exists { - merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", query.Name)) + queryName := query.Metadata.Name + if queryName != "" { + if _, exists := set[queryName]; exists { + merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName)) continue } - set[query.Name] = struct{}{} + set[queryName] = struct{}{} } - query.Filename = filepath.Base(filename) + query.Metadata.Filename = filepath.Base(filename) if query != nil { q = append(q, query) } diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 53e3043c7d..0b626e1081 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -43,14 +43,31 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, errors.New("missing semicolon at end of file") } - name, cmd, err := metadata.ParseQueryNameAndType(strings.TrimSpace(rawSQL), c.parser.CommentSyntax()) + name, cmd, err := metadata.ParseQueryNameAndType(rawSQL, metadata.CommentSyntax(c.parser.CommentSyntax())) if err != nil { return nil, err } + if err := validate.Cmd(raw.Stmt, name, cmd); err != nil { return nil, err } + md := metadata.Metadata{ + Name: name, + Cmd: cmd, + } + + // TODO eventually can use this for name and type/cmd parsing too + cleanedComments, err := source.CleanedComments(rawSQL, c.parser.CommentSyntax()) + if err != nil { + return nil, err + } + + md.Params, md.Flags, err = metadata.ParseParamsAndFlags(cleanedComments) + if err != nil { + return nil, err + } + var anlys *analysis if c.analyzer != nil { // TODO: Handle panics @@ -90,17 +107,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - flags, err := metadata.ParseQueryFlags(comments) - if err != nil { - return nil, err - } + md.Comments = comments return &Query{ RawStmt: raw, - Cmd: cmd, - Comments: comments, - Name: name, - Flags: flags, + Metadata: md, Params: anlys.Parameters, Columns: anlys.Columns, SQL: trimmed, diff --git a/internal/compiler/query.go b/internal/compiler/query.go index 117cf44813..df580c197b 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -1,6 +1,7 @@ package compiler import ( + "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/sql/ast" ) @@ -41,15 +42,9 @@ type Column struct { type Query struct { SQL string - Name string - Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows, copyFrom - Flags map[string]bool + Metadata metadata.Metadata Columns []*Column Params []Parameter - Comments []string - - // XXX: Hack - Filename string // Needed for CopyFrom InsertIntoTable *ast.TableName diff --git a/internal/endtoend/testdata/comment_syntax/mysql/go/query.sql.go b/internal/endtoend/testdata/comment_syntax/mysql/go/query.sql.go index 7f4b916150..29909a8da2 100644 --- a/internal/endtoend/testdata/comment_syntax/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/comment_syntax/mysql/go/query.sql.go @@ -22,7 +22,6 @@ func (q *Queries) DoubleDash(ctx context.Context) (sql.NullString, error) { } const hash = `-- name: Hash :one -# name: Hash :one SELECT bar FROM foo LIMIT 1 ` diff --git a/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt index 6b0840fc37..06ec54327f 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt @@ -1,5 +1,5 @@ # package querytest -query.sql:1:1: invalid query comment: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt index 6b0840fc37..06ec54327f 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt @@ -1,5 +1,5 @@ # package querytest -query.sql:1:1: invalid query comment: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt index 6b0840fc37..06ec54327f 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt @@ -1,5 +1,5 @@ # package querytest -query.sql:1:1: invalid query comment: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/engine/dolphin/parse.go b/internal/engine/dolphin/parse.go index 676362c448..22d3a1d224 100644 --- a/internal/engine/dolphin/parse.go +++ b/internal/engine/dolphin/parse.go @@ -10,7 +10,7 @@ import ( "github.com/pingcap/tidb/parser" _ "github.com/pingcap/tidb/parser/test_driver" - "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -86,8 +86,8 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { } // https://dev.mysql.com/doc/refman/8.0/en/comments.html -func (p *Parser) CommentSyntax() metadata.CommentSyntax { - return metadata.CommentSyntax{ +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ Dash: true, SlashStar: true, Hash: true, diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index c1ac83381c..957a3073ae 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -12,7 +12,7 @@ import ( nodes "github.com/pganalyze/pg_query_go/v4" "github.com/pganalyze/pg_query_go/v4/parser" - "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -199,8 +199,8 @@ func normalizeErr(err error) error { } // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-COMMENTS -func (p *Parser) CommentSyntax() metadata.CommentSyntax { - return metadata.CommentSyntax{ +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ Dash: true, SlashStar: true, } diff --git a/internal/engine/sqlite/parse.go b/internal/engine/sqlite/parse.go index bf0bacad9f..56005dd2ee 100644 --- a/internal/engine/sqlite/parse.go +++ b/internal/engine/sqlite/parse.go @@ -8,7 +8,7 @@ import ( "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/sqlc-dev/sqlc/internal/engine/sqlite/parser" - "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" ) @@ -86,8 +86,8 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { return stmts, nil } -func (p *Parser) CommentSyntax() metadata.CommentSyntax { - return metadata.CommentSyntax{ +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ Dash: true, Hash: false, SlashStar: true, diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 4176da1e2b..97ff36dbd2 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -1,15 +1,24 @@ package metadata import ( + "bufio" "fmt" "strings" "unicode" + + "github.com/sqlc-dev/sqlc/internal/source" ) -type CommentSyntax struct { - Dash bool - Hash bool - SlashStar bool +type CommentSyntax source.CommentSyntax + +type Metadata struct { + Name string + Cmd string + Comments []string + Params map[string]string + Flags map[string]bool + + Filename string } const ( @@ -83,7 +92,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string if prefix == "/*" { part = part[:len(part)-1] // removes the trailing "*/" element } - if len(part) == 2 { + if len(part) == 3 { return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: %s", line) } if len(part) != 4 { @@ -104,19 +113,39 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string return "", "", nil } -func ParseQueryFlags(comments []string) (map[string]bool, error) { +func ParseParamsAndFlags(comments []string) (map[string]string, map[string]bool, error) { + params := make(map[string]string) flags := make(map[string]bool) + for _, line := range comments { - cleanLine := strings.TrimPrefix(line, "--") - cleanLine = strings.TrimPrefix(cleanLine, "/*") - cleanLine = strings.TrimPrefix(cleanLine, "#") - cleanLine = strings.TrimSuffix(cleanLine, "*/") - cleanLine = strings.TrimSpace(cleanLine) - if strings.HasPrefix(cleanLine, "@") { - flagName := strings.SplitN(cleanLine, " ", 2)[0] - flags[flagName] = true + s := bufio.NewScanner(strings.NewReader(line)) + s.Split(bufio.ScanWords) + + s.Scan() + token := s.Text() + + if !strings.HasPrefix(token, "@") { continue } + + switch token { + case "@param": + s.Scan() + name := s.Text() + var rest []string + for s.Scan() { + paramToken := s.Text() + rest = append(rest, paramToken) + } + params[name] = strings.Join(rest, " ") + default: + flags[token] = true + } + + if s.Err() != nil { + return params, flags, s.Err() + } } - return flags, nil + + return params, flags, nil } diff --git a/internal/metadata/meta_test.go b/internal/metadata/meta_test.go index cbfcb6fba6..3c2be6d6de 100644 --- a/internal/metadata/meta_test.go +++ b/internal/metadata/meta_test.go @@ -32,34 +32,108 @@ func TestParseQueryNameAndType(t *testing.T) { } } - query := `-- name: CreateFoo :one` - queryName, queryType, err := ParseQueryNameAndType(query, CommentSyntax{Dash: true}) - if err != nil { - t.Errorf("expected valid metadata: %q", query) - } - if queryName != "CreateFoo" { - t.Errorf("incorrect queryName parsed: %q", query) - } - if queryType != CmdOne { - t.Errorf("incorrect queryType parsed: %q", query) + for query, cs := range map[string]CommentSyntax{ + `-- name: CreateFoo :one`: {Dash: true}, + `# name: CreateFoo :one`: {Hash: true}, + `/* name: CreateFoo :one */`: {SlashStar: true}, + } { + queryName, queryCmd, err := ParseQueryNameAndType(query, cs) + if err != nil { + t.Errorf("expected valid metadata: %q", query) + } + if queryName != "CreateFoo" { + t.Errorf("incorrect queryName parsed: (%q) %q", queryName, query) + } + if queryCmd != CmdOne { + t.Errorf("incorrect queryCmd parsed: (%q) %q", queryCmd, query) + } } } +func TestParseQueryParams(t *testing.T) { + for _, comments := range [][]string{ + { + " name: CreateFoo :one", + " @param foo_id UUID", + }, + { + " name: CreateFoo :one ", + " @param foo_id UUID ", + }, + { + " name: CreateFoo :one", + "@param foo_id UUID", + " invalid", + }, + { + " name: CreateFoo :one", + " @invalid", + " @param foo_id UUID", + }, + { + " name: GetFoos :many ", + " @param foo_id UUID ", + " @param @invalid UUID ", + }, + } { + params, _, err := ParseParamsAndFlags(comments) + if err != nil { + t.Errorf("expected comments to parse, got err: %s", err) + } + + pt, ok := params["foo_id"] + if !ok { + t.Errorf("expected param not found") + } + + if pt != "UUID" { + t.Error("unexpected param metadata:", pt) + } + + _, ok = params["invalid"] + if ok { + t.Errorf("unexpected param found") + } + } +} + func TestParseQueryFlags(t *testing.T) { for _, comments := range [][]string{ { - "-- name: CreateFoo :one", - "-- @flag-foo", + " name: CreateFoo :one", + " @flag-foo", + }, + { + " name: CreateFoo :one ", + "@flag-foo ", + }, + { + " name: CreateFoo :one", + " @flag-foo @flag-bar", + }, + { + " name: GetFoos :many", + " @param @flag-bar UUID", + " @flag-foo", + }, + { + " name: GetFoos :many", + " @flag-foo", + " @param @flag-bar UUID", }, } { - flags, err := ParseQueryFlags(comments) + _, flags, err := ParseParamsAndFlags(comments) if err != nil { - t.Errorf("expected query flags to parse, got error: %s", err) + t.Errorf("expected comments to parse, got err: %s", err) } if !flags["@flag-foo"] { t.Errorf("expected flag not found") } + + if flags["@flag-bar"] { + t.Errorf("unexpected flag found") + } } -} \ No newline at end of file +} diff --git a/internal/source/code.go b/internal/source/code.go index f34e3e3684..8b88a24136 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -15,6 +15,12 @@ type Edit struct { OldFunc func(string) int } +type CommentSyntax struct { + Dash bool + Hash bool + SlashStar bool +} + func LineNumber(source string, head int) (int, int) { // Calculate the true line and column number for a query, ignoring spaces var comment bool @@ -101,6 +107,9 @@ func StripComments(sql string) (string, []string, error) { if strings.HasPrefix(t, "/* name:") && strings.HasSuffix(t, "*/") { continue } + if strings.HasPrefix(t, "# name:") { + continue + } if strings.HasPrefix(t, "--") { comments = append(comments, strings.TrimPrefix(t, "--")) continue @@ -111,7 +120,46 @@ func StripComments(sql string) (string, []string, error) { comments = append(comments, t) continue } + if strings.HasPrefix(t, "#") { + comments = append(comments, strings.TrimPrefix(t, "#")) + continue + } lines = append(lines, t) } return strings.Join(lines, "\n"), comments, s.Err() } + +func CleanedComments(rawSQL string, cs CommentSyntax) ([]string, error) { + s := bufio.NewScanner(strings.NewReader(strings.TrimSpace(rawSQL))) + var comments []string + for s.Scan() { + line := s.Text() + var prefix string + if strings.HasPrefix(line, "--") { + if !cs.Dash { + continue + } + prefix = "--" + } + if strings.HasPrefix(line, "/*") { + if !cs.SlashStar { + continue + } + prefix = "/*" + } + if strings.HasPrefix(line, "#") { + if !cs.Hash { + continue + } + prefix = "#" + } + if prefix == "" { + continue + } + + rest := line[len(prefix):] + rest = strings.TrimSuffix(rest, "*/") + comments = append(comments, rest) + } + return comments, s.Err() +}