package pg
import (
)
// ErrTxDone is returned by operations on a Tx after the transaction has been
// closed (for example by commit or rollback) when the underlying connection
// pool reports pool.ErrClosed.
var ErrTxDone = errors.New("pg: transaction has already been committed or rolled back")
// Tx is a database transaction. Every statement it runs goes through db, a
// baseDB rebound to a sticky connection pool when the transaction is created.
//
// NOTE(review): method, receiver, parameter and local identifiers are missing
// throughout this copy of the file; restore them before building.
type Tx struct {
// db is the handle that all statements of the transaction are executed
// through; it is closed when the transaction is closed.
db *baseDB
// ctx is the context captured when the transaction was created; methods
// without an explicit context parameter use it.
ctx context.Context
// stmtsMu guards stmts.
stmtsMu sync.Mutex
// stmts holds statements prepared within this transaction; they are closed
// together with the transaction.
stmts []*Stmt
// _closed is 0 while the transaction is open and 1 once it has been closed.
// It is only accessed through sync/atomic.
_closed int32
}
// Compile-time check that *Tx implements orm.DB.
var _ orm.DB = (*Tx)(nil)
// Returns the context stored on the transaction when it was created.
// NOTE(review): method name stripped; presumably Context.
func ( *Tx) () context.Context {
return .ctx
}
// Starts a transaction using the context returned by the receiver's
// db.Context(), delegating the work to BeginContext.
// NOTE(review): method name stripped; presumably Begin.
func ( *baseDB) () (*Tx, error) {
return .BeginContext(.db.Context())
}
// Creates a Tx whose db is the receiver rebound to a new sticky connection
// pool wrapping the receiver's pool, stores the given context on it, and then
// runs the transaction's begin step.
// If the begin step fails, the Tx is closed (its prepared statements and its
// db are closed) and only the error is returned.
// NOTE(review): method name stripped; presumably BeginContext.
func ( *baseDB) ( context.Context) (*Tx, error) {
:= &Tx{
db: .withPool(pool.NewStickyConnPool(.pool)),
ctx: ,
}
:= .begin()
if != nil {
// Release the transaction's resources; the half-started Tx is never
// handed to the caller.
.close()
return nil,
}
return , nil
}
// Begins a transaction with the given context and hands the callback to the
// Tx's RunInTransaction. That method commits the transaction if the callback
// succeeds and rolls it back if the callback returns an error.
// If starting the transaction fails, that error is returned unchanged and the
// callback is not called.
// NOTE(review): method name stripped; presumably RunInTransaction.
func ( *baseDB) ( context.Context, func(*Tx) error) error {
, := .BeginContext()
if != nil {
return
}
return .RunInTransaction(, )
}
// Returns the receiver itself and a nil error. No new BEGIN is issued, so
// calling this on an open transaction simply reuses that transaction.
// NOTE(review): method name stripped; presumably Begin.
func ( *Tx) () (*Tx, error) {
return , nil
}
// Runs the callback inside this transaction.
//   - If the callback returns an error, the transaction is rolled back and the
//     callback's error is returned; a rollback failure is only logged.
//   - Otherwise the transaction is committed and the commit error (if any) is
//     returned.
//   - If the callback panics, the deferred handler rolls back, logs a rollback
//     failure, and then panics again.
//
// NOTE(review): method name stripped; presumably RunInTransaction.
func ( *Tx) ( context.Context, func(*Tx) error) error {
defer func() {
if := recover(); != nil {
if := .RollbackContext(); != nil {
// NOTE(review): this logs when the rollback *failed* during panic
// recovery; the "panicked" wording is misleading.
internal.Logger.Printf(, "tx.Rollback panicked: %s", )
}
// Propagate the panic to the caller after the rollback attempt.
panic()
}
}()
if := (); != nil {
// The callback's error takes precedence; the rollback error is only logged.
if := .RollbackContext(); != nil {
internal.Logger.Printf(, "tx.Rollback failed: %s", )
}
return
}
return .CommitContext()
}
// Runs the callback on a connection obtained through the transaction's db.
// Once the transaction is closed, a pool.ErrClosed from the pool is reported as
// ErrTxDone, so callers see a transaction-level error rather than a pool error.
// NOTE(review): the comparison uses ==, so a wrapped pool.ErrClosed would not
// be translated; consider errors.Is if the pool can wrap it.
// NOTE(review): method name stripped; presumably withConn.
func ( *Tx) ( context.Context, func(context.Context, *pool.Conn) error) error {
:= .db.withConn(, )
if .closed() && == pool.ErrClosed {
return ErrTxDone
}
return
}
// Returns a statement prepared within this transaction from the query text
// (the q field) of an existing statement.
// Preparation errors are not returned directly. Instead they are stored in the
// returned Stmt's stickyErr field (presumably reported by later calls on that
// Stmt — confirm).
// NOTE(review): method name stripped; presumably Stmt.
func ( *Tx) ( *Stmt) *Stmt {
, := .Prepare(.q)
if != nil {
return &Stmt{stickyErr: }
}
return
}
// Prepares the query text for use within this transaction.
// The statement is prepared on a new sticky connection pool layered over the
// transaction's pool. On success it is appended to stmts, so it is closed when
// the transaction is closed.
// stmtsMu is held for the whole call, so concurrent prepares are serialized
// and cannot interleave with the statement cleanup performed on close.
// NOTE(review): method name stripped; presumably Prepare.
func ( *Tx) ( string) (*Stmt, error) {
.stmtsMu.Lock()
defer .stmtsMu.Unlock()
:= .db.withPool(pool.NewStickyConnPool(.db.pool))
, := prepareStmt(, )
if != nil {
return nil,
}
.stmts = append(.stmts, )
return , nil
}
// Executes the query with the given params under the transaction's stored
// context.
// NOTE(review): method name stripped; presumably Exec.
func ( *Tx) ( interface{}, ...interface{}) (Result, error) {
return .exec(.ctx, , ...)
}
// Executes the query with the given params under the caller-supplied context.
// NOTE(review): method name stripped; presumably ExecContext.
func ( *Tx) ( context.Context, interface{}, ...interface{}) (Result, error) {
return .exec(, , ...)
}
// Shared implementation for executing a query without a result model:
//  1. format the query and params into a pooled write buffer;
//  2. run the db's before-query hook, which returns (presumably) an updated
//     context and an event for the after-query hook;
//  3. run the simple query on a transaction connection;
//  4. run the after-query hook with the result and the query error.
//
// If the after-query hook fails, its error replaces the result. Otherwise the
// result and the query error are both returned as-is.
// NOTE(review): method name stripped; presumably exec.
func ( *Tx) ( context.Context, interface{}, ...interface{}) (Result, error) {
:= pool.GetWriteBuffer()
defer pool.PutWriteBuffer()
if := writeQueryMsg(, .db.fmter, , ...); != nil {
return nil,
}
, , := .db.beforeQuery(, , nil, , , .Query())
if != nil {
return nil,
}
var Result
:= .withConn(, func( context.Context, *pool.Conn) error {
// Assign (not declare) so the result is visible after withConn returns.
, = .db.simpleQuery(, , )
return
})
if := .db.afterQuery(, , , ); != nil {
return nil,
}
return ,
}
// Executes the query under the transaction's stored context and checks the
// affected-row count (see execOne).
// NOTE(review): method name stripped; presumably ExecOne.
func ( *Tx) ( interface{}, ...interface{}) (Result, error) {
return .execOne(.ctx, , ...)
}
// Executes the query under the caller-supplied context and checks the
// affected-row count (see execOne).
// NOTE(review): method name stripped; presumably ExecOneContext.
func ( *Tx) ( context.Context, interface{}, ...interface{}) (Result, error) {
return .execOne(, , ...)
}
// Executes the query through ExecContext, then checks the affected-row count
// with internal.AssertOneRow (presumably requiring exactly one row — confirm).
// On either failure the result is discarded and only the error is returned.
// NOTE(review): method name stripped; presumably execOne.
func ( *Tx) ( context.Context, interface{}, ...interface{}) (Result, error) {
, := .ExecContext(, , ...)
if != nil {
return nil,
}
if := internal.AssertOneRow(.RowsAffected()); != nil {
return nil,
}
return , nil
}
// Runs a query with a model under the transaction's stored context.
// NOTE(review): method name stripped; presumably Query.
func ( *Tx) ( interface{}, interface{}, ...interface{}) (Result, error) {
return .query(.ctx, , , ...)
}
// Runs a query with a model under the caller-supplied context.
// NOTE(review): method name stripped; presumably QueryContext.
func ( *Tx) (
context.Context,
interface{},
interface{},
...interface{},
) (Result, error) {
return .query(, , , ...)
}
// Shared implementation for queries that take a model. It mirrors exec, except
// that the model is passed to the before-query hook and to simpleQueryData
// (presumably as the destination for returned rows — confirm).
// If the after-query hook fails, its error replaces the result. Otherwise the
// result and the query error are both returned as-is.
// NOTE(review): method name stripped; presumably query.
func ( *Tx) (
context.Context,
interface{},
interface{},
...interface{},
) (Result, error) {
:= pool.GetWriteBuffer()
defer pool.PutWriteBuffer()
if := writeQueryMsg(, .db.fmter, , ...); != nil {
return nil,
}
, , := .db.beforeQuery(, , , , , .Query())
if != nil {
return nil,
}
// NOTE(review): this is a concrete *result. If it stays nil (e.g. the query
// fails), it is passed to afterQuery and returned as a non-nil Result
// interface holding a nil pointer. Callers must check the error, not
// Result == nil.
var *result
:= .withConn(, func( context.Context, *pool.Conn) error {
// Assign (not declare) so the result is visible after withConn returns.
, = .db.simpleQueryData(, , , )
return
})
if := .db.afterQuery(, , , ); != nil {
return nil,
}
return ,
}
// Runs a single-row query with a model under the transaction's stored context
// (see queryOne).
// NOTE(review): method name stripped; presumably QueryOne.
func ( *Tx) ( interface{}, interface{}, ...interface{}) (Result, error) {
return .queryOne(.ctx, , , ...)
}
// Runs a single-row query with a model under the caller-supplied context
// (see queryOne).
// NOTE(review): method name stripped; presumably QueryOneContext.
func ( *Tx) (
context.Context,
interface{},
interface{},
...interface{},
) (Result, error) {
return .queryOne(, , , ...)
}
// Wraps the model argument with orm.NewModel and runs the query through
// QueryContext with the wrapped model. It then checks the affected-row count
// with internal.AssertOneRow (presumably requiring exactly one row — confirm).
// On any failure only the error is returned.
// NOTE(review): method name stripped; presumably queryOne.
func ( *Tx) (
context.Context,
interface{},
interface{},
...interface{},
) (Result, error) {
, := orm.NewModel()
if != nil {
return nil,
}
, := .QueryContext(, , , ...)
if != nil {
return nil,
}
if := internal.AssertOneRow(.RowsAffected()); != nil {
return nil,
}
return , nil
}
// Returns a new orm query for the given model(s), presumably bound to this
// transaction (the first argument to orm.NewQuery is stripped here).
// NOTE(review): method name stripped; presumably Model.
func ( *Tx) ( ...interface{}) *Query {
return orm.NewQuery(, ...)
}
// Returns a new orm query for the given model(s) that carries the
// caller-supplied context, presumably bound to this transaction.
// NOTE(review): method name stripped; presumably ModelContext.
func ( *Tx) ( context.Context, ...interface{}) *Query {
return orm.NewQueryContext(, , ...)
}
// Delegates to the db's copyFrom, reading from the given io.Reader on a
// transaction connection. It always runs under the transaction's stored
// context; there is no context parameter.
// The named results are assigned inside the closure and returned afterwards.
// NOTE(review): method name stripped; presumably CopyFrom.
func ( *Tx) ( io.Reader, interface{}, ...interface{}) ( Result, error) {
= .withConn(.ctx, func( context.Context, *pool.Conn) error {
, = .db.copyFrom(, , , , ...)
return
})
return ,
}
// Delegates to the db's copyTo, writing to the given io.Writer on a
// transaction connection. It always runs under the transaction's stored
// context; there is no context parameter.
// The named results are assigned inside the closure and returned afterwards.
// NOTE(review): method name stripped; presumably CopyTo.
func ( *Tx) ( io.Writer, interface{}, ...interface{}) ( Result, error) {
= .withConn(.ctx, func( context.Context, *pool.Conn) error {
, = .db.copyTo(, , , , ...)
return
})
return ,
}
// Returns the query formatter of the transaction's underlying db.
// NOTE(review): method name stripped; presumably Formatter.
func ( *Tx) () orm.QueryFormatter {
return .db.Formatter()
}
// Issues BEGIN on the transaction's connection. It makes at most
// db.opt.MaxRetries+1 attempts and stops as soon as db.shouldRetry rejects the
// last error (presumably including a nil error on success).
// Before each retry it:
//   - sleeps for db.retryBackoff, returning the sleep's error if it fails
//     (presumably context cancellation — confirm);
//   - resets the sticky connection pool, returning any reset error
//     (presumably so the retry uses a fresh connection — confirm).
//
// Returns the error from the last BEGIN attempt.
// NOTE(review): method name stripped; presumably begin.
func ( *Tx) ( context.Context) error {
var error
for := 0; <= .db.opt.MaxRetries; ++ {
if > 0 {
if := internal.Sleep(, .db.retryBackoff(-1)); != nil {
return
}
// Assumes db.pool is a *pool.StickyConnPool (the Tx is created with a
// sticky pool); any other pool type would make this assertion panic.
:= .db.pool.(*pool.StickyConnPool).Reset()
if != nil {
return
}
}
_, = .ExecContext(, "BEGIN")
if !.db.shouldRetry() {
break
}
}
return
}
// Commits the transaction under the transaction's stored context.
// NOTE(review): method name stripped; presumably Commit.
func ( *Tx) () error {
return .CommitContext(.ctx)
}
// Sends COMMIT, then closes the transaction whether or not COMMIT succeeded,
// and returns the COMMIT error.
// The statement runs under internal.UndoContext(ctx). This is presumably a
// context detached from the caller's cancellation, so COMMIT is not aborted by
// an already-cancelled context — confirm.
// NOTE(review): method name stripped; presumably CommitContext.
func ( *Tx) ( context.Context) error {
, := .ExecContext(internal.UndoContext(), "COMMIT")
.close()
return
}
// Rolls the transaction back under the transaction's stored context.
// NOTE(review): method name stripped; presumably Rollback.
func ( *Tx) () error {
return .RollbackContext(.ctx)
}
// Sends ROLLBACK, then closes the transaction whether or not ROLLBACK
// succeeded, and returns the ROLLBACK error.
// The statement runs under internal.UndoContext(ctx). This is presumably a
// context detached from the caller's cancellation, so the rollback is still
// attempted after cancellation — confirm.
// NOTE(review): method name stripped; presumably RollbackContext.
func ( *Tx) ( context.Context) error {
, := .ExecContext(internal.UndoContext(), "ROLLBACK")
.close()
return
}
// Closes the transaction under the transaction's stored context (see
// CloseContext).
// NOTE(review): method name stripped; presumably Close.
func ( *Tx) () error {
return .CloseContext(.ctx)
}
// Rolls the transaction back unless it is already closed. If it is already
// closed (e.g. after commit or rollback), this is a no-op that returns nil.
// NOTE(review): method name stripped; presumably CloseContext.
func ( *Tx) ( context.Context) error {
if .closed() {
return nil
}
return .RollbackContext()
}
// Marks the transaction closed and releases its resources.
// The atomic compare-and-swap on _closed makes this idempotent: only the first
// caller proceeds. That caller closes every prepared statement, clears stmts,
// and closes the transaction's db. Close errors are deliberately discarded
// (best-effort cleanup).
// NOTE(review): method name stripped; presumably close.
func ( *Tx) () {
if !atomic.CompareAndSwapInt32(&._closed, 0, 1) {
return
}
.stmtsMu.Lock()
defer .stmtsMu.Unlock()
for , := range .stmts {
_ = .Close()
}
.stmts = nil
_ = .db.Close()
}
// Reports whether the transaction has been closed, i.e. whether the _closed
// flag has been atomically set to 1.
// NOTE(review): method name stripped; presumably closed.
func ( *Tx) () bool {
return atomic.LoadInt32(&._closed) == 1
}