diff --git a/docs/reference/query-annotations.md b/docs/reference/query-annotations.md index 4fabe05aae..492335705d 100644 --- a/docs/reference/query-annotations.md +++ b/docs/reference/query-annotations.md @@ -113,6 +113,25 @@ func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { } ``` +## `:first` + +The generated method will return a pointer to a single record via +[QueryRowContext](https://golang.org/pkg/database/sql/#DB.QueryRowContext). +If there are no results, it returns `nil, nil`. + +```sql +-- name: GetAuthor :first +SELECT * FROM authors +WHERE id = $1 LIMIT 1; +``` + +```go +func (q *Queries) GetAuthor(ctx context.Context, id int64) (*Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) + // ... +} +``` + ## `:batchexec` __NOTE: This command only works with PostgreSQL using the `pgx/v4` and `pgx/v5` drivers and outputting Go code.__ diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7df56a0a41..a7324c4283 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -74,6 +74,12 @@ func (t *tmplCtx) codegenQueryMethod(q Query) string { } return db + ".QueryRowContext" + case ":first": + if t.EmitPreparedQueries { + return "q.queryRow" + } + return db + ".QueryRowContext" + case ":many": if t.EmitPreparedQueries { return "q.query" @@ -91,6 +97,8 @@ func (t *tmplCtx) codegenQueryMethod(q Query) string { func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) { switch q.Cmd { case ":one": + fallthrough + case ":first": return "row :=", nil case ":many": return "rows, err :=", nil diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index ccca4f603c..17741d7c7a 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -342,6 +342,25 @@ func (i *importer) queryImports(filename string) fileImports { } return false }) + usesFirst := false + for _, q := range gq { + if q.Cmd == metadata.CmdFirst { + usesFirst = true + break + } + } + if usesFirst { + std["errors"] = struct{}{} + sqlpkg := parseDriver(i.Options.SqlPackage) + switch sqlpkg { + case opts.SQLDriverPGXV4: + pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} + case opts.SQLDriverPGXV5: + pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{} + default: + std["database/sql"] = struct{}{} + } + } sliceScan := func() bool { for _, q := range gq { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..52da1c59d5 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -12,6 +12,7 @@ import ( type QueryValue struct { Emit bool EmitPointer bool + ForcePointer bool Name string DBName string // The name of the field in the database. Only set if Struct==nil. Struct *Struct @@ -32,7 +33,7 @@ func (v QueryValue) IsStruct() bool { } func (v QueryValue) IsPointer() bool { - return v.EmitPointer && v.Struct != nil + return v.ForcePointer || (v.EmitPointer && v.Struct != nil) } func (v QueryValue) isEmpty() bool { @@ -270,7 +271,7 @@ type Query struct { } func (q Query) hasRetType() bool { - scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || + scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdFirst || q.Cmd == metadata.CmdMany || q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne return scanned && !q.Ret.isEmpty() } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 0820488f9d..daa9492ff5 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -322,6 +322,9 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] EmitPointer: options.EmitResultStructPointers, } } + if query.Cmd == metadata.CmdFirst { + gq.Ret.ForcePointer = true + } qs = append(qs, gq) } @@ -334,6 +337,7 @@ var cmdReturnsData = map[string]struct{}{ metadata.CmdBatchOne: {}, metadata.CmdMany: {}, metadata.CmdOne: {}, + metadata.CmdFirst: {}, } func putOutColumns(query *plugin.Query) bool { diff --git a/internal/codegen/golang/result_test.go b/internal/codegen/golang/result_test.go index 0c58525ec3..49ad232bce 100644 --- a/internal/codegen/golang/result_test.go +++ b/internal/codegen/golang/result_test.go @@ -36,6 +36,10 @@ func TestPutOutColumns_ForZeroColumns(t *testing.T) { cmd: metadata.CmdOne, want: true, }, + { + cmd: metadata.CmdFirst, + want: true, + }, { cmd: metadata.CmdCopyFrom, want: false, diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl index cf7cd36cb9..dcdcb25499 100644 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl @@ -11,6 +11,15 @@ {{end -}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {{- end}} + {{- if and (eq .Cmd ":first") ($dbtxParam) }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":first" }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- end}} {{- if and (eq .Cmd ":many") ($dbtxParam) }} {{range .Comments}}//{{.}} {{end -}} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 59a88c880a..9afdee68a8 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -44,6 +44,31 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{- end}} return {{.Ret.ReturnName}}, err } +{{else if eq .Cmd ":first"}} +{{range .Comments}}//{{.}} +{{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- else -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} + {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} + var {{.Ret.Name}} {{.Ret.Type}} + {{- end}} + err := row.Scan({{.Ret.Scan}}) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + {{- if $.WrapErrors}} + err = fmt.Errorf("query {{.MethodName}}: %w", err) + {{- end}} + return nil, err + } + return {{.Ret.ReturnName}}, nil +} {{end}} {{if eq .Cmd ":many"}} diff --git a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl index 3cbefe6df4..771c5b0d61 100644 --- a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl @@ -11,6 +11,15 @@ {{end -}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {{- end}} + {{- if and (eq .Cmd ":first") ($dbtxParam) }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":first"}} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- end}} {{- if and (eq .Cmd ":many") ($dbtxParam) }} {{range .Comments}}//{{.}} {{end -}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 1e7f4e22a4..6579028ab6 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -35,6 +35,26 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- end}} return {{.Ret.ReturnName}}, err } +{{else if eq .Cmd ":first"}} +{{range .Comments}}//{{.}} +{{end -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- template "queryCodeStdExec" . }} + {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} + var {{.Ret.Name}} {{.Ret.Type}} + {{- end}} + err := row.Scan({{.Ret.Scan}}) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + {{- if $.WrapErrors}} + err = fmt.Errorf("query {{.MethodName}}: %w", err) + {{- end}} + return nil, err + } + return {{.Ret.ReturnName}}, nil +} {{end}} {{if eq .Cmd ":many"}} diff --git a/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt index 8c745b7e3b..9839311013 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/pgx/v4/stderr.txt @@ -1,5 +1,5 @@ # package querytest -query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':first', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt index 8c745b7e3b..9839311013 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/pgx/v5/stderr.txt @@ -1,5 +1,5 @@ # package querytest -query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':first', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt b/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt index 06ec54327f..3d4dd39be4 100644 --- a/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt +++ b/internal/endtoend/testdata/invalid_queries_foo/stdlib/stderr.txt @@ -1,7 +1,7 @@ # package querytest -query.sql:1:1: missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos +query.sql:1:1: missing query type [':one', ':first', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: -- name: ListFoos query.sql:5:1: invalid query comment: -- name: ListFoos :one :many query.sql:8:1: invalid query type: :two query.sql:11:1: query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause query.sql:14:1: query "UpdateFoo" specifies parameter ":one" without containing a RETURNING clause -query.sql:17:1: query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause \ No newline at end of file +query.sql:17:1: query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 8f63624d2c..d3ef5b8333 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -33,6 +33,7 @@ const ( CmdExecLastId = ":execlastid" CmdMany = ":many" CmdOne = ":one" + CmdFirst = ":first" CmdCopyFrom = ":copyfrom" CmdBatchExec = ":batchexec" CmdBatchMany = ":batchmany" @@ -98,7 +99,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string part = part[:len(part)-1] // removes the trailing "*/" element } if len(part) == 3 { - return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: %s", line) + return "", "", fmt.Errorf("missing query type [':one', ':first', ':many', ':exec', ':execrows', ':execlastid', ':execresult', ':copyfrom', 'batchexec', 'batchmany', 'batchone']: %s", line) } if len(part) != 4 { return "", "", fmt.Errorf("invalid query comment: %s", line) @@ -106,7 +107,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne: + case CmdOne, CmdFirst, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index 66e849de6c..48a24f5986 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -71,7 +71,7 @@ func Cmd(n ast.Node, name, cmd string) error { return err } } - if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne) { + if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdFirst || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne) { return nil } var list *ast.List