package pgx
import (
	"bytes"
	"context"
	"fmt"
	"io"
	// NOTE(review): this file also references pgio (AppendInt16/AppendInt32)
	// and pgconn (StatementDescription), whose imports are missing here.
	// Presumably github.com/jackc/pgx/v5/internal/pgio and
	// github.com/jackc/pgx/v5/pgconn — restore them.
)
func ( [][]any) CopyFromSource {
return ©FromRows{rows: , idx: -1}
}
// copyFromRows is the CopyFromSource returned by CopyFromRows. It iterates
// over an in-memory [][]any.
type copyFromRows struct {
	rows [][]any
	idx  int // index of the current row; starts at -1 so the first Next selects row 0
}

// Next advances to the next row and reports whether one exists.
func (ctr *copyFromRows) Next() bool {
	ctr.idx++
	return ctr.idx < len(ctr.rows)
}

// Values returns the current row. It must only be called after Next has
// returned true; otherwise the index is out of range and it panics.
func (ctr *copyFromRows) Values() ([]any, error) {
	return ctr.rows[ctr.idx], nil
}

// Err always returns nil: iterating an in-memory slice cannot fail.
func (ctr *copyFromRows) Err() error {
	return nil
}
func ( int, func(int) ([]any, error)) CopyFromSource {
return ©FromSlice{next: , idx: -1, len: }
}
// copyFromSlice is the CopyFromSource returned by CopyFromSlice. It produces
// rows on demand by index.
type copyFromSlice struct {
	next func(int) ([]any, error) // builds the row at the given index
	idx  int                      // current index; starts at -1
	len  int                      // total number of rows
	err  error                    // first error returned by next, if any
}

// Next advances to the next index and reports whether it is below len.
func (cts *copyFromSlice) Next() bool {
	cts.idx++
	return cts.idx < cts.len
}

// Values builds the current row by calling next. Any error is also stored so
// that Err reports it.
func (cts *copyFromSlice) Values() ([]any, error) {
	values, err := cts.next(cts.idx)
	if err != nil {
		cts.err = err
	}
	return values, err
}

// Err returns the error recorded by Values, or nil.
func (cts *copyFromSlice) Err() error {
	return cts.err
}
func ( func() ( []any, error)) CopyFromSource {
return ©FromFunc{next: }
}
// copyFromFunc is the CopyFromSource returned by CopyFromFunc.
type copyFromFunc struct {
	next     func() ([]any, error)
	valueRow []any // row produced by the most recent call to next
	err      error // error produced by the most recent call to next
}

// Next fetches the next row from the callback. It returns true only when the
// callback produced a non-nil row and no error.
func (g *copyFromFunc) Next() bool {
	g.valueRow, g.err = g.next()
	return g.valueRow != nil && g.err == nil
}

// Values returns the row and error from the most recent Next call.
func (g *copyFromFunc) Values() ([]any, error) {
	return g.valueRow, g.err
}

// Err returns the error from the most recent Next call, or nil.
func (g *copyFromFunc) Err() error {
	return g.err
}
// CopyFromSource supplies the rows that (*Conn).CopyFrom sends to the
// server. CopyFrom calls Next before each row and Values to obtain it, and it
// consults Err to decide whether the copy must be aborted.
type CopyFromSource interface {
	// Next advances to the next row, making it available to Values. It
	// returns false when there are no more rows or an error has occurred.
	Next() bool

	// Values returns the values of the current row.
	Values() (row []any, err error)

	// Err returns the error, if any, that stopped iteration. A non-nil value
	// makes CopyFrom abort the copy.
	Err() error
}
// copyFrom holds the state of a single CopyFrom operation: the connection,
// the target table and columns, the row source, and the query exec mode used
// to look up the column types.
type copyFrom struct {
	conn        *Conn
	tableName   Identifier
	columnNames []string
	rowSrc      CopyFromSource
	// NOTE(review): readerErrChan is allocated when a copyFrom is built but is
	// never read or written by any code visible in this file; possibly
	// vestigial — verify before removing.
	readerErrChan chan error
	// mode controls how the statement description (column type OIDs) is
	// obtained; run may rewrite Exec/SimpleProtocol to DescribeExec.
	mode QueryExecMode
}
func ( *copyFrom) ( context.Context) (int64, error) {
if .conn.copyFromTracer != nil {
= .conn.copyFromTracer.TraceCopyFromStart(, .conn, TraceCopyFromStartData{
TableName: .tableName,
ColumnNames: .columnNames,
})
}
:= .tableName.Sanitize()
:= &bytes.Buffer{}
for , := range .columnNames {
if != 0 {
.WriteString(", ")
}
.WriteString(quoteIdentifier())
}
:= .String()
var *pgconn.StatementDescription
switch .mode {
case QueryExecModeExec, QueryExecModeSimpleProtocol:
.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var error
, = .conn.getStatementDescription(
,
.mode,
fmt.Sprintf("select %s from %s", , ),
)
if != nil {
return 0, fmt.Errorf("statement description failed: %w", )
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", .mode)
}
, := io.Pipe()
:= make(chan struct{})
go func() {
defer close()
:= .conn.wbuf
= append(, "PGCOPY\n\377\r\n\000"...)
= pgio.AppendInt32(, 0)
= pgio.AppendInt32(, 0)
:= true
for {
var error
, , = .buildCopyBuf(, )
if != nil {
.CloseWithError()
return
}
if .rowSrc.Err() != nil {
.CloseWithError(.rowSrc.Err())
return
}
if len() > 0 {
_, = .Write()
if != nil {
.Close()
return
}
}
= [:0]
}
.Close()
}()
, := .conn.pgConn.CopyFrom(, , fmt.Sprintf("copy %s ( %s ) from stdin binary;", , ))
.Close()
<-
if .conn.copyFromTracer != nil {
.conn.copyFromTracer.TraceCopyFromEnd(, .conn, TraceCopyFromEndData{
CommandTag: ,
Err: ,
})
}
return .RowsAffected(),
}
func ( *copyFrom) ( []byte, *pgconn.StatementDescription) (bool, []byte, error) {
const = 65536 - 5
:= 0
:= 0
for .rowSrc.Next() {
= len()
, := .rowSrc.Values()
if != nil {
return false, nil,
}
if len() != len(.columnNames) {
return false, nil, fmt.Errorf("expected %d values, got %d values", len(.columnNames), len())
}
= pgio.AppendInt16(, int16(len(.columnNames)))
for , := range {
, = encodeCopyValue(.conn.typeMap, , .Fields[].DataTypeOID, )
if != nil {
return false, nil,
}
}
:= len() -
if > {
=
}
if len() > - {
return true, , nil
}
}
return false, , nil
}
func ( *Conn) ( context.Context, Identifier, []string, CopyFromSource) (int64, error) {
:= ©From{
conn: ,
tableName: ,
columnNames: ,
rowSrc: ,
readerErrChan: make(chan error),
mode: .config.DefaultQueryExecMode,
}
return .run()
}