package emit
import (
"fmt"
"go/format"
"go/token"
"go/types"
"slices"
"strconv"
"strings"
"go.pact.im/x/plumb/internal/discover"
"go.pact.im/x/plumb/internal/gotypes"
"go.pact.im/x/plumb/internal/solve"
)
// Result is the output of File: the generated Go source and an accompanying
// human-readable report.
type Result struct {
// Source is the generated file after it has been run through go/format.
Source string
// Report is the text produced by buildReport for the same inputs.
Report string
}
// File renders the complete generated source for the destination package
// identified by importPath and packageName. It emits one function per plan
// and also returns a textual report.
//
// Rendering happens in two passes:
//   - The first pass renders every plan against a recording qualifier. It is
//     used only to learn which packages the output refers to.
//   - Those packages are then given collision-free import names.
//   - The second pass renders the real function text using those names.
//
// The assembled file is passed through go/format. A formatting failure means
// the emitter produced invalid Go, so File panics and includes the offending
// source in the message.
func File(importPath, packageName string, pkgs []*discover.Package, plans []*solve.Plan, dest *solve.DestInfo) *Result {
	needErrors := slices.ContainsFunc(plans, planNeedsErrors)
	lifted := liftedNamesOf(plans)

	// Pass 1: record referenced packages and collect the generated function
	// names, which import names must not collide with.
	recorder := newRecording(importPath)
	planNames := make(map[string]bool, len(plans))
	for _, pl := range plans {
		recordPlanPackages(pl, recorder, dest, lifted)
		planNames[pl.Name] = true
	}

	aliases, errorsAlias, imports := assignAliases(recorder.recordedPackages(), needErrors, importPath, lifted, planNames, dest)

	// Pass 2: render for real with the final import names.
	final := newFinal(importPath, aliases)
	funcs := make([]string, 0, len(plans))
	for _, pl := range plans {
		funcs = append(funcs, renderPlan(pl, final, errorsAlias, dest, lifted))
	}

	src := assembleFile(packageName, imports, funcs)
	out, err := format.Source([]byte(src))
	if err != nil {
		panic(fmt.Sprintf("plumb: generated source did not format: %v\n---\n%s", err, src))
	}
	return &Result{
		Source: string(out),
		Report: buildReport(importPath, pkgs, plans, imports, aliases),
	}
}
// recordPlanPackages renders pl against q and discards the text. The render
// is done only for its effect on q. File passes a recording qualifier here and
// afterwards reads the packages it collected. No errors alias is known yet, so
// an empty one is used; that only affects the discarded output.
func recordPlanPackages(pl *solve.Plan, q *qualifier, dest *solve.DestInfo, lifted map[string]bool) {
	renderPlan(pl, q, "", dest, lifted)
}
// liftedNamesOf returns the set of names of every type parameter that any of
// the plans lifts into its generated function's type-parameter list.
func liftedNamesOf(plans []*solve.Plan) map[string]bool {
	set := make(map[string]bool)
	for i := range plans {
		for _, param := range plans[i].Lifted {
			set[param.Obj().Name()] = true
		}
	}
	return set
}
// assignAliases chooses an import name for every package the generated file
// references and renders the matching import-spec lines.
//
// Inputs:
//   - pkgs are the packages recorded during the discovery pass.
//   - The standard "errors" package is added when needErrors is set.
//   - The destination package (destPath) is never imported.
//
// Names that cannot be used:
//   - those returned by reservedIdents, plus the generated function names in
//     setNames;
//   - Go keywords;
//   - predeclared identifiers, because an import name would shadow them for
//     the whole file and the emitted code itself writes error, any and nil;
//   - "init", which the compiler rejects as an import name.
//
// Collisions are resolved by appending 2, 3, ... to the package name.
// Packages are handled in import-path order, so the result is deterministic.
//
// It returns three values:
//   - the name chosen for each import path;
//   - the name chosen for "errors" (empty when needErrors is false);
//   - the import lines: a bare quoted path when the chosen name equals the
//     package name, and `name "path"` otherwise.
func assignAliases(pkgs []*types.Package, needErrors bool, destPath string, lifted, setNames map[string]bool, dest *solve.DestInfo) (map[string]string, string, []string) {
	taken := reservedIdents(dest, lifted)
	for n := range setNames {
		taken[n] = true
	}
	// usable reports whether name can be bound as an import name at file
	// scope without clashing with something the generated code relies on.
	usable := func(name string) bool {
		return !taken[name] &&
			!token.IsKeyword(name) &&
			name != "init" &&
			types.Universe.Lookup(name) == nil
	}

	type imp struct {
		path  string
		name  string
		alias string
	}
	var imps []imp
	seen := map[string]bool{}
	add := func(path, name string) {
		if path == destPath || seen[path] {
			return
		}
		seen[path] = true
		imps = append(imps, imp{path: path, name: name})
	}
	for _, p := range pkgs {
		add(p.Path(), p.Name())
	}
	if needErrors {
		add("errors", "errors")
	}
	slices.SortFunc(imps, func(a, b imp) int { return strings.Compare(a.path, b.path) })

	aliasByPath := map[string]string{}
	var lines []string
	for i := range imps {
		base := imps[i].name
		if base == "" {
			base = "pkg"
		}
		name := base
		for n := 2; !usable(name); n++ {
			name = base + strconv.Itoa(n)
		}
		taken[name] = true
		imps[i].alias = name
		aliasByPath[imps[i].path] = name
		if name == imps[i].name {
			lines = append(lines, fmt.Sprintf("%q", imps[i].path))
		} else {
			lines = append(lines, fmt.Sprintf("%s %q", name, imps[i].path))
		}
	}

	errorsAlias := ""
	if needErrors {
		errorsAlias = aliasByPath["errors"]
	}
	return aliasByPath, errorsAlias, lines
}
// renderPlan renders pl as a single Go function declaration of the form
//
//	func Name[Lifted...](inputs...) (outputs..., cleanup, err) { ... }
//
// Parameters:
//   - q qualifies type names.
//   - errorsAlias is the import name of the "errors" package, used only when
//     cleanup errors are joined.
//   - dest and lifted are passed to the local-name allocator.
//
// Every result is named. The body calls each instance of pl.Order in turn and
// binds its results to fresh locals. When an error result is non-nil, the body
// first runs the cleanups acquired so far, most recent first. It then stores
// the error in the named error result and returns.
//
// NOTE(review): alloc presumably hands out unique names first-come,
// first-served. If so, the order of the alloc calls below fixes the emitted
// identifiers. Confirm this before reordering anything.
func renderPlan (pl *solve .Plan , q *qualifier , errorsAlias string , dest *solve .DestInfo , lifted map [string ]bool ) string {
alloc := newAllocator (dest , q , lifted )
// localOf maps a type to the name of the local currently bound to it. Call
// arguments are looked up by their source type. Later bindings overwrite
// earlier ones.
var localOf gotypes .Map [types .Type , string ]
setLocal := func (t types .Type , name string ) { localOf .Set (t , name ) }
// local panics when nothing of type t has been bound yet. That is a broken
// plan invariant, not a user error.
local := func (t types .Type ) string {
n , ok := localOf .At (t )
if !ok {
panic (fmt .Sprintf ("plumb: emit: no local bound for %s" , gotypes .TypeName (t )))
}
return n
}
// Named results, one per output, in pl.Outputs order.
var resultDecls []string
outNames := make ([]string , len (pl .Outputs ))
for i , o := range pl .Outputs {
n := alloc .alloc (baseName (o ))
outNames [i ] = n
resultDecls = append (resultDecls , n +" " +q .typeString (o ))
}
// Optional aggregate cleanup result. It returns an error only if the plan's
// cleanup can fail.
cleanupName := ""
if pl .AnyCleanup {
cleanupName = alloc .alloc ("cleanup" )
ct := "func()"
if pl .CleanupFailable {
ct = "func() error"
}
resultDecls = append (resultDecls , cleanupName +" " +ct )
}
// Optional error result, assigned before every early return.
errName := ""
if pl .Fallible {
errName = alloc .alloc ("err" )
resultDecls = append (resultDecls , errName +" error" )
}
// Parameters. A non-empty, non-blank input name is lower-camel-cased and
// used as the parameter name; otherwise the name is derived from the type.
var paramDecls []string
for _ , in := range pl .Inputs {
base := baseName (in .Type )
if h := in .Name ; h != "" && h != "_" {
base = lowerCamel (h )
}
n := alloc .alloc (base )
setLocal (in .Type , n )
paramDecls = append (paramDecls , n +" " +q .typeString (in .Type ))
}
var body []string
// acquired lists the cleanups obtained so far, in acquisition order.
var acquired []cleanupRef
// errLocal is one local shared by every error-returning call. It is
// allocated on first use. errDeclared records whether it has already been
// declared with :=.
errLocal := ""
errDeclared := false
for _ , in := range pl .Order {
// Each argument is the local bound to its source type, with any &/*
// coercion applied.
args := make ([]string , len (in .Inputs ))
for i , ref := range pl .Args [in ] {
args [i ] = coerceExpr (local (ref .SrcType ), ref .Coerce )
}
rhs := renderInstance (in , q , args )
var lhs []string
newVar := false
hasErr := false
var newCleanups []cleanupRef
for _ , r := range in .Results {
switch r .Kind {
case solve .ResultKindValue :
vn := alloc .alloc (baseName (r .Typ ))
setLocal (r .Typ , vn )
lhs = append (lhs , vn )
newVar = true
case solve .ResultKindCleanup :
cn := alloc .alloc (cleanupBaseName (in .Prov .Fn ))
lhs = append (lhs , cn )
newVar = true
newCleanups = append (newCleanups , cleanupRef {name : cn , failable : r .Failable })
case solve .ResultKindError :
if errLocal == "" {
errLocal = alloc .alloc ("e" )
}
lhs = append (lhs , errLocal )
hasErr = true
default :
panic (fmt .Sprintf ("plumb: unhandled result kind %d" , r .Kind ))
}
}
// A call with no results becomes a bare expression statement.
// Otherwise := is used, unless the left side holds only the error local
// and it has already been declared. In that case := would fail with
// "no new variables", so plain assignment is used instead.
switch {
case len (in .Results ) == 0 :
body = append (body , rhs )
default :
op := ":="
if !newVar && errDeclared {
op = "="
}
body = append (body , strings .Join (lhs , ", " )+" " +op +" " +rhs )
}
// On failure: unwind the cleanups acquired before this call, set the
// (possibly joined) error result, and return.
if hasErr {
errDeclared = true
stmts , errExpr := renderUnwind (acquired , errLocal , pl , errorsAlias , alloc )
block := []string {"if " + errLocal + " != nil {" }
block = append (block , stmts ...)
block = append (block , errName +" = " +errExpr , "return" , "}" )
body = append (body , block ...)
}
// This call's own cleanups are recorded only after its error check.
// So the unwind for this call's error never includes them.
acquired = append (acquired , newCleanups ...)
}
// Copy the bound locals into the named output results.
for i , o := range pl .Outputs {
body = append (body , outNames [i ]+" = " +local (o ))
}
// Expose one aggregate cleanup that releases everything acquired, in
// reverse order.
if pl .AnyCleanup {
body = append (body , cleanupName +" = " +renderAggregate (acquired , pl , errorsAlias , alloc ))
}
// A function with named results needs a terminating return.
if len (resultDecls ) > 0 {
body = append (body , "return" )
}
// Assemble the declaration text unindented. File runs the assembled file
// through go/format afterwards.
var b strings .Builder
b .WriteString ("func " )
b .WriteString (pl .Name )
b .WriteString (renderLiftedHeader (pl .Lifted , q .typeString ))
b .WriteByte ('(' )
b .WriteString (strings .Join (paramDecls , ", " ))
b .WriteByte (')' )
if len (resultDecls ) > 0 {
b .WriteString (" (" )
b .WriteString (strings .Join (resultDecls , ", " ))
b .WriteByte (')' )
}
b .WriteString (" {\n" )
for _ , line := range body {
b .WriteString (line )
b .WriteByte ('\n' )
}
b .WriteString ("}" )
return b .String ()
}
// cleanupRef refers to a cleanup function stored in a generated local.
type cleanupRef struct {
// name is the generated local that holds the cleanup function.
name string
// failable reports whether calling the cleanup yields an error that the
// generated code captures.
failable bool
}
// renderUnwind builds the early-return path taken when errLocal is non-nil.
// It returns two things:
//   - statements that call every cleanup in acquired, most recent first;
//   - the expression to store in the function's error result.
//
// If the plan's cleanup can fail and at least one acquired cleanup is
// failable, the errors of the failable cleanups are captured. They are joined
// with errLocal in that order.
//
// Otherwise every cleanup is simply called and errLocal is returned unchanged.
func renderUnwind(acquired []cleanupRef, errLocal string, pl *solve.Plan, errorsAlias string, alloc *allocator) (stmts []string, errExpr string) {
	rev := reverseCleanups(acquired)
	if pl.CleanupFailable && anyFailable(rev) {
		calls, cleanupErrs := cleanupStmts(rev, alloc)
		all := append([]string{errLocal}, cleanupErrs...)
		return calls, joinErrs(errorsAlias, all)
	}
	for _, ref := range rev {
		stmts = append(stmts, ref.name+"()")
	}
	return stmts, errLocal
}
// anyFailable reports whether at least one cleanup in cs returns an error.
func anyFailable(cs []cleanupRef) bool {
	return slices.ContainsFunc(cs, func(c cleanupRef) bool { return c.failable })
}
// renderAggregate returns a function literal that runs every cleanup in
// acquired, most recent first.
//
// When the plan's cleanup is failable, the literal has type func() error.
// It returns the joined errors of the failable cleanups.
//
// Otherwise it is a plain func().
func renderAggregate(acquired []cleanupRef, pl *solve.Plan, errorsAlias string, alloc *allocator) string {
	rev := reverseCleanups(acquired)
	var sb strings.Builder
	if pl.CleanupFailable {
		stmts, errs := cleanupStmts(rev, alloc)
		sb.WriteString("func() error {\n")
		for _, st := range stmts {
			sb.WriteString(st)
			sb.WriteString("\n")
		}
		sb.WriteString("return " + joinErrs(errorsAlias, errs) + "\n}")
		return sb.String()
	}
	sb.WriteString("func() {\n")
	for _, ref := range rev {
		sb.WriteString(ref.name + "()\n")
	}
	sb.WriteString("}")
	return sb.String()
}
// cleanupStmts emits one call statement per cleanup in rev, in the given
// order.
//
// The result of a failable cleanup is captured in a fresh local named from a
// branch of alloc. It returns the statements and, in the same order, the
// names of those error locals.
//
// NOTE(review): alloc.branch presumably yields a child allocator whose names
// do not reserve anything in the parent scope. Confirm this.
func cleanupStmts(rev []cleanupRef, alloc *allocator) (stmts, errLocals []string) {
	scope := alloc.branch()
	for _, ref := range rev {
		call := ref.name + "()"
		if !ref.failable {
			stmts = append(stmts, call)
			continue
		}
		v := scope.alloc("cleanupErr")
		stmts = append(stmts, v+" := "+call)
		errLocals = append(errLocals, v)
	}
	return stmts, errLocals
}
// joinErrs returns a Go expression that combines the error-valued locals in
// errs:
//   - no locals: the literal nil;
//   - exactly one local: that local;
//   - otherwise: a call to Join on the errors package imported as errorsAlias.
//
// The empty case matters. Emitting "errorsAlias.Join()" there would require an
// errors import that is not otherwise needed, and with an empty alias it would
// not even parse. nil is also exactly what errors.Join returns with no
// arguments.
func joinErrs(errorsAlias string, errs []string) string {
	switch len(errs) {
	case 0:
		return "nil"
	case 1:
		return errs[0]
	default:
		return errorsAlias + ".Join(" + strings.Join(errs, ", ") + ")"
	}
}
// planNeedsErrors reports whether rendering pl will emit a call to
// errors.Join, so that File knows to import the "errors" package.
//
// This can only happen when the plan's aggregate cleanup can fail, and then
// only in two places:
//   - the aggregate cleanup joins the errors of two or more failable
//     cleanups;
//   - a fallible call runs after at least one failable cleanup has been
//     acquired, so its unwind joins the call's error with that cleanup's
//     error.
//
// Both conditions are checked in a single pass over pl.Order.
func planNeedsErrors(pl *solve.Plan) bool {
	if !pl.CleanupFailable {
		return false
	}
	failableSoFar := 0
	for _, in := range pl.Order {
		// An instance's own cleanups are not yet held when its own error is
		// checked, so test before counting them.
		if failableSoFar > 0 && instanceFallible(in) {
			return true
		}
		for _, r := range in.Results {
			if r.Kind == solve.ResultKindCleanup && r.Failable {
				failableSoFar++
			}
		}
	}
	return failableSoFar >= 2
}
// instanceFallible reports whether any result of in is an error.
func instanceFallible(in *solve.Instance) bool {
	for i := range in.Results {
		if in.Results[i].Kind == solve.ResultKindError {
			return true
		}
	}
	return false
}
// reverseCleanups returns a reversed copy of in. Cleanups run in reverse
// acquisition order, and the input is left untouched.
func reverseCleanups(in []cleanupRef) []cleanupRef {
	out := slices.Clone(in)
	for lo, hi := 0, len(out)-1; lo < hi; lo, hi = lo+1, hi-1 {
		out[lo], out[hi] = out[hi], out[lo]
	}
	return out
}
// coerceExpr returns the expression that adapts the local named local to the
// form an argument needs: unchanged, its address (&), or its pointee (*).
// An unknown coercion is an emitter bug and panics.
func coerceExpr(local string, c solve.Coerce) string {
	var op string
	switch c {
	case solve.CoerceNone:
		// Passed through as is.
	case solve.CoerceToPtr:
		op = "&"
	case solve.CoerceToVal:
		op = "*"
	default:
		panic(fmt.Sprintf("plumb: unhandled coercion %d", c))
	}
	return op + local
}
// renderLiftedHeader renders the type-parameter list "[A C1, B C2]" for the
// lifted type parameters. It returns "" when there are none.
//
// A constraint that gotypes reports as collapsing to any is written as "any".
// Any other constraint is rendered with qual.
func renderLiftedHeader(lifted []*types.TypeParam, qual func(types.Type) string) string {
	if len(lifted) == 0 {
		return ""
	}
	var sb strings.Builder
	sb.WriteByte('[')
	for i, tp := range lifted {
		if i > 0 {
			sb.WriteString(", ")
		}
		constraint := "any"
		if c := tp.Constraint(); !gotypes.ConstraintCollapsesToAny(c) {
			constraint = qual(c)
		}
		sb.WriteString(tp.Obj().Name())
		sb.WriteByte(' ')
		sb.WriteString(constraint)
	}
	sb.WriteByte(']')
	return sb.String()
}
// assembleFile concatenates the pieces of a generated file, in this order:
//   - the "Code generated" marker;
//   - the package clause;
//   - a parenthesized import block, only when importLines is non-empty;
//   - the function declarations separated by blank lines.
//
// The result ends with a newline and is not yet gofmt-formatted.
func assembleFile(pkgName string, importLines, funcs []string) string {
	header := "// Code generated by plumb. DO NOT EDIT.\n\n" + "package " + pkgName + "\n\n"

	var imports string
	if len(importLines) != 0 {
		var ib strings.Builder
		ib.WriteString("import (\n")
		for _, spec := range importLines {
			ib.WriteString(spec + "\n")
		}
		ib.WriteString(")\n\n")
		imports = ib.String()
	}

	return header + imports + strings.Join(funcs, "\n\n") + "\n"
}