Zzq/function contracts/general nonnil by lizard-boy · Pull Request #8 · greptileai/nilaway · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Zzq/function contracts/general nonnil #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions assertion/function/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,24 +304,43 @@ func duplicateFullTriggersFromContractedFunctionsToCallers(
if r == nil {
// should not happen since funcResults should contain all the functions including any
// contracted functions.
panic(fmt.Sprintf("Did not find the contracted function %s in funcResults", ctrtFunc.Id()))
panic(fmt.Sprintf("Did not find the contracted function %s in funcResults",
ctrtFunc.FullName()))
}
ctrts := funcContracts[ctrtFunc]
if ctrts == nil {
// should not happen since ctrtFunc is a contracted function
panic(fmt.Sprintf(
"Did not find the contracted function %s in funcContracts",
ctrtFunc.FullName()))
}
contract := ctrts[0]
ctrtParamIndex := contract.IndexOfNonnilIn()
ctrtRetIndex := contract.IndexOfNonnilOut()
for _, trigger := range r.triggers {
// If the full trigger has a FuncParam producer or a UseAsReturn consumer, then create
// a duplicated (possibly controlled) full trigger from it and add the created full
// trigger to every caller.
_, isParamProducer := trigger.Producer.Annotation.(annotation.FuncParam)
_, isReturnConsumer := trigger.Consumer.Annotation.(annotation.UseAsReturn)
if !isParamProducer && !isReturnConsumer {
// No need to duplicate the full trigger
p, isContractedParam := trigger.Producer.Annotation.(annotation.FuncParam)
if isContractedParam {
isContractedParam = ctrtParamIndex == p.TriggerIfNilable.Ann.(annotation.ParamAnnotationKey).ParamNum
}
c, isContractedReturn := trigger.Consumer.Annotation.(annotation.UseAsReturn)
if isContractedReturn {
isContractedReturn = ctrtRetIndex == c.TriggerIfNonNil.Ann.(annotation.RetAnnotationKey).RetNum
}
if !isContractedParam && !isContractedReturn {
// We only duplicate the full trigger if it is the right parameter which is the one
// with contract value NONNIL in a general nonnil->nonnil contract.
// TODO: However, this could be changed in the future if we support multiple
// contracts and/or other kinds of contract values.
continue
}
// Duplicate the full trigger in every caller
for caller, callExprs := range calls {
for _, callExpr := range callExprs {
dupTrigger := duplicateFullTrigger(trigger, ctrtFunc, callExpr, pass,
isParamProducer, isReturnConsumer)

ctrtParamIndex, isContractedParam, isContractedReturn)
// Store the duplicated full trigger
dupTriggers[caller] = append(dupTriggers[caller], dupTrigger)
}
Expand Down Expand Up @@ -349,11 +368,15 @@ func duplicateFullTrigger(
callee *types.Func,
callExpr *ast.CallExpr,
pass *analysis.Pass,
ctrtParamIndex int,
isParamProducer bool,
isReturnConsumer bool,
) annotation.FullTrigger {
// TODO: what if we have more than one parameter, planned in future revisions
argExpr := callExpr.Args[0]
// TODO: what if we have other kinds of contracts than a general nonnil->nonnil contract,
// planned in the future.

// Assume the contract is a general nonnil->nonnil contract
argExpr := callExpr.Args[ctrtParamIndex]
argLoc := util.PosToLocation(argExpr.Pos(), pass)

// Create the duplicated full trigger
Expand All @@ -379,7 +402,7 @@ func duplicateFullTrigger(
retLoc := util.PosToLocation(callExpr.Pos(), pass)
dupTrigger.Consumer = annotation.DuplicateReturnConsumer(trigger.Consumer, retLoc)
// Set up the site that controls the controlled full trigger to be created
c := annotation.NewCallSiteParamKey(callee, 0, argLoc)
c := annotation.NewCallSiteParamKey(callee, ctrtParamIndex, argLoc)
dupTrigger.Controller = &c
}

Expand Down Expand Up @@ -424,16 +447,15 @@ func findCallsToContractedFunctions(
return calls
}

// hasOnlyNonNilToNonNilContract returns whether the given function has only one contract that is
// nonnil->nonnil.
// hasOnlyNonNilToNonNilContract returns true if the given function has only a single contract
// and the contract is a general nonnil -> nonnil contract, i.e., the contract has only one
// nonnil in input and only one nonnil in output, and all the other values are any.
func hasOnlyNonNilToNonNilContract(funcContracts functioncontracts.Map, funcObj *types.Func) bool {
contracts, ok := funcContracts[funcObj]
if !ok || len(contracts) != 1 {
return false
}
ctr := contracts[0]
return len(ctr.Ins) == 1 && ctr.Ins[0] == functioncontracts.NonNil &&
len(ctr.Outs) == 1 && ctr.Outs[0] == functioncontracts.NonNil
return contracts[0].IsGeneralNonnnilToNonnil()
}

// analyzeFunc analyzes a given function declaration and emit generated triggers, or an error if
Expand Down
2 changes: 1 addition & 1 deletion assertion/function/assertiontree/parse_expr_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C

for i := 0; i < numResults; i++ {
var retKey annotation.Key
if r.HasContract(funcObj) {
if r.RetHasContract(funcObj, i) {
// Creates a new return site with location information at every call site for a
// function with contracts. The return site is unique at every call site, even with the
// same function called.
Expand Down
30 changes: 25 additions & 5 deletions assertion/function/assertiontree/root_assertion_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"go/types"

"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/assertion/function/functioncontracts"
"go.uber.org/nilaway/config"
"go.uber.org/nilaway/util"
"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -54,10 +55,29 @@ func (r *RootAssertionNode) LocationOf(expr ast.Expr) token.Position {
return util.PosToLocation(expr.Pos(), r.Pass())
}

// HasContract returns if the given function has any contracts.
func (r *RootAssertionNode) HasContract(funcObj *types.Func) bool {
_, ok := r.functionContext.funcContracts[funcObj]
return ok
// getSingleNonnilToNonnilContract returns the one contract of funcObj when funcObj has exactly one
// contract and that contract is a general nonnil->nonnil contract: only one nonnil in input, only
// one nonnil in output, and any for all the other values. For example, contract(_,nonnil->nonnil,_)
// qualifies, but contract(nonnil,nonnil->nonnil,_) does not. In every other case, including when
// funcObj has no contracts at all, it returns nil.
func (r *RootAssertionNode) getSingleNonnilToNonnilContract(funcObj *types.Func) *functioncontracts.FunctionContract {
	if ctrts := r.functionContext.funcContracts[funcObj]; len(ctrts) == 1 {
		if only := ctrts[0]; only.IsGeneralNonnnilToNonnil() {
			return only
		}
	}
	return nil
}

// ParamHasContract reports whether the i-th parameter of funcObj has the NonNil contract value in
// funcObj's single general nonnil->nonnil contract (as selected by
// getSingleNonnilToNonnilContract). It returns false when funcObj has no such contract.
//
// It also returns false, instead of panicking, when i is outside the contract's input values.
// This can happen if the contract's arity does not match the indices a caller iterates over, for
// instance a caller walking the arguments of a call to a variadic function, or a handwritten
// contract that lists fewer values than the signature has parameters.
func (r *RootAssertionNode) ParamHasContract(funcObj *types.Func, i int) bool {
	ctrt := r.getSingleNonnilToNonnilContract(funcObj)
	if ctrt == nil || i < 0 || i >= len(ctrt.Ins) {
		return false
	}
	return ctrt.Ins[i] == functioncontracts.NonNil
}

// RetHasContract reports whether the i-th return value of funcObj has the NonNil contract value in
// funcObj's single general nonnil->nonnil contract (as selected by
// getSingleNonnilToNonnilContract). It returns false when funcObj has no such contract.
//
// It also returns false, instead of panicking, when i is outside the contract's output values.
// This can happen, for example, if a handwritten contract lists a different number of outputs
// than the signature has results.
func (r *RootAssertionNode) RetHasContract(funcObj *types.Func, i int) bool {
	ctrt := r.getSingleNonnilToNonnilContract(funcObj)
	if ctrt == nil || i < 0 || i >= len(ctrt.Outs) {
		return false
	}
	return ctrt.Outs[i] == functioncontracts.NonNil
}

// MinimalString for a RootAssertionNode returns a minimal string representation of that root node
Expand Down Expand Up @@ -662,7 +682,7 @@ func (r *RootAssertionNode) AddComputation(expr ast.Expr) {
})
} else {
var paramKey annotation.Key
if r.HasContract(fdecl) {
if r.ParamHasContract(fdecl, i) {
// Creates a new param site with location information at every call site
// for a function with contracts. The param site is unique at every call
// site, even with the same function called.
Expand Down
154 changes: 153 additions & 1 deletion assertion/function/functioncontracts/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
package functioncontracts

import (
	"context"
	"fmt"
	"go/ast"
	"go/types"
	"reflect"
	"runtime/debug"
	"sync"

	"go.uber.org/nilaway/config"
	"go.uber.org/nilaway/util"
	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/buildssa"
	"golang.org/x/tools/go/ssa"
)

const _doc = "Read the contracts of each function in this package, returning the results."
Expand All @@ -44,7 +51,14 @@ var Analyzer = &analysis.Analyzer{
Doc: _doc,
Run: run,
ResultType: reflect.TypeOf((*Result)(nil)).Elem(),
Requires: []*analysis.Analyzer{config.Analyzer},
Requires: []*analysis.Analyzer{buildssa.Analyzer, config.Analyzer},
}

// functionResult is the struct that is received from the channel for each function.
// Each value reports the outcome of contract inference for a single function. It carries either
// the inferred contracts or an error. In this file, err is set only by inferContractsToChannel
// when inference panics, and in that case contracts is an empty slice.
type functionResult struct {
	// funcObj is the function whose contracts were inferred.
	funcObj *types.Func
	// contracts holds the inferred contracts; it may be empty.
	contracts []*FunctionContract
	// err is non-nil if inference failed (e.g., a recovered internal panic).
	err error
}

func run(pass *analysis.Pass) (result interface{}, _ error) {
Expand Down Expand Up @@ -72,3 +86,141 @@ func run(pass *analysis.Pass) (result interface{}, _ error) {

return Result{FunctionContracts: collectFunctionContracts(pass)}, nil
}

// collectFunctionContracts collects all the function contracts and returns a map that associates
// every function with its contracts if it has any. We prefer to parse handwritten contracts from
// the comments at the top of each function. Only when there are no handwritten contracts there,
// do we try to automatically infer contracts.
//
// Only files that are in scope according to the config and whose doc comment passes
// util.DocContainsFunctionContractsCheck are considered. Inference runs in one goroutine per
// eligible function. Inference is best-effort: a function is left out of the returned map if its
// inference reports an error (e.g., a recovered internal panic) or yields no contract.
func collectFunctionContracts(pass *analysis.Pass) Map {
	conf := pass.ResultOf[config.Analyzer].(*config.Config)

	// Collect ssa for every function.
	ssaInput := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
	ssaOfFunc := make(map[*types.Func]*ssa.Function, len(ssaInput.SrcFuncs))
	for _, fnssa := range ssaInput.SrcFuncs {
		if fnssa == nil {
			// should be guaranteed to be non-nil; otherwise it would have panicked in the library
			// https://cs.opensource.google/go/x/tools/+/refs/tags/v0.12.0:go/analysis/passes/buildssa/buildssa.go;l=99
			continue
		}
		if funcObj, ok := fnssa.Object().(*types.Func); ok {
			ssaOfFunc[funcObj] = fnssa
		}
	}

	// Set up variables for synchronization and communication.
	// NOTE(review): the derived context is discarded and never handed to the worker goroutines,
	// so cancel currently has no effect on them. Wire the context through if cancellation is
	// ever needed.
	_, cancel := context.WithCancel(context.Background())
	defer cancel()
	var wg sync.WaitGroup
	funcChan := make(chan functionResult)

	m := Map{}
	for _, file := range pass.Files {
		if !conf.IsFileInScope(file) || !util.DocContainsFunctionContractsCheck(file.Doc) {
			continue
		}
		for _, decl := range file.Decls {
			funcDecl, ok := decl.(*ast.FuncDecl)
			if !ok {
				// Ignore any non-function declaration
				// TODO: If we want to support contracts for anonymous functions (function
				// literals) in the future, then we need to handle more types here.
				continue
			}
			funcObj, ok := pass.TypesInfo.ObjectOf(funcDecl.Name).(*types.Func)
			if !ok {
				// Defensive: a function declaration's name should always be defined as a
				// *types.Func. If type information for it is missing, skip the declaration
				// rather than panic.
				continue
			}

			// First, we try to parse the contracts from the comments at the top of the function.
			// If there are any, we do not need to infer contracts for this function.
			if parsedContracts := parseContracts(funcDecl.Doc); len(parsedContracts) != 0 {
				m[funcObj] = parsedContracts
				continue
			}
			// If we reach here, it means that there are no handwritten contracts for this
			// function. We need to infer contracts for this function.
			if funcDecl.Type.Params.NumFields() == 0 ||
				funcDecl.Type.Results.NumFields() == 0 ||
				allParamOrRetTypesBarNilness(funcObj) ||
				funcObj.Type().(*types.Signature).Variadic() {
				// We ignore any function without any parameters or return values since they cannot
				// have any contracts.
				// We ignore any function whose parameters all have types that cannot have nil as
				// a valid value, or whose return values all do, e.g., only ints.

				// TODO: We ignore variadic parameters since they are not well handled when
				// creating triggers. We will need to create a Always Nilable producer if no
				// argument is passed to a site that is supposed to be a variadic parameter.
				// We leave this as future work.
				continue
			}
			fnssa, ok := ssaOfFunc[funcObj]
			if !ok {
				// For some reason, we cannot find the ssa for this function. We ignore this
				// function.
				continue
			}
			wg.Add(1)
			// Infer contracts for a function that does not have any contracts specified.
			go inferContractsToChannel(funcObj, fnssa, funcChan, &wg)
		}
	}

	// Spawn another goroutine that will close the channel when all analyses are done. This makes
	// sure the channel receive logic in the main thread (below) can properly terminate.
	go func() {
		wg.Wait()
		close(funcChan)
	}()

	// Collect inferred contracts from the channel.
	for r := range funcChan {
		if r.err != nil {
			// Inference is best-effort: a function whose inference failed (e.g., its panic was
			// recovered into r.err) simply gets no inferred contract.
			continue
		}
		if len(r.contracts) == 0 || r.contracts[0] == nil {
			continue
		}
		m[r.funcObj] = r.contracts
	}
	return m
}

// allParamOrRetTypesBarNilness reports whether every parameter of funcObj, or every result of
// funcObj, has a type that cannot have nil as a valid value (e.g., all ints). Either side alone is
// enough: the answer is true if all parameters bar nilness OR all results bar nilness. The results
// are only examined when the parameters do not already decide the answer.
func allParamOrRetTypesBarNilness(funcObj *types.Func) bool {
	sig := funcObj.Type().(*types.Signature)
	if allInTupleTypesBarNilness(sig.Params()) {
		return true
	}
	return allInTupleTypesBarNilness(sig.Results())
}

// allInTupleTypesBarNilness reports whether util.TypeBarsNilness holds for every element of vars,
// i.e., none of the tuple's types can have nil as a valid value (e.g., a tuple of ints). It scans
// the elements in order and stops at the first one whose type admits nil. An empty tuple yields
// true.
func allInTupleTypesBarNilness(vars *types.Tuple) bool {
	n := vars.Len()
	idx := 0
	for idx < n && util.TypeBarsNilness(vars.At(idx).Type()) {
		idx++
	}
	return idx == n
}

// inferContractsToChannel infers contracts for a function that does not have any contracts
// specified and sends the result to the channel. It is meant to run in its own goroutine after a
// matching wg.Add(1). It calls wg.Done exactly once, and only after its result has been sent:
// either the inferred contracts, or an error describing a recovered panic.
func inferContractsToChannel(
	funcObj *types.Func,
	fnssa *ssa.Function,
	fnChan chan functionResult,
	wg *sync.WaitGroup,
) {
	// Deferred calls run in LIFO order, so registering wg.Done first makes it run last. The
	// goroutine is thus marked done only after the recovery handler below has sent its result.
	// Marking it done earlier would let a caller that waits on wg and then closes fnChan close
	// the channel before that send. The send would then panic with "send on closed channel" and
	// crash the whole analysis.
	defer wg.Done()
	// As a last resort, convert the panics into errors and return.
	defer func() {
		if r := recover(); r != nil {
			e := fmt.Errorf("INTERNAL PANIC: %s\n%s", r, string(debug.Stack()))
			fnChan <- functionResult{err: e, funcObj: funcObj, contracts: []*FunctionContract{}}
		}
	}()

	fnChan <- functionResult{
		funcObj:   funcObj,
		contracts: inferContracts(fnssa),
	}
}
Loading
0