package pgx
import (
)
// TxIsoLevel is the textual transaction isolation level. When set on
// TxOptions.IsoLevel it is written verbatim after " isolation level " in the
// generated begin statement.
type TxIsoLevel string
// Isolation level values available for TxOptions.IsoLevel.
const (
Serializable TxIsoLevel = "serializable"
RepeatableRead TxIsoLevel = "repeatable read"
ReadCommitted TxIsoLevel = "read committed"
ReadUncommitted TxIsoLevel = "read uncommitted"
)
// TxAccessMode is the textual transaction access mode. A non-empty value is
// appended, preceded by a single space, to the generated begin statement.
type TxAccessMode string
// Access mode values available for TxOptions.AccessMode.
const (
ReadWrite TxAccessMode = "read write"
ReadOnly TxAccessMode = "read only"
)
// TxDeferrableMode is the textual transaction deferrable mode. A non-empty
// value is appended, preceded by a single space, to the generated begin
// statement.
type TxDeferrableMode string
// Deferrable mode values available for TxOptions.DeferrableMode.
const (
Deferrable TxDeferrableMode = "deferrable"
NotDeferrable TxDeferrableMode = "not deferrable"
)
// TxOptions describes how a transaction is begun and committed. The zero
// value yields a plain "begin" and a plain "commit".
type TxOptions struct {
// IsoLevel, AccessMode and DeferrableMode are each added to the generated
// begin statement, in this order, when non-empty.
IsoLevel TxIsoLevel
AccessMode TxAccessMode
DeferrableMode TxDeferrableMode
// BeginQuery, when non-empty, is used verbatim as the begin statement; the
// three mode fields above are then not consulted.
BeginQuery string
// CommitQuery, when non-empty, is copied into the transaction and executed
// on commit in place of the literal "commit".
CommitQuery string
}
// emptyTxOptions is the zero TxOptions, used to recognise "no options set".
var emptyTxOptions TxOptions
// Builds the SQL statement that starts a transaction from a TxOptions value.
// NOTE(review): the receiver and method identifiers are missing on the
// signature line in this copy; a later call site invokes a method named
// beginSQL() with this shape — confirm that this is it.
//
// The result is "begin" for the zero value, BeginQuery verbatim when it is
// set, and otherwise "begin" followed by whichever of isolation level, access
// mode and deferrable mode are non-empty, in that order.
func ( TxOptions) () string {
// Fast path: comparison against the zero value (left operand missing here,
// presumably the receiver).
if == emptyTxOptions {
return "begin"
}
// A caller-supplied begin statement takes precedence over the mode fields.
if .BeginQuery != "" {
return .BeginQuery
}
var strings.Builder
// Reserve 64 bytes up front for the generated statement.
.Grow(64)
.WriteString("begin")
if .IsoLevel != "" {
.WriteString(" isolation level ")
.WriteString(string(.IsoLevel))
}
if .AccessMode != "" {
.WriteByte(' ')
.WriteString(string(.AccessMode))
}
if .DeferrableMode != "" {
.WriteByte(' ')
.WriteString(string(.DeferrableMode))
}
return .String()
}
// ErrTxClosed is returned by the transaction methods in this file when the
// transaction's closed flag is already set.
var ErrTxClosed = errors.New("tx is closed")
// ErrTxCommitRollback is returned by the commit path when the commit statement
// executes without error but its result's String() is "ROLLBACK".
var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
// Starts a transaction with default options by delegating to BeginTx with a
// zero TxOptions.
// NOTE(review): the receiver, method and parameter identifiers are missing in
// this copy; the signature matches Begin in the Tx interface — confirm.
func ( *Conn) ( context.Context) (Tx, error) {
return .BeginTx(, TxOptions{})
}
// Starts a transaction configured by a TxOptions value: the statement produced
// by the options' beginSQL() is executed, and on success a *dbTx carrying the
// options' CommitQuery is returned.
// NOTE(review): the receiver, method and parameter identifiers are missing in
// this copy; the (context.Context, TxOptions) shape matches the BeginTx call
// made elsewhere in this file — confirm.
func ( *Conn) ( context.Context, TxOptions) (Tx, error) {
, := .Exec(, .beginSQL())
if != nil {
// die() is defined outside this file; presumably it marks the connection
// unusable because the begin statement failed — confirm.
.die()
return nil,
}
return &dbTx{
conn: ,
commitQuery: .CommitQuery,
}, nil
}
// Tx is the set of operations available on a transaction. Within this file it
// is implemented by *dbTx (a transaction begun on a Conn) and
// *dbSimulatedNestedTx (a nested transaction backed by a savepoint).
type Tx interface {
// Begin starts a nested transaction; the implementations in this file do so
// with a savepoint.
Begin(ctx context.Context) (Tx, error)
// Commit and Rollback end the transaction. The implementations in this file
// return ErrTxClosed when called on an already closed transaction.
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
// The remaining methods forward to the underlying connection or parent
// transaction in the implementations in this file.
CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
SendBatch(ctx context.Context, b *Batch) BatchResults
LargeObjects() LargeObjects
Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error)
Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
Query(ctx context.Context, sql string, args ...any) (Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) Row
// Conn returns the connection the transaction runs on.
Conn() *Conn
}
// dbTx is the Tx implementation returned when a transaction is begun directly
// on a Conn.
type dbTx struct {
// conn is the connection every statement of the transaction is executed on.
conn *Conn
// savepointNum is incremented for each nested begin and used to name the
// savepoint "sp_<n>".
savepointNum int64
// closed is set once a commit or rollback has been attempted.
closed bool
// commitQuery replaces the literal "commit" when non-empty.
commitQuery string
}
// Starts a nested transaction on top of a dbTx by issuing a savepoint.
// NOTE(review): the receiver, method and parameter identifiers are missing in
// this copy; the signature matches Begin in the Tx interface — confirm.
//
// Returns ErrTxClosed when the closed flag is set. Otherwise savepointNum is
// incremented, "savepoint sp_<savepointNum>" is executed on conn, and a
// *dbSimulatedNestedTx recording that savepoint number is returned.
func ( *dbTx) ( context.Context) (Tx, error) {
if .closed {
return nil, ErrTxClosed
}
// The counter is advanced before the statement runs and is not rolled back
// if the statement fails.
.savepointNum++
, := .conn.Exec(, "savepoint sp_"+strconv.FormatInt(.savepointNum, 10))
if != nil {
return nil,
}
return &dbSimulatedNestedTx{tx: , savepointNum: .savepointNum}, nil
}
// Commits the transaction, executing commitQuery when it is non-empty and the
// literal "commit" otherwise.
// NOTE(review): the receiver, method, parameter and local identifiers are
// missing in this copy; by the statement executed this is the Commit method of
// the Tx interface — confirm.
//
// Returns ErrTxClosed when already closed, ErrTxCommitRollback when the
// statement succeeds but its result's String() is "ROLLBACK", and nil on a
// normal commit. On an execution failure the bare return presumably yields
// that error.
func ( *dbTx) ( context.Context) error {
if .closed {
return ErrTxClosed
}
:= "commit"
if .commitQuery != "" {
= .commitQuery
}
, := .conn.Exec(, )
// Marked closed whether or not the statement succeeded.
.closed = true
if != nil {
// A transaction status other than 'I' after the failure causes the
// connection to be closed; 'I' presumably denotes an idle session with no
// open transaction — confirm against PgConn().TxStatus(). The error from
// Close is discarded.
if .conn.PgConn().TxStatus() != 'I' {
_ = .conn.Close()
}
return
}
// The statement ran, but the server reported a rollback instead of a commit.
if .String() == "ROLLBACK" {
return ErrTxCommitRollback
}
return nil
}
// Rolls the transaction back by executing "rollback" on conn.
// NOTE(review): the receiver, method, parameter and local identifiers are
// missing in this copy; by the statement executed this is the Rollback method
// of the Tx interface — confirm.
//
// Returns ErrTxClosed when already closed and nil on success. On an execution
// failure the bare return presumably yields that error.
func ( *dbTx) ( context.Context) error {
if .closed {
return ErrTxClosed
}
, := .conn.Exec(, "rollback")
// Marked closed whether or not the statement succeeded.
.closed = true
if != nil {
// die() is defined outside this file; presumably it marks the connection
// unusable because the rollback could not be completed — confirm.
.conn.die()
return
}
return nil
}
// Forwards a statement and its variadic arguments to conn.Exec, or returns an
// empty command tag and ErrTxClosed when the transaction is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Exec in the Tx interface — confirm.
func ( *dbTx) ( context.Context, string, ...any) ( pgconn.CommandTag, error) {
if .closed {
return pgconn.CommandTag{}, ErrTxClosed
}
return .conn.Exec(, , ...)
}
// Forwards its two string arguments to conn.Prepare, or returns nil and
// ErrTxClosed when the transaction is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Prepare(ctx, name, sql) in the Tx interface — confirm.
func ( *dbTx) ( context.Context, , string) (*pgconn.StatementDescription, error) {
if .closed {
return nil, ErrTxClosed
}
return .conn.Prepare(, , )
}
// Forwards a query and its variadic arguments to conn.Query.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Query in the Tx interface — confirm.
//
// When the transaction is closed, the Rows value returned is a non-nil
// *baseRows that is already closed and has its err field set from
// ErrTxClosed; the second return value is presumably the same error.
func ( *dbTx) ( context.Context, string, ...any) (Rows, error) {
if .closed {
:= ErrTxClosed
return &baseRows{closed: true, err: },
}
return .conn.Query(, , ...)
}
// Runs Query with the same arguments and exposes the result as a Row.
// NOTE(review): identifiers are missing in this copy; the signature matches
// QueryRow in the Tx interface — confirm.
func ( *dbTx) ( context.Context, string, ...any) Row {
, := .Query(, , ...)
// The single-value type assertion requires the Rows from Query to be a
// *baseRows, which is then converted to *connRow.
return (*connRow)(.(*baseRows))
}
// Forwards its arguments to conn.CopyFrom, or returns 0 and ErrTxClosed when
// the transaction is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// CopyFrom(ctx, tableName, columnNames, rowSrc) in the Tx interface — confirm.
func ( *dbTx) ( context.Context, Identifier, []string, CopyFromSource) (int64, error) {
if .closed {
return 0, ErrTxClosed
}
return .conn.CopyFrom(, , , )
}
// Forwards a batch to conn.SendBatch. Because the signature has no error
// return, a closed transaction is reported through a *batchResults whose err
// field is ErrTxClosed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// SendBatch in the Tx interface — confirm.
func ( *dbTx) ( context.Context, *Batch) BatchResults {
if .closed {
return &batchResults{err: ErrTxClosed}
}
return .conn.SendBatch(, )
}
// Returns a LargeObjects value with its tx field populated (the value is
// missing in this copy; presumably the receiver). No closed check is made.
// NOTE(review): identifiers are missing; the signature matches LargeObjects in
// the Tx interface — confirm.
func ( *dbTx) () LargeObjects {
return LargeObjects{tx: }
}
// Returns the conn field, i.e. the connection this transaction executes on.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Conn in the Tx interface — confirm.
func ( *dbTx) () *Conn {
return .conn
}
// dbSimulatedNestedTx is the Tx implementation returned for a nested begin. It
// is backed by a savepoint inside a parent transaction.
type dbSimulatedNestedTx struct {
// tx is the parent transaction that operations are forwarded to.
tx Tx
// savepointNum identifies this level's savepoint, named "sp_<savepointNum>".
savepointNum int64
// closed is set once a commit or rollback has been attempted at this level.
closed bool
}
// Starts a further nested transaction by delegating to the parent's Begin, or
// returns nil and ErrTxClosed when this level is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Begin in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context) (Tx, error) {
if .closed {
return nil, ErrTxClosed
}
return .tx.Begin()
}
// Commits this nested level by executing
// "release savepoint sp_<savepointNum>".
// NOTE(review): identifiers are missing in this copy; by the statement
// executed this is the Commit method of the Tx interface — confirm.
//
// Returns ErrTxClosed when already closed; otherwise the bare return
// presumably yields the error from the release statement.
func ( *dbSimulatedNestedTx) ( context.Context) error {
if .closed {
return ErrTxClosed
}
, := .Exec(, "release savepoint sp_"+strconv.FormatInt(.savepointNum, 10))
// Marked closed whether or not the statement succeeded.
.closed = true
return
}
// Rolls this nested level back by executing
// "rollback to savepoint sp_<savepointNum>".
// NOTE(review): identifiers are missing in this copy; by the statement
// executed this is the Rollback method of the Tx interface — confirm.
//
// Returns ErrTxClosed when already closed; otherwise the bare return
// presumably yields the error from the rollback statement.
func ( *dbSimulatedNestedTx) ( context.Context) error {
if .closed {
return ErrTxClosed
}
, := .Exec(, "rollback to savepoint sp_"+strconv.FormatInt(.savepointNum, 10))
// Marked closed whether or not the statement succeeded.
.closed = true
return
}
// Forwards a statement and its variadic arguments to the parent's Exec, or
// returns an empty command tag and ErrTxClosed when this level is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Exec in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context, string, ...any) ( pgconn.CommandTag, error) {
if .closed {
return pgconn.CommandTag{}, ErrTxClosed
}
return .tx.Exec(, , ...)
}
// Forwards its two string arguments to the parent's Prepare, or returns nil
// and ErrTxClosed when this level is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Prepare(ctx, name, sql) in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context, , string) (*pgconn.StatementDescription, error) {
if .closed {
return nil, ErrTxClosed
}
return .tx.Prepare(, , )
}
// Forwards a query and its variadic arguments to the parent's Query.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Query in the Tx interface — confirm.
//
// When this level is closed, the Rows value returned is a non-nil *baseRows
// that is already closed and has its err field set from ErrTxClosed; the
// second return value is presumably the same error.
func ( *dbSimulatedNestedTx) ( context.Context, string, ...any) (Rows, error) {
if .closed {
:= ErrTxClosed
return &baseRows{closed: true, err: },
}
return .tx.Query(, , ...)
}
// Runs Query with the same arguments and exposes the result as a Row.
// NOTE(review): identifiers are missing in this copy; the signature matches
// QueryRow in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context, string, ...any) Row {
, := .Query(, , ...)
// The single-value type assertion requires the Rows from Query to be a
// *baseRows, which is then converted to *connRow.
return (*connRow)(.(*baseRows))
}
// Forwards its arguments to the parent's CopyFrom, or returns 0 and
// ErrTxClosed when this level is closed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// CopyFrom(ctx, tableName, columnNames, rowSrc) in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context, Identifier, []string, CopyFromSource) (int64, error) {
if .closed {
return 0, ErrTxClosed
}
return .tx.CopyFrom(, , , )
}
// Forwards a batch to the parent's SendBatch. Because the signature has no
// error return, a closed level is reported through a *batchResults whose err
// field is ErrTxClosed.
// NOTE(review): identifiers are missing in this copy; the signature matches
// SendBatch in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) ( context.Context, *Batch) BatchResults {
if .closed {
return &batchResults{err: ErrTxClosed}
}
return .tx.SendBatch(, )
}
// Returns a LargeObjects value with its tx field populated (the value is
// missing in this copy; presumably the receiver). No closed check is made.
// NOTE(review): identifiers are missing; the signature matches LargeObjects in
// the Tx interface — confirm.
func ( *dbSimulatedNestedTx) () LargeObjects {
return LargeObjects{tx: }
}
// Returns the connection reported by the parent transaction's Conn().
// NOTE(review): identifiers are missing in this copy; the signature matches
// Conn in the Tx interface — confirm.
func ( *dbSimulatedNestedTx) () *Conn {
return .tx.Conn()
}
// Begins a transaction and runs a callback inside it.
// NOTE(review): the function name and all parameter identifiers are missing in
// this copy — confirm the exported name against callers.
//
// Parameters, in order: a context; any value with a method of shape
// (context.Context) (Tx, error), used to begin the transaction; and a
// func(Tx) error callback. If beginning fails, the bare return presumably
// yields that error through the named result and the callback is not run.
// Otherwise the outcome is whatever beginFuncExec returns.
func (
context.Context,
interface {
( context.Context) (Tx, error)
},
func(Tx) error,
) ( error) {
var Tx
, = .()
if != nil {
return
}
return beginFuncExec(, , )
}
// Begins a transaction with explicit TxOptions and runs a callback inside it.
// NOTE(review): the function name and all parameter identifiers are missing in
// this copy — confirm the exported name against callers.
//
// Parameters, in order: a context; any value with a method of shape
// (context.Context, TxOptions) (Tx, error), used to begin the transaction; the
// TxOptions passed to that method; and a func(Tx) error callback. If beginning
// fails, the bare return presumably yields that error through the named result
// and the callback is not run. Otherwise the outcome is whatever beginFuncExec
// returns.
func (
context.Context,
interface {
( context.Context, TxOptions) (Tx, error)
},
TxOptions,
func(Tx) error,
) ( error) {
var Tx
, = .(, )
if != nil {
return
}
return beginFuncExec(, , )
}
// Runs a callback against an already begun Tx and then commits, rolling back
// if the callback fails.
// NOTE(review): the function name and identifiers are missing in this copy;
// the parameter list matches the beginFuncExec calls made in the two wrappers
// in this file — confirm.
func ( context.Context, Tx, func(Tx) error) ( error) {
// A rollback is always attempted on the way out. After a completed commit or
// an explicit rollback this reports ErrTxClosed, which is filtered out
// below; any other rollback error is assigned (target missing in this copy,
// presumably the named result).
defer func() {
:= .Rollback()
if != nil && !errors.Is(, ErrTxClosed) {
=
}
}()
:= ()
if != nil {
// The callback failed: roll back immediately and discard the rollback error
// so that the callback's error is the one presumably returned.
_ = .Rollback()
return
}
return .Commit()
}