package pgx
import (
	"context"
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
)
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
	SQL string
	Arguments []any
	// Fn, if set, is called with the BatchResults positioned at this query's
	// result when the batch results are closed.
	Fn batchItemFunc
	// sd caches the prepared statement description for this query, if any.
	// NOTE(review): set elsewhere in the package — confirm lifecycle there.
	sd *pgconn.StatementDescription
}

// batchItemFunc is a per-query callback that consumes this query's result
// from the supplied BatchResults and returns any error it encountered.
type batchItemFunc func(BatchResults) error
func ( *QueuedQuery) ( func( Rows) error) {
.Fn = func( BatchResults) error {
, := .Query()
defer .Close()
:= ()
if != nil {
return
}
.Close()
return .Err()
}
}
func ( *QueuedQuery) ( func( Row) error) {
.Fn = func( BatchResults) error {
:= .QueryRow()
return ()
}
}
func ( *QueuedQuery) ( func( pgconn.CommandTag) error) {
.Fn = func( BatchResults) error {
, := .Exec()
if != nil {
return
}
return ()
}
}
// Batch collects queries so they can be sent to the PostgreSQL server in a
// single round trip instead of one at a time.
type Batch struct {
	// QueuedQueries holds the queries in the order they were queued.
	QueuedQueries []*QueuedQuery
}
func ( *Batch) ( string, ...any) *QueuedQuery {
:= &QueuedQuery{
SQL: ,
Arguments: ,
}
.QueuedQueries = append(.QueuedQueries, )
return
}
func ( *Batch) () int {
return len(.QueuedQueries)
}
// BatchResults reads the results of an executed Batch, one queued query at a
// time, in the order the queries were queued.
type BatchResults interface {
	// Exec reads the next query's result as a command tag.
	Exec() (pgconn.CommandTag, error)
	// Query reads the next query's result as rows.
	Query() (Rows, error)
	// QueryRow reads the next query's result as a single row.
	QueryRow() Row
	// Close finishes reading the batch, running any per-query callbacks for
	// unread results, and releases resources. It is safe to call repeatedly.
	Close() error
}
// batchResults implements BatchResults on top of a pgconn.MultiResultReader
// (the non-pipeline batch protocol).
type batchResults struct {
	ctx context.Context
	conn *Conn
	// mrr yields one result per queued query, in order.
	mrr *pgconn.MultiResultReader
	// err is sticky: once set, subsequent operations return it.
	err error
	// b is the batch being read; used for per-query callbacks and tracing.
	b *Batch
	// qqIdx indexes the next unread query in b.QueuedQueries.
	qqIdx int
	closed bool
	// endTraced guards that TraceBatchEnd fires at most once.
	endTraced bool
}
func ( *batchResults) () (pgconn.CommandTag, error) {
if .err != nil {
return pgconn.CommandTag{}, .err
}
if .closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
, , := .nextQueryAndArgs()
if !.mrr.NextResult() {
:= .mrr.Close()
if == nil {
= errors.New("no more results in batch")
}
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchQuery(.ctx, .conn, TraceBatchQueryData{
SQL: ,
Args: ,
Err: ,
})
}
return pgconn.CommandTag{},
}
, := .mrr.ResultReader().Close()
if != nil {
.err =
.mrr.Close()
}
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchQuery(.ctx, .conn, TraceBatchQueryData{
SQL: ,
Args: ,
CommandTag: ,
Err: .err,
})
}
return , .err
}
func ( *batchResults) () (Rows, error) {
, , := .nextQueryAndArgs()
if ! {
= "batch query"
}
if .err != nil {
return &baseRows{err: .err, closed: true}, .err
}
if .closed {
:= fmt.Errorf("batch already closed")
return &baseRows{err: , closed: true},
}
:= .conn.getRows(.ctx, , )
.batchTracer = .conn.batchTracer
if !.mrr.NextResult() {
.err = .mrr.Close()
if .err == nil {
.err = errors.New("no more results in batch")
}
.closed = true
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchQuery(.ctx, .conn, TraceBatchQueryData{
SQL: ,
Args: ,
Err: .err,
})
}
return , .err
}
.resultReader = .mrr.ResultReader()
return , nil
}
func ( *batchResults) () Row {
, := .Query()
return (*connRow)(.(*baseRows))
}
func ( *batchResults) () error {
defer func() {
if !.endTraced {
if .conn != nil && .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchEnd(.ctx, .conn, TraceBatchEndData{Err: .err})
}
.endTraced = true
}
invalidateCachesOnBatchResultsError(.conn, .b, .err)
}()
if .err != nil {
return .err
}
if .closed {
return nil
}
for .err == nil && !.closed && .b != nil && .qqIdx < len(.b.QueuedQueries) {
if .b.QueuedQueries[.qqIdx].Fn != nil {
:= .b.QueuedQueries[.qqIdx].Fn()
if != nil {
.err =
}
} else {
.Exec()
}
}
.closed = true
:= .mrr.Close()
if .err == nil {
.err =
}
return .err
}
func ( *batchResults) () error {
return .err
}
func ( *batchResults) () ( string, []any, bool) {
if .b != nil && .qqIdx < len(.b.QueuedQueries) {
:= .b.QueuedQueries[.qqIdx]
= .SQL
= .Arguments
= true
.qqIdx++
}
return , ,
}
// pipelineBatchResults implements BatchResults on top of a pgconn.Pipeline
// (the extended-protocol pipeline mode).
type pipelineBatchResults struct {
	ctx context.Context
	conn *Conn
	pipeline *pgconn.Pipeline
	// lastRows is checked before each operation: an error on the previously
	// returned rows poisons the whole batch.
	lastRows *baseRows
	// err is sticky: once set, subsequent operations return it.
	err error
	// b is the batch being read; used for per-query callbacks and tracing.
	b *Batch
	// qqIdx indexes the next unread query in b.QueuedQueries.
	qqIdx int
	closed bool
	// endTraced guards that TraceBatchEnd fires at most once.
	endTraced bool
}
func ( *pipelineBatchResults) () (pgconn.CommandTag, error) {
if .err != nil {
return pgconn.CommandTag{}, .err
}
if .closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
if .lastRows != nil && .lastRows.err != nil {
.err = .lastRows.err
return pgconn.CommandTag{}, .err
}
, , := .nextQueryAndArgs()
if != nil {
return pgconn.CommandTag{},
}
, := .pipeline.GetResults()
if != nil {
.err =
return pgconn.CommandTag{}, .err
}
var pgconn.CommandTag
switch results := .(type) {
case *pgconn.ResultReader:
, .err = .Close()
default:
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", )
}
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchQuery(.ctx, .conn, TraceBatchQueryData{
SQL: ,
Args: ,
CommandTag: ,
Err: .err,
})
}
return , .err
}
func ( *pipelineBatchResults) () (Rows, error) {
if .err != nil {
return &baseRows{err: .err, closed: true}, .err
}
if .closed {
:= fmt.Errorf("batch already closed")
return &baseRows{err: , closed: true},
}
if .lastRows != nil && .lastRows.err != nil {
.err = .lastRows.err
return &baseRows{err: .err, closed: true}, .err
}
, , := .nextQueryAndArgs()
if != nil {
return &baseRows{err: , closed: true},
}
:= .conn.getRows(.ctx, , )
.batchTracer = .conn.batchTracer
.lastRows =
, := .pipeline.GetResults()
if != nil {
.err =
.err =
.closed = true
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchQuery(.ctx, .conn, TraceBatchQueryData{
SQL: ,
Args: ,
Err: ,
})
}
} else {
switch results := .(type) {
case *pgconn.ResultReader:
.resultReader =
default:
= fmt.Errorf("unexpected pipeline result: %T", )
.err =
.err =
.closed = true
}
}
return , .err
}
func ( *pipelineBatchResults) () Row {
, := .Query()
return (*connRow)(.(*baseRows))
}
func ( *pipelineBatchResults) () error {
defer func() {
if !.endTraced {
if .conn.batchTracer != nil {
.conn.batchTracer.TraceBatchEnd(.ctx, .conn, TraceBatchEndData{Err: .err})
}
.endTraced = true
}
invalidateCachesOnBatchResultsError(.conn, .b, .err)
}()
if .err == nil && .lastRows != nil && .lastRows.err != nil {
.err = .lastRows.err
}
if .closed {
return .err
}
for .err == nil && !.closed && .b != nil && .qqIdx < len(.b.QueuedQueries) {
if .b.QueuedQueries[.qqIdx].Fn != nil {
:= .b.QueuedQueries[.qqIdx].Fn()
if != nil {
.err =
}
} else {
.Exec()
}
}
.closed = true
:= .pipeline.Close()
if .err == nil {
.err =
}
return .err
}
func ( *pipelineBatchResults) () error {
return .err
}
func ( *pipelineBatchResults) () ( string, []any, error) {
if .b == nil {
return "", nil, errors.New("no reference to batch")
}
if .qqIdx >= len(.b.QueuedQueries) {
return "", nil, errors.New("no more results in batch")
}
:= .b.QueuedQueries[.qqIdx]
.qqIdx++
return .SQL, .Arguments, nil
}
// emptyBatchResults implements BatchResults for a batch with no queries to
// read: every read operation reports "no more results in batch".
type emptyBatchResults struct {
	conn *Conn
	closed bool
}
func ( *emptyBatchResults) () (pgconn.CommandTag, error) {
if .closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
return pgconn.CommandTag{}, errors.New("no more results in batch")
}
func ( *emptyBatchResults) () (Rows, error) {
if .closed {
:= fmt.Errorf("batch already closed")
return &baseRows{err: , closed: true},
}
:= .conn.getRows(context.Background(), "", nil)
.err = errors.New("no more results in batch")
.closed = true
return , .err
}
func ( *emptyBatchResults) () Row {
, := .Query()
return (*connRow)(.(*baseRows))
}
func ( *emptyBatchResults) () error {
.closed = true
return nil
}
func ( *Conn, *Batch, error) {
if != nil && != nil && != nil {
if := .statementCache; != nil {
for , := range .QueuedQueries {
.Invalidate(.SQL)
}
}
if := .descriptionCache; != nil {
for , := range .QueuedQueries {
.Invalidate(.SQL)
}
}
}
}
// ErrPreprocessingBatch is returned when a batch fails before execution, while
// its queries are being preprocessed. It records which preprocessing step
// failed, the SQL involved, and the underlying error.
type ErrPreprocessingBatch struct {
	step string
	sql string
	err error
}

// NewErrPreprocessingBatch builds an ErrPreprocessingBatch for the given
// preprocessing step, SQL text, and underlying error.
// NOTE(review): the constructor's name was reconstructed — confirm against callers.
func NewErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch {
	return ErrPreprocessingBatch{step: step, sql: sql, err: err}
}

// Error implements the error interface, naming the failed step and cause.
func (e ErrPreprocessingBatch) Error() string {
	return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err)
}

// Unwrap returns the underlying error for errors.Is/errors.As chains.
func (e ErrPreprocessingBatch) Unwrap() error {
	return e.err
}

// SQL returns the SQL text of the query that failed preprocessing.
func (e ErrPreprocessingBatch) SQL() string {
	return e.sql
}