package migrations
import (
)
// Migration is one versioned schema change: an Up function that applies it
// and an optional Down function that reverts it.
type Migration struct {
// Version orders migrations. It is the value recorded in the migrations
// table once Up succeeds.
Version int64
// UpTx: when true, Up is run against the runner's *pg.Tx instead of the
// plain DB handle, so its effects commit together with the version row.
UpTx bool
Up func(DB) error
// DownTx: when true, Down is run against the runner's *pg.Tx instead of the
// plain DB handle.
DownTx bool
// Down may be nil. Reverting such a migration only lowers the recorded
// version.
Down func(DB) error
}
// Returns the migration's Version formatted as a base-10 string.
// NOTE(review): identifiers are elided in this copy. The `() string` shape
// matches fmt.Stringer, so this is presumably `String` — confirm.
func ( *Migration) () string {
return strconv.FormatInt(.Version, 10)
}
// Collection holds a set of migrations plus the configuration used to run
// them against a database.
type Collection struct {
// tableName is the bookkeeping table. It may be schema-qualified as
// "schema.table"; an unqualified name is treated as living in "public".
tableName string
// sqlAutodiscoverDisabled turns off scanning directories for *.sql migrations.
sqlAutodiscoverDisabled bool
// mu guards visitedDirs and migrations.
mu sync.Mutex
// visitedDirs records directories already scanned for SQL migrations, so each
// one is scanned at most once. It is allocated lazily.
visitedDirs map[string]struct{}
// migrations is kept in ascending Version order by addMigration.
migrations []*Migration
}
// Creates a Collection that records applied versions in the
// "gopg_migrations" table. It is seeded with the given migrations, each
// inserted in Version order.
// NOTE(review): the constructor's name is elided in this copy.
func ( ...*Migration) *Collection {
:= &Collection{
tableName: "gopg_migrations",
}
for , := range {
.addMigration()
}
return
}
// Overrides the migrations table name, which may be schema-qualified
// ("schema.table"). Returns the receiver so calls can be chained.
// NOTE(review): the method name is elided in this copy.
func ( *Collection) ( string) *Collection {
.tableName =
return
}
// schemaTableName splits tableName at its first '.' into (schema, table).
// When there is no '.', the schema defaults to "public".
func ( *Collection) () (string, string) {
if := strings.IndexByte(.tableName, '.'); >= 0 {
return .tableName[:], .tableName[+1:]
}
return "public", .tableName
}
// Sets whether automatic discovery of *.sql migration files is disabled.
// While discovery is enabled, registering a Go migration and listing the
// collection's migrations also scan directories for SQL files.
// Returns the receiver so calls can be chained.
// NOTE(review): the method name is elided in this copy.
func ( *Collection) ( bool) *Collection {
.sqlAutodiscoverDisabled =
return
}
// Register adds a migration defined in the calling Go file, whose base name
// must look like "<version>_<comment>.go". It takes an up function and an
// optional down function. Both run on the plain DB handle rather than the
// migration transaction, because register is called with false.
func ( *Collection) ( ...func(DB) error) error {
return .register(false, ...)
}
// RegisterTx is like Register, but the up/down functions are handed the
// migration transaction, because register is called with true.
func ( *Collection) ( ...func(DB) error) error {
return .register(true, ...)
}
// register is the shared implementation behind Register and RegisterTx.
//
// The bool argument (the only bool in scope) becomes both UpTx and DownTx.
// The variadic functions are either (up) or (up, down); any other count is an
// error. The version is parsed from the file name of the nearest caller
// outside this package (migrationFile + extractVersionGo). Unless SQL
// autodiscovery is disabled, that file's directory is scanned for *.sql
// migrations before the Go migration is added.
func ( *Collection) ( bool, ...func(DB) error) error {
var , func(DB) error
switch len() {
case 0:
return errors.New("Register expects at least 1 arg")
case 1:
= [0]
case 2:
= [0]
= [1]
default:
return fmt.Errorf("Register expects at most 2 args, got %d", len())
}
// Locate the Go source file that called into Register/RegisterTx.
:= migrationFile()
, := extractVersionGo()
if != nil {
return
}
if !.sqlAutodiscoverDisabled {
// Pick up sibling *.sql migrations from the same directory.
= .DiscoverSQLMigrations(filepath.Dir())
if != nil {
return
}
}
.addMigration(&Migration{
Version: ,
UpTx: ,
Up: ,
DownTx: ,
Down: ,
})
return nil
}
// migrationFile returns the source file of the nearest stack frame (among the
// first 32) whose function name does not contain "/go-pg/migrations".
// That is presumably the user's migration file calling into this package.
// It returns "" if every inspected frame matches.
// NOTE(review): this relies on the import path containing
// "/go-pg/migrations"; a fork or rename under another path would make it
// return this package's own file — confirm.
func () string {
const = 32
var []uintptr
// skip=1 omits runtime.Callers itself, so the walk starts at this function.
:= runtime.Callers(1, [:])
:= runtime.CallersFrames([:])
for {
, := .Next()
if ! {
break
}
if !strings.Contains(.Function, "/go-pg/migrations") {
return .File
}
}
return ""
}
// DiscoverSQLMigrations scans a directory on the local filesystem for *.sql
// migrations. The path is made absolute first. This also normalizes the key
// used by the visited-directory check, so one directory reached via different
// relative paths is scanned once.
func ( *Collection) ( string) error {
, := filepath.Abs()
if != nil {
return
}
return .DiscoverSQLMigrationsFromFilesystem(osfilesystem{}, )
}
// DiscoverSQLMigrationsFromFilesystem registers migrations from the *.sql
// files in a directory of the given http.FileSystem.
//
// File naming rules:
//   - "<version>_<comment>.up.sql" supplies the Up script.
//   - "<version>_<comment>.down.sql" supplies the Down script.
//   - A ".tx.up.sql" / ".tx.down.sql" suffix makes that script run inside the
//     migration transaction.
//   - Up and down files with the same version are merged into one Migration.
//   - Any other *.sql name is an error.
//   - Non-.sql entries and subdirectories are skipped.
//
// Each directory is processed at most once per Collection. A missing
// directory is not an error. Migrations are added only after every file has
// been parsed, so on any error none of this directory's migrations are added.
//
// NOTE(review): the directory is marked visited before Open, so after an
// Open/Readdir/parse error later calls for the same dir return nil without
// retrying.
// NOTE(review): a Stat error other than "not exist" is ignored.
// NOTE(review): file paths are built with filepath.Join, but http.FileSystem
// implementations generally expect '/'-separated names. This matters on
// Windows for filesystems other than the local one — confirm.
func ( *Collection) ( http.FileSystem, string) error {
if .isVisitedDir() {
return nil
}
, := .Open()
if os.IsNotExist() {
return nil
}
if != nil {
return
}
defer .Close()
if , := .Stat(); os.IsNotExist() {
return nil
}
var []*Migration
// Find-or-create the Migration for a version so .up and .down files merge.
:= func( int64) *Migration {
for := range {
:= []
if .Version == {
return
}
}
= append(, &Migration{
Version: ,
})
return [len()-1]
}
, := .Readdir(-1)
if != nil {
return
}
// Process entries in name order so results are deterministic.
sort.Slice(, func(, int) bool { return [].Name() < [].Name() })
for , := range {
if .IsDir() {
continue
}
:= .Name()
if !strings.HasSuffix(, ".sql") {
continue
}
// The version is the integer prefix before the first '_'.
:= strings.IndexByte(, '_')
if == -1 {
:= fmt.Errorf(
"file=%q must have name in format version_comment, e.g. 1_initial",
)
return
}
, := strconv.ParseInt([:], 10, 64)
if != nil {
return
}
:= ()
:= filepath.Join(, )
if strings.HasSuffix(, ".up.sql") {
if .Up != nil {
return fmt.Errorf("migration=%d already has Up func", )
}
.UpTx = strings.HasSuffix(, ".tx.up.sql")
.Up = newSQLMigration(, )
continue
}
if strings.HasSuffix(, ".down.sql") {
if .Down != nil {
return fmt.Errorf("migration=%d already has Down func", )
}
.DownTx = strings.HasSuffix(, ".tx.down.sql")
.Down = newSQLMigration(, )
continue
}
return fmt.Errorf(
"file=%q must have extension .up.sql or .down.sql", )
}
for , := range {
.addMigration()
}
return nil
}
// isVisitedDir reports whether the directory was already seen. If it was not,
// the directory is recorded. The check and the mark both happen under c.mu,
// so two concurrent callers cannot both get false for the same directory.
func ( *Collection) ( string) bool {
.mu.Lock()
defer .mu.Unlock()
if , := .visitedDirs[]; {
return true
}
if .visitedDirs == nil {
.visitedDirs = make(map[string]struct{})
}
.visitedDirs[] = struct{}{}
return false
}
// newSQLMigration returns a migration function that executes a SQL file from
// the given http.FileSystem. The file is opened and read each time the
// migration runs, not when it is registered.
//
// Batching rules:
//   - A line that is exactly "--gopg:split" ends the current statement batch.
//   - Any other line starting with "--gopg:" is rejected as an unknown
//     directive.
//   - Each batch is sent with one Exec call, in file order.
//
// When there is more than one batch and the handle is a *pg.DB (a pool), all
// batches run on a single connection obtained with Conn(). That connection is
// closed when the function returns.
//
// NOTE(review): bufio.Scanner's default 64 KiB token limit applies per line.
// A longer line stops the scan, and Err then reports bufio.ErrTooLong.
// NOTE(review): a "--gopg:split" before any SQL, or two splits in a row,
// appends an empty batch, which is still passed to Exec.
func ( http.FileSystem, string) func(DB) error {
return func( DB) error {
, := .Open()
if != nil {
return
}
defer .Close()
:= bufio.NewScanner()
var []byte
var []string
for .Scan() {
:= .Bytes()
const = "--gopg:"
if bytes.HasPrefix(, []byte()) {
= [len():]
if bytes.Equal(, []byte("split")) {
// string(...) copies, so the buffer can be reused for the next batch.
= append(, string())
= [:0]
continue
}
return fmt.Errorf("unknown gopg directive: %q", )
}
= append(, ...)
= append(, '\n')
}
// A trailing batch is kept only if it has content.
if len() > 0 {
= append(, string())
}
if := .Err(); != nil {
return
}
if len() > 1 {
switch v := .(type) {
case *pg.DB:
:= .Conn()
defer .Close()
=
}
}
for , := range {
_, = .Exec()
if != nil {
return
}
}
return nil
}
}
// addMigration inserts the migration before the first existing one with a
// greater Version, or appends it if there is none. This keeps c.migrations in
// ascending Version order. Equal versions are not rejected here; Run reports
// them through validateMigrations.
func ( *Collection) ( *Migration) {
.mu.Lock()
defer .mu.Unlock()
for , := range .migrations {
if .Version > .Version {
.migrations = insert(.migrations, , )
return
}
}
.migrations = append(.migrations, )
}
// insert places an element at the given index of the slice, shifting later
// elements one position right, and returns the resulting slice. Like append,
// it may reuse and overwrite the input's backing array. The index must be in
// [0, len(slice)].
func ( []*Migration, int, *Migration) []*Migration {
= append(, nil)
copy([+1:], [:])
[] =
return
}
// Like Register, but panics with the error if registration fails.
// NOTE(review): the method name is elided in this copy; presumably
// MustRegister.
func ( *Collection) ( ...func(DB) error) {
:= .Register(...)
if != nil {
panic()
}
}
// Like RegisterTx, but panics with the error if registration fails.
// NOTE(review): the method name is elided in this copy; presumably
// MustRegisterTx.
func ( *Collection) ( ...func(DB) error) {
:= .RegisterTx(...)
if != nil {
panic()
}
}
// Migrations returns a copy of the registered migrations slice, in ascending
// Version order. The *Migration values themselves are shared, not copied.
//
// Unless SQL autodiscovery is disabled, it first scans the caller's source
// directory and the current working directory for *.sql migrations. Discovery
// errors are explicitly discarded here (best effort).
//
// NOTE(review): because these errors are dropped, a malformed *.sql name makes
// that whole directory contribute no migrations, with no error surfaced.
func ( *Collection) () []*Migration {
if !.sqlAutodiscoverDisabled {
_ = .DiscoverSQLMigrations(filepath.Dir(migrationFile()))
, := os.Getwd()
if == nil {
_ = .DiscoverSQLMigrations()
}
}
.mu.Lock()
defer .mu.Unlock()
:= make([]*Migration, len(.migrations))
copy(, .migrations)
return
}
// Runs a migration command against the database. The first string argument
// selects the command; the default is "up".
//
//   init              create the migrations table (and its schema if missing)
//   create <words...> write "<last version + 1>_<words>.go" from the template
//                     in the working directory
//   version           only read the current version
//   up [target]       apply migrations with current < Version <= target (all
//                     pending if target is omitted), committing after each one
//   down              revert the highest migration with Version <= current
//   reset             repeat "down", committing each step, until the version
//                     stops changing
//   set_version <n>   record n as the current version without running anything
//
// The two int64 results appear to be the version before and after the
// command — confirm.
//
// Migrations with duplicate versions are rejected up front. Every command
// other than init and create requires the migrations table to exist. Work is
// done in transactions opened by begin, which, where supported, lock the
// migrations table exclusively.
//
// NOTE(review): Go evaluates the receiver of `defer tx.Close()` when the defer
// statement runs. Transactions re-opened inside the "up" and "reset" loops are
// therefore not closed by it if a later step returns an error — check whether
// they leak on those paths.
// NOTE(review): "create" without a description prints a message and returns a
// nil error.
// NOTE(review): the method name is elided in this copy; presumably Run.
func ( *Collection) ( DB, ...string) (, int64, error) {
:= .Migrations()
= validateMigrations()
if != nil {
return
}
:= "up"
if len() > 0 {
= [0]
}
switch {
case "init":
= .createTable()
if != nil {
return
}
return
case "create":
if len() < 2 {
fmt.Println("please provide migration description")
return
}
var int64
if len() > 0 {
= [len()-1].Version
}
:= fmtMigrationFilename(+1, strings.Join([1:], "_"))
= createMigrationFile()
if != nil {
return
}
fmt.Println("created new migration", )
return
}
// Every remaining command needs the bookkeeping table.
, := .tableExists()
if != nil {
return
}
if ! {
= fmt.Errorf("table %q does not exist; did you run init?", .tableName)
return
}
, , := .begin()
if != nil {
return
}
// Binds the transaction that is current at this point.
defer .Close()
=
=
switch {
case "version":
case "up":
:= int64(math.MaxInt64)
if len() > 1 {
, = strconv.ParseInt([1], 10, 64)
if != nil {
return
}
if > {
// Already past the target: leave the switch without migrating.
break
}
}
for , := range {
if .Version > {
break
}
if == nil {
// The previous migration was committed. Open a new (locked) tx and
// re-read the version.
, , = .begin()
if != nil {
return
}
}
// Skip migrations that are already applied.
if .Version <= {
continue
}
, = .runUp(, , )
if != nil {
return
}
// Commit per migration, so earlier ones stay applied if a later one fails.
= .Commit()
if != nil {
return
}
= nil
}
case "down":
, = .down(, , , )
if != nil {
return
}
case "reset":
for {
if == nil {
, , = .begin()
if != nil {
return
}
}
, = .down(, , , )
if != nil {
return
}
= .Commit()
if != nil {
return
}
= nil
if == {
// The version did not change: nothing left to revert. Exits the for loop.
break
}
=
}
case "set_version":
if len() < 2 {
= fmt.Errorf("set_version requires version as 2nd arg, e.g. set_version 42")
return
}
, = strconv.ParseInt([1], 10, 64)
if != nil {
return
}
= .SetVersion(, )
if != nil {
return
}
default:
// The nil check below is always true here; it is redundant but harmless.
= fmt.Errorf("unsupported command: %q", )
if != nil {
return
}
}
// Commit any transaction still open (e.g. from "version", "down", "set_version").
if != nil {
= .Commit()
}
return
}
// validateMigrations returns an error naming the first Version that appears
// more than once in the slice, or nil if all versions are distinct.
func ( []*Migration) error {
:= make(map[int64]struct{}, len())
for , := range {
if , := [.Version]; {
return fmt.Errorf(
"there are multiple migrations with version=%d", .Version)
}
[.Version] = struct{}{}
}
return nil
}
// runUp applies the migration's Up and, on success, records its Version on the
// transaction (via run).
//
// When UpTx is set, Up receives the transaction as its DB, so the schema change
// and the version row commit or roll back together. Otherwise Up runs on the
// plain DB handle.
//
// NOTE(review): Up is called without a nil check. A Migration discovered from
// only a ".down.sql" file has a nil Up, and calling it would panic.
func ( *Collection) ( DB, *pg.Tx, *Migration) (int64, error) {
if .UpTx {
=
}
return .run(, func() (int64, error) {
:= .Up()
if != nil {
return 0,
}
return .Version, nil
})
}
// runDown reverts the migration by calling Down (skipped when Down is nil). On
// success it records Version-1 on the transaction (via run).
//
// When DownTx is set, Down receives the transaction as its DB. Otherwise it
// runs on the plain DB handle.
func ( *Collection) ( DB, *pg.Tx, *Migration) (int64, error) {
if .DownTx {
=
}
return .run(, func() (int64, error) {
if .Down != nil {
:= .Down()
if != nil {
return 0,
}
}
return .Version - 1, nil
})
}
// run calls the given function. If it succeeds, the version it returned is
// recorded by inserting a row through the transaction (SetVersion). If the
// function fails, nothing is recorded and its error is returned.
func ( *Collection) (
*pg.Tx, func() (int64, error),
) ( int64, error) {
, = ()
if != nil {
return
}
= .SetVersion(, )
return
}
// down reverts one migration starting from the given current version.
//
// It scans the migrations from the end (assumed sorted ascending) for the
// highest one with Version <= current version, and runs its Down via runDown.
// It returns 0 immediately when the version is 0. It returns the version
// unchanged when no such migration exists.
func ( *Collection) ( DB, *pg.Tx, []*Migration, int64) (int64, error) {
if == 0 {
return 0, nil
}
var *Migration
for := len() - 1; >= 0; -- {
:= []
if .Version <= {
=
break
}
}
if == nil {
return , nil
}
return .runDown(, , )
}
// schemaExists reports whether the schema part of tableName exists, by
// querying information_schema.schemata.
// NOTE(review): pg.SafeQuery splices the name in unescaped between the quotes.
// A schema name containing a single quote would break or alter the query.
// The name comes from the configured table name — confirm it is never
// user-supplied.
func ( *Collection) ( DB) (bool, error) {
, := .schemaTableName()
return .Model().
Table("information_schema.schemata").
Where("schema_name = '?'", pg.SafeQuery()).
Exists()
}
// tableExists reports whether the migrations table exists in its schema, by
// querying pg_tables.
// NOTE(review): as in schemaExists, pg.SafeQuery inside '...' performs no
// escaping.
func ( *Collection) ( DB) (bool, error) {
, := .schemaTableName()
return .Model().
Table("pg_tables").
Where("schemaname = '?'", pg.SafeQuery()).
Where("tablename = '?'", pg.SafeQuery()).
Exists()
}
// Version returns the most recently recorded version: the row with the highest
// id in the migrations table. It returns 0 when the table has no rows.
func ( *Collection) ( DB) (int64, error) {
var int64
, := .QueryOne(pg.Scan(&), `
SELECT version FROM ? ORDER BY id DESC LIMIT 1
`, pg.SafeQuery(.tableName))
if != nil {
if == pg.ErrNoRows {
return 0, nil
}
return 0,
}
return , nil
}
// SetVersion records a new current version by appending a row stamped with
// now(). Earlier rows are left in place as history.
func ( *Collection) ( DB, int64) error {
, := .Exec(`
INSERT INTO ? (version, created_at) VALUES (?, now())
`, pg.SafeQuery(.tableName), )
return
}
// createTable creates the schema named in tableName if it does not exist. It
// then creates the migrations table (id serial, version bigint,
// created_at timestamptz) if that does not already exist.
func ( *Collection) ( DB) error {
, := .schemaExists()
if != nil {
return
}
if ! {
, := .schemaTableName()
, := .Exec(`CREATE SCHEMA IF NOT EXISTS ?`, pg.SafeQuery())
if != nil {
return
}
}
_, = .Exec(`
CREATE TABLE IF NOT EXISTS ? (
id serial,
version bigint,
created_at timestamptz
)
`, pg.SafeQuery(.tableName))
return
}
// Substrings matched against the error from "LOCK TABLE ... IN EXCLUSIVE MODE".
// When begin sees one, it continues without the table lock instead of failing.
// The names indicate CockroachDB and YugabyteDB, presumably because those
// databases do not support this lock mode.
const (
cockroachdbErrorMatch = `at or near "lock"`
yugabytedbErrorMatch = `lock mode not supported yet`
)
// begin opens a transaction for running migrations. It returns the
// transaction together with the current version, read inside it.
//
// Steps:
//  1. Disable idle_in_transaction_session_timeout, so a slow migration is not
//     terminated for sitting idle in a transaction. If the server rejects the
//     setting, the failed transaction is rolled back and a fresh one is started
//     without it.
//  2. Lock the migrations table in EXCLUSIVE MODE, so concurrent runners are
//     serialized. If the lock fails with one of the known CockroachDB/YugabyteDB
//     messages, it continues in a fresh, unlocked transaction. Any other error
//     is returned.
//
// NOTE(review): a plain SET (not SET LOCAL) that commits persists for the rest
// of the session. With a pooled *pg.DB, the setting may outlive this
// transaction on that connection.
func ( *Collection) ( DB) (*pg.Tx, int64, error) {
, := .Begin()
if != nil {
return nil, 0,
}
_, = .Exec("SET idle_in_transaction_session_timeout = 0")
if != nil {
// A failed statement aborts a PostgreSQL transaction; start over without it.
_ = .Rollback()
, = .Begin()
if != nil {
return nil, 0,
}
}
_, = .Exec("LOCK TABLE ? IN EXCLUSIVE MODE", pg.SafeQuery(.tableName))
if != nil {
_ = .Rollback()
// Fall back to an unlocked transaction only for the known "unsupported" errors.
if !strings.Contains(.Error(), cockroachdbErrorMatch) && !strings.Contains(.Error(), yugabytedbErrorMatch) {
return nil, 0,
}
, = .Begin()
if != nil {
return nil, 0,
}
}
// Read the version inside the transaction, under the lock where one was taken.
, := .Version()
if != nil {
_ = .Rollback()
return nil, 0,
}
return , , nil
}
// extractVersionGo parses the version from a Go migration file path. The base
// name must end in ".go" and start with a base-10 integer followed by '_'.
// For example, "1_initial.go" yields 1.
func ( string) (int64, error) {
:= filepath.Base()
if !strings.HasSuffix(, ".go") {
return 0, fmt.Errorf("file=%q must have extension .go", )
}
:= strings.IndexByte(, '_')
if == -1 {
:= fmt.Errorf(
"file=%q must have name in format version_comment, e.g. 1_initial",
)
return 0,
}
, := strconv.ParseInt([:], 10, 64)
if != nil {
return 0,
}
return , nil
}
// migrationNameRE matches runs of characters other than lowercase ASCII letters
// and digits. The migration-filename formatter replaces each run with "_".
var migrationNameRE = regexp.MustCompile(`[^a-z0-9]+`)
// fmtMigrationFilename builds "<version>_<description>.go". The description is
// lowercased, and each run of characters outside [a-z0-9] becomes "_".
// For example, (3, "Add users table") yields "3_add_users_table.go".
func ( int64, string) string {
= strings.ToLower()
= migrationNameRE.ReplaceAllString(, "_")
return fmt.Sprintf("%d_%s.go", , )
}
// createMigrationFile writes migrationTemplate, with mode 0644, to the given
// file name in the current working directory. It refuses to overwrite an
// existing file.
//
// NOTE(review): any os.Stat result other than a not-exist error is reported as
// "already exists", including permission errors. When the file really exists,
// the error is nil and %s renders it as "%!s(<nil>)". The Stat-then-write
// sequence is also racy; opening with os.O_CREATE|os.O_EXCL would be atomic.
// NOTE(review): path.Join always uses '/'; filepath.Join is the OS-correct
// choice. ioutil.WriteFile is deprecated in favor of os.WriteFile (Go 1.16+).
func ( string) error {
, := os.Getwd()
if != nil {
return
}
= path.Join(, )
_, = os.Stat()
if !os.IsNotExist() {
return fmt.Errorf("file=%q already exists (%s)", , )
}
return ioutil.WriteFile(, migrationTemplate, 0o644)
}
// migrationTemplate is the Go source written for a new migration by the
// "create" command. It registers a transactional migration
// (migrations.MustRegisterTx) with placeholder up and down Exec calls for the
// author to fill in.
var migrationTemplate = []byte(`package main
import (
"github.com/go-pg/migrations"
)
func init() {
migrations.MustRegisterTx(func(db migrations.DB) error {
_, err := db.Exec("")
return err
}, func(db migrations.DB) error {
_, err := db.Exec("")
return err
})
}
`)
// osfilesystem adapts the local OS filesystem to http.FileSystem, so SQL
// migrations on disk can be read through DiscoverSQLMigrationsFromFilesystem.
type osfilesystem struct{}
// Opens the named file with os.Open; *os.File satisfies http.File. This is the
// http.FileSystem Open method (its name is elided in this copy).
func (osfilesystem) ( string) (http.File, error) {
return os.Open()
}